Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing grammar to completion endpoint #2532

Merged
merged 3 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ embedding: examples/embedding/embedding.cpp build-info.h ggml.
save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS)
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)

$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
Expand Down
2 changes: 2 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ node .

`mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1).

`grammar`: Set the grammar for grammar-based sampling (default: no grammar).

`seed`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).

`ignore_eos`: Ignore end of stream token and continue generating (default: false).
Expand Down
60 changes: 58 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
Expand Down Expand Up @@ -195,6 +196,8 @@ struct llama_server_context
llama_context *ctx = nullptr;
gpt_params params;

llama_grammar *grammar = nullptr;

bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
Expand Down Expand Up @@ -226,6 +229,7 @@ struct llama_server_context
void rewind()
{
params.antiprompt.clear();
params.grammar.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
Expand All @@ -237,6 +241,7 @@ struct llama_server_context
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
grammar = nullptr;

n_remain = 0;
n_past = 0;
Expand All @@ -257,6 +262,33 @@ struct llama_server_context
return true;
}

bool loadGrammar()
{
if (!params.grammar.empty()) {
grammar_parser::parse_state parsed_grammar;

parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
return false;
}
grammar_parser::print_grammar(stderr, parsed_grammar);

{
auto it = params.logit_bias.find(llama_token_eos());
if (it != params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}

std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
return true;
}

void loadPrompt()
{
params.prompt.insert(0, 1, ' '); // always add a first space
Expand Down Expand Up @@ -420,6 +452,10 @@ struct llama_server_context
logits[llama_token_nl()] = nl_logit;
}

if (grammar != nullptr) {
llama_sample_grammar(ctx, &candidates_p, grammar);
}

if (temp <= 0)
{
// Greedy sampling
Expand Down Expand Up @@ -457,10 +493,15 @@ struct llama_server_context
}
}

if (grammar != nullptr) {
llama_grammar_accept_token(ctx, grammar, result.tok);
}

for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
{
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
num_tokens_predicted++;
Expand Down Expand Up @@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
{"stream", llama.stream},
{"logit_bias", llama.params.logit_bias},
{"n_probs", llama.params.n_probs},
{"grammar", llama.params.grammar},
};
}

Expand Down Expand Up @@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.grammar = body.value("grammar", default_params.grammar);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);

llama.params.logit_bias.clear();
Expand Down Expand Up @@ -1179,6 +1222,12 @@ int main(int argc, char **argv)

parse_options_completion(json::parse(req.body), llama);

if (!llama.loadGrammar())
{
res.status = 400;
return;
}

llama.loadPrompt();
llama.beginCompletion();

Expand Down Expand Up @@ -1330,8 +1379,12 @@ int main(int argc, char **argv)

svr.set_error_handler([](const Request &, Response &res)
{
res.set_content("File Not Found", "text/plain");
res.status = 404; });
if (res.status == 400) {
res.set_content("Invalid request", "text/plain");
} else {
res.set_content("File Not Found", "text/plain");
res.status = 404;
} });

// set timeouts and change hostname and port
svr.set_read_timeout(sparams.read_timeout);
Expand Down Expand Up @@ -1359,6 +1412,9 @@ int main(int argc, char **argv)
return 1;
}

if (llama.grammar != nullptr) {
llama_grammar_free(llama.grammar);
}
llama_backend_free();

return 0;
Expand Down
Loading