-
Notifications
You must be signed in to change notification settings - Fork 750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Implement outlines converter model for structured output (#1211) #1318
Changes from 3 commits
5d4862c
55c314f
73ce555
21e253e
8368a12
0919d0d
f7e3448
ca34a23
185d574
0fee0ae
9f09fd7
87ccb62
535a67a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
|
||
from typing import Any, Callable, List, Literal, Type, Union | ||
|
||
import outlines | ||
from pydantic import BaseModel | ||
|
||
from .base import BaseConverter | ||
|
||
|
||
class OutlinesConverter(BaseConverter):
    r"""OutlinesConverter wraps an ``outlines`` model and constrains its
    generation to a regex, a JSON schema, a Python type, a fixed list of
    choices or a context-free grammar.

    Args:
        model_type (str): The model identifier handed to the ``outlines``
            model loader of the selected platform.
        platform (str, optional): The platform to be used.
            1. transformers
            2. mamba
            3. vllm
            4. llamacpp
            5. mlx
            (default: "transformers")
        **kwargs: The keyword arguments forwarded unchanged to the
            ``outlines`` model loader. See the outlines documentation for
            more details. See
            https://dottxt-ai.github.io/outlines/latest/reference/models/models/

    Raises:
        ValueError: If ``platform`` is not one of the supported values.
    """

    # Supported ``platform`` value -> name of the loader function on
    # ``outlines.models``. The loader is looked up by name with ``getattr``
    # so that only the loader of the requested platform is ever accessed.
    _PLATFORM_LOADERS = {
        "vllm": "vllm",
        "transformers": "transformers",
        "mamba": "mamba",
        "llamacpp": "llamacpp",
        "mlx": "mlxlm",
    }

    # ``type`` accepted by ``convert`` -> (name of the handling method,
    # name of the keyword argument that method requires).
    _CONVERTERS = {
        "regex": ("convert_regex", "regex_pattern"),
        "json": ("convert_json", "output_schema"),
        "type": ("convert_type", "type_name"),
        "choice": ("convert_choice", "choices"),
        "grammar": ("convert_grammar", "grammar"),
    }

    def __init__(
        self, model_type: str, platform: str = "transformers", **kwargs: Any
    ) -> None:
        self.model_type = model_type

        # The isinstance guard keeps unhashable values on the ValueError
        # path instead of failing the dict lookup with a TypeError.
        loader_name = (
            self._PLATFORM_LOADERS.get(platform)
            if isinstance(platform, str)
            else None
        )
        if loader_name is None:
            raise ValueError(f"Unsupported platform: {platform}")

        from outlines import models

        self._outlines_model = getattr(models, loader_name)(
            model_type, **kwargs
        )

    def convert_regex(self, content: str, regex_pattern: str) -> str:
        r"""Generate text from the content that matches a regex pattern.

        Args:
            content (str): The content (prompt) to be converted.
            regex_pattern (str): The regex pattern the output must match.

        Returns:
            str: The converted content.
        """
        regex_generator = outlines.generate.regex(
            self._outlines_model, regex_pattern
        )
        return regex_generator(content)

    def convert_json(
        self,
        content: str,
        output_schema: Union[Type[BaseModel], str, Callable],
    ) -> Any:
        r"""Generate structured output that follows a JSON schema.

        Args:
            content (str): The content (prompt) to be converted.
            output_schema (Union[Type[BaseModel], str, Callable]): The
                expected format of the response.

        Returns:
            Any: The object produced by the ``outlines`` JSON generator.
                Its concrete type depends on ``output_schema``: a pydantic
                model class yields an instance of that class, while a JSON
                schema string yields a dict.
        """
        json_generator = outlines.generate.json(
            self._outlines_model, output_schema
        )
        return json_generator(content)

    def convert_type(self, content: str, type_name: type) -> Any:
        r"""Generate output constrained to the specified type.

        The following types are currently available:
            1. int
            2. float
            3. bool
            4. datetime.date
            5. datetime.time
            6. datetime.datetime
            7. custom types (https://dottxt-ai.github.io/outlines/latest/reference/generation/types/)

        Args:
            content (str): The content (prompt) to be converted.
            type_name (type): The type to be used.

        Returns:
            Any: The value produced by the ``outlines`` format generator
                for ``type_name``.
        """
        type_generator = outlines.generate.format(
            self._outlines_model, type_name
        )
        return type_generator(content)

    def convert_choice(self, content: str, choices: List[str]) -> str:
        r"""Generate output restricted to one of the given choices.

        Args:
            content (str): The content (prompt) to be converted.
            choices (List[str]): The choices to be used.

        Returns:
            str: The converted content.
        """
        choices_generator = outlines.generate.choice(
            self._outlines_model, choices
        )
        return choices_generator(content)

    def convert_grammar(self, content: str, grammar: str) -> str:
        r"""Generate output that conforms to the specified grammar.

        Args:
            content (str): The content (prompt) to be converted.
            grammar (str): The grammar to be used.

        Returns:
            str: The converted content.
        """
        grammar_generator = outlines.generate.cfg(
            self._outlines_model, grammar
        )
        return grammar_generator(content)

    def convert(  # type: ignore[override]
        self,
        type: Literal["regex", "json", "type", "choice", "grammar"],
        content: str,
        **kwargs: Any,
    ) -> Any:
        r"""Dispatch the content to the converter selected by ``type``.

        Args:
            type (Literal["regex", "json", "type", "choice", "grammar"]):
                The kind of constraint to apply. Each kind requires one
                keyword argument:
                1. "regex": ``regex_pattern``
                2. "json": ``output_schema``
                3. "type": ``type_name``
                4. "choice": ``choices``
                5. "grammar": ``grammar``
            content (str): The content to be formatted.
            **kwargs: The keyword arguments to be used. Only the argument
                required by the selected ``type`` is read.

        Returns:
            Any: The value returned by the selected ``convert_*`` method.

        Raises:
            ValueError: If ``type`` is not supported, or if the keyword
                argument required by ``type`` is missing or None.
        """
        # ``type`` shadows the builtin inside this method; rebind it to a
        # clearer local name and keep the public parameter name unchanged.
        kind = type
        handler = self._CONVERTERS.get(kind) if isinstance(kind, str) else None
        if handler is None:
            raise ValueError("Unsupported output schema type")

        method_name, argument_name = handler
        argument = kwargs.get(argument_name)
        # Fail early with a clear message instead of handing None to
        # outlines, where the resulting error would be hard to trace back.
        if argument is None:
            raise ValueError(
                f"Missing required keyword argument '{argument_name}' "
                f"for type '{kind}'"
            )
        return getattr(self, method_name)(content, argument)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. file naming, use |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
|
||
from pydantic import BaseModel | ||
|
||
from camel.schemas import OutlinesConverter | ||
|
||
# Define the model using OutlinesConverter | ||
model = OutlinesConverter(
    platform="transformers",
    model_type="microsoft/Phi-3-mini-4k-instruct",
)

######## Regex conversion #########

# 12-hour clock time: hour 1-12 (optional leading zero), minutes 00-59,
# then an optional whitespace character and an optional am/pm suffix.
time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
output = model.convert_regex(
    content="The the best time to visit a dentist is at ",
    regex_pattern=time_regex_pattern,
)

print(output)
''' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unify by using |
||
6:00 pm | ||
''' | ||
MuggleJinx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
######## JSON conversion ######### | ||
|
||
|
||
# 1. Using a Pydantic model | ||
class Temperature(BaseModel):
    # Schema handed to ``model.convert_json`` as ``output_schema`` in the
    # pydantic variant of the JSON example.
    location: str
    # Declared as a plain string, not a date type.
    date: str
    temperature: float
|
||
|
||
output = model.convert_json(
    content="Today is 2023-09-01, the temperature in Beijing is 30 degrees.",
    output_schema=Temperature,
)

print(output)
# Output recorded from a sample run:
# location='Beijing' date='2023-09-01' temperature=30.0
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the output format uses '=' @Wendong-Fan , not {'location':'Beijing'} There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks, I split it into 2 functions: one returns the pydantic object, one returns the internal dict object. |
||
|
||
# 2. Using a JSON schema | ||
# JSON schema supplied as a string; all three properties are required.
schema = """
{
"title": "User",
"type": "object",
"properties": {
"name": {"type": "string"},
"last_name": {"type": "string"},
"id": {"type": "integer"}
},
"required": ["name", "last_name", "id"]
}
"""

output = model.convert_json(
    content="Create a user profile with the fields name, last_name and id",
    output_schema=schema,
)

print(output)
# Output recorded from a sample run:
# {'name': 'John', 'last_name': 'Doe', 'id': 123456}
|
||
|
||
####### Type constraints ####### | ||
output = model.convert_type(
    content=(
        "When I was 6 my sister was half my age. "
        "Now I'm 70 how old is my sister?"
    ),
    type_name=int,
)

print(output)
# Output recorded from a sample run:
# 35
|
||
|
||
####### Multiple choices ####### | ||
|
||
capital_options = ["Paris", "London", "Berlin", "Madrid"]
output = model.convert_choice(
    content="What is the capital of Spain?",
    choices=capital_options,
)

print(output)
# Output recorded from a sample run:
# Madrid
|
||
|
||
####### Grammar ####### | ||
|
||
# Grammar for arithmetic expressions over numbers with + - * /, unary
# minus and parentheses.
arithmetic_grammar = """
?start: expression

?expression: term (("+" | "-") term)*

?term: factor (("*" | "/") factor)*

?factor: NUMBER
| "-" factor
| "(" expression ")"

%import common.NUMBER
"""

output = model.convert_grammar(
    content=(
        "Alice had 4 apples and Bob ate 2. "
        "Write an expression for Alice's apples:"
    ),
    grammar=arithmetic_grammar,
)

print(output)
# Output recorded from a sample run:
# (8-2)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
outlines is not a necessary package
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sry, it is needed in
outlines.generate.regex
for example.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can move
import outlines
within the class to make it optional