Skip to content

Commit

Permalink
[Feature] Expose Grammar.union (#227)
Browse files Browse the repository at this point in the history
This PR provides Grammar.union to combine several grammars into one. The combined grammar can accept strings that follow any of the prior grammars.
  • Loading branch information
Ubospica authored Mar 5, 2025
1 parent fbc3767 commit 0670b14
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
18 changes: 18 additions & 0 deletions python/xgrammar/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,21 @@ def concat(*grammars: "Grammar") -> "Grammar":
"""
grammar_handles = [grammar._handle for grammar in grammars]
return Grammar._create_from_handle(_core.Grammar.concat(grammar_handles))

@staticmethod
def union(*grammars: "Grammar") -> "Grammar":
"""Create a grammar that matches any of the grammars in the list. That is equivalent to
using the `|` operator to concatenate the grammars in the list.
Parameters
----------
grammars : List[Grammar]
The grammars to create the union of.
Returns
-------
grammar : Grammar
The union of the grammars.
"""
grammar_handles = [grammar._handle for grammar in grammars]
return Grammar._create_from_handle(_core.Grammar.union(grammar_handles))
18 changes: 0 additions & 18 deletions python/xgrammar/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,6 @@ def _get_matcher_from_grammar_and_tokenizer_info(
return GrammarMatcher(compiled_grammar, **kwargs)


def _get_grammar_union(*grammars: "Grammar") -> "Grammar":
"""Create a grammar that matches any of the grammars in the list. That is equivalent to
using the `|` operator to concatenate the grammars in the list.
Parameters
----------
grammars : List[Grammar]
The grammars to create the union of.
Returns
-------
grammar : Grammar
The union of the grammars.
"""
grammar_handles = [grammar._handle for grammar in grammars]
return Grammar._create_from_handle(_core.Grammar.union(grammar_handles))


def _get_allow_empty_rule_ids(compiled_grammar: CompiledGrammar) -> List[int]:
return _core.testing._get_allow_empty_rule_ids(compiled_grammar._handle)

Expand Down
3 changes: 1 addition & 2 deletions tests/python/test_grammar_union_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest

import xgrammar as xgr
from xgrammar.testing import _get_grammar_union


def test_grammar_union():
Expand Down Expand Up @@ -42,7 +41,7 @@ def test_grammar_union():
r3 ::= (("abc"))
"""

union_grammar = _get_grammar_union(grammar1, grammar2, grammar3)
union_grammar = xgr.Grammar.union(grammar1, grammar2, grammar3)
assert str(union_grammar) == expected


Expand Down

0 comments on commit 0670b14

Please sign in to comment.