diff --git a/src/lfe.py b/src/lfe.py
index feebad2..dd7d267 100644
--- a/src/lfe.py
+++ b/src/lfe.py
@@ -1,5 +1,5 @@
 """Benchmark the lm-format-enforcer library."""
-from lmformatenforcer import RegexParser, TokenEnforcer
+from lmformatenforcer import JsonSchemaParser, RegexParser, TokenEnforcer
 from lmformatenforcer.integrations.transformers import (
     build_token_enforcer_tokenizer_data,
 )
@@ -12,7 +12,7 @@
     "google/gemma-2-2b-it",  # 256,128 tokens vocabulary
 ]
 
-case = [
+regex_case = [
     (r"\d{3}-\d{2}-\d{4}", "203-22-1234"),
     (
         r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
@@ -25,8 +25,8 @@
 ]
 
 
-class LMFormatEnforcer:
-    params = [models, case]
+class LMFormatEnforcerRegex:
+    params = [models, regex_case]
     param_names = ["model", "regex"]
     timeout = 600
 
@@ -51,3 +51,56 @@ def time_lfe(self, _, regex):
 
         for i in range(len(regex_example_tokens)):
             _ = token_enforcer.get_allowed_tokens(regex_example_tokens[: i + 1])
+
+
+json_case = [
+    (
+        {
+            "$defs": {
+                "Armor": {
+                    "enum": ["leather", "chainmail", "plate"],
+                    "title": "Armor",
+                    "type": "string",
+                }
+            },
+            "properties": {
+                "name": {"maxLength": 10, "title": "Name", "type": "string"},
+                "age": {"title": "Age", "type": "integer"},
+                "armor": {"$ref": "#/$defs/Armor"},
+                "strength": {"title": "Strength", "type": "integer"},
+            },
+            "required": ["name", "age", "armor", "strength"],
+            "title": "Character",
+            "type": "object",
+        },
+        '{"name": "Warrior", "age": 26, "armor": "leather", "strength": 10}',
+    )
+]
+
+
+class LMFormatEnforcerJsonSchema:
+    params = [models, json_case]
+    param_names = ["model", "json_schema"]
+    timeout = 600
+
+    def setup(self, model, _):
+        """Set up the benchmark.
+
+        We convert the tokenizer during setup as this only
+        needs to be done once for a given model.
+
+        """
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model, clean_up_tokenization_spaces=True
+        )
+        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
+
+    def time_lfe(self, _, json_schema):
+        schema, json_example = json_schema
+        json_example_tokens = self.tokenizer.encode(json_example)
+
+        parser = JsonSchemaParser(schema)
+        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)
+
+        for i in range(len(json_example_tokens)):
+            _ = token_enforcer.get_allowed_tokens(json_example_tokens[: i + 1])
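The new json_case pairs a JSON Schema with a completion that the schema accepts; time_lfe replays that completion prefix by prefix and asks the TokenEnforcer for the allowed-token set at every step, which is the work being benchmarked. A minimal stand-alone sketch of the same flow outside asv follows; the gpt2 tokenizer and the inlined two-field schema are illustrative placeholders, not part of the benchmark suite.

# Stand-alone sketch mirroring LMFormatEnforcerJsonSchema.time_lfe.
# Assumes lm-format-enforcer and transformers are installed; "gpt2" is
# an arbitrary small tokenizer chosen purely for illustration.
from lmformatenforcer import JsonSchemaParser, TokenEnforcer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)
from transformers import AutoTokenizer

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
example = '{"name": "Warrior", "age": 26}'

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)
token_enforcer = TokenEnforcer(tokenizer_data, JsonSchemaParser(schema))

# Walk the example token by token and query the allowed tokens after
# each prefix, exactly like the timed loop in the benchmark class.
tokens = tokenizer.encode(example)
for i in range(len(tokens)):
    allowed = token_enforcer.get_allowed_tokens(tokens[: i + 1])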