feat: Ollama support #148

Merged 6 commits on Feb 15, 2025
11 changes: 11 additions & 0 deletions basilisk/provider.py
@@ -116,6 +116,17 @@ def engine_cls(self) -> Type[BaseEngine]:
        engine_cls_path="basilisk.provider_engine.mistralai_engine.MistralAIEngine",
        allow_custom_base_url=True,
    ),
    Provider(
        id="ollama",
        name="Ollama",
        base_url="http://127.0.0.1:11434",
        api_type=ProviderAPIType.OLLAMA,
        organization_mode_available=False,
        require_api_key=False,
        env_var_name_api_key="OLLAMA_API_KEY",
        engine_cls_path="basilisk.provider_engine.ollama_engine.OllamaEngine",
        allow_custom_base_url=True,
    ),
Comment on lines +119 to +129
Contributor

🧹 Nitpick (assertive)

Consider robust handling of local vs. remote usage.

With require_api_key=False and a local base URL, the provider relies on a locally running, reachable Ollama instance. If remote configurations are supported in the future, consider adding a fallback or clear guidance for cases where the server is not running, to improve the user experience.
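A minimal sketch of such a guard, assuming the project can add a small reachability check (the helper name and call site are hypothetical; httpx is assumed to be available since the ollama client already depends on it):

import httpx


def ollama_server_reachable(base_url: str, timeout: float = 2.0) -> bool:
    """Return True if an Ollama server answers at base_url (hypothetical helper)."""
    try:
        # A running Ollama instance replies on its root endpoint.
        return httpx.get(base_url, timeout=timeout).status_code == 200
    except httpx.HTTPError:
        return False

The engine or the account dialog could call this before listing models and surface a clear "Ollama server not running at <base_url>" message instead of a raw connection error.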

    Provider(
        id="openai",
        name="OpenAI",
167 changes: 167 additions & 0 deletions basilisk/provider_engine/ollama_engine.py
@@ -0,0 +1,167 @@
"""Ollama provider engine implementation."""

import json
import logging
from functools import cached_property
from typing import Iterator

from ollama import ChatResponse, Client

from basilisk.conversation import (
    Conversation,
    ImageFileTypes,
    Message,
    MessageBlock,
    MessageRoleEnum,
)
from basilisk.decorators import measure_time

from .base_engine import BaseEngine, ProviderAIModel, ProviderCapability

log = logging.getLogger(__name__)


class OllamaEngine(BaseEngine):
"""Engine implementation for Ollama API integration."""

    capabilities: set[ProviderCapability] = {
        ProviderCapability.TEXT,
        ProviderCapability.IMAGE,
    }
Comment on lines +24 to +30
Contributor

🧹 Nitpick (assertive)

Add unit tests for the new engine.

This engine class is critical for Ollama API integration yet it lacks direct test coverage. Please consider adding corresponding tests to maintain code reliability and facilitate future refactoring.

Would you like me to provide an initial test file scaffold for OllamaEngine?
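A rough scaffold, offered only as a sketch: the test path, the engine constructor signature, and the Message fields are assumptions, not part of this PR.

# tests/provider_engine/test_ollama_engine.py -- hypothetical scaffold.
from unittest.mock import MagicMock, patch

from basilisk.provider_engine.ollama_engine import OllamaEngine


def test_client_uses_custom_base_url():
    # Assumes OllamaEngine(account) mirrors the other engines' constructors.
    account = MagicMock()
    account.custom_base_url = "http://localhost:11434"
    engine = OllamaEngine(account)
    with patch("basilisk.provider_engine.ollama_engine.Client") as client_cls:
        _ = engine.client
    client_cls.assert_called_once_with(host="http://localhost:11434")


def test_completion_response_without_stream_sets_assistant_message():
    # Assumes Message accepts the role and content keywords used in the diff.
    engine = OllamaEngine(MagicMock())
    block = MagicMock()
    response = {"message": {"content": "hello"}}
    result = engine.completion_response_without_stream(response, block)
    assert result is block
    assert result.response.content == "hello"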


    @cached_property
    @measure_time
    def models(self) -> list[ProviderAIModel]:
        """Get Ollama models.

        Returns:
            A list of provider AI models.
        """
        models = []
        models_list = self.client.list().models
        for model in models_list:
            info = self.client.show(model.model)
            context_length = 0
            description = json.dumps(info.modelinfo, indent=2)
            description += f"\n\n{info.license}"
            for k, v in info.modelinfo.items():
                if k.endswith("context_length"):
                    context_length = v
            models.append(
                ProviderAIModel(
                    id=model.model,
                    name=model.model,
                    description=description,
                    context_window=context_length,
                    max_output_tokens=0,
                    max_temperature=2,
                    default_temperature=1,
                    vision=True,
                )
            )

        return models

Comment on lines +32 to +64
Contributor

🧹 Nitpick (assertive)

Validate handling of nullable or missing fields.
The logic for parsing info.license and info.modelinfo (especially context length) looks good. However, consider robust error handling or a helpful fallback if these fields are missing or unexpected in the Ollama response.
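For instance, a more defensive version of the parsing inside the models() loop might look like this (a sketch only; the attribute names follow the ollama client objects already used in the diff):

# Hedged sketch: tolerate a missing license or modelinfo in the show() response.
modelinfo = getattr(info, "modelinfo", None) or {}
license_text = getattr(info, "license", None) or ""
description = json.dumps(modelinfo, indent=2)
if license_text:
    description += f"\n\n{license_text}"
# Fall back to 0 when no *_context_length key is present.
context_length = next(
    (v for k, v in modelinfo.items() if k.endswith("context_length")),
    0,
)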

    @cached_property
    def client(self) -> Client:
        """Get Ollama client.

        Returns:
            The Ollama client instance.
        """
        base_url = self.account.custom_base_url or str(
            self.account.provider.base_url
        )
        log.info(f"Base URL: {base_url}")
        return Client(host=base_url)

    def completion(
        self,
        new_block: MessageBlock,
        conversation: Conversation,
        system_message: Message | None,
        **kwargs,
    ) -> ChatResponse | Iterator[ChatResponse]:
        """Get completion from Ollama.

        Args:
            new_block: The new message block.
            conversation: The conversation instance.
            system_message: The system message, if any.
            **kwargs: Additional keyword arguments.

        Returns:
            The chat response or an iterator of chat responses.
        """
        super().completion(new_block, conversation, system_message, **kwargs)
        params = {
            "model": new_block.model.model_id,
            "messages": self.get_messages(
                new_block, conversation, system_message
            ),
            "stream": new_block.stream,
        }
        params.update(kwargs)
        return self.client.chat(**params)

    def prepare_message_request(self, message: Message):
        """Prepare message request for Ollama.

        Args:
            message: The message to prepare.

        Returns:
            The prepared message request.
        """
        super().prepare_message_request(message)
        images = []
        if message.attachments:
            for attachment in message.attachments:
                if attachment.type == ImageFileTypes.IMAGE_URL:
                    log.warning(
                        f"Received unsupported image type: {attachment.type}, {attachment.location}"
                    )
                    raise NotImplementedError(
                        "images URL are not supported for Ollama"
                    )
                images.append(attachment.encode_image())
        return {
            "role": message.role.value,
            "content": message.content,
            "images": images,
        }

    prepare_message_response = prepare_message_request

    def completion_response_with_stream(self, stream):
        """Process a streaming completion response.

        Args:
            stream: The stream of chat completion responses.

        Returns:
            An iterator of the completion response content.
        """
        for chunk in stream:
            content = chunk.get("message", {}).get("content")
            if content:
                yield content

    def completion_response_without_stream(
        self, response, new_block: MessageBlock, **kwargs
    ) -> MessageBlock:
        """Process a non-streaming completion response.

        Args:
            response: The chat completion response.
            new_block: The message block to update with the response.
            **kwargs: Additional keyword arguments.

        Returns:
            The updated message block with the response.
        """
        new_block.response = Message(
            role=MessageRoleEnum.ASSISTANT,
            content=response["message"]["content"],
        )
        return new_block
Comment on lines +164 to +167
Contributor

⚠️ Potential issue

Add fallback or error handling for missing content in responses.

This line can raise a KeyError if "message" or "content" are absent from the response. Consider adding a safe fallback or raising a more descriptive error for debugging.

Here’s a potential fix:

- content=response["message"]["content"],
+ content=response.get("message", {}).get("content") or ""
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            role=MessageRoleEnum.ASSISTANT,
-           content=response["message"]["content"],
+           content=response.get("message", {}).get("content") or "",
        )
        return new_block
