Skip to content

Commit

Permalink
small enhancement based on review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Wendong-Fan committed Dec 17, 2024
1 parent 9f09fd7 commit 87ccb62
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 189 deletions.
6 changes: 2 additions & 4 deletions camel/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict

from pydantic import BaseModel


class BaseConverter(ABC):
r"""A base class for schema outputs that includes functionality
Expand All @@ -30,7 +28,7 @@ class BaseConverter(ABC):
@abstractmethod
def convert(
self, content: str, *args: Any, **kwargs: Dict[str, Any]
) -> BaseModel:
) -> Any:
r"""Structures the input text into the expected response format.
Args:
Expand All @@ -40,6 +38,6 @@ def convert(
prompt (Optional[str], optional): The prompt to be used.
Returns:
Optional[BaseModel]: The structured response.
Any: The converted response.
"""
pass
25 changes: 18 additions & 7 deletions camel/schemas/outlines_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from typing import Any, Callable, List, Literal, Type, Union

import outlines
from pydantic import BaseModel

from .base import BaseConverter
Expand Down Expand Up @@ -65,7 +64,7 @@ def __init__(
case _:
raise ValueError(f"Unsupported platform: {platform}")

def convert_regex(self, content: str, regex_pattern: str):
def convert_regex(self, content: str, regex_pattern: str) -> str:
r"""Convert the content to the specified regex pattern.
Args:
Expand All @@ -75,6 +74,8 @@ def convert_regex(self, content: str, regex_pattern: str):
Returns:
str: The converted content.
"""
import outlines

regex_generator = outlines.generate.regex(
self._outlines_model, regex_pattern
)
Expand All @@ -96,6 +97,8 @@ def convert_json(
Returns:
dict: The converted content in JSON format.
"""
import outlines

json_generator = outlines.generate.json(
self._outlines_model, output_schema
)
Expand All @@ -116,12 +119,14 @@ def convert_pydantic(
Returns:
BaseModel: The converted content in pydantic model format.
"""
import outlines

json_generator = outlines.generate.json(
self._outlines_model, output_schema
)
return json_generator(content)

def convert_type(self, content: str, type_name: type):
def convert_type(self, content: str, type_name: type) -> str:
r"""Convert the content to the specified type.
The following types are currently available:
Expand All @@ -140,12 +145,14 @@ def convert_type(self, content: str, type_name: type):
Returns:
str: The converted content.
"""
import outlines

type_generator = outlines.generate.format(
self._outlines_model, type_name
)
return type_generator(content)

def convert_choice(self, content: str, choices: List[str]):
def convert_choice(self, content: str, choices: List[str]) -> str:
r"""Convert the content to the specified choice.
Args:
Expand All @@ -155,12 +162,14 @@ def convert_choice(self, content: str, choices: List[str]):
Returns:
str: The converted content.
"""
import outlines

choices_generator = outlines.generate.choice(
self._outlines_model, choices
)
return choices_generator(content)

def convert_grammar(self, content: str, grammar: str):
def convert_grammar(self, content: str, grammar: str) -> str:
r"""Convert the content to the specified grammar.
Args:
Expand All @@ -170,17 +179,19 @@ def convert_grammar(self, content: str, grammar: str):
Returns:
str: The converted content.
"""
import outlines

grammar_generator = outlines.generate.cfg(
self._outlines_model, grammar
)
return grammar_generator(content)

def convert( # type: ignore[override]
self,
type: Literal["regex", "json", "type", "choice", "grammar"],
content: str,
type: Literal["regex", "json", "type", "choice", "grammar"],
**kwargs,
):
) -> Any:
r"""Formats the input content into the expected BaseModel.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
)

print(output)
'''
"""
===============================================================================
6:00 pm
===============================================================================
'''
"""


######## Pydantic conversion #########
Expand All @@ -52,17 +52,17 @@ class Temperature(BaseModel):
)

print(type(output))
'''
"""
===============================================================================
<class '__main__.Temperature'>
===============================================================================
'''
"""
print(output)
'''
"""
===============================================================================
location='Beijing' date='2023-09-01' temperature=30.0
===============================================================================
'''
"""


######## JSON conversion #########
Expand All @@ -87,17 +87,17 @@ class Temperature(BaseModel):
output_schema=schema,
)
print(type(output))
'''
"""
===============================================================================
<class 'dict'>
===============================================================================
'''
"""
print(output)
'''
"""
===============================================================================
{'name': 'John', 'last_name': 'Doe', 'id': 123456}
===============================================================================
'''
"""

# 2. Using a function (Callable)

Expand All @@ -112,17 +112,17 @@ def get_temperature(location: str, date: str, temperature: float):
)

print(type(output))
'''
"""
===============================================================================
<class 'dict'>
===============================================================================
'''
"""
print(output)
'''
"""
===============================================================================
{'location': 'Beijing', 'date': '2023-09-01', 'temperature': 30}
===============================================================================
'''
"""


######## Type constraints #########
Expand All @@ -133,11 +133,11 @@ def get_temperature(location: str, date: str, temperature: float):
)

print(output)
'''
"""
===============================================================================
35
===============================================================================
'''
"""


######## Mutliple choices #########
Expand All @@ -148,11 +148,11 @@ def get_temperature(location: str, date: str, temperature: float):
)

print(output)
'''
"""
===============================================================================
Madrid
===============================================================================
'''
"""


######## Grammer #########
Expand All @@ -178,8 +178,8 @@ def get_temperature(location: str, date: str, temperature: float):
)

print(output)
'''
"""
===============================================================================
(8-2)
===============================================================================
'''
"""
Loading

0 comments on commit 87ccb62

Please sign in to comment.