Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement outlines converter model for structured output (#1211) #1318

Merged
merged 13 commits into from
Dec 17, 2024
3 changes: 2 additions & 1 deletion camel/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from .openai_converter import OpenAISchemaConverter
from .outlines_converter import OutlinesConverter

__all__ = ["OpenAISchemaConverter"]
__all__ = ["OpenAISchemaConverter", "OutlinesConverter"]
6 changes: 2 additions & 4 deletions camel/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict

from pydantic import BaseModel


class BaseConverter(ABC):
r"""A base class for schema outputs that includes functionality
Expand All @@ -30,7 +28,7 @@ class BaseConverter(ABC):
@abstractmethod
def convert(
self, content: str, *args: Any, **kwargs: Dict[str, Any]
) -> BaseModel:
) -> Any:
r"""Structures the input text into the expected response format.

Args:
Expand All @@ -40,6 +38,6 @@ def convert(
prompt (Optional[str], optional): The prompt to be used.

Returns:
Optional[BaseModel]: The structured response.
Any: The converted response.
"""
pass
249 changes: 249 additions & 0 deletions camel/schemas/outlines_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from typing import Any, Callable, List, Literal, Type, Union

from pydantic import BaseModel

from .base import BaseConverter


class OutlinesConverter(BaseConverter):
r"""OutlinesConverter is a class that converts a string or a function
into a BaseModel schema.

Args:
model_type (str, optional): The model type to be used.
platform (str, optional): The platform to be used.
1. transformers
2. mamba
3. vllm
4. llamacpp
5. mlx
(default: "transformers")
**kwargs: The keyword arguments to be used. See the outlines
documentation for more details. See
https://dottxt-ai.github.io/outlines/latest/reference/models/models/
"""

def __init__(
self,
model_type: str,
platform: Literal[
"vllm", "transformers", "mamba", "llamacpp", "mlx"
] = "transformers",
**kwargs: Any,
):
self.model_type = model_type
from outlines import models

match platform:
case "vllm":
self._outlines_model = models.vllm(model_type, **kwargs)
case "transformers":
self._outlines_model = models.transformers(
model_type, **kwargs
)
case "mamba":
self._outlines_model = models.mamba(model_type, **kwargs)
case "llamacpp":
self._outlines_model = models.llamacpp(model_type, **kwargs)
case "mlx":
self._outlines_model = models.mlxlm(model_type, **kwargs)
case _:
raise ValueError(f"Unsupported platform: {platform}")

def convert_regex(self, content: str, regex_pattern: str) -> str:
r"""Convert the content to the specified regex pattern.

Args:
content (str): The content to be converted.
regex_pattern (str): The regex pattern to be used.

Returns:
str: The converted content.
"""
import outlines

regex_generator = outlines.generate.regex(
self._outlines_model, regex_pattern
)
return regex_generator(content)

def convert_json(
self,
content: str,
output_schema: Union[str, Callable],
) -> dict:
r"""Convert the content to the specified JSON schema given by
output_schema.

Args:
content (str): The content to be converted.
output_schema (Union[str, Callable]): The expected format of the
response.

Returns:
dict: The converted content in JSON format.
"""
import outlines

json_generator = outlines.generate.json(
self._outlines_model, output_schema
)
return json_generator(content)

def convert_pydantic(
self,
content: str,
output_schema: Type[BaseModel],
) -> BaseModel:
r"""Convert the content to the specified Pydantic schema.

Args:
content (str): The content to be converted.
output_schema (Type[BaseModel]): The expected format of the
response.

Returns:
BaseModel: The converted content in pydantic model format.
"""
import outlines

json_generator = outlines.generate.json(
self._outlines_model, output_schema
)
return json_generator(content)

def convert_type(self, content: str, type_name: type) -> str:
r"""Convert the content to the specified type.

The following types are currently available:
1. int
2. float
3. bool
4. datetime.date
5. datetime.time
6. datetime.datetime
7. custom types (https://dottxt-ai.github.io/outlines/latest/reference/generation/types/)

Args:
content (str): The content to be converted.
type_name (type): The type to be used.

Returns:
str: The converted content.
"""
import outlines

type_generator = outlines.generate.format(
self._outlines_model, type_name
)
return type_generator(content)

def convert_choice(self, content: str, choices: List[str]) -> str:
r"""Convert the content to the specified choice.

Args:
content (str): The content to be converted.
choices (List[str]): The choices to be used.

Returns:
str: The converted content.
"""
import outlines

choices_generator = outlines.generate.choice(
self._outlines_model, choices
)
return choices_generator(content)

def convert_grammar(self, content: str, grammar: str) -> str:
r"""Convert the content to the specified grammar.

Args:
content (str): The content to be converted.
grammar (str): The grammar to be used.

Returns:
str: The converted content.
"""
import outlines

grammar_generator = outlines.generate.cfg(
self._outlines_model, grammar
)
return grammar_generator(content)

def convert( # type: ignore[override]
self,
content: str,
type: Literal["regex", "json", "type", "choice", "grammar"],
**kwargs,
) -> Any:
r"""Formats the input content into the expected BaseModel.

Args:
type (Literal["regex", "json", "type", "choice", "grammar"]):
The type of conversion to perform. Options are:
- "regex": Match the content against a regex pattern.
- "pydantic": Convert the content into a pydantic model.
- "json": Convert the content into a JSON based on a
schema.
- "type": Convert the content into a specified type.
- "choice": Match the content against a list of valid
choices.
- "grammar": Convert the content using a specified grammar.
content (str): The content to be formatted.
**kwargs: Additional keyword arguments specific to the conversion
type.

- For "regex":
regex_pattern (str): The regex pattern to use for matching.

- For "pydantic":
output_schema (Type[BaseModel]): The schema to validate and
format the pydantic model.

- For "json":
output_schema (Union[str, Callable]): The schema to validate
and format the JSON object.

- For "type":
type_name (str): The target type name for the conversion.

- For "choice":
choices (List[str]): A list of valid choices to match against.

- For "grammar":
grammar (str): The grammar definition to use for content
conversion.
"""
match type:
case "regex":
return self.convert_regex(content, kwargs.get("regex_pattern")) # type: ignore[arg-type]
case "pydantic":
return self.convert_pydantic(
content, kwargs.get("output_schema")
) # type: ignore[arg-type]
case "json":
return self.convert_json(content, kwargs.get("output_schema")) # type: ignore[arg-type]
case "type":
return self.convert_type(content, kwargs.get("type_name")) # type: ignore[arg-type]
case "choice":
return self.convert_choice(content, kwargs.get("choices")) # type: ignore[arg-type]
case "grammar":
return self.convert_grammar(content, kwargs.get("grammar")) # type: ignore[arg-type]
case _:
raise ValueError("Unsupported output schema type")
Loading
Loading