Merge pull request #5 from ej52/dev
refactor and add more options
ej52 authored Nov 10, 2023
2 parents 32e3cfc + 3c5a344 commit ff61636
Showing 13 changed files with 371 additions and 216 deletions.
18 changes: 13 additions & 5 deletions README.md
@@ -46,16 +46,23 @@ Options for Ollama Conversation can be set via the user interface, by taking the
* If multiple instances of Ollama Conversation are configured, choose the instance you want to configure.
* Select the integration, then select ___Configure___.

#### System Prompt
The starting text for the AI language model to generate new text from. This text can include information about your Home Assistant instance, devices, and areas, and is written using Home Assistant Templating.
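
A minimal example of such a prompt, assuming it is rendered with the `ha_name` and `exposed_entities` variables this integration passes to the template (the attributes accessed on each entity are illustrative assumptions):

```jinja2
You are a voice assistant for the Home Assistant instance named {{ ha_name }}.
Answer briefly and only refer to the devices listed below.

{# The fields accessed on each entity are assumptions for illustration. #}
{% for entity in exposed_entities %}
- {{ entity.name }} ({{ entity.entity_id }}): {{ entity.state }}
{% endfor %}
```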

#### Model Configuration
The language model and additional parameters used to fine-tune the responses. A sketch of how these options map onto the request sent to Ollama follows the table.

| Option         | Description |
| -------------- | ----------- |
| Model          | The model used to generate responses. |
| Context Size   | Sets the size of the context window used to generate the next token. |
| Maximum Tokens | The maximum number of words or “tokens” that the AI model should generate in its completion of the prompt. |
| Mirostat Mode  | Enables Mirostat sampling for controlling perplexity. |
| Mirostat ETA   | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate results in slower adjustments, while a higher learning rate makes the algorithm more responsive. |
| Mirostat TAU   | Controls the balance between coherence and diversity of the output. A lower value results in more focused and coherent text. |
| Temperature    | The temperature of the model. A higher value (e.g. 0.95) leads to more unexpected results, while a lower value (e.g. 0.5) produces more deterministic results. |
| Top K          | Reduces the probability of generating nonsense. A higher value (e.g. 100) gives more diverse answers, while a lower value (e.g. 10) is more conservative. |
| Top P          | Works together with Top K. A higher value (e.g. 0.95) leads to more diverse text, while a lower value (e.g. 0.5) generates more focused and conservative text. |
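
For orientation, these options end up in the `options` object of the request the integration sends to Ollama. A minimal sketch using the option keys from this PR's `query` method; the values are placeholders, not the integration's defaults:

```python
# Illustrative mapping from the configuration options above to the
# "options" object sent to Ollama. Values are placeholders only.
options = {
    "num_ctx": 2048,       # Context Size
    "num_predict": 128,    # Maximum Tokens
    "mirostat": 0,         # Mirostat Mode
    "mirostat_eta": 0.1,   # Mirostat ETA
    "mirostat_tau": 5.0,   # Mirostat TAU
    "temperature": 0.8,    # Temperature
    "top_k": 40,           # Top K
    "top_p": 0.9,          # Top P
}
```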


## Contributions are welcome!
@@ -70,3 +77,4 @@ Discussions for this integration over on [Home Assistant Community][discussions]
[ollama]: https://ollama.ai/
[ollama-github]: https://github.com/jmorganca/ollama
[sentence-trigger]: https://www.home-assistant.io/docs/automation/trigger/#sentence-trigger
[discussions]: https://community.home-assistant.io/t/custom-integration-ollama-conversation-local-ai-agent/636103
161 changes: 99 additions & 62 deletions custom_components/ollama_conversation/__init__.py
@@ -5,44 +5,54 @@
"""
from __future__ import annotations

import json
from typing import Literal

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryAuthFailed, TemplateError
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError, TemplateError
from homeassistant.helpers import intent, template
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util import ulid

from .api import (
OllamaApiClient,
OllamaApiClientAuthenticationError,
OllamaApiClientError,
)
from .api import OllamaApiClient
from .const import (
DOMAIN, LOGGER,

CONF_BASE_URL,
CONF_CHAT_MODEL,
CONF_PROMPT,
CONF_TOP_K,
CONF_TOP_P,
CONF_MODEL,
CONF_CTX_SIZE,
CONF_MAX_TOKENS,
CONF_MIROSTAT_MODE,
CONF_MIROSTAT_ETA,
CONF_MIROSTAT_TAU,
CONF_TEMPERATURE,
CONF_REPEAT_PENALTY,
CONF_TOP_K,
CONF_TOP_P,
CONF_PROMPT_SYSTEM,

DEFAULT_CHAT_MODEL,
DEFAULT_PROMPT,
DEFAULT_MODEL,
DEFAULT_CTX_SIZE,
DEFAULT_MAX_TOKENS,
DEFAULT_MIROSTAT_MODE,
DEFAULT_MIROSTAT_ETA,
DEFAULT_MIROSTAT_TAU,
DEFAULT_TEMPERATURE,
DEFAULT_REPEAT_PENALTY,
DEFAULT_TOP_K,
DEFAULT_TOP_P
DEFAULT_TOP_P,
DEFAULT_PROMPT_SYSTEM
)
from .coordinator import OllamaDataUpdateCoordinator
from .exceptions import (
ApiClientError,
ApiCommError,
ApiJsonError,
ApiTimeoutError
)
from .helpers import get_exposed_entities

# https://developers.home-assistant.io/docs/config_entries_index/#setting-up-an-entry
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
@@ -63,10 +73,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
try:
response = await client.async_get_heartbeat()
if not response:
raise OllamaApiClientError("Invalid Ollama server")
except OllamaApiClientAuthenticationError as exception:
raise ConfigEntryAuthFailed(exception) from exception
except OllamaApiClientError as err:
raise ApiClientError("Invalid Ollama server")
except ApiClientError as err:
raise ConfigEntryNotReady(err) from err

entry.async_on_unload(entry.add_update_listener(async_reload_entry))
@@ -95,12 +103,7 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry, client: OllamaApiCli
self.hass = hass
self.entry = entry
self.client = client
self.history: dict[str, list[dict]] = {}

@property
def attribution(self):
"""Return the attribution."""
return {"name": "Powered by Ollama", "url": "https://github.com/ej52/hass-ollama-conversation"}
self.history: dict[str, dict] = {}

@property
def supported_languages(self) -> list[str] | Literal["*"]:
@@ -111,73 +114,107 @@ async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""

intent_response = intent.IntentResponse(language=user_input.language)

model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
raw_system_prompt = self.entry.options.get(CONF_PROMPT_SYSTEM, DEFAULT_PROMPT_SYSTEM)
exposed_entities = get_exposed_entities(self.hass)

if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
context = self.history[conversation_id]
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid()
context = None
try:
system_prompt = self._async_generate_prompt(raw_system_prompt, exposed_entities)
except TemplateError as err:
LOGGER.error("Error rendering system prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"I had a problem with my system prompt, please check the logs for more information.",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages = {
"system": system_prompt,
"context": None,
}

messages["prompt"] = user_input.text

try:
system_prompt = self._async_generate_prompt(prompt)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
response = await self.query(messages)
except (
ApiCommError,
ApiJsonError,
ApiTimeoutError
) as err:
LOGGER.error("Error generating prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
f"Something went wrong, {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

payload = {
"model": model,
"context": context,
"system": system_prompt,
"prompt": user_input.text,
"stream": False,
"options": {
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
"num_ctx": self.entry.options.get(CONF_CTX_SIZE, DEFAULT_CTX_SIZE),
"num_predict": self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS),
"temperature": self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
}
}

LOGGER.debug("Prompt for %s: %s", model, json.dumps(payload))

try:
result = await self.client.async_generate(payload)
except OllamaApiClientError as err:
except HomeAssistantError as err:
LOGGER.error("Something went wrong: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to the Ollama server: {err}",
"Something went wrong, please check the logs for more information.",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

LOGGER.debug("Response %s", json.dumps(result))

self.history[conversation_id] = result["context"]
intent_response.async_set_speech(result["response"])
messages["context"] = response["context"]
self.history[conversation_id] = messages

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response["response"])
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

def _async_generate_prompt(self, raw_prompt: str) -> str:
def _async_generate_prompt(self, raw_prompt: str, exposed_entities) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
"exposed_entities": exposed_entities,
},
parse_result=False,
)

async def query(
self,
messages
):
"""Process a sentence."""
model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL)

LOGGER.debug("Prompt for %s: %s", model, messages["prompt"])

result = await self.client.async_generate({
"model": model,
"context": messages["context"],
"system": messages["system"],
"prompt": messages["prompt"],
"stream": False,
"options": {
"mirostat": int(self.entry.options.get(CONF_MIROSTAT_MODE, DEFAULT_MIROSTAT_MODE)),
"mirostat_eta": self.entry.options.get(CONF_MIROSTAT_ETA, DEFAULT_MIROSTAT_ETA),
"mirostat_tau": self.entry.options.get(CONF_MIROSTAT_TAU, DEFAULT_MIROSTAT_TAU),
"num_ctx": self.entry.options.get(CONF_CTX_SIZE, DEFAULT_CTX_SIZE),
"num_predict": self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS),
"temperature": self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
"repeat_penalty": self.entry.options.get(CONF_REPEAT_PENALTY, DEFAULT_REPEAT_PENALTY),
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
}
})

response: str = result["response"]
LOGGER.debug("Response %s", response)
return result
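
Outside the diff itself, here is a rough, self-contained sketch of the request/response cycle the refactored `query` method drives: the rendered system prompt, the user's text, and the previous turn's `context` go to Ollama's generate endpoint, and the returned `context` is stored so the next turn continues the same conversation. The `/api/generate` path and the bare aiohttp call are assumptions for illustration; the integration itself goes through `OllamaApiClient`:

```python
import aiohttp


async def generate_turn(base_url: str, messages: dict) -> dict:
    """Illustrative sketch of one conversation turn against Ollama.

    `messages` mirrors the dict kept in self.history: "system", "prompt"
    and "context" (None on the first turn).
    """
    payload = {
        "model": "llama2",               # CONF_MODEL in the integration
        "system": messages["system"],    # rendered system prompt
        "context": messages["context"],  # None on the first turn
        "prompt": messages["prompt"],    # the user's sentence
        "stream": False,
        "options": {"temperature": 0.8, "top_k": 40, "top_p": 0.9},
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/api/generate", json=payload) as resp:
            resp.raise_for_status()
            result = await resp.json()

    # The integration stores result["context"] in self.history so the next
    # turn picks up where this one left off.
    messages["context"] = result["context"]
    return result
```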
59 changes: 22 additions & 37 deletions custom_components/ollama_conversation/api.py
@@ -8,26 +8,16 @@
import async_timeout

from .const import TIMEOUT


class OllamaApiClientError(Exception):
"""Exception to indicate a general API error."""


class OllamaApiClientCommunicationError(
OllamaApiClientError
):
"""Exception to indicate a communication error."""


class OllamaApiClientAuthenticationError(
OllamaApiClientError
):
"""Exception to indicate an authentication error."""
from .exceptions import (
ApiClientError,
ApiCommError,
ApiJsonError,
ApiTimeoutError
)


class OllamaApiClient:
"""Sample API Client."""
"""Ollama API Client."""

def __init__(
self,
@@ -40,10 +30,10 @@ def __init__(

async def async_get_heartbeat(self) -> bool:
"""Get heartbeat from the API."""
response = await self._api_wrapper(
response: str = await self._api_wrapper(
method="get", url=self._base_url, decode_json=False
)
return response == "Ollama is running"
return response.strip() == "Ollama is running"

async def async_get_models(self) -> any:
"""Get models from the API."""
@@ -78,28 +68,23 @@ async def _api_wrapper(
method=method,
url=url,
headers=headers,
raise_for_status=True,
json=data,
)

if response.status in (401, 403):
raise OllamaApiClientAuthenticationError(
"Invalid credentials",
)
if response.status == 404 and decode_json:
json = await response.json()
raise ApiJsonError(json["error"])

response.raise_for_status()

if decode_json:
return await response.json()
return await response.text()

except asyncio.TimeoutError as exception:
raise OllamaApiClientCommunicationError(
"Timeout error fetching information",
) from exception
except (aiohttp.ClientError, socket.gaierror) as exception:
raise OllamaApiClientCommunicationError(
"Error fetching information",
) from exception
except Exception as exception: # pylint: disable=broad-except
raise OllamaApiClientError(
"Something really wrong happened!"
) from exception
except ApiJsonError as e:
raise e
except asyncio.TimeoutError as e:
raise ApiTimeoutError("timeout while talking to the server") from e
except (aiohttp.ClientError, socket.gaierror) as e:
raise ApiCommError("unknown error while talking to the server") from e
except Exception as e: # pylint: disable=broad-except
raise ApiClientError("something really went wrong!") from e
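
The new `custom_components/ollama_conversation/exceptions.py` module is not shown in this excerpt. Since `__init__.py` now catches `HomeAssistantError` to cover these errors, a plausible sketch of its contents — an assumption, not the actual file from this PR — looks like this:

```python
"""Hypothetical sketch of custom_components/ollama_conversation/exceptions.py.

Inferred from the imports above and from __init__.py catching
HomeAssistantError; the real file in the PR may differ.
"""
from homeassistant.exceptions import HomeAssistantError


class ApiClientError(HomeAssistantError):
    """General error while talking to the Ollama API."""


class ApiCommError(ApiClientError):
    """Communication error (network failure, DNS error, connection reset)."""


class ApiJsonError(ApiClientError):
    """The API returned an error payload (e.g. a 404 with an "error" field)."""


class ApiTimeoutError(ApiClientError):
    """The request to the API timed out."""
```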