Merge pull request #82 from BatsResearch/add-ollama-support
Add ollama support
dotpyu authored Sep 24, 2024
2 parents a513883 + c307750 commit c668172
Showing 8 changed files with 190 additions and 14 deletions.
15 changes: 11 additions & 4 deletions alfred/client/client.py
@@ -81,6 +81,7 @@ def __init__(
"groq",
"torch",
"openllm",
"ollama",
"dummy",
], f"Invalid model type: {self.model_type}"
else:
@@ -100,7 +101,7 @@ def __init__(
self.run = self.cache.cached_query(self.run)

self.grpcClient = None
if end_point and model_type not in ["dummy", "openllm", ]:
if end_point and model_type not in ["dummy", "openllm", "ollama",]:
end_point_pieces = end_point.split(":")
self.end_point_ip, self.end_point_port = (
"".join(end_point_pieces[:-1]),
@@ -186,6 +187,12 @@ def __init__(

base_url = kwargs.get("base_url", end_point)
self.model = OpenLLMModel(self.model, base_url=base_url, **kwargs)
elif self.model_type == "ollama":
from ..fm.ollama import OllamaModel

if not model and end_point:
model = end_point
self.model = OllamaModel(model)
elif self.model_type == "cohere":
from ..fm.cohere import CohereModel

@@ -438,12 +445,12 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any):
:param log_save_path: The file to save the chat logs.
:type log_save_path: Optional[str]
"""
if self.model_type in ["openai", "anthropic", "google", "huggingface", "groq"]:
if self.model_type in ["openai", "anthropic", "google", "huggingface", "groq", "ollama"]:
self.model.chat(log_save_path=log_save_path, **kwargs)
else:
logger.error(
"Chat APIs are only supported for Anthropic, Google Gemini and OpenAI models."
"Chat APIs are only supported for Anthropic, Google, OpenAI, HuggingFace, Ollama and Groq models."
)
raise NotImplementedError(
"Currently Chat are only supported for Anthropic, Google Gemini and OpenAI models."
"Chat APIs are only supported for Anthropic, Google, OpenAI, HuggingFace, Ollama and Groq models."
)
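Taken together, the client changes register "ollama" as a model type, let a bare endpoint stand in for the model name, and enable chat() for Ollama. A minimal usage sketch under assumptions not in this diff (the import path, the `llama3` model name, and the endpoint URL are all illustrative):

```python
from alfred.client import Client

# Local Ollama daemon: name a pulled model directly.
client = Client(model_type="ollama", model="llama3")

# Remote daemon: with no model given, the endpoint falls back to `model`
# (per the constructor change above) and OllamaModel treats the URL as a host.
remote = Client(model_type="ollama", end_point="http://192.168.1.10:11434")

# Chat is now wired up for Ollama as well.
client.chat()
```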
2 changes: 1 addition & 1 deletion alfred/fm/model.py
@@ -38,7 +38,7 @@ def _generate_batch(
:rtype List[Response]
"""
raise NotImplementedError(
f"_infer_batch() is not implemented for {self.__class__.__name__}"
f"_generate_batch() is not implemented for {self.__class__.__name__}"
)

def _score_batch(
162 changes: 162 additions & 0 deletions alfred/fm/ollama.py
@@ -0,0 +1,162 @@
import json
import logging
import re
from typing import Any, List

from .model import LocalAccessFoundationModel
from .response import CompletionResponse

logger = logging.getLogger(__name__)

try:
import ollama
except ImportError:
raise ImportError("Please install Ollama with `pip install ollama`")

from .utils import colorize_str, type_print

class OllamaModel(LocalAccessFoundationModel):
"""
OllamaModel wraps an Ollama model. Ollama is a tool for running large language models locally.
source: https://github.com/ollama/ollama
"""

def __init__(self, model: str, **kwargs: Any):
"""
Initialize an Ollama model.
:param model: The name of the Ollama model to run, or the URL of a remote Ollama host.
:type model: str
"""

def is_url(string):
url_pattern = re.compile(
r'^(?:'
r'(?:http)s?://'
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'
r'localhost|'
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
r'(?::\d+)?'
r'(?:/?|[/?]\S+)'
r'|'
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}(?::\d+)?'
r')$',
re.IGNORECASE
)
return bool(url_pattern.match(string))

self.model_string = model

super().__init__(model)

# If `model` is a URL or bare IP address, connect to that remote Ollama
# host; otherwise use the default local `ollama` module client.
if is_url(model):
    self.client = ollama.Client(host=model)
else:
    self.client = ollama

def _generate_batch(
self,
batch_instance: List[str],
**kwargs: Any,
) -> List[CompletionResponse]:
"""
Generate completions for a batch of queries.
:param batch_instance: A list of queries.
:type batch_instance: List[str]
:param kwargs: Additional keyword arguments.
:return: A list of `CompletionResponse` objects with the generated content.
:rtype: List[CompletionResponse]
"""

responses = []
for query in batch_instance:
response = self.client.generate(
model=self.model_string,
prompt=query,
)
responses.append(CompletionResponse(prediction=response['response']))

return responses

def chat(self, **kwargs: Any):
"""
Launch an interactive chat session with the Ollama API.
"""

def _feedback(feedback: str, no_newline=False):
print(
colorize_str("Chat AI: ", "GREEN") + feedback,
end="\n" if not no_newline else "",
)

model = kwargs.get("model", self.model_string)
c_title = colorize_str("Alfred's Ollama Chat", "BLUE")
c_model = colorize_str(model, "WARNING")
c_exit = colorize_str("exit", "FAIL")
c_ctrlc = colorize_str("Ctrl+C", "FAIL")

temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 1024)
log_save_path = kwargs.get("log_save_path", None)
manual_chat_sequence = kwargs.get("manual_chat_sequence", None)

print(f"Welcome to the {c_title} session!\nYou are using the {c_model} model.")
print(f"Type '{c_exit}' or hit {c_ctrlc} to exit the chat session.")

message_log = [
{
"role": "system",
"content": "You are an intelligent assistant. Please answer the user with professional language.",
}
]

print()
print("======== Chat Begin ========")
print()

try:
while True:
if manual_chat_sequence is not None:
query = manual_chat_sequence.pop(0)
_feedback(query, no_newline=True)
print()
if len(manual_chat_sequence) == 0:
break
else:
query = input(colorize_str("You: "))
if query.lower() == "exit":
_feedback("Goodbye!")
break
message_log.append({"role": "user", "content": query})
_feedback("", no_newline=True)

response = self.client.chat(
    model=model,
    messages=message_log,
    stream=True,
    # Forward the sampling settings collected above; the Ollama client
    # accepts them through its `options` mapping (num_predict caps new tokens).
    options={"temperature": temperature, "num_predict": max_tokens},
)

full_response = []
for chunk in response:
try:
txt = chunk['message']['content']
type_print(txt)
full_response.append(txt)
except KeyError:
pass
print()

full_response = "".join(full_response).strip()
message_log.append({"role": "assistant", "content": full_response})
except KeyboardInterrupt:
_feedback("Goodbye!")

print()
print("======== Chat End ========")
print()
print(colorize_str("Thank you for using Alfred!"))

if log_save_path:
with open(log_save_path, "w") as f:
json.dump(message_log, f)
print(f"Your chat log is saved to {log_save_path}")
5 changes: 5 additions & 0 deletions alfred/run_server.py
@@ -59,6 +59,7 @@ def __init__(
"groq",
"google",
"torch",
"ollama",
"dummy",
], f"Invalid model type: {self.model_type}"
if self.model_type == "huggingface":
@@ -97,6 +98,10 @@ def __init__(
from .fm.groq import GroqModel

self.model = GroqModel(self.model, **kwargs)
elif self.model_type == "ollama":
from .fm.ollama import OllamaModel

self.model = OllamaModel(self.model)
elif self.model_type == "dummy":
from .fm.dummy import DummyModel

1 change: 1 addition & 0 deletions docs/README.md
@@ -30,6 +30,7 @@ A full list of `Alfred` project modules.
- [Huggingfacedocument](alfred/fm/huggingfacedocument.md#huggingfacedocument)
- [Huggingfacevlm](alfred/fm/huggingfacevlm.md#huggingfacevlm)
- [Model](alfred/fm/model.md#model)
- [Ollama](alfred/fm/ollama.md#ollama)
- [Onnx](alfred/fm/onnx.md#onnx)
- [Openai](alfred/fm/openai.md#openai)
- [Openllm](alfred/fm/openllm.md#openllm)
16 changes: 8 additions & 8 deletions docs/alfred/client/client.md
@@ -44,7 +44,7 @@ class Client:

### Client().__call__

[Show source in client.py:319](../../../alfred/client/client.py#L319)
[Show source in client.py:326](../../../alfred/client/client.py#L326)

__call__() function to run the model on the queries.
Equivalent to run() function.
@@ -71,7 +71,7 @@

### Client().calibrate

[Show source in client.py:335](../../../alfred/client/client.py#L335)
[Show source in client.py:342](../../../alfred/client/client.py#L342)

calibrate are used to calibrate foundation models contextually given the template.
A voter class may be passed to calibrate the model with a specific voter.
@@ -115,7 +115,7 @@

### Client().chat

[Show source in client.py:433](../../../alfred/client/client.py#L433)
[Show source in client.py:440](../../../alfred/client/client.py#L440)

Chat with the model APIs.
Currently, Alfred supports Chat APIs from Anthropic and OpenAI
@@ -133,7 +133,7 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any): ...

### Client().encode

[Show source in client.py:407](../../../alfred/client/client.py#L407)
[Show source in client.py:414](../../../alfred/client/client.py#L414)

embed() function to embed the queries.

@@ -155,7 +155,7 @@

### Client().generate

[Show source in client.py:278](../../../alfred/client/client.py#L278)
[Show source in client.py:285](../../../alfred/client/client.py#L285)

Wrapper function to generate the response(s) from the model. (For completion)

Expand Down Expand Up @@ -183,7 +183,7 @@ def generate(

### Client().remote_run

[Show source in client.py:252](../../../alfred/client/client.py#L252)
[Show source in client.py:259](../../../alfred/client/client.py#L259)

Wrapper function for running the model on the queries thru a gRPC Server.

@@ -209,7 +209,7 @@

### Client().run

[Show source in client.py:232](../../../alfred/client/client.py#L232)
[Show source in client.py:239](../../../alfred/client/client.py#L239)

Run the model on the queries.

@@ -235,7 +235,7 @@

### Client().score

[Show source in client.py:295](../../../alfred/client/client.py#L295)
[Show source in client.py:302](../../../alfred/client/client.py#L302)

Wrapper function to score the response(s) from the model. (For ranking)

1 change: 1 addition & 0 deletions docs/alfred/fm/index.md
@@ -20,6 +20,7 @@
- [Huggingfacedocument](./huggingfacedocument.md)
- [Huggingfacevlm](./huggingfacevlm.md)
- [Model](./model.md)
- [Ollama](./ollama.md)
- [Onnx](./onnx.md)
- [Openai](./openai.md)
- [Openllm](./openllm.md)
2 changes: 1 addition & 1 deletion docs/alfred/run_server.md
@@ -28,7 +28,7 @@ class ModelServer:

## start_server

[Show source in run_server.py:132](../../alfred/run_server.py#L132)
[Show source in run_server.py:137](../../alfred/run_server.py#L137)

Wrapper function to start gRPC Server.

