diff --git a/utils.cpp b/utils.cpp index 1d5309c..45c9cab 100644 --- a/utils.cpp +++ b/utils.cpp @@ -26,41 +26,95 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency()); } + bool invalid_param = false; + std::string arg; for (int i = 1; i < argc; i++) { - std::string arg = argv[i]; + arg = argv[i]; if (arg == "-s" || arg == "--seed") { - params.seed = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.seed = std::stoi(argv[i]); } else if (arg == "-t" || arg == "--threads") { - params.n_threads = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads = std::stoi(argv[i]); } else if (arg == "-p" || arg == "--prompt") { - params.prompt = argv[++i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.prompt = argv[i]; } else if (arg == "-f" || arg == "--file") { - std::ifstream file(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); if (params.prompt.back() == '\n') { params.prompt.pop_back(); } } else if (arg == "-n" || arg == "--n_predict") { - params.n_predict = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_predict = std::stoi(argv[i]); } else if (arg == "--top_k") { - params.top_k = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.top_k = std::stoi(argv[i]); } else if (arg == "-c" || arg == "--ctx_size") { - params.n_ctx = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ctx = std::stoi(argv[i]); } else if (arg == "--memory_f16") { params.memory_f16 = true; } else if (arg == "--top_p") { - params.top_p = std::stof(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.top_p = std::stof(argv[i]); } else if (arg == "--temp") { - params.temp = std::stof(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.temp = std::stof(argv[i]); } else if (arg == "--repeat_last_n") { - params.repeat_last_n = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.repeat_last_n = std::stoi(argv[i]); } else if (arg == "--repeat_penalty") { - params.repeat_penalty = std::stof(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.repeat_penalty = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch_size") { - params.n_batch = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_batch = std::stoi(argv[i]); } else if (arg == "-m" || arg == "--model") { - params.model = argv[++i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.model = argv[i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; } else if (arg == "--interactive-first") { @@ -70,13 +124,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--color") { params.use_color = true; } else if (arg == "-r" || arg == "--reverse-prompt") { - params.antiprompt.push_back(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.antiprompt.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--ignore-eos") { params.ignore_eos = true; } else if (arg == "--n_parts") { - params.n_parts = std::stoi(argv[++i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_parts = std::stoi(argv[i]); } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); @@ -85,9 +147,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, params); - exit(0); + exit(1); } } + if (invalid_param) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + gpt_print_usage(argc, argv, params); + exit(1); + } return true; }