@ -102,6 +102,9 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std : : vector < float > logits ;
bool logits_all = false ;
// input embedding (1-dimensional array: [n_embd])
std : : vector < float > embedding ;
} ;
struct llama_context_params llama_context_default_params ( ) {
@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false ,
/*.logits_all =*/ false ,
/*.vocab_only =*/ false ,
/*.embedding =*/ false ,
} ;
return result ;
@ -592,8 +596,6 @@ static bool llama_model_load(
fin . close ( ) ;
}
lctx . logits . reserve ( lctx . model . hparams . n_ctx ) ;
lctx . t_load_us = ggml_time_us ( ) - t_start_us ;
return true ;
@ -791,6 +793,9 @@ static bool llama_eval_internal(
inpL = cur ;
}
// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL ;
// norm
{
inpL = ggml_rms_norm ( ctx0 , inpL ) ;
@ -799,6 +804,8 @@ static bool llama_eval_internal(
inpL = ggml_mul ( ctx0 ,
ggml_repeat ( ctx0 , model . norm , inpL ) ,
inpL ) ;
embeddings = inpL ;
}
// lm_head
@ -821,15 +828,26 @@ static bool llama_eval_internal(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
auto & logits_out = lctx . logits ;
// extract logits
{
auto & logits_out = lctx . logits ;
if ( lctx . logits_all ) {
logits_out . resize ( n_vocab * N ) ;
memcpy ( logits_out . data ( ) , ( float * ) ggml_get_data ( inpL ) , sizeof ( float ) * n_vocab * N ) ;
} else {
// return result for just the last token
logits_out . resize ( n_vocab ) ;
memcpy ( logits_out . data ( ) , ( float * ) ggml_get_data ( inpL ) + ( n_vocab * ( N - 1 ) ) , sizeof ( float ) * n_vocab ) ;
}
}
// extract embeddings
if ( lctx . embedding . size ( ) ) {
auto & embedding_out = lctx . embedding ;
if ( lctx . logits_all ) {
logits_out . resize ( n_vocab * N ) ;
memcpy ( logits_out . data ( ) , ( float * ) ggml_get_data ( inpL ) , sizeof ( float ) * n_vocab * N ) ;
} else {
// return result for just the last token
logits_out . resize ( n_vocab ) ;
memcpy ( logits_out . data ( ) , ( float * ) ggml_get_data ( inpL ) + ( n_vocab * ( N - 1 ) ) , sizeof ( float ) * n_vocab ) ;
embedding_out . resize ( n_embd ) ;
memcpy ( embedding_out . data ( ) , ( float * ) ggml_get_data ( embeddings ) + ( n_embd * ( N - 1 ) ) , sizeof ( float ) * n_embd ) ;
}
if ( mem_per_token = = 0 ) {
@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
return nullptr ;
}
// reserve memory for context buffers
{
const auto & hparams = ctx - > model . hparams ;
if ( params . logits_all ) {
ctx - > logits . reserve ( hparams . n_ctx * hparams . n_vocab ) ;
} else {
ctx - > logits . reserve ( hparams . n_ctx ) ;
}
if ( params . embedding ) {
ctx - > embedding . reserve ( hparams . n_embd ) ;
}
}
return ctx ;
}
@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx - > logits . data ( ) ;
}
float * llama_get_embeddings ( struct llama_context * ctx ) {
return ctx - > embedding . data ( ) ;
}
const char * llama_token_to_str ( struct llama_context * ctx , llama_token token ) {
if ( token > = llama_n_vocab ( ctx ) ) {
return nullptr ;