diff --git a/pyproject.toml b/pyproject.toml
index 1e61129..59ff5b5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ retrying = "^1.3.4"
 anthropic_support = ["anthropic"]
 hf_support = ["accelerate", "datasets", "torch", "transformers"]
 vllm_support = ["accelerate", "torch", "vllm"]
+ollama_support = ["litellm"]
 
 all = [
     # anthropic
@@ -37,6 +38,7 @@ all = [
     "datasets",
     "torch",
     "transformers",
+    "litellm"
 ]
 all_with_extras = [
     # all
@@ -49,6 +51,8 @@ all_with_extras = [
     "transformers",
     # More Extras
     "vllm",
+    # ollama
+    "litellm"
 ]
 # To export dependencies to pip, use:
 # poetry export -f requirements.txt --with dev --without-hashes --output requirements-dev.txt
diff --git a/synthesizer/core/base.py b/synthesizer/core/base.py
index 0bb2e55..b4f79db 100644
--- a/synthesizer/core/base.py
+++ b/synthesizer/core/base.py
@@ -13,6 +13,7 @@ class LLMProviderName(Enum):
     LLAMACPP = "llamacpp"
     LITE_LLM = "lite-llm"
     SCIPHI = "sciphi"
+    OLLAMA = "ollama"
 
 
 class RAGProviderName(Enum):
diff --git a/synthesizer/interface/__init__.py b/synthesizer/interface/__init__.py
index 06ef46a..fc01331 100644
--- a/synthesizer/interface/__init__.py
+++ b/synthesizer/interface/__init__.py
@@ -12,6 +12,7 @@
 from synthesizer.interface.llm.sciphi_interface import SciPhiLLMInterface
 from synthesizer.interface.llm.vllm_interface import vLLMInterface
 from synthesizer.interface.llm_interface_manager import LLMInterfaceManager
+from synthesizer.interface.llm.ollama_interface import OllamaLLMInterface
 from synthesizer.interface.rag.agent_search import (
     AgentSearchRAGConfig,
     AgentSearchRAGInterface,
@@ -38,6 +39,7 @@
     "OpenAILLMInterface",
     "SciPhiLLMInterface",
     "vLLMInterface",
+    "OllamaLLMInterface",
     # RAG
     "RAGInterfaceManager",
     "RAGProviderConfig",
diff --git a/synthesizer/interface/llm/ollama_interface.py b/synthesizer/interface/llm/ollama_interface.py
new file mode 100644
index 0000000..4554542
--- /dev/null
+++ b/synthesizer/interface/llm/ollama_interface.py
@@ -0,0 +1,60 @@
+"""A module for interfacing with Ollama."""
+import logging
+
+from synthesizer.interface.base import LLMInterface, LLMProviderName
+from synthesizer.interface.llm_interface_manager import llm_interface
+from synthesizer.llm import GenerationConfig, OllamaConfig, OllamaLLM
+
+logger = logging.getLogger(__name__)
+
+
+@llm_interface
+class OllamaLLMInterface(LLMInterface):
+    """A class to interface with Ollama."""
+
+    provider_name = LLMProviderName.OLLAMA
+    system_message = "You are a helpful assistant."
+
+    def __init__(
+        self,
+        config: OllamaConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        self.config = config
+        self._model = OllamaLLM(config)
+
+    def get_completion(
+        self, prompt: str, generation_config: GenerationConfig
+    ) -> str:
+        """Get a completion from Ollama based on the provided prompt."""
+
+        logger.debug(
+            f"Getting completion from Ollama for model={generation_config.model_name}"
+        )
+        if "instruct" in generation_config.model_name:
+            return self._model.get_instruct_completion(
+                prompt, generation_config
+            )
+        else:
+            return self._model.get_chat_completion(
+                [
+                    {
+                        "role": "system",
+                        "content": OllamaLLMInterface.system_message,
+                    },
+                    {"role": "user", "content": prompt},
+                ],
+                generation_config,
+            )
+
+    def get_chat_completion(
+        self, conversation: list[dict], generation_config: GenerationConfig
+    ) -> str:
+        raise NotImplementedError(
+            "Chat completion not yet implemented for Ollama."
+        )
+
+    @property
+    def model(self) -> OllamaLLM:
+        return self._model
diff --git a/synthesizer/interface/rag/bing_search/base.py b/synthesizer/interface/rag/bing_search/base.py
index d8832ad..1ce7c7c 100644
--- a/synthesizer/interface/rag/bing_search/base.py
+++ b/synthesizer/interface/rag/bing_search/base.py
@@ -37,7 +37,7 @@ def __init__(
     ) -> None:
         super().__init__(config)
         self.config: BingRAGConfig = config
-        print('self.config = ', self.config)
+        print("self.config = ", self.config)
         api_key = self.config.api_key or os.getenv("BING_API_KEY")
         if not api_key:
             raise ValueError(
diff --git a/synthesizer/llm/__init__.py b/synthesizer/llm/__init__.py
index f4b2a04..90d9b5a 100644
--- a/synthesizer/llm/__init__.py
+++ b/synthesizer/llm/__init__.py
@@ -8,6 +8,7 @@
 from synthesizer.llm.models.openai_llm import OpenAIConfig, OpenAILLM
 from synthesizer.llm.models.sciphi_llm import SciPhiConfig, SciPhiLLM
 from synthesizer.llm.models.vllm_llm import vLLM, vLLMConfig
+from synthesizer.llm.models.ollama_llm import OllamaConfig, OllamaLLM
 
 __all__ = [
     # Base
@@ -26,4 +27,6 @@
     "SciPhiLLM",
     "vLLMConfig",
     "vLLM",
+    "OllamaConfig",
+    "OllamaLLM",
 ]
diff --git a/synthesizer/llm/models/ollama_llm.py b/synthesizer/llm/models/ollama_llm.py
new file mode 100644
index 0000000..e76cf15
--- /dev/null
+++ b/synthesizer/llm/models/ollama_llm.py
@@ -0,0 +1,65 @@
+"""A module for creating Ollama model abstractions."""
+import os
+from dataclasses import dataclass
+
+from litellm import completion
+
+from synthesizer.core import LLMProviderName
+from synthesizer.llm.base import LLM, GenerationConfig, LLMConfig
+from synthesizer.llm.config_manager import model_config
+
+
+@model_config
+@dataclass
+class OllamaConfig(LLMConfig):
+    """Configuration for Ollama models."""
+
+    # Base
+    provider_name: LLMProviderName = LLMProviderName.OLLAMA
+    api_base: str = "http://localhost:11434"
+
+
+class OllamaLLM(LLM):
+    """A concrete class for creating Ollama models."""
+
+    def __init__(
+        self,
+        config: OllamaConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.config: OllamaConfig = config
+
+        # Verify at runtime that an OllamaConfig was provided.
+        if not isinstance(self.config, OllamaConfig):
+            raise ValueError(
+                "The provided config must be an instance of OllamaConfig."
+            )
+
+    def get_chat_completion(
+        self,
+        messages: list[dict[str, str]],
+        generation_config: GenerationConfig,
+    ) -> str:
+        """Get a chat completion from Ollama based on the provided messages."""
+
+        # litellm addresses Ollama models as "ollama/<model>".
+        response = completion(
+            model=f"ollama/{generation_config.model_name}",
+            messages=messages,
+            api_base=self.config.api_base,
+            stream=generation_config.do_stream,
+        )
+
+        return response.choices[0].message["content"]
+
+    def get_instruct_completion(
+        self,
+        messages: list[dict[str, str]],
+        generation_config: GenerationConfig,
+    ) -> str:
+        """Get an instruction completion from Ollama."""
+        raise NotImplementedError(
+            "Instruction completion is not yet supported for Ollama."
+        )
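Usage sketch for reviewers (not part of the change set): a minimal way to exercise the new Ollama provider end to end. It assumes a local Ollama server is running (`ollama serve`) with a model already pulled (`ollama pull mistral`), and it assumes `GenerationConfig` can be constructed with a `model_name` keyword argument, which this diff does not show. Because `OllamaLLM.get_instruct_completion` is still a stub, the sketch uses a model name that does not contain "instruct" so that `OllamaLLMInterface.get_completion` dispatches to the chat-completion path.

```python
# Hypothetical usage sketch, not part of the diff.
# Assumes `ollama serve` is running locally and `ollama pull mistral` has been run.
from synthesizer.interface import OllamaLLMInterface
from synthesizer.llm import GenerationConfig, OllamaConfig

# OllamaConfig defaults api_base to http://localhost:11434.
config = OllamaConfig()
interface = OllamaLLMInterface(config)

# The exact GenerationConfig signature is not part of this diff; model_name is
# assumed to be accepted as a keyword argument here.
generation_config = GenerationConfig(model_name="mistral")

# "mistral" contains no "instruct", so get_completion routes through
# OllamaLLM.get_chat_completion with the default system message.
print(
    interface.get_completion(
        "Summarize retrieval-augmented generation in one sentence.",
        generation_config,
    )
)
```

If `GenerationConfig` requires additional fields (for example `do_stream`, which `get_chat_completion` reads), only the construction line above should need adjusting.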