This repository has been archived by the owner on Feb 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d45900f
commit 49de39d
Showing
7 changed files
with
136 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""A module for interfacing with Ollama""" | ||
import logging | ||
|
||
from synthesizer.interface.base import LLMInterface, LLMProviderName | ||
from synthesizer.interface.llm_interface_manager import llm_interface | ||
from synthesizer.llm import GenerationConfig, OllamaConfig, OllamaLLM | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@llm_interface | ||
class OllamaLLMInterface(LLMInterface): | ||
"""A class to interface with Ollama.""" | ||
|
||
provider_name = LLMProviderName.OLLAMA | ||
system_message = "You are a helpful assistant." | ||
|
||
def __init__( | ||
self, | ||
config: OllamaConfig, | ||
*args, | ||
**kwargs, | ||
) -> None: | ||
self.config = config | ||
self._model = OllamaLLM(config) | ||
|
||
def get_completion( | ||
self, prompt: str, generation_config: GenerationConfig | ||
) -> str: | ||
"""Get a completion from the Ollama based on the provided prompt.""" | ||
|
||
logger.debug( | ||
f"Getting completion from Ollama for model={generation_config.model_name}" | ||
) | ||
if "instruct" in generation_config.model_name: | ||
return self.model.get_instruct_completion( | ||
prompt, generation_config | ||
) | ||
else: | ||
return self._model.get_chat_completion( | ||
[ | ||
{ | ||
"role": "system", | ||
"content": OllamaLLMInterface.system_message, | ||
}, | ||
{"role": "user", "content": prompt}, | ||
], | ||
generation_config, | ||
) | ||
|
||
def get_chat_completion( | ||
self, conversation: list[dict], generation_config: GenerationConfig | ||
) -> str: | ||
raise NotImplementedError( | ||
"Chat completion not yet implemented for Ollama." | ||
) | ||
|
||
@property | ||
def model(self) -> OllamaLLM: | ||
return self._model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""A module for creating Ollama model abstractions.""" | ||
import os | ||
from dataclasses import dataclass | ||
|
||
from litellm import completion | ||
|
||
from synthesizer.core import LLMProviderName | ||
from synthesizer.llm.base import LLM, GenerationConfig, LLMConfig | ||
from synthesizer.llm.config_manager import model_config | ||
|
||
|
||
@model_config | ||
@dataclass | ||
class OllamaConfig(LLMConfig): | ||
"""Configuration for Ollama models.""" | ||
|
||
# Base | ||
provider_name: LLMProviderName = LLMProviderName.OLLAMA | ||
api_base: str = "http://localhost:11434" | ||
|
||
|
||
class OllamaLLM(LLM): | ||
"""A concrete class for creating Ollama models.""" | ||
|
||
def __init__( | ||
self, | ||
config: OllamaConfig, | ||
*args, | ||
**kwargs, | ||
) -> None: | ||
super().__init__() | ||
self.config: OllamaConfig = config | ||
|
||
# set the config here, again, for typing purposes | ||
if not isinstance(self.config, OllamaConfig): | ||
raise ValueError( | ||
"The provided config must be an instance of OllamaConfig." | ||
) | ||
|
||
def get_chat_completion( | ||
self, | ||
messages: list[dict[str, str]], | ||
generation_config: GenerationConfig, | ||
) -> str: | ||
"""Get a chat completion from Ollama based on the provided prompt.""" | ||
|
||
# Create the chat completion | ||
response = completion( | ||
model="ollama/mistral", | ||
messages=messages, | ||
api_base=self.config.api_base, | ||
stream=generation_config.do_stream, | ||
) | ||
|
||
return response.choices[0].message["content"] | ||
|
||
def get_instruct_completion( | ||
self, | ||
messages: list[dict[str, str]], | ||
generation_config: GenerationConfig, | ||
) -> str: | ||
"""Get an instruction completion from Ollama.""" | ||
raise NotImplementedError( | ||
"Instruction completion is not yet supported for Ollama." | ||
) |