Skip to content

Commit

Permalink
grammars: 1.5x faster inference w/ complex grammars (vector reserves / reuses) (ggerganov#6609)
Browse files Browse the repository at this point in the history

* grammars: reserve rejects & next candidates

* grammars: reuse new_stacks

* grammars: fix missing sig change in llama.h

* grammars: fix test (api changed)

* grammars: update gbnf-validator.cpp

* grammars: simpler syntax (no swap)
  • Loading branch information
ochafik authored and tybalex committed Apr 17, 2024
1 parent 18a84a7 commit 1e8310d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
size_t pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
Expand Down
16 changes: 10 additions & 6 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr) {
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {

std::vector<std::vector<const llama_grammar_element *>> new_stacks;
new_stacks.clear();

for (const auto & stack : stacks) {
if (stack.empty()) {
Expand All @@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
llama_grammar_advance_stack(rules, new_stack, new_stacks);
}
}

return new_stacks;
}

static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
Expand All @@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const std::vector<llama_grammar_candidate> & candidates) {

std::vector<llama_grammar_candidate> rejects;
rejects.reserve(candidates.size());

if (stack.empty()) {
for (const auto & tok : candidates) {
Expand All @@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const llama_grammar_element * stack_pos = stack.back();

std::vector<llama_grammar_candidate> next_candidates;
next_candidates.reserve(candidates.size());

for (const auto & tok : candidates) {
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
Expand Down Expand Up @@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
grammar->stacks = tmp_new_stacks;
}
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
Expand Down
5 changes: 3 additions & 2 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1097,10 +1097,11 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
struct llama_context * ctx
);

std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr);
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);

std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
Expand Down
6 changes: 3 additions & 3 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ number ::= [0-9]+)""";

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
}

Expand Down Expand Up @@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);

// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
Expand Down Expand Up @@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
Expand Down

0 comments on commit 1e8310d

Please sign in to comment.