Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make reverse prompt option act as a stop token in non-interactive scenarios #1032

Merged
merged 7 commits into from
May 19, 2023
6 changes: 3 additions & 3 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}
if (params.prompt_cache_all &&
(params.interactive || params.interactive_first ||
params.instruct || params.antiprompt.size())) {
params.instruct)) {
fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
gpt_print_usage(argc, argv, default_params);
exit(1);
Expand All @@ -368,8 +368,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n");
fprintf(stderr, " (can be specified more than once for multiple prompts).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
Expand Down
26 changes: 18 additions & 8 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ int main(int argc, char ** argv) {
params.antiprompt.push_back("### Instruction:\n\n");
}

// enable interactive mode if reverse prompt or interactive start is specified
if (params.antiprompt.size() != 0 || params.interactive_first) {
// enable interactive mode if interactive start is specified
if (params.interactive_first) {
params.interactive = true;
}

Expand Down Expand Up @@ -306,7 +306,7 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd;

while (n_remain != 0 || params.interactive) {
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (embd.size() > 0) {
// infinite text generation via context swapping
Expand Down Expand Up @@ -504,9 +504,8 @@ int main(int argc, char ** argv) {
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}

// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {

// check for reverse prompt
if (params.antiprompt.size()) {
Expand All @@ -517,10 +516,21 @@ int main(int argc, char ** argv) {

is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
// so we'll compensate for that by widening the search window a bit.
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
size_t extra_padding = params.interactive ? 0 : 2;
size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
: 0;

if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
if (params.interactive) {
is_interacting = true;
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
is_antiprompt = true;
fflush(stdout);
break;
}
}
Expand Down