From 7166c8cf90af4b26e99219a57d663ba545128312 Mon Sep 17 00:00:00 2001 From: keenanpepper Date: Sat, 25 Jan 2025 16:07:52 -0800 Subject: [PATCH 1/6] community: Add Goodfire chat model --- libs/community/extended_testing_deps.txt | 1 + .../chat_models/__init__.py | 9 +- .../chat_models/goodfire.py | 174 ++++++++++++++++++ .../unit_tests/chat_models/test_goodfire.py | 88 +++++++++ .../unit_tests/chat_models/test_imports.py | 3 +- 5 files changed, 272 insertions(+), 3 deletions(-) create mode 100644 libs/community/langchain_community/chat_models/goodfire.py create mode 100644 libs/community/tests/unit_tests/chat_models/test_goodfire.py diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index e04a857ae17fd..4192e034e4034 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -27,6 +27,7 @@ friendli-client>=1.2.4,<2 geopandas>=0.13.1 gitpython>=3.1.32,<4 gliner>=0.2.7 +goodfire>=0.3.4 google-cloud-documentai>=2.20.1,<3 gql>=3.4.1,<4 gradientai>=1.4.0,<2 diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 9c83bdecbfc88..c77919c36b307 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -70,6 +70,9 @@ from langchain_community.chat_models.gigachat import ( GigaChat, ) + from langchain_community.chat_models.goodfire import ( + Goodfire, + ) from langchain_community.chat_models.google_palm import ( ChatGooglePalm, ) @@ -245,8 +248,9 @@ "ChatLlamaCpp", "ErnieBotChat", "FakeListChatModel", - "GPTRouter", "GigaChat", + "Goodfire", + "GPTRouter", "HumanInputChatModel", "JinaChat", "LlamaEdgeChatService", @@ -310,8 +314,9 @@ "ChatZhipuAI": "langchain_community.chat_models.zhipuai", "ErnieBotChat": "langchain_community.chat_models.ernie", "FakeListChatModel": "langchain_community.chat_models.fake", - "GPTRouter": "langchain_community.chat_models.gpt_router", "GigaChat": "langchain_community.chat_models.gigachat", + "Goodfire": "langchain_community.chat_models.goodfire", + "GPTRouter": "langchain_community.chat_models.gpt_router", "HumanInputChatModel": "langchain_community.chat_models.human", "JinaChat": "langchain_community.chat_models.jinachat", "LlamaEdgeChatService": "langchain_community.chat_models.llama_edge", diff --git a/libs/community/langchain_community/chat_models/goodfire.py b/libs/community/langchain_community/chat_models/goodfire.py new file mode 100644 index 0000000000000..94d79f4ec9297 --- /dev/null +++ b/libs/community/langchain_community/chat_models/goodfire.py @@ -0,0 +1,174 @@ +from typing import Any, Dict, List, Optional + +import goodfire +from goodfire.variants.variants import SUPPORTED_MODELS +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from pydantic import Field, SecretStr, model_validator + + +def format_for_goodfire(messages: List[BaseMessage]) -> List[dict]: + """ + Format messages for Goodfire by setting "role" based on the message type. + """ + output = [] + for message in messages: + if isinstance(message, HumanMessage): + output.append({"role": "user", "content": message.content}) + elif isinstance(message, AIMessage): + output.append({"role": "assistant", "content": message.content}) + elif isinstance(message, SystemMessage): + output.append({"role": "system", "content": message.content}) + else: + raise ValueError(f"Unknown message type: {type(message)}") + return output + + +def format_for_langchain(message: dict) -> BaseMessage: + """ + Format a Goodfire message for Langchain. This assumes that the message is an + assistant message (AIMessage). + """ + assert message["role"] == "assistant", ( + f"Expected role 'assistant', got {message['role']}" + ) + return AIMessage(content=message["content"]) + + +class Goodfire(BaseChatModel): + """Goodfire chat model.""" + + goodfire_api_key: SecretStr = Field(default=SecretStr("")) + sync_client: goodfire.Client = Field( + default_factory=lambda: goodfire.Client(api_key="") + ) + async_client: goodfire.AsyncClient = Field( + default_factory=lambda: goodfire.AsyncClient(api_key="") + ) + variant: goodfire.Variant # Removed default - this must be set + + @property + def _llm_type(self) -> str: + return "goodfire" + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"goodfire_api_key": "GOODFIRE_API_KEY"} + + def __init__( + self, + model: SUPPORTED_MODELS, + goodfire_api_key: Optional[str] = None, + variant: Optional[goodfire.Variant] = None, + **kwargs: Any, + ): + """Initialize the Goodfire chat model. + + Args: + model: The model to use, must be one of the supported models. + goodfire_api_key: The API key to use. If None, will look for + GOODFIRE_API_KEY env var. + variant: Optional variant to use. If not provided, will be created + from the model parameter. + """ + # Create variant first + variant_instance = variant or goodfire.Variant(model) + + # Include variant in kwargs for parent initialization + kwargs["variant"] = variant_instance + + # Initialize parent class + super().__init__(**kwargs) + + # Initialize API key and clients if provided + if goodfire_api_key: + self.goodfire_api_key = SecretStr(goodfire_api_key) + self.sync_client = goodfire.Client(api_key=goodfire_api_key) + self.async_client = goodfire.AsyncClient(api_key=goodfire_api_key) + + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key exists in environment.""" + values["goodfire_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "goodfire_api_key", + "GOODFIRE_API_KEY", + ) + ) + + # Initialize clients with the validated API key + api_key = values["goodfire_api_key"].get_secret_value() + values["sync_client"] = goodfire.Client(api_key=api_key) + values["async_client"] = goodfire.AsyncClient(api_key=api_key) + + return values + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """ + Generate a response from Goodfire. + """ + + # If a model is provided, use it instead of the default variant + if "model" in kwargs: + model = kwargs.pop("model") + else: + model = self.variant + + goodfire_response = self.sync_client.chat.completions.create( + messages=format_for_goodfire(messages), + model=model, + **kwargs, + ) + + return ChatResult( + generations=[ + ChatGeneration( + message=format_for_langchain(goodfire_response.choices[0].message) + ) + ] + ) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """ + Generate a response from Goodfire. + """ + + # If a model is provided, use it instead of the default variant + if "model" in kwargs: + model = kwargs.pop("model") + else: + model = self.variant + + goodfire_response = await self.async_client.chat.completions.create( + messages=format_for_goodfire(messages), + model=model, + **kwargs, + ) + + return ChatResult( + generations=[ + ChatGeneration( + message=format_for_langchain(goodfire_response.choices[0].message) + ) + ] + ) diff --git a/libs/community/tests/unit_tests/chat_models/test_goodfire.py b/libs/community/tests/unit_tests/chat_models/test_goodfire.py new file mode 100644 index 0000000000000..4e3cfa57d92f3 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_goodfire.py @@ -0,0 +1,88 @@ +"""Test Goodfire Chat API wrapper.""" + +import os +from typing import List + +import goodfire +import pytest +from goodfire.variants.variants import SUPPORTED_MODELS +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage + +from langchain_community.chat_models import Goodfire +from langchain_community.chat_models.goodfire import ( + format_for_goodfire, + format_for_langchain, +) + +os.environ["GOODFIRE_API_KEY"] = "test_key" + +VALID_MODEL: SUPPORTED_MODELS = "meta-llama/Llama-3.3-70B-Instruct" + + +@pytest.mark.requires("goodfire") +def test_goodfire_model_param() -> None: + llm = Goodfire(model=VALID_MODEL) + assert isinstance(llm.variant, goodfire.Variant) + assert llm.variant.base_model == VALID_MODEL + + +@pytest.mark.requires("goodfire") +def test_goodfire_initialization() -> None: + """Test goodfire initialization with API key.""" + llm = Goodfire(model=VALID_MODEL, goodfire_api_key="test_key") + assert llm.goodfire_api_key.get_secret_value() == "test_key" + assert isinstance(llm.sync_client, goodfire.Client) + assert isinstance(llm.async_client, goodfire.AsyncClient) + + +@pytest.mark.parametrize( + ("messages", "expected"), + [ + ([HumanMessage(content="Hello")], [{"role": "user", "content": "Hello"}]), + ( + [HumanMessage(content="Hello"), AIMessage(content="Hi there!")], + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ), + ( + [ + SystemMessage(content="You're an assistant"), + HumanMessage(content="Hello"), + AIMessage(content="Hi there!"), + ], + [ + {"role": "system", "content": "You're an assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ), + ], +) +def test_message_formatting(messages: List[BaseMessage], expected: List[dict]) -> None: + result = format_for_goodfire(messages) + assert result == expected + + +def test_format_for_langchain() -> None: + message = {"role": "assistant", "content": "Hello there!"} + result = format_for_langchain(message) + assert isinstance(result, AIMessage) + assert result.content == "Hello there!" + + +def test_format_for_langchain_invalid_role() -> None: + message = {"role": "user", "content": "Hello"} + with pytest.raises(AssertionError, match="Expected role 'assistant'"): + format_for_langchain(message) + + +@pytest.mark.requires("goodfire") +def test_invalid_message_type() -> None: + class CustomMessage(BaseMessage): + content: str + type: str = "custom" + + with pytest.raises(ValueError, match="Unknown message type"): + format_for_goodfire([CustomMessage(content="test")]) diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index 1c7dff198f2b6..c930a88483754 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -50,8 +50,9 @@ "ChatZhipuAI", "ErnieBotChat", "FakeListChatModel", - "GPTRouter", "GigaChat", + "Goodfire", + "GPTRouter", "HumanInputChatModel", "JinaChat", "LlamaEdgeChatService", From edbc911c9ae31dc573866e9d8e81f35d5d4256ed Mon Sep 17 00:00:00 2001 From: keenanpepper Date: Sat, 25 Jan 2025 16:51:04 -0800 Subject: [PATCH 2/6] Make sure to import optional dependency inside function --- .../chat_models/goodfire.py | 32 ++++++++++++------- .../unit_tests/chat_models/test_goodfire.py | 18 +++++++++-- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/libs/community/langchain_community/chat_models/goodfire.py b/libs/community/langchain_community/chat_models/goodfire.py index 94d79f4ec9297..039e52ba2b50b 100644 --- a/libs/community/langchain_community/chat_models/goodfire.py +++ b/libs/community/langchain_community/chat_models/goodfire.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional -import goodfire -from goodfire.variants.variants import SUPPORTED_MODELS from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -45,13 +43,9 @@ class Goodfire(BaseChatModel): """Goodfire chat model.""" goodfire_api_key: SecretStr = Field(default=SecretStr("")) - sync_client: goodfire.Client = Field( - default_factory=lambda: goodfire.Client(api_key="") - ) - async_client: goodfire.AsyncClient = Field( - default_factory=lambda: goodfire.AsyncClient(api_key="") - ) - variant: goodfire.Variant # Removed default - this must be set + sync_client: Any = Field(default=None) + async_client: Any = Field(default=None) + variant: Any # Changed type hint since we can't import goodfire at module level @property def _llm_type(self) -> str: @@ -63,9 +57,9 @@ def lc_secrets(self) -> Dict[str, str]: def __init__( self, - model: SUPPORTED_MODELS, + model: str, # Changed from SUPPORTED_MODELS since we can't import it goodfire_api_key: Optional[str] = None, - variant: Optional[goodfire.Variant] = None, + variant: Optional[Any] = None, **kwargs: Any, ): """Initialize the Goodfire chat model. @@ -77,6 +71,14 @@ def __init__( variant: Optional variant to use. If not provided, will be created from the model parameter. """ + try: + import goodfire + except ImportError as e: + raise ImportError( + "Could not import goodfire python package. " + "Please install it with `pip install goodfire`." + ) from e + # Create variant first variant_instance = variant or goodfire.Variant(model) @@ -96,6 +98,14 @@ def __init__( @classmethod def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" + try: + import goodfire + except ImportError as e: + raise ImportError( + "Could not import goodfire python package. " + "Please install it with `pip install goodfire`." + ) from e + values["goodfire_api_key"] = convert_to_secret_str( get_from_dict_or_env( values, diff --git a/libs/community/tests/unit_tests/chat_models/test_goodfire.py b/libs/community/tests/unit_tests/chat_models/test_goodfire.py index 4e3cfa57d92f3..8997693224388 100644 --- a/libs/community/tests/unit_tests/chat_models/test_goodfire.py +++ b/libs/community/tests/unit_tests/chat_models/test_goodfire.py @@ -3,9 +3,7 @@ import os from typing import List -import goodfire import pytest -from goodfire.variants.variants import SUPPORTED_MODELS from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_community.chat_models import Goodfire @@ -16,11 +14,18 @@ os.environ["GOODFIRE_API_KEY"] = "test_key" -VALID_MODEL: SUPPORTED_MODELS = "meta-llama/Llama-3.3-70B-Instruct" +VALID_MODEL: str = "meta-llama/Llama-3.3-70B-Instruct" @pytest.mark.requires("goodfire") def test_goodfire_model_param() -> None: + try: + import goodfire + except ImportError as e: + raise ImportError( + "Could not import goodfire python package. " + "Please install it with `pip install goodfire`." + ) from e llm = Goodfire(model=VALID_MODEL) assert isinstance(llm.variant, goodfire.Variant) assert llm.variant.base_model == VALID_MODEL @@ -29,6 +34,13 @@ def test_goodfire_model_param() -> None: @pytest.mark.requires("goodfire") def test_goodfire_initialization() -> None: """Test goodfire initialization with API key.""" + try: + import goodfire + except ImportError as e: + raise ImportError( + "Could not import goodfire python package. " + "Please install it with `pip install goodfire`." + ) from e llm = Goodfire(model=VALID_MODEL, goodfire_api_key="test_key") assert llm.goodfire_api_key.get_secret_value() == "test_key" assert isinstance(llm.sync_client, goodfire.Client) From faeee1c15f0dc23829e9e269a8b04586c50ac04a Mon Sep 17 00:00:00 2001 From: keenanpepper Date: Sat, 25 Jan 2025 17:09:23 -0800 Subject: [PATCH 3/6] Rename to ChatGoodfire --- .../langchain_community/chat_models/__init__.py | 10 +++++----- .../langchain_community/chat_models/goodfire.py | 2 +- .../tests/unit_tests/chat_models/test_goodfire.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index c77919c36b307..ab1232ec294e6 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -71,7 +71,7 @@ GigaChat, ) from langchain_community.chat_models.goodfire import ( - Goodfire, + ChatGoodfire, ) from langchain_community.chat_models.google_palm import ( ChatGooglePalm, @@ -213,6 +213,7 @@ "ChatEverlyAI", "ChatFireworks", "ChatFriendli", + "ChatGoodfire", "ChatGooglePalm", "ChatHuggingFace", "ChatHunyuan", @@ -248,9 +249,8 @@ "ChatLlamaCpp", "ErnieBotChat", "FakeListChatModel", - "GigaChat", - "Goodfire", "GPTRouter", + "GigaChat", "HumanInputChatModel", "JinaChat", "LlamaEdgeChatService", @@ -280,6 +280,7 @@ "ChatEdenAI": "langchain_community.chat_models.edenai", "ChatFireworks": "langchain_community.chat_models.fireworks", "ChatFriendli": "langchain_community.chat_models.friendli", + "ChatGoodfire": "langchain_community.chat_models.goodfire", "ChatGooglePalm": "langchain_community.chat_models.google_palm", "ChatHuggingFace": "langchain_community.chat_models.huggingface", "ChatHunyuan": "langchain_community.chat_models.hunyuan", @@ -314,9 +315,8 @@ "ChatZhipuAI": "langchain_community.chat_models.zhipuai", "ErnieBotChat": "langchain_community.chat_models.ernie", "FakeListChatModel": "langchain_community.chat_models.fake", - "GigaChat": "langchain_community.chat_models.gigachat", - "Goodfire": "langchain_community.chat_models.goodfire", "GPTRouter": "langchain_community.chat_models.gpt_router", + "GigaChat": "langchain_community.chat_models.gigachat", "HumanInputChatModel": "langchain_community.chat_models.human", "JinaChat": "langchain_community.chat_models.jinachat", "LlamaEdgeChatService": "langchain_community.chat_models.llama_edge", diff --git a/libs/community/langchain_community/chat_models/goodfire.py b/libs/community/langchain_community/chat_models/goodfire.py index 039e52ba2b50b..916bdc56562c6 100644 --- a/libs/community/langchain_community/chat_models/goodfire.py +++ b/libs/community/langchain_community/chat_models/goodfire.py @@ -39,7 +39,7 @@ def format_for_langchain(message: dict) -> BaseMessage: return AIMessage(content=message["content"]) -class Goodfire(BaseChatModel): +class ChatGoodfire(BaseChatModel): """Goodfire chat model.""" goodfire_api_key: SecretStr = Field(default=SecretStr("")) diff --git a/libs/community/tests/unit_tests/chat_models/test_goodfire.py b/libs/community/tests/unit_tests/chat_models/test_goodfire.py index 8997693224388..8de71b54284a1 100644 --- a/libs/community/tests/unit_tests/chat_models/test_goodfire.py +++ b/libs/community/tests/unit_tests/chat_models/test_goodfire.py @@ -6,7 +6,7 @@ import pytest from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langchain_community.chat_models import Goodfire +from langchain_community.chat_models import ChatGoodfire from langchain_community.chat_models.goodfire import ( format_for_goodfire, format_for_langchain, @@ -26,7 +26,7 @@ def test_goodfire_model_param() -> None: "Could not import goodfire python package. " "Please install it with `pip install goodfire`." ) from e - llm = Goodfire(model=VALID_MODEL) + llm = ChatGoodfire(model=VALID_MODEL) assert isinstance(llm.variant, goodfire.Variant) assert llm.variant.base_model == VALID_MODEL @@ -41,7 +41,7 @@ def test_goodfire_initialization() -> None: "Could not import goodfire python package. " "Please install it with `pip install goodfire`." ) from e - llm = Goodfire(model=VALID_MODEL, goodfire_api_key="test_key") + llm = ChatGoodfire(model=VALID_MODEL, goodfire_api_key="test_key") assert llm.goodfire_api_key.get_secret_value() == "test_key" assert isinstance(llm.sync_client, goodfire.Client) assert isinstance(llm.async_client, goodfire.AsyncClient) From 2a54d6512df48600a5271af41964d26094145920 Mon Sep 17 00:00:00 2001 From: keenanpepper Date: Sat, 25 Jan 2025 17:10:34 -0800 Subject: [PATCH 4/6] Update test with ChatGoodfire --- libs/community/tests/unit_tests/chat_models/test_imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index c930a88483754..4ee99db67f59d 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -15,6 +15,7 @@ "ChatEdenAI", "ChatFireworks", "ChatFriendli", + "ChatGoodfire", "ChatGooglePalm", "ChatHuggingFace", "ChatHunyuan", @@ -51,7 +52,6 @@ "ErnieBotChat", "FakeListChatModel", "GigaChat", - "Goodfire", "GPTRouter", "HumanInputChatModel", "JinaChat", From beefd6773e95f63d1116458eec2cf5385047fdae Mon Sep 17 00:00:00 2001 From: keenanpepper Date: Sat, 25 Jan 2025 17:11:24 -0800 Subject: [PATCH 5/6] Restore order of unrelated stuff --- libs/community/tests/unit_tests/chat_models/test_imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index 4ee99db67f59d..cef97deb4b53b 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -51,8 +51,8 @@ "ChatZhipuAI", "ErnieBotChat", "FakeListChatModel", - "GigaChat", "GPTRouter", + "GigaChat", "HumanInputChatModel", "JinaChat", "LlamaEdgeChatService", From 0fee32a3dd98dcdfec1b4a4b5e56f73d242e764c Mon Sep 17 00:00:00 2001 From: keenanpepper Date: Sat, 25 Jan 2025 17:54:48 -0800 Subject: [PATCH 6/6] Simplify constructor and add docs --- docs/docs/integrations/chat/goodfire.ipynb | 266 ++++++++++++++++++ .../chat_models/goodfire.py | 21 +- .../unit_tests/chat_models/test_goodfire.py | 22 +- 3 files changed, 291 insertions(+), 18 deletions(-) create mode 100644 docs/docs/integrations/chat/goodfire.ipynb diff --git a/docs/docs/integrations/chat/goodfire.ipynb b/docs/docs/integrations/chat/goodfire.ipynb new file mode 100644 index 0000000000000..514982252316f --- /dev/null +++ b/docs/docs/integrations/chat/goodfire.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "raw", + "id": "afaf8039", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "---\n", + "sidebar_label: Goodfire\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "e49f1e0d", + "metadata": {}, + "source": [ + "# Goodfire\n", + "\n", + "Goodfire is an AI inference platform to run certain Llama models with SAE feature steering. See the [Goodfire docs](https://docs.goodfire.ai/) for more information." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if \"GOODFIRE_API_KEY\" not in os.environ:\n", + " os.environ[\"GOODFIRE_API_KEY\"] = getpass.getpass(\"Enter your Goodfire API key: \")" + ] + }, + { + "cell_type": "markdown", + "id": "a38cde65-254d-4219-a441-068766c0d4b5", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our model object and generate chat completions:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "model must be a Goodfire variant, got ", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m enthusiasm_variant \u001b[38;5;241m=\u001b[39m goodfire\u001b[38;5;241m.\u001b[39mVariant(MODEL_NAME)\n\u001b[1;32m 13\u001b[0m enthusiasm_variant\u001b[38;5;241m.\u001b[39mset(enthusiasm_feature, \u001b[38;5;241m0.3\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m llm \u001b[38;5;241m=\u001b[39m \u001b[43mChatGoodfire\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mMODEL_NAME\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43mvariant\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43menthusiasm_variant\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.6\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# other params...\u001b[39;49;00m\n\u001b[1;32m 21\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/langchain/libs/community/langchain_community/chat_models/goodfire.py:80\u001b[0m, in \u001b[0;36mChatGoodfire.__init__\u001b[0;34m(self, model, goodfire_api_key, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not import goodfire python package. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease install it with `pip install goodfire`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 77\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, goodfire\u001b[38;5;241m.\u001b[39mVariant):\n\u001b[0;32m---> 80\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel must be a Goodfire variant, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(model)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 82\u001b[0m \u001b[38;5;66;03m# Include model in kwargs for parent initialization\u001b[39;00m\n\u001b[1;32m 83\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m model\n", + "\u001b[0;31mValueError\u001b[0m: model must be a Goodfire variant, got " + ] + } + ], + "source": [ + "from langchain_community.chat_models import ChatGoodfire\n", + "import goodfire\n", + "\n", + "MODEL_NAME = \"meta-llama/Llama-3.3-70B-Instruct\"\n", + "\n", + "goodfire_client = goodfire.Client(api_key=os.environ[\"GOODFIRE_API_KEY\"])\n", + "\n", + "base_variant = goodfire.Variant(MODEL_NAME)\n", + "\n", + "enthusiasm_feature = goodfire_client.features.lookup([55543], base_variant)[55543]\n", + "\n", + "enthusiasm_variant = goodfire.Variant(MODEL_NAME)\n", + "enthusiasm_variant.set(enthusiasm_feature, 0.3)\n", + "\n", + "llm = ChatGoodfire(\n", + " model=enthusiasm_variant,\n", + " temperature=0.6,\n", + " seed=42,\n", + " # other params...\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2b4f3e15", + "metadata": {}, + "source": [ + "## Invocation" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "62e0dbc3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='J\\'ADORE LA PROGRAMMATION! \\n\\n(or in a more casual tone: J\\'ADORE LE CODAGE!)\\n\\nNote: \"J\\'adore\" is a stronger way to say \"I love\" in French, it\\'s more like \"I\\'m crazy about\" or \"I\\'m absolutely passionate about\". If you want to use a more literal translation, you can say: \"J\\'aime la programmation\" which means \"I like programming\".', additional_kwargs={}, response_metadata={}, id='run-d91dd50b-1d6a-4c04-a78c-b1b922c1fc92-0')" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", + " ),\n", + " (\"human\", \"I love programming.\"),\n", + "]\n", + "ai_msg = llm.invoke(messages)\n", + "ai_msg" + ] + }, + { + "cell_type": "markdown", + "id": "39f7d928", + "metadata": {}, + "source": [ + "Note: The variant can be overridden after instantiation by providing a new variant to the `model` parameter." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "ceac2cb6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"J'adore la programmation.\", additional_kwargs={}, response_metadata={}, id='run-b646d8ed-74c3-40a2-8530-7f094060bf23-0')" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ai_msg = llm.invoke(messages, model=base_variant)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "d86145b3-bfef-46e8-b227-4dda5c9c2705", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "J'ADORE LA PROGRAMMATION! \n", + "\n", + "(or in a more casual tone: J'ADORE LE CODAGE!)\n", + "\n", + "Note: \"J'adore\" is a stronger way to say \"I love\" in French, it's more like \"I'm crazy about\" or \"I'm absolutely passionate about\". If you want to use a more literal translation, you can say: \"J'aime la programmation\" which means \"I like programming\".\n" + ] + } + ], + "source": [ + "print(ai_msg.content)" + ] + }, + { + "cell_type": "markdown", + "id": "18e2bfc0-7e78-4528-a73f-499ac150dca8", + "metadata": {}, + "source": [ + "## Chaining\n", + "\n", + "We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='Ich liebe das Programmieren.', additional_kwargs={}, response_metadata={}, id='run-f77167ac-e9a8-4fc0-9e43-5a4800290324-0')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", + " ),\n", + " (\"human\", \"{input}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | llm\n", + "chain.invoke(\n", + " {\n", + " \"input_language\": \"English\",\n", + " \"output_language\": \"German\",\n", + " \"input\": \"I love programming.\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all ChatGoodfire features and configurations head to the API reference: https://python.langchain.com/api_reference/goodfire/chat_models/langchain_goodfire.chat_models.ChatGoodfire.html\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/chat_models/goodfire.py b/libs/community/langchain_community/chat_models/goodfire.py index 916bdc56562c6..ca131332de105 100644 --- a/libs/community/langchain_community/chat_models/goodfire.py +++ b/libs/community/langchain_community/chat_models/goodfire.py @@ -45,7 +45,7 @@ class ChatGoodfire(BaseChatModel): goodfire_api_key: SecretStr = Field(default=SecretStr("")) sync_client: Any = Field(default=None) async_client: Any = Field(default=None) - variant: Any # Changed type hint since we can't import goodfire at module level + model: Any # Changed type hint since we can't import goodfire at module level @property def _llm_type(self) -> str: @@ -57,19 +57,16 @@ def lc_secrets(self) -> Dict[str, str]: def __init__( self, - model: str, # Changed from SUPPORTED_MODELS since we can't import it + model: Any, goodfire_api_key: Optional[str] = None, - variant: Optional[Any] = None, **kwargs: Any, ): """Initialize the Goodfire chat model. Args: - model: The model to use, must be one of the supported models. + model: The Goodfire variant to use. goodfire_api_key: The API key to use. If None, will look for GOODFIRE_API_KEY env var. - variant: Optional variant to use. If not provided, will be created - from the model parameter. """ try: import goodfire @@ -79,11 +76,11 @@ def __init__( "Please install it with `pip install goodfire`." ) from e - # Create variant first - variant_instance = variant or goodfire.Variant(model) + if not isinstance(model, goodfire.Variant): + raise ValueError(f"model must be a Goodfire variant, got {type(model)}") - # Include variant in kwargs for parent initialization - kwargs["variant"] = variant_instance + # Include model in kwargs for parent initialization + kwargs["model"] = model # Initialize parent class super().__init__(**kwargs) @@ -136,7 +133,7 @@ def _generate( if "model" in kwargs: model = kwargs.pop("model") else: - model = self.variant + model = self.model goodfire_response = self.sync_client.chat.completions.create( messages=format_for_goodfire(messages), @@ -167,7 +164,7 @@ async def _agenerate( if "model" in kwargs: model = kwargs.pop("model") else: - model = self.variant + model = self.model goodfire_response = await self.async_client.chat.completions.create( messages=format_for_goodfire(messages), diff --git a/libs/community/tests/unit_tests/chat_models/test_goodfire.py b/libs/community/tests/unit_tests/chat_models/test_goodfire.py index 8de71b54284a1..44b8906394a10 100644 --- a/libs/community/tests/unit_tests/chat_models/test_goodfire.py +++ b/libs/community/tests/unit_tests/chat_models/test_goodfire.py @@ -1,7 +1,7 @@ """Test Goodfire Chat API wrapper.""" import os -from typing import List +from typing import Any, List import pytest from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage @@ -14,7 +14,16 @@ os.environ["GOODFIRE_API_KEY"] = "test_key" -VALID_MODEL: str = "meta-llama/Llama-3.3-70B-Instruct" + +def get_valid_variant() -> Any: + try: + import goodfire + except ImportError as e: + raise ImportError( + "Could not import goodfire python package. " + "Please install it with `pip install goodfire`." + ) from e + return goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct") @pytest.mark.requires("goodfire") @@ -26,9 +35,10 @@ def test_goodfire_model_param() -> None: "Could not import goodfire python package. " "Please install it with `pip install goodfire`." ) from e - llm = ChatGoodfire(model=VALID_MODEL) - assert isinstance(llm.variant, goodfire.Variant) - assert llm.variant.base_model == VALID_MODEL + base_variant = get_valid_variant() + llm = ChatGoodfire(model=base_variant) + assert isinstance(llm.model, goodfire.Variant) + assert llm.model.base_model == base_variant.base_model @pytest.mark.requires("goodfire") @@ -41,7 +51,7 @@ def test_goodfire_initialization() -> None: "Could not import goodfire python package. " "Please install it with `pip install goodfire`." ) from e - llm = ChatGoodfire(model=VALID_MODEL, goodfire_api_key="test_key") + llm = ChatGoodfire(model=get_valid_variant(), goodfire_api_key="test_key") assert llm.goodfire_api_key.get_secret_value() == "test_key" assert isinstance(llm.sync_client, goodfire.Client) assert isinstance(llm.async_client, goodfire.AsyncClient)