
Commit

Add Json Schema benchmark
rlouf committed Aug 15, 2024
1 parent 32c2bef commit aa898b4
Showing 1 changed file with 57 additions and 4 deletions.
61 changes: 57 additions & 4 deletions src/lfe.py
@@ -1,5 +1,5 @@
"""Benchmark the lm-format-enforcer library."""
-from lmformatenforcer import RegexParser, TokenEnforcer
+from lmformatenforcer import JsonSchemaParser, RegexParser, TokenEnforcer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)
@@ -12,7 +12,7 @@
"google/gemma-2-2b-it", # 256,128 tokens vocabulary
]

-case = [
+regex_case = [
    (r"\d{3}-\d{2}-\d{4}", "203-22-1234"),
    (
        r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
@@ -25,8 +25,8 @@
]


-class LMFormatEnforcer:
-    params = [models, case]
+class LMFormatEnforcerRegex:
+    params = [models, regex_case]
    param_names = ["model", "regex"]
    timeout = 600

@@ -51,3 +51,56 @@ def time_lfe(self, _, regex):

        for i in range(len(regex_example_tokens)):
            _ = token_enforcer.get_allowed_tokens(regex_example_tokens[: i + 1])


json_case = [
    (
        {
            "$defs": {
                "Armor": {
                    "enum": ["leather", "chainmail", "plate"],
                    "title": "Armor",
                    "type": "string",
                }
            },
            "properties": {
                "name": {"maxLength": 10, "title": "Name", "type": "string"},
                "age": {"title": "Age", "type": "integer"},
                "armor": {"$ref": "#/$defs/Armor"},
                "strength": {"title": "Strength", "type": "integer"},
            },
            "required": ["name", "age", "armor", "strength"],
            "title": "Character",
            "type": "object",
        },
"""{'name': 'Super Warrior', 'age': 26, 'armor': 'leather', 'armor': 10}""",
    )
]


class LMFormatEnforcerJsonSchema:
    params = [models, json_case]
    param_names = ["model", "json_schema"]
    timeout = 600

    def setup(self, model, _):
        """Set up the benchmark.
        We convert the tokenizer during setup as this only
        needs to be done once for a given model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model, clean_up_tokenization_spaces=True
        )
        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)

    def time_lfe(self, _, json_case):
        json_schema, json_example = json_case
        json_example_tokens = self.tokenizer.encode(json_example)

        parser = JsonSchemaParser(json_schema)
        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)

        for i in range(len(json_example_tokens)):
            _ = token_enforcer.get_allowed_tokens(json_example_tokens[: i + 1])
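
For context only (none of the following is part of this commit): the Character schema hard-coded in json_case matches what pydantic v2 emits for a small model, so a sketch along these lines could regenerate both the schema and a valid example string. The Armor and Character classes below are illustrative assumptions, not code from this repository.

from enum import Enum

from pydantic import BaseModel, Field


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: str = Field(max_length=10)
    age: int
    armor: Armor
    strength: int


# Equivalent to the schema dict hard-coded in json_case (pydantic v2 API).
schema = Character.model_json_schema()

# A valid instance serialised to JSON; this is the kind of example string the
# benchmark tokenizes and replays token by token through the TokenEnforcer.
example = Character(name="Warrior", age=26, armor=Armor.leather, strength=10)
print(example.model_dump_json())

If the suite is driven by airspeed velocity, which the params/param_names/time_* layout suggests, the new case should be runnable on its own with something like asv run --bench LMFormatEnforcerJsonSchema.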
