feat: add test_cohere_model.py and solved comment problems
NeilJohnson0930 committed Sep 23, 2024
1 parent 0686fde commit 3ac127b
Showing 6 changed files with 99 additions and 58 deletions.
38 changes: 15 additions & 23 deletions camel/configs/cohere_config.py
@@ -15,28 +15,29 @@
 
 from typing import Optional
 
-from pydantic import field_validator
-
 from camel.configs.base_config import BaseConfig
 
 
 class CohereConfig(BaseConfig):
-    """Defines the parameters for generating chat completions using the
+    r"""Defines the parameters for generating chat completions using the
     Cohere API.
     Args:
-        temperature (Optional[float]): Controls randomness in the model.
+        temperature (Optional[float], optional):
+            Controls randomness in the model.
             Values closer to 0 make the model more deterministic, while
-            values closer to 1 make it more random. Defaults to None.
-        p (Optional[float]): Sets a p% nucleus sampling threshold.
-            Defaults to None.
-        k (Optional[int]): Limits the number of tokens to sample from on
-            each step. Defaults to None.
-        max_tokens (Optional[int]): The maximum number of tokens to generate.
-            Defaults to None.
-        prompt_truncation (Optional[str]): How to truncate the prompt if it
+            values closer to 1 make it more random. (default: :obj:`None`)
+        p (Optional[float], optional): Sets a p% nucleus sampling threshold.
+            (default: :obj:`None`)
+        k (Optional[int], optional):
+            Limits the number of tokens to sample from on
+            each step. (default: :obj:`None`)
+        max_tokens (Optional[int], optional):
+            The maximum number of tokens to generate.
+            (default: :obj:`None`)
+        prompt_truncation (Optional[str], optional):
+            How to truncate the prompt if it
             exceeds the model's context length. Can be 'START', 'END', or
-            'AUTO'. Defaults to None.
+            'AUTO'. (default: :obj:`None`)
     """
 
     temperature: Optional[float] = None
@@ -45,15 +46,6 @@ class CohereConfig(BaseConfig):
     max_tokens: Optional[int] = None
     prompt_truncation: Optional[str] = None
 
-    @field_validator("prompt_truncation")
-    @classmethod
-    def validate_prompt_truncation(cls, v):
-        if v is not None and v not in ['START', 'END', 'AUTO']:
-            raise ValueError(
-                "prompt_truncation must be 'START', 'END', or 'AUTO'"
-            )
-        return v
-
     def as_dict(self):
         return {k: v for k, v in super().as_dict().items() if v is not None}
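With the `field_validator` dropped, `CohereConfig` no longer rejects invalid `prompt_truncation` values at construction time; only unknown parameter names are caught later by `check_model_config`, and value errors are left to the Cohere API itself. A minimal sketch of the resulting behavior (hypothetical values, assuming the pydantic-based `BaseConfig` semantics shown in this diff):

from camel.configs import CohereConfig

# as_dict() drops None-valued fields, so unset parameters are never
# forwarded to the Cohere client.
config = CohereConfig(temperature=0.7, max_tokens=256)
print(config.as_dict())  # expected: {'temperature': 0.7, 'max_tokens': 256}

# No client-side value check anymore: an invalid prompt_truncation
# passes silently here and would only fail once the Cohere API sees it.
config = CohereConfig(prompt_truncation="MIDDLE")  # no ValueError raised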
2 changes: 2 additions & 0 deletions camel/models/__init__.py
@@ -14,6 +14,7 @@
 from .anthropic_model import AnthropicModel
 from .azure_openai_model import AzureOpenAIModel
 from .base_model import BaseModelBackend
+from .cohere_model import CohereModel
 from .gemini_model import GeminiModel
 from .groq_model import GroqModel
 from .litellm_model import LiteLLMModel
@@ -41,6 +42,7 @@
     'GroqModel',
     'StubModel',
     'ZhipuAIModel',
+    'CohereModel',
     'OpenSourceModel',
     'ModelFactory',
     'LiteLLMModel',
46 changes: 11 additions & 35 deletions camel/models/cohere_model.py
@@ -19,9 +19,10 @@
 if TYPE_CHECKING:
     from cohere.types import NonStreamedChatResponse
 
+from camel.configs import COHERE_API_PARAMS
 from camel.messages import OpenAIMessage
 from camel.models import BaseModelBackend
-from camel.types import ChatCompletion, ModelPlatformType, ModelType
+from camel.types import ChatCompletion, ModelType
 from camel.utils import (
     BaseTokenCounter,
     OpenAITokenCounter,
@@ -40,23 +41,23 @@
 
 
 class CohereModel(BaseModelBackend):
-    """Cohere API in a unified BaseModelBackend interface."""
+    r"""Cohere API in a unified BaseModelBackend interface."""
 
     def __init__(
         self,
         model_type: ModelType,
         model_config_dict: Dict[str, Any],
         api_key: Optional[str] = None,
+        url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
-        model_platform: Optional[ModelPlatformType] = None,
     ):
         import cohere
 
         super().__init__(
             model_type, model_config_dict, api_key, token_counter=token_counter
         )
         self._api_key = api_key or os.environ.get("COHERE_API_KEY")
-        self.model_platform = model_platform
+        self._url = url or os.environ.get("COHERE_SERVER_URL")
 
         self._client = cohere.Client(api_key=self._api_key)
         self._token_counter: Optional[BaseTokenCounter] = None
@@ -143,53 +144,36 @@ def _to_cohere_chatmessage(
 
     @property
     def token_counter(self) -> BaseTokenCounter:
-        """Initialize the token counter for the model backend.
+        r"""Initialize the token counter for the model backend.
         Returns:
             BaseTokenCounter: The token counter following the model's
                 tokenization style.
         """
         if not self._token_counter:
             self._token_counter = OpenAITokenCounter(
-                model=ModelType.GPT_3_5_TURBO
+                model=ModelType.GPT_4O_MINI
             )
         return self._token_counter
 
     @api_keys_required("COHERE_API_KEY")
     def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:
-        """Runs inference of Cohere chat completion.
+        r"""Runs inference of Cohere chat completion.
         Args:
             messages (List[OpenAIMessage]): Message list with the chat history
                 in OpenAI API format.
         Returns:
             ChatCompletion.
         """
         from cohere.core.api_error import ApiError
 
         cohere_messages = self._to_cohere_chatmessage(messages)
 
-        # Filter out unsupported parameters
-        supported_params = {
-            'temperature',
-            'p',
-            'k',
-            'max_tokens',
-            'prompt_truncation',
-        }
-        filtered_config = {
-            k: v
-            for k, v in self.model_config_dict.items()
-            if k in supported_params
-        }
-
         try:
             response = self._client.chat(
                 message=cohere_messages[-1]["message"],
                 chat_history=cohere_messages[:-1],  # type: ignore[arg-type]
                 model=self.model_type.value,
-                **filtered_config,
+                **self.model_config_dict,
             )
         except ApiError as e:
             logging.error(f"Cohere API Error: {e.status_code}")
@@ -218,22 +202,14 @@ def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:
         return openai_response
 
     def check_model_config(self):
-        """Check whether the model configuration contains any
+        r"""Check whether the model configuration contains any
         unexpected arguments to Cohere API.
         Raises:
             ValueError: If the model configuration dictionary contains any
                 unexpected arguments to Cohere API.
         """
-        supported_params = {
-            'temperature',
-            'p',
-            'k',
-            'max_tokens',
-            'prompt_truncation',
-        }
         for param in self.model_config_dict:
-            if param not in supported_params:
+            if param not in COHERE_API_PARAMS:
                 raise ValueError(
                     f"Unexpected argument `{param}` is "
                     "input into Cohere model backend."
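The hand-maintained `supported_params` sets in `run()` and `check_model_config()` are replaced by the imported `COHERE_API_PARAMS` constant, so the accepted parameter set lives in one place. The commit does not show its definition; the sketch below is a hedged guess at how such a constant is typically derived from the config class (the actual definition in `camel/configs` may differ):

from camel.configs.cohere_config import CohereConfig

# Hypothetical derivation: pydantic v2 exposes a class's declared
# fields via model_fields, so the accepted set stays in sync with
# CohereConfig automatically instead of being duplicated by hand.
COHERE_API_PARAMS = {param for param in CohereConfig.model_fields.keys()}
print(COHERE_API_PARAMS)
# e.g. {'temperature', 'p', 'k', 'max_tokens', 'prompt_truncation', ...}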
8 changes: 8 additions & 0 deletions examples/models/cohere_model_example.py
@@ -41,3 +41,11 @@
 # Get response information
 response = camel_agent.step(user_msg)
 print(response.msgs[0].content)
+'''
+===============================================================================
+Hello CAMEL AI! It's great to connect with an open-source community focused on
+autonomous and communicative agents. Your work is fascinating and has a wide
+range of applications. I look forward to learning more about your research and
+contributions to the field of AI.
+===============================================================================
+'''
1 change: 1 addition & 0 deletions examples/models/role_playing_with_cohere.py
@@ -75,6 +75,7 @@ def main(
         model_platform=model_platform,
         model_type=model_type,
         model_config_dict=model_config.as_dict(),
+        api_key=api_key,
     )
 
     # Set up role playing session
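The example now forwards an explicit `api_key` into the model constructor instead of relying solely on the `COHERE_API_KEY` environment variable read inside `CohereModel.__init__`. A hypothetical invocation (only the keyword arguments visible in this hunk are certain; `main`'s full signature is not shown here):

import os

# Pass the key explicitly; if omitted, CohereModel falls back to
# os.environ.get("COHERE_API_KEY") as shown in cohere_model.py above.
main(api_key=os.environ["COHERE_API_KEY"])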
62 changes: 62 additions & 0 deletions test/models/test_cohere_model.py
@@ -0,0 +1,62 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+import re
+
+import pytest
+
+from camel.configs import CohereConfig, OpenSourceConfig
+from camel.models import CohereModel
+from camel.types import ModelType
+from camel.utils import OpenAITokenCounter
+
+
+@pytest.mark.model_backend
+@pytest.mark.parametrize(
+    "model_type",
+    [
+        ModelType.COHERE_COMMAND_R,
+        ModelType.COHERE_COMMAND_LIGHT,
+        ModelType.COHERE_COMMAND,
+        ModelType.COHERE_COMMAND_NIGHTLY,
+    ],
+)
+def test_cohere_model(model_type):
+    model_config_dict = CohereConfig().as_dict()
+    model = CohereModel(model_type, model_config_dict)
+    assert model.model_type == model_type
+    assert model.model_config_dict == model_config_dict
+    assert isinstance(model.token_counter, OpenAITokenCounter)
+    assert isinstance(model.model_type.value_for_tiktoken, str)
+    assert isinstance(model.model_type.token_limit, int)
+
+
+@pytest.mark.model_backend
+def test_cohere_model_unexpected_argument():
+    model_type = ModelType.COHERE_COMMAND_R
+    model_config = OpenSourceConfig(
+        model_path="vicuna-7b-v1.5",
+        server_url="http://localhost:8000/v1",
+    )
+    model_config_dict = model_config.as_dict()
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            (
+                "Unexpected argument `model_path` is "
+                "input into Cohere model backend."
+            )
+        ),
+    ):
+        _ = CohereModel(model_type, model_config_dict)
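Both tests are gated behind the `model_backend` pytest marker and exercise only constructor wiring and config validation, without calling the Cohere API. One way to run just this module from Python (equivalent to `pytest -m model_backend test/models/test_cohere_model.py`, assuming the marker is registered in the project's pytest configuration):

import pytest

# Select only the tests marked model_backend in the new module.
pytest.main(["-m", "model_backend", "test/models/test_cohere_model.py"])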
