Release v2.1.1: Token Handling Improvement and DeepSeek Support
- Completely reworked token handling mechanism
- Removed custom token calculation logic
- Direct max_tokens passing to LLM APIs (see the sketch below)
- Added support for DeepSeek provider
- Integrated deepseek-chat and deepseek-reasoner models

Thanks to @estiens for reporting token handling issues and providing valuable feedback (#1).
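In short: v2.1.0 estimated token usage with a local word-count heuristic and trimmed context to fit, while v2.1.1 forwards the configured limit to the provider unchanged. A schematic sketch of the difference (variable names follow the coordinator.py diff below):

```python
# Schematic of how the request's max_tokens is now chosen.
temp_max_tokens = 1000  # user-configured limit (DEFAULT_MAX_TOKENS)

# v2.1.0 (removed): heuristic word-count estimate, then dynamic trimming:
#   context_tokens = self._calculate_context_tokens(messages, temp_model)
#   max_tokens = min(temp_max_tokens, max(0, temp_max_tokens - context_tokens))

# v2.1.1: the configured limit goes to the LLM API unchanged.
kwargs = {"max_tokens": temp_max_tokens}
```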
SMKRV committed Jan 28, 2025
1 parent 82e1f0c commit bfd64d1
Showing 16 changed files with 125 additions and 148 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -108,6 +108,7 @@ Transform your smart home experience with powerful AI assistance powered by mult
- Active API key from:
- OpenAI ([Get key](https://platform.openai.com/account/api-keys))
- Anthropic ([Get key](https://console.anthropic.com/))
- DeepSeek 🆕 ([Get key](https://platform.deepseek.com/api_keys))
- OpenRouter ([Get key](https://openrouter.ai/keys))
- Any OpenAI-compatible API provider
- Python 3.9 or newer
@@ -120,7 +121,7 @@ Transform your smart home experience with powerful AI assistance powered by mult
- 🔑 **API Key**: Provider-specific authentication
- 🤖 **Model Selection**: Flexible, provider-specific models
- 🌡️ **Temperature**: Creativity control (0.0-2.0)
- 📏 **Max Tokens**: Response length limit (token usage is estimated using a heuristic method based on word count and specific word characteristics, which may differ from actual token usage)
- 📏 **Max Tokens**: Response length limit (passed directly to the LLM API to control the maximum length of the response)
- ⏱️ **Request Interval**: API call throttling
- 💾 **History Size**: Number of messages to retain
- 🌍 **Custom API Endpoint**: Optional advanced configuration
5 changes: 5 additions & 0 deletions custom_components/ha_text_ai/__init__.py
@@ -39,11 +39,14 @@
CONF_CONTEXT_MESSAGES,
API_PROVIDER_OPENAI,
API_PROVIDER_ANTHROPIC,
API_PROVIDER_DEEPSEEK,
DEFAULT_MODEL,
DEFAULT_DEEPSEEK_MODEL,
DEFAULT_TEMPERATURE,
DEFAULT_MAX_TOKENS,
DEFAULT_OPENAI_ENDPOINT,
DEFAULT_ANTHROPIC_ENDPOINT,
DEFAULT_DEEPSEEK_ENDPOINT,
DEFAULT_REQUEST_INTERVAL,
DEFAULT_CONTEXT_MESSAGES,
API_TIMEOUT,
@@ -235,6 +238,8 @@ async def async_check_api(session, endpoint: str, headers: dict, provider: str)
try:
if provider == API_PROVIDER_ANTHROPIC:
check_url = f"{endpoint}/v1/models"
elif provider == API_PROVIDER_DEEPSEEK:
check_url = f"{endpoint}/models" # DeepSeek
else: # OpenAI
check_url = f"{endpoint}/models"

42 changes: 37 additions & 5 deletions custom_components/ha_text_ai/api_client.py
@@ -18,6 +18,8 @@
API_TIMEOUT,
API_RETRY_COUNT,
API_PROVIDER_ANTHROPIC,
API_PROVIDER_DEEPSEEK,
API_PROVIDER_OPENAI,
MIN_TEMPERATURE,
MAX_TEMPERATURE,
MIN_MAX_TOKENS,
@@ -109,19 +111,49 @@ async def create(
return await self._create_anthropic_completion(
model, messages, temperature, max_tokens
)
elif self.api_provider == API_PROVIDER_DEEPSEEK:
return await self._create_deepseek_completion(
model, messages, temperature, max_tokens
)
else:
return await self._create_openai_completion(
model, messages, temperature, max_tokens
)
except (KeyError, IndexError) as e:
if "'choices'" in str(e) or "'message'" in str(e):
raise HomeAssistantError("Failed to get a response from the AI model. Please check your internet connection and try again later.")
else:
raise
except Exception as e:
_LOGGER.error("API request failed: %s", str(e))
raise HomeAssistantError(f"API request failed: {str(e)}")

async def _create_deepseek_completion(
self,
model: str,
messages: List[Dict[str, str]],
temperature: float,
max_tokens: int,
) -> Dict[str, Any]:
"""Create completion using DeepSeek API."""
url = f"{self.endpoint}/chat/completions"
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": False
}

data = await self._make_request(url, payload)
return {
"choices": [
{
"message": {"content": data["choices"][0]["message"]["content"]},
}
],
"usage": {
"prompt_tokens": data["usage"]["prompt_tokens"],
"completion_tokens": data["usage"]["completion_tokens"],
"total_tokens": data["usage"]["total_tokens"],
},
}

async def _create_openai_completion(
self,
model: str,
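As a standalone illustration, the request issued by _create_deepseek_completion boils down to the sketch below. The endpoint, payload shape, and response fields follow the hunk above and const.py; the Bearer auth header and the aiohttp session handling are assumptions, since _make_request's internals are not part of this diff.

```python
import asyncio

import aiohttp


async def ask_deepseek(api_key: str, question: str) -> str:
    """Minimal DeepSeek chat completion mirroring the integration's payload."""
    url = "https://api.deepseek.com/chat/completions"  # DEFAULT_DEEPSEEK_ENDPOINT + path
    headers = {"Authorization": f"Bearer {api_key}"}  # assumed auth scheme
    payload = {
        "model": "deepseek-chat",  # DEFAULT_DEEPSEEK_MODEL
        "messages": [{"role": "user", "content": question}],
        "temperature": 0.1,  # DEFAULT_TEMPERATURE
        "max_tokens": 1000,  # DEFAULT_MAX_TOKENS, passed through as-is
        "stream": False,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, headers=headers) as resp:
            resp.raise_for_status()
            data = await resp.json()
    return data["choices"][0]["message"]["content"]


if __name__ == "__main__":
    print(asyncio.run(ask_deepseek("sk-your-key", "Hello!")))
```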
30 changes: 19 additions & 11 deletions custom_components/ha_text_ai/config_flow.py
@@ -28,13 +28,16 @@
CONF_CONTEXT_MESSAGES,
API_PROVIDER_OPENAI,
API_PROVIDER_ANTHROPIC,
API_PROVIDER_DEEPSEEK,
API_PROVIDERS,
DEFAULT_MODEL,
DEFAULT_DEEPSEEK_MODEL,
DEFAULT_TEMPERATURE,
DEFAULT_MAX_TOKENS,
DEFAULT_REQUEST_INTERVAL,
DEFAULT_OPENAI_ENDPOINT,
DEFAULT_ANTHROPIC_ENDPOINT,
DEFAULT_DEEPSEEK_ENDPOINT,
DEFAULT_CONTEXT_MESSAGES,
MIN_TEMPERATURE,
MAX_TEMPERATURE,
@@ -90,17 +93,22 @@ async def async_step_provider(self, user_input: Optional[Dict[str, Any]] = None)
self._errors = {}

if user_input is None:
default_endpoint = (
DEFAULT_OPENAI_ENDPOINT if self._provider == API_PROVIDER_OPENAI
else DEFAULT_ANTHROPIC_ENDPOINT
)
# Select the endpoint based on the provider
default_endpoint = {
API_PROVIDER_OPENAI: DEFAULT_OPENAI_ENDPOINT,
API_PROVIDER_ANTHROPIC: DEFAULT_ANTHROPIC_ENDPOINT,
API_PROVIDER_DEEPSEEK: DEFAULT_DEEPSEEK_ENDPOINT,
}.get(self._provider, DEFAULT_OPENAI_ENDPOINT)

# Select the default model based on the provider
default_model = DEFAULT_DEEPSEEK_MODEL if self._provider == API_PROVIDER_DEEPSEEK else DEFAULT_MODEL

return self.async_show_form(
step_id="provider",
data_schema=vol.Schema({
vol.Required(CONF_NAME, default="my_assistant"): str,
vol.Required(CONF_API_KEY): str,
vol.Required(CONF_MODEL, default=DEFAULT_MODEL): str,
vol.Required(CONF_MODEL, default=default_model): str,
vol.Required(CONF_API_ENDPOINT, default=default_endpoint): str,
vol.Optional(CONF_TEMPERATURE, default=DEFAULT_TEMPERATURE): vol.All(
vol.Coerce(float),
@@ -156,15 +164,13 @@ async def async_step_provider(self, user_input: Optional[Dict[str, Any]] = None)
if not await self._async_validate_api(input_copy):
return self.async_show_form(
step_id="provider",
data_schema=vol.Schema({
}),
data_schema=vol.Schema({}),
errors=self._errors
)
except Exception as e:
return self.async_show_form(
step_id="provider",
data_schema=vol.Schema({
}),
data_schema=vol.Schema({}),
errors={"base": str(e)}
)

@@ -250,14 +256,16 @@ async def _create_entry(self, user_input: Dict[str, Any]) -> FlowResult:

unique_id = f"{DOMAIN}_{normalized_name}_{self._provider}".lower()

default_model = DEFAULT_DEEPSEEK_MODEL if self._provider == API_PROVIDER_DEEPSEEK else DEFAULT_MODEL

entry_data = {
CONF_API_PROVIDER: self._provider,
CONF_NAME: instance_name,
"normalized_name": normalized_name,
CONF_API_KEY: user_input.get(CONF_API_KEY),
CONF_API_ENDPOINT: user_input.get(CONF_API_ENDPOINT),
"unique_id": unique_id,
CONF_MODEL: user_input.get(CONF_MODEL, DEFAULT_MODEL),
CONF_MODEL: user_input.get(CONF_MODEL, default_model),
CONF_TEMPERATURE: user_input.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
CONF_MAX_TOKENS: user_input.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS),
CONF_REQUEST_INTERVAL: user_input.get(CONF_REQUEST_INTERVAL, DEFAULT_REQUEST_INTERVAL),
@@ -302,7 +310,7 @@ async def async_step_init(self, user_input: Optional[Dict[str, Any]] = None) ->
data_schema=vol.Schema({
vol.Optional(
CONF_MODEL,
default=current_data.get(CONF_MODEL, DEFAULT_MODEL)
default=current_data.get(CONF_MODEL, default_model)
): str,
vol.Optional(
CONF_TEMPERATURE,
10 changes: 7 additions & 3 deletions custom_components/ha_text_ai/const.py
@@ -7,7 +7,7 @@
@source: https://github.com/smkrv/ha-text-ai
"""
import os
import json
import json
from typing import Final
import voluptuous as vol
from homeassistant.const import Platform, CONF_API_KEY, CONF_NAME
@@ -21,10 +21,12 @@
CONF_API_PROVIDER: Final = "api_provider"
API_PROVIDER_OPENAI: Final = "openai"
API_PROVIDER_ANTHROPIC: Final = "anthropic"
API_PROVIDER_DEEPSEEK: Final = "deepseek"

API_PROVIDERS: Final = [
API_PROVIDER_OPENAI,
API_PROVIDER_ANTHROPIC
API_PROVIDER_ANTHROPIC,
API_PROVIDER_DEEPSEEK
]

# Read version from manifest.json
@@ -46,6 +48,7 @@
# Default endpoints
DEFAULT_OPENAI_ENDPOINT: Final = "https://api.openai.com/v1"
DEFAULT_ANTHROPIC_ENDPOINT: Final = "https://api.anthropic.com"
DEFAULT_DEEPSEEK_ENDPOINT: Final = "https://api.deepseek.com"

# Configuration constants
CONF_MODEL: Final = "model"
@@ -65,6 +68,7 @@

# Default values
DEFAULT_MODEL: Final = "gpt-4o-mini"
DEFAULT_DEEPSEEK_MODEL: Final = "deepseek-chat"
DEFAULT_TEMPERATURE: Final = 0.1
DEFAULT_MAX_TOKENS: Final = 1000
DEFAULT_REQUEST_INTERVAL: Final = 1.0
@@ -80,7 +84,7 @@
MIN_TEMPERATURE: Final = 0.0
MAX_TEMPERATURE: Final = 2.0
MIN_MAX_TOKENS: Final = 1
MAX_MAX_TOKENS: Final = 4096
MAX_MAX_TOKENS: Final = 100000
MIN_REQUEST_INTERVAL: Final = 0.1
MAX_REQUEST_INTERVAL: Final = 60.0

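For reference, a minimal sketch of enforcing these bounds with voluptuous — the pattern mirrors the temperature validator in the config_flow.py hunk above; applying it to max tokens like this is an illustrative assumption, not part of the diff:

```python
import voluptuous as vol

MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 100000  # raised from 4096 in this release

# Coerce user input to int and reject values outside the allowed range.
max_tokens_validator = vol.All(
    vol.Coerce(int),
    vol.Range(min=MIN_MAX_TOKENS, max=MAX_MAX_TOKENS),
)

print(max_tokens_validator("1000"))  # -> 1000
# max_tokens_validator(200000)       # would raise vol.Invalid (out of range)
```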
91 changes: 5 additions & 86 deletions custom_components/ha_text_ai/coordinator.py
@@ -796,46 +796,6 @@ def _get_safe_initial_state(self) -> Dict[str, Any]:
"normalized_name": self.normalized_name,
}

def _calculate_context_tokens(self, messages: List[Dict[str, str]], model: Optional[str] = None) -> int:
total_tokens = 0

# Compile regular expressions for performance
number_pattern = re.compile(r'[0-9]')
special_char_pattern = re.compile(r'[^\w\s]')
whitespace_pattern = re.compile(r'\s+')

try:
for message in messages:
text = message.get('content', '')
if text:
# Normalize whitespace
text = whitespace_pattern.sub(' ', text.strip())

# Advanced token estimation heuristics
words = text.split()
for word in words:
# Complexity-based token calculation
if len(word) > 8: # Long words
total_tokens += 2
elif number_pattern.search(word): # Words with numbers
total_tokens += 1.5
elif special_char_pattern.search(word): # Words with special characters
total_tokens += 1.5
elif word.isupper(): # Acronyms and technical terms
total_tokens += 1.5
else: # Regular words
total_tokens += 1

# Additional correction for technical texts
total_tokens = math.ceil(total_tokens * 1.2)

return int(total_tokens)

except Exception as e:
_LOGGER.error(f"Token calculation error: {e}. Processed {len(messages)} messages.")
# Safe fallback with minimal token estimation
return len(messages) * 100

async def async_ask_question(
self,
question: str,
@@ -876,9 +836,7 @@ async def async_process_question(
system_prompt: Optional[str] = None,
context_messages: Optional[int] = None,
) -> dict:
"""
Enhanced question processing with intelligent token management.
"""
"""Process question with context management."""
if self.client is None:
raise HomeAssistantError("AI client not initialized")

@@ -900,62 +858,23 @@ async def async_process_question(
if temp_system_prompt:
messages.append({"role": "system", "content": temp_system_prompt})

# Context history management
# Add context from history
context_history = self._conversation_history[-temp_context_messages:]

# Comprehensive token calculation
context_tokens = self._calculate_context_tokens(
[{"content": entry["question"]} for entry in context_history] +
[{"content": entry["response"]} for entry in context_history] +
[{"content": question}],
temp_model
)

# Dynamic token allocation
available_tokens = max(0, temp_max_tokens - context_tokens)

# Context trimming if over token limit
if context_tokens > temp_max_tokens:
_LOGGER.warning(
f"Token limit exceeded. "
f"Context: {context_tokens}, "
f"Max: {temp_max_tokens}"
)

# Intelligent context reduction
while context_tokens > temp_max_tokens // 2 and context_history:
context_history.pop(0)
context_tokens = self._calculate_context_tokens(
[{"content": entry["question"]} for entry in context_history] +
[{"content": entry["response"]} for entry in context_history] +
[{"content": question}],
temp_model
)

# Rebuild messages with trimmed context
for entry in context_history:
messages.append({"role": "user", "content": entry["question"]})
messages.append({"role": "assistant", "content": entry["response"]})

# Add current question
messages.append({"role": "user", "content": question})

# Detailed token logging
_LOGGER.debug(
f"Token Analysis: "
f"Context Tokens: {context_tokens}, "
f"Max Tokens: {temp_max_tokens}, "
f"Available Tokens: {available_tokens}"
)

# Prepare API call with dynamic token management
# Process message
kwargs = {
"model": temp_model,
"temperature": temp_temperature,
"max_tokens": min(temp_max_tokens, available_tokens),
"max_tokens": temp_max_tokens,
"messages": messages,
}

# Process message
response = await self.async_process_message(question, **kwargs)

# Update metrics
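Condensed, the new request assembly amounts to the sketch below: recent history pairs plus the current question, with max_tokens forwarded as-is. This is a simplified standalone rendering of the hunk above; the coordinator's surrounding state and metrics are elided.

```python
from typing import Any, Dict, List, Optional


def build_request_kwargs(
    history: List[Dict[str, str]],
    question: str,
    model: str,
    temperature: float,
    max_tokens: int,
    context_messages: int,
    system_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Assemble chat messages from recent history; no local token heuristics."""
    messages: List[Dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Keep only the configured number of recent exchanges.
    for entry in history[-context_messages:]:
        messages.append({"role": "user", "content": entry["question"]})
        messages.append({"role": "assistant", "content": entry["response"]})
    messages.append({"role": "user", "content": question})
    return {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,  # passed through unchanged
        "messages": messages,
    }


# Example: two remembered exchanges, keep only the most recent as context.
kwargs = build_request_kwargs(
    history=[
        {"question": "Is the porch light on?", "response": "Yes, it is on."},
        {"question": "Dim it to 50%.", "response": "Done."},
    ],
    question="And now turn it off.",
    model="deepseek-chat",
    temperature=0.1,
    max_tokens=1000,
    context_messages=1,
)
print(kwargs["messages"])
```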
2 changes: 1 addition & 1 deletion custom_components/ha_text_ai/manifest.json
@@ -23,6 +23,6 @@
"single_config_entry": false,
"ssdp": [],
"usb": [],
"version": "2.1.0",
"version": "2.1.1",
"zeroconf": []
}
2 changes: 1 addition & 1 deletion custom_components/ha_text_ai/sensor.py
@@ -155,7 +155,7 @@ def __init__(
name=self._attr_name,
manufacturer="Community",
model=f"{model} ({api_provider} provider)",
sw_version=VERSION,
sw_version=VERSION,
)

_LOGGER.debug(