@ -2,13 +2,10 @@
# include <cassert>
# include <cstring>
# include <iostream>
# include <fstream>
# include <sstream>
# include <string>
# include <iterator>
# include <algorithm>
# include <regex>
# if defined (_WIN32)
# include <fcntl.h>
@ -26,43 +23,6 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
# define CP_UTF8 65001
# endif
void split_args ( const std : : string & args_string , std : : vector < std : : string > & output_args )
{
std : : string current_arg = " " ;
bool in_quotes = false ;
char quote_type ;
for ( char c : args_string ) {
if ( c = = ' " ' | | c = = ' \' ' ) {
if ( ! in_quotes ) {
in_quotes = true ;
quote_type = c ;
} else if ( quote_type = = c ) {
in_quotes = false ;
} else {
current_arg + = c ;
}
} else if ( in_quotes ) {
current_arg + = c ;
} else if ( std : : isspace ( c ) ) {
if ( current_arg ! = " " ) {
output_args . push_back ( current_arg ) ;
current_arg = " " ;
}
} else {
current_arg + = c ;
}
}
if ( current_arg ! = " " ) {
output_args . push_back ( current_arg ) ;
}
}
std : : string unescape ( const std : : string & str ) {
return std : : regex_replace ( str , std : : regex ( " \\ \\ n " ) , " \n " ) ;
}
bool gpt_params_parse ( int argc , char * * argv , gpt_params & params ) {
// determine sensible default number of threads.
// std::thread::hardware_concurrency may not be equal to the number of cores, or may return 0.
@ -80,66 +40,35 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std : : string arg ;
gpt_params default_params ;
// get additional arguments from config files
std : : vector < std : : string > args ;
for ( int i = 1 ; i < argc ; i + + ) {
arg = argv [ i ] ;
if ( arg = = " --config " ) {
if ( + + i > = argc ) {
invalid_param = true ;
break ;
}
std : : ifstream file ( argv [ i ] ) ;
if ( ! file ) {
fprintf ( stderr , " error: failed to open file '%s' \n " , argv [ i ] ) ;
invalid_param = true ;
break ;
}
std : : string args_string ;
std : : copy ( std : : istreambuf_iterator < char > ( file ) , std : : istreambuf_iterator < char > ( ) , back_inserter ( args_string ) ) ;
if ( args_string . back ( ) = = ' \n ' ) {
args_string . pop_back ( ) ;
}
split_args ( args_string , args ) ;
for ( int j = 0 ; j < args . size ( ) ; j + + ) {
args [ j ] = unescape ( args [ j ] ) ;
}
} else {
args . emplace_back ( argv [ i ] ) ;
}
}
// parse args
int args_c = static_cast < int > ( args . size ( ) ) ;
for ( int i = 0 ; i < args_c & & ! invalid_param ; i + + ) {
arg = args [ i ] ;
if ( arg = = " -s " | | arg = = " --seed " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = argc ) {
invalid_param = true ;
break ;
}
params . seed = std : : stoi ( arg s [ i ] ) ;
params . seed = std : : stoi ( argv [ i ] ) ;
} else if ( arg = = " -t " | | arg = = " --threads " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = argc ) {
invalid_param = true ;
break ;
}
params . n_threads = std : : stoi ( arg s [ i ] ) ;
params . n_threads = std : : stoi ( argv [ i ] ) ;
} else if ( arg = = " -p " | | arg = = " --prompt " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = argc ) {
invalid_param = true ;
break ;
}
params . prompt = arg s [ i ] ;
params . prompt = argv [ i ] ;
} else if ( arg = = " -f " | | arg = = " --file " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = argc ) {
invalid_param = true ;
break ;
}
std : : ifstream file ( arg s [ i ] ) ;
std : : ifstream file ( argv [ i ] ) ;
if ( ! file ) {
fprintf ( stderr , " error: failed to open file '%s' \n " , arg s[ i ] . c_str ( ) ) ;
fprintf ( stderr , " error: failed to open file '%s' \n " , argv [ i ] ) ;
invalid_param = true ;
break ;
}
@ -148,100 +77,80 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params . prompt . pop_back ( ) ;
}
} else if ( arg = = " -n " | | arg = = " --n_predict " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . n_predict = std : : stoi ( arg s [ i ] ) ;
params . n_predict = std : : stoi ( arg v [ i ] ) ;
} else if ( arg = = " --top_k " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . top_k = std : : stoi ( arg s [ i ] ) ;
params . top_k = std : : stoi ( arg v [ i ] ) ;
} else if ( arg = = " -c " | | arg = = " --ctx_size " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . n_ctx = std : : stoi ( arg s [ i ] ) ;
params . n_ctx = std : : stoi ( arg v [ i ] ) ;
} else if ( arg = = " --memory_f32 " ) {
params . memory_f16 = false ;
} else if ( arg = = " --top_p " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . top_p = std : : stof ( arg s [ i ] ) ;
params . top_p = std : : stof ( arg v [ i ] ) ;
} else if ( arg = = " --temp " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . temp = std : : stof ( arg s [ i ] ) ;
params . temp = std : : stof ( arg v [ i ] ) ;
} else if ( arg = = " --repeat_last_n " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . repeat_last_n = std : : stoi ( arg s [ i ] ) ;
params . repeat_last_n = std : : stoi ( arg v [ i ] ) ;
} else if ( arg = = " --repeat_penalty " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . repeat_penalty = std : : stof ( arg s [ i ] ) ;
params . repeat_penalty = std : : stof ( arg v [ i ] ) ;
} else if ( arg = = " -b " | | arg = = " --batch_size " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . n_batch = std : : stoi ( arg s [ i ] ) ;
params . n_batch = std : : stoi ( arg v [ i ] ) ;
params . n_batch = std : : min ( 512 , params . n_batch ) ;
} else if ( arg = = " --keep " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . n_keep = std : : stoi ( arg s [ i ] ) ;
params . n_keep = std : : stoi ( arg v [ i ] ) ;
} else if ( arg = = " -m " | | arg = = " --model " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . model = arg s [ i ] ;
params . model = arg v [ i ] ;
} else if ( arg = = " -i " | | arg = = " --interactive " ) {
params . interactive = true ;
} else if ( arg = = " --embedding " ) {
params . embedding = true ;
} else if ( arg = = " --clean-interface " ) {
params . clean_interface = true ;
} else if ( arg = = " --interactive-start " ) {
params . interactive = true ;
} else if ( arg = = " --interactive-first " ) {
params . interactive_start = true ;
} else if ( arg = = " -ins " | | arg = = " --instruct " ) {
fprintf ( stderr , " \n \n Warning: instruct mode is deprecated! Use: \n "
" --clean-interface "
" --interactive-first "
" --keep -1 "
" --ins-prefix-bos "
" --ins-prefix \" \\ n \\ n### Instruction: \\ n \\ n \" "
" --ins-suffix \" \\ n \\ n### Response: \\ n \\ n \" "
" -r \" ### Instruction: \\ n \\ n \" "
" \n \n " ) ;
// params.instruct = true;
params . clean_interface = true ;
params . interactive_start = true ;
params . n_keep = - 1 ;
params . instruct_prefix_bos = true ;
params . instruct_prefix = " \n \n ### Instruction: \n \n " ;
params . instruct_suffix = " \n \n ### Response: \n \n " ;
params . antiprompt . push_back ( " ### Instruction: \n \n " ) ;
params . instruct = true ;
} else if ( arg = = " --color " ) {
params . use_color = true ;
} else if ( arg = = " --disable-multiline " ) {
params . multiline_mode = false ;
} else if ( arg = = " --mlock " ) {
params . use_mlock = true ;
} else if ( arg = = " --no-mmap " ) {
@ -251,94 +160,65 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if ( arg = = " --verbose-prompt " ) {
params . verbose_prompt = true ;
} else if ( arg = = " -r " | | arg = = " --reverse-prompt " ) {
if ( + + i > = args_c ) {
invalid_param = true ;
break ;
}
params . antiprompt . push_back ( args [ i ] ) ;
} else if ( arg = = " --stop-prompt " ) {
if ( + + i > = args_c ) {
if ( + + i > = argc ) {
invalid_param = true ;
break ;
}
params . stopprompt . push_back ( args [ i ] ) ;
} else if ( arg = = " --rm-trailing-space-workaround " ) {
params . rm_trailing_space_workaround = true ;
params . antiprompt . push_back ( argv [ i ] ) ;
} else if ( arg = = " --perplexity " ) {
params . perplexity = true ;
} else if ( arg = = " --ignore-eos " ) {
params . ignore_eos = true ;
} else if ( arg = = " --n_parts " ) {
if ( + + i > = arg s_ c) {
if ( + + i > = arg c) {
invalid_param = true ;
break ;
}
params . n_parts = std : : stoi ( arg s [ i ] ) ;
params . n_parts = std : : stoi ( arg v [ i ] ) ;
} else if ( arg = = " -h " | | arg = = " --help " ) {
gpt_print_usage ( arg v[ 0 ] , default_params ) ;
gpt_print_usage ( arg c, argv , default_params ) ;
exit ( 0 ) ;
} else if ( arg = = " --random-prompt " ) {
params . random_prompt = true ;
} else if ( arg = = " --in-prefix " ) {
if ( + + i > = args_c ) {
invalid_param = true ;
break ;
}
params . input_prefix = args [ i ] ;
} else if ( arg = = " --ins-prefix-bos " ) {
params . instruct_prefix_bos = true ;
} else if ( arg = = " --ins-prefix " ) {
if ( + + i > = args_c ) {
invalid_param = true ;
break ;
}
params . instruct_prefix = args [ i ] ;
} else if ( arg = = " --ins-suffix-bos " ) {
params . instruct_suffix_bos = true ;
} else if ( arg = = " --ins-suffix " ) {
if ( + + i > = args_c ) {
if ( + + i > = argc ) {
invalid_param = true ;
break ;
}
params . in struct_suffix = args [ i ] ;
params . input_prefix = argv [ i ] ;
} else {
fprintf ( stderr , " error: unknown argument: %s \n " , arg . c_str ( ) ) ;
gpt_print_usage ( arg v[ 0 ] , default_params ) ;
gpt_print_usage ( argc , argv , default_params ) ;
exit ( 1 ) ;
}
}
if ( invalid_param ) {
fprintf ( stderr , " error: invalid parameter for argument: %s \n " , arg . c_str ( ) ) ;
gpt_print_usage ( arg v[ 0 ] , default_params ) ;
gpt_print_usage ( argc , argv , default_params ) ;
exit ( 1 ) ;
}
return true ;
}
void gpt_print_usage ( char * argv _0 , const gpt_params & params ) {
fprintf ( stderr , " usage: %s [options] \n " , argv _0 ) ;
void gpt_print_usage ( int /*argc*/ , char * * argv , const gpt_params & params ) {
fprintf ( stderr , " usage: %s [options] \n " , argv [ 0 ] ) ;
fprintf ( stderr , " \n " ) ;
fprintf ( stderr , " options: \n " ) ;
fprintf ( stderr , " -h, --help show this help message and exit \n " ) ;
fprintf ( stderr , " -i, --interactive run in interactive mode \n " ) ;
fprintf ( stderr , " --interactive-first run in interactive mode and wait for input right away \n " ) ;
fprintf ( stderr , " - -clean-interface hides input prefix & suffix and displays '>' instead \n " ) ;
fprintf ( stderr , " - ins, --instruct run in instruction mode (use with Alpaca models) \n " ) ;
fprintf ( stderr , " -r PROMPT, --reverse-prompt PROMPT \n " ) ;
fprintf ( stderr , " run in interactive mode and poll user input upon seeing PROMPT (can be \n " ) ;
fprintf ( stderr , " specified more than once for multiple prompts). \n " ) ;
fprintf ( stderr , " --color colorise output to distinguish prompt and user input from generations \n " ) ;
fprintf ( stderr , " --disable-multiline disable multiline mode (use Ctrl+D on Linux/Mac and Ctrl+Z then Return on Windows to toggle multiline) \n " ) ;
fprintf ( stderr , " -s SEED, --seed SEED RNG seed (default: -1, use random seed for <= 0) \n " ) ;
fprintf ( stderr , " -t N, --threads N number of threads to use during computation (default: %d) \n " , params . n_threads ) ;
fprintf ( stderr , " -p PROMPT, --prompt PROMPT \n " ) ;
fprintf ( stderr , " prompt to start generation with (default: empty) \n " ) ;
fprintf ( stderr , " --random-prompt start with a randomized prompt. \n " ) ;
fprintf ( stderr , " --in-prefix STRING string to prefix user inputs with (default: empty) \n " ) ;
fprintf ( stderr , " --ins-prefix STRING (instruct) prefix user inputs with tokenized string (default: empty) \n " ) ;
fprintf ( stderr , " --ins-prefix-bos (instruct) prepend bos token to instruct prefix. \n " ) ;
fprintf ( stderr , " --ins-suffix STRING (instruct) suffix user inputs with tokenized string (default: empty) \n " ) ;
fprintf ( stderr , " --ins-suffix-bos (instruct) prepend bos token to instruct suffix. \n " ) ;
fprintf ( stderr , " -f FNAME, --file FNAME \n " ) ;
fprintf ( stderr , " prompt file to start generation. \n " ) ;
fprintf ( stderr , " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity) \n " , params . n_predict ) ;
@ -448,61 +328,3 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
str = strTo ;
}
# endif
bool get_input_text ( std : : string & input_text , bool eof_toggled_multiline_mode ) {
bool another_line = true ;
bool is_eof_multiline_toggled = false ;
do {
std : : string line ;
# if defined(_WIN32)
auto & stdcin = std : : wcin ;
std : : wstring wline ;
if ( ! std : : getline ( stdcin , wline ) ) {
// input stream is bad or EOF received
if ( stdcin . bad ( ) ) {
fprintf ( stderr , " %s: error: input stream bad \n " , __func__ ) ;
return 1 ;
}
}
win32_utf8_encode ( wline , line ) ;
# else
auto & stdcin = std : : cin ;
if ( ! std : : getline ( stdcin , line ) ) {
// input stream is bad or EOF received
if ( stdcin . bad ( ) ) {
fprintf ( stderr , " %s: error: input stream bad \n " , __func__ ) ;
return 1 ;
}
}
# endif
if ( stdcin . eof ( ) ) {
stdcin . clear ( ) ;
stdcin . seekg ( 0 , std : : ios : : beg ) ;
if ( ! eof_toggled_multiline_mode ) {
another_line = false ;
} else {
is_eof_multiline_toggled = ! is_eof_multiline_toggled ;
if ( is_eof_multiline_toggled ) {
input_text + = line ;
continue ;
}
}
}
if ( ! eof_toggled_multiline_mode ) {
if ( line . empty ( ) | | line . back ( ) ! = ' \\ ' ) {
another_line = false ;
} else {
line . pop_back ( ) ; // Remove the continue character
}
} else {
if ( ! is_eof_multiline_toggled ) {
another_line = false ;
}
}
input_text + = line ;
if ( another_line ) {
input_text + = ' \n ' ; // Append the line to the result
}
} while ( another_line ) ;
return true ;
}