Add an --in-prefix option: a string to prefix user inputs with.

  * utils.h   - new gpt_params::input_prefix field (default: empty).
  * utils.cpp - parse "--in-prefix STRING" and list it in the usage text.
                A missing STRING is reported as an error instead of reading
                past the end of argv.
  * main.cpp  - print the configured prefix to stderr next to the reverse
                prompt report, and seed the user-input buffer with it. The
                prefix is echoed through a "%s" format so that any '%'
                characters in it are printed literally.

diff --git a/main.cpp b/main.cpp
index 3f49ad9..143585b 100644
--- a/main.cpp
+++ b/main.cpp
@@ -300,6 +300,10 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
             }
         }
+
+        if (!params.input_prefix.empty()) {
+            fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
+        }
     }
     fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     fprintf(stderr, "\n\n");
@@ -472,6 +476,11 @@ int main(int argc, char ** argv) {
             }
 
             std::string buffer;
+            if (!params.input_prefix.empty()) {
+                buffer += params.input_prefix;
+                printf("%s", buffer.c_str()); // user-supplied text must never be used as the format string
+            }
+
             std::string line;
             bool another_line = true;
             do {
diff --git a/utils.cpp b/utils.cpp
index 2f995c1..ef3b67a 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -155,6 +155,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             exit(0);
         } else if (arg == "--random-prompt") {
             params.random_prompt = true;
+        } else if (arg == "--in-prefix") {
+            // argv[argc] is a null pointer; assigning it to a std::string is undefined behaviour.
+            if (++i >= argc) {
+                fprintf(stderr, "error: missing value for argument: %s\n", arg.c_str());
+                gpt_print_usage(argc, argv, params);
+                exit(1);
+            }
+            params.input_prefix = argv[i];
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, params);
@@ -187,6 +195,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, " prompt to start generation with (default: empty)\n");
     fprintf(stderr, " --random-prompt start with a randomized prompt.\n");
+    fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
     fprintf(stderr, " -f FNAME, --file FNAME\n");
     fprintf(stderr, " prompt file to start generation.\n");
     fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
diff --git a/utils.h b/utils.h
index d469bc6..0690ef7 100644
--- a/utils.h
+++ b/utils.h
@@ -30,6 +30,7 @@ struct gpt_params {
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
+    std::string input_prefix = ""; // string to prefix user inputs with
 
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 