@@ -27,20 +27,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     int count = 0;
     int seq_count = tokens.size() / params.n_ctx;
+    int n_vocab = llama_n_vocab(ctx);
 
     double nll = 0.0;
 
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
+    fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
-                                            //       it is better to always be power of 2 for better performance
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
-
+        int end = start + params.n_ctx;
+
+        std::vector<float> logits;
+        int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        for (int j = 0; j < num_batches; ++j) {
+            int batch_start = start + j * params.n_batch;
+            int batch_size = std::min(end - batch_start, params.n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+            auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
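
This hunk replaces the single `llama_eval` call per chunk with a loop that feeds each `n_ctx`-sized chunk to the model in `n_batch`-sized pieces, accumulating the returned logits into one contiguous `std::vector<float>` so the scoring loop below can index it as before. A minimal standalone sketch of the index arithmetic; the `n_ctx`/`n_batch` values are illustrative, not from the patch:

```cpp
// Sketch of the batch-splitting arithmetic introduced above.
// The values of n_ctx and n_batch are made up for illustration.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx   = 512; // tokens per perplexity chunk
    const int n_batch = 200; // deliberately not a divisor of n_ctx

    // Ceiling division: the last batch is allowed to be short.
    const int num_batches = (n_ctx + n_batch - 1) / n_batch;

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = j * n_batch;
        const int batch_size  = std::min(n_ctx - batch_start, n_batch);
        // n_past (the KV-cache offset passed to llama_eval) is
        // j * n_batch, since every batch before the last is full.
        printf("batch %d: start=%d size=%d n_past=%d\n",
               j, batch_start, batch_size, j * n_batch);
    }
    return 0;
}
```

Only the last batch can be short, so `j * n_batch` always equals the number of tokens already evaluated, which is why it can double as the `n_past` argument.
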
@@ -59,15 +66,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens.  Then, we split the input up into context window size chunks to
         // process the entire prompt.
-
-        auto logits = llama_get_logits(ctx);
-        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
             // Calculate probability of next token, given the previous ones.
-            int n_vocab = llama_n_vocab(ctx);
             std::vector<float> tok_logits(
-                logits + j * n_vocab,
-                logits + (j + 1) * n_vocab);
-            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            float prob = softmax(tok_logits)[tokens[start + j + 1]];
             nll += -std::log(prob);
             ++count;
         }
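
Because the logits for the whole chunk are now buffered in a vector, the scoring loop switches from raw pointer offsets into `llama_get_logits` to `logits.begin()` offsets, and its start index becomes `std::min(512, params.n_ctx / 2)`, so with context windows larger than 1024 every token with at least 512 tokens of history gets scored rather than only the second half. Below is a self-contained sketch of the per-token negative log-likelihood accumulation, assuming a numerically stabilized `softmax()` helper similar to the one this file defines; the 4-token vocabulary and logit values are made up:

```cpp
// Toy version of the NLL/perplexity accumulation in the loop above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Stand-in for the file's softmax() helper (assumed, not verbatim):
// subtract the max logit for stability, exponentiate, normalize.
static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    float max_logit = logits[0];
    for (float v : logits) max_logit = std::max(max_logit, v);
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = expf(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & p : probs) p /= sum;
    return probs;
}

int main() {
    // Model logits for one position over a toy 4-token vocabulary,
    // and the token that actually came next in the text.
    std::vector<float> tok_logits = {2.0f, 0.5f, -1.0f, 0.0f};
    const int next_token = 0;

    double nll = 0.0;
    int count = 0;

    const float prob = softmax(tok_logits)[next_token];
    nll += -std::log(prob);
    ++count;

    // Perplexity is the exponential of the mean negative
    // log-likelihood accumulated over all scored tokens.
    printf("ppl = %.4f\n", std::exp(nll / count));
    return 0;
}
```
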
@@ -82,11 +86,13 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
+    params.n_batch = 512;
 
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
 
     params.perplexity = true;
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
 
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
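
The `main` changes give `params.n_batch` a default of 512 before argument parsing and then clamp it to the context size, since a batch wider than `n_ctx` could never be filled by a single chunk. A trivial sketch of the clamp, with illustrative values:

```cpp
// Why n_batch is clamped: a batch wider than the context window
// is wasted, so the effective batch is min(n_batch, n_ctx).
#include <algorithm>
#include <cstdio>

int main() {
    int n_batch = 512; // default, set before gpt_params_parse
    int n_ctx   = 256; // e.g. the user passed -c 256

    n_batch = std::min(n_batch, n_ctx);
    printf("effective n_batch = %d\n", n_batch); // prints 256
    return 0;
}
```
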