diff --git a/basilisk/config/__init__.py b/basilisk/config/__init__.py index 1c1317d7..ed1de433 100644 --- a/basilisk/config/__init__.py +++ b/basilisk/config/__init__.py @@ -1,6 +1,11 @@ """Configuration module for Basilisk.""" -from .account_config import Account, AccountManager, AccountOrganization +from .account_config import ( + CUSTOM_BASE_URL_PATTERN, + Account, + AccountManager, + AccountOrganization, +) from .account_config import get_account_config as accounts from .config_enums import ( AccountSource, @@ -27,6 +32,8 @@ "conf", "ConversationProfile", "conversation_profiles", + "CUSTOM_BASE_URL_PATTERN", + "get_account_source_labels", "KeyStorageMethodEnum", "LogLevelEnum", "ReleaseChannelEnum", diff --git a/basilisk/config/account_config.py b/basilisk/config/account_config.py index 18f3c33d..16d6f995 100644 --- a/basilisk/config/account_config.py +++ b/basilisk/config/account_config.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import re from functools import cache, cached_property from os import getenv from typing import Annotated, Any, Iterable, Optional, Union @@ -38,6 +39,10 @@ log = logging.getLogger(__name__) +CUSTOM_BASE_URL_PATTERN = re.compile( + r"^https?://[\w.-]+(?::\d{1,5})?(?:/[\w-]+)*/?$" +) + class AccountOrganization(BaseModel): """Manage organization key for an account.""" @@ -131,6 +136,11 @@ class Account(BaseModel): organizations: Optional[list[AccountOrganization]] = Field(default=None) active_organization_id: Optional[UUID4] = Field(default=None) source: AccountSource = Field(default=AccountSource.CONFIG, exclude=True) + custom_base_url: Optional[str] = Field( + default=None, + pattern=CUSTOM_BASE_URL_PATTERN, + description="Custom base URL for the API provider. Must be a valid HTTP/HTTPS URL.", + ) def __init__(self, **data: Any): """Initialize an account instance. If an error occurs, log the error and raise an exception.""" @@ -186,6 +196,8 @@ def validate_api_key( if not value: raise ValueError("API key not found in keyring") return SecretStr(value) + elif not data["provider"].require_api_key and value is None: + return None else: raise ValueError("Invalid API key storage method") diff --git a/basilisk/gui/account_dialog.py b/basilisk/gui/account_dialog.py index 5e668c50..ddb37f6e 100644 --- a/basilisk/gui/account_dialog.py +++ b/basilisk/gui/account_dialog.py @@ -1,20 +1,21 @@ """Account dialog for managing accounts and organizations in the basiliskLLM application.""" import logging -from typing import Optional +import re import wx from more_itertools import first, locate from pydantic import SecretStr from basilisk.config import ( + CUSTOM_BASE_URL_PATTERN, Account, AccountOrganization, AccountSource, KeyStorageMethodEnum, accounts, ) -from basilisk.provider import get_provider, providers +from basilisk.provider import Provider, get_provider, providers log = logging.getLogger(__name__) @@ -28,7 +29,7 @@ def __init__( self, parent: wx.Window, title: str, - organization: Optional[AccountOrganization] = None, + organization: AccountOrganization | None = None, size: tuple[int, int] = (400, 200), ): """Initialize the dialog for editing account organization settings. @@ -428,13 +429,13 @@ def __init__( size: The size of the dialog. account: The account to edit. If None, a new account will be created. """ - wx.Dialog.__init__(self, parent, title=title, size=size) + super().__init__(parent, title=title, size=size) self.parent = parent self.account = account self.init_ui() if account: self.init_data() - self.update_ui() + self.update_ui() self.Centre() self.Show() self.name.SetFocus() @@ -459,65 +460,76 @@ def init_ui(self): self.name = wx.TextCtrl(panel) sizer.Add(self.name, 0, wx.EXPAND) - label = wx.StaticText( + self.provider_label = wx.StaticText( panel, # Translators: A label in account dialog label=_("&Provider:"), style=wx.ALIGN_LEFT, ) - sizer.Add(label, 0, wx.ALL, 5) + sizer.Add(self.provider_label, 0, wx.ALL, 5) provider_choices = [provider.name for provider in providers] - self.provider = wx.ComboBox( + self.provider_combo = wx.ComboBox( panel, choices=provider_choices, style=wx.CB_READONLY ) - self.provider.Bind(wx.EVT_COMBOBOX, lambda e: self.update_ui()) - sizer.Add(self.provider, 0, wx.EXPAND) + self.provider_combo.Bind(wx.EVT_COMBOBOX, lambda e: self.update_ui()) + sizer.Add(self.provider_combo, 0, wx.EXPAND) - label = wx.StaticText( + self.api_key_storage_method_label = wx.StaticText( panel, style=wx.ALIGN_LEFT, # Translators: A label in account dialog label=_("API &key storage method:"), ) - sizer.Add(label, 0, wx.ALL, 5) - self.api_key_storage_method = wx.ComboBox( + sizer.Add(self.api_key_storage_method_label, 0, wx.ALL, 5) + self.api_key_storage_method_combo = wx.ComboBox( panel, choices=list(key_storage_methods.values()), style=wx.CB_READONLY, ) - sizer.Add(self.api_key_storage_method, 0, wx.EXPAND) - self.api_key_storage_method.Disable() - label = wx.StaticText( + sizer.Add(self.api_key_storage_method_combo, 0, wx.EXPAND) + + self.api_key_label = wx.StaticText( panel, style=wx.ALIGN_LEFT, # Translators: A label in account dialog label=_("API &key:"), ) - sizer.Add(label, 0, wx.ALL, 5) - self.api_key = wx.TextCtrl(panel) - self.api_key.Disable() - sizer.Add(self.api_key, 0, wx.EXPAND) + sizer.Add(self.api_key_label, 0, wx.ALL, 5) + self.api_key_text_ctrl = wx.TextCtrl(panel) + sizer.Add(self.api_key_text_ctrl, 0, wx.EXPAND) - label = wx.StaticText( - panel, label=_("&Organization to use:"), style=wx.ALIGN_LEFT + self.organization_label = wx.StaticText( + panel, + # Translators: A label in account dialog + label=_("&Organization to use:"), + style=wx.ALIGN_LEFT, ) - sizer.Add(label, 0, wx.ALL, 5) - self.organization = wx.ComboBox(panel, style=wx.CB_READONLY) - self.organization.Disable() - sizer.Add(self.organization, 0, wx.EXPAND) + sizer.Add(self.organization_label, 0, wx.ALL, 5) + self.organization_text_ctrl = wx.ComboBox(panel, style=wx.CB_READONLY) + sizer.Add(self.organization_text_ctrl, 0, wx.EXPAND) - bSizer = wx.BoxSizer(wx.HORIZONTAL) + self.custom_base_url_label = wx.StaticText( + panel, + # Translators: A label in account dialog + label=_("Custom &base URL:"), + style=wx.ALIGN_LEFT, + ) + sizer.Add(self.custom_base_url_label, 0, wx.ALL, 5) + self.custom_base_url_text_ctrl = wx.TextCtrl(panel) + sizer.Add(self.custom_base_url_text_ctrl, 0, wx.EXPAND) + + buttons_sizer = wx.BoxSizer(wx.HORIZONTAL) btn = wx.Button(panel, wx.ID_OK) btn.SetDefault() btn.Bind(wx.EVT_BUTTON, self.on_ok) - bSizer.Add(btn, 0, wx.ALL, 5) + buttons_sizer.Add(btn, 0, wx.ALL, 5) btn = wx.Button(panel, wx.ID_CANCEL) btn.Bind(wx.EVT_BUTTON, self.on_cancel) - bSizer.Add(btn, 0, wx.ALL, 5) + buttons_sizer.Add(btn, 0, wx.ALL, 5) - sizer.Add(bSizer, 0, wx.ALL, 5) + sizer.Add(buttons_sizer, 0, wx.ALL, 5) def init_data(self): """Initialize the data for the dialog. @@ -526,34 +538,52 @@ def init_data(self): and organization settings are set in the dialog. """ if not self.account: - self.api_key_storage_method.SetSelection(0) + self.api_key_storage_method_combo.SetSelection(0) return + self.name.SetValue(self.account.name) index = first( locate(providers, lambda x: x.name == self.account.provider.name), -1, ) - self.provider.SetSelection(index) + self.provider_combo.SetSelection(index) + if self.account.api_key and self.account.api_key_storage_method: - index = first( - locate( - key_storage_methods.keys(), - lambda x: x == self.account.api_key_storage_method, - ), - -1, + self._set_api_key_data() + + self._init_organization_data() + + if self.account.custom_base_url: + self.custom_base_url_text_ctrl.SetValue( + self.account.custom_base_url ) - self.api_key_storage_method.SetSelection(index) - self.api_key.SetValue(self.account.api_key.get_secret_value()) - self.organization.Enable( + + def _set_api_key_data(self) -> None: + """Set API key related fields from account data.""" + index = first( + locate( + key_storage_methods.keys(), + lambda x: x == self.account.api_key_storage_method, + ), + -1, + ) + self.api_key_storage_method_combo.SetSelection(index) + self.api_key_text_ctrl.SetValue(self.account.api_key.get_secret_value()) + + def _init_organization_data(self) -> None: + """Initialize organization related fields.""" + self.organization_text_ctrl.Enable( self.account.provider.organization_mode_available ) if not self.account.provider.organization_mode_available: return + if self.account.organizations: choices = [_("Personal")] + [ organization.name for organization in self.account.organizations ] - self.organization.SetItems(choices) + self.organization_text_ctrl.SetItems(choices) + if self.account.active_organization_id: index = ( first( @@ -565,26 +595,76 @@ def init_data(self): ) + 1 ) - self.organization.SetSelection(index) + self.organization_text_ctrl.SetSelection(index) - def update_ui(self): - """Update the user interface of the dialog. + @property + def provider(self) -> Provider | None: + """Get the provider object from the selected provider name. - Enable or disable API key and organization fields based on the selected provider's requirements. + Returns: + The provider object if a provider is selected, otherwise None. """ - provider_index = self.provider.GetSelection() - if provider_index == -1: - log.debug("No provider selected") + provider_index = self.provider_combo.GetSelection() + if provider_index == wx.NOT_FOUND: + return None + provider_name = self.provider_combo.GetValue() + return get_provider(name=provider_name) + + def update_ui(self) -> None: + """Update UI elements based on selected provider.""" + provider = self.provider + if not provider: + self._disable_all_fields() return - provider_name = self.provider.GetValue() - provider = get_provider(name=provider_name) - if provider.require_api_key: - self.api_key.Enable() - self.api_key_storage_method.Enable() - if self.account: - self.organization.Enable(provider.organization_mode_available) - def on_ok(self, event: wx.Event | None): + self._update_api_key_fields(provider.require_api_key) + self._update_organization_fields(provider.organization_mode_available) + self._update_base_url_fields(provider) + + def _disable_all_fields(self) -> None: + """Disable all provider-dependent fields.""" + fields = [ + self.api_key_label, + self.api_key_text_ctrl, + self.api_key_storage_method_label, + self.api_key_storage_method_combo, + self.organization_label, + self.organization_text_ctrl, + self.custom_base_url_label, + self.custom_base_url_text_ctrl, + ] + for field in fields: + field.Disable() + + def _update_api_key_fields(self, enable: bool) -> None: + """Update API key related fields state.""" + fields = [ + self.api_key_label, + self.api_key_text_ctrl, + self.api_key_storage_method_label, + self.api_key_storage_method_combo, + ] + for field in fields: + field.Enable(enable) + + def _update_organization_fields(self, enable: bool) -> None: + """Update organization related fields state.""" + self.organization_label.Enable(enable) + self.organization_text_ctrl.Enable(enable) + + def _update_base_url_fields(self, provider: Provider) -> None: + """Update base URL related fields.""" + self.custom_base_url_label.Enable(provider.allow_custom_base_url) + self.custom_base_url_text_ctrl.Enable(provider.allow_custom_base_url) + default_base_url = provider.base_url + if default_base_url: + self.custom_base_url_label.SetLabel( + _("Custom &base URL (default: {})").format(default_base_url) + ) + else: + self.custom_base_url_label.SetLabel(_("Custom &base URL")) + + def on_ok(self, event: wx.CommandEvent) -> None: """Handle the OK button click event. Validate the account settings and create or update the account. @@ -593,63 +673,142 @@ def on_ok(self, event: wx.Event | None): Args: event: The event that triggered the OK button click. If None, the OK button was not clicked. """ - if not self.name.GetValue(): - msg = _("Please enter a name") - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.name.SetFocus() - return - provider_index = self.provider.GetSelection() - if provider_index == -1: - msg = _("Please select a provider") - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.provider.SetFocus() + error_message = self._validate_form() + if error_message: + msg, field = error_message + wx.MessageBox( + msg, + # Translators: A title for the error message in account dialog + _("Error"), + wx.OK | wx.ICON_ERROR, + ) + field.SetFocus() return - provider_name = self.provider.GetValue() - provider = get_provider(name=provider_name) + + self._save_account_data() + self.EndModal(wx.ID_OK) + + def _validate_form(self) -> tuple[str, wx.Window] | None: + """Validate form data and return a tuple of error message and field to focus on if any. + + Returns None if form data is valid. + """ + if not self.name.GetValue(): + # Translators: An error message in account dialog + return _("Please enter a name"), self.name + + provider = self.provider + if not provider: + # Translators: An error message in account dialog + return _("Please select a provider"), self.provider_combo + if provider.require_api_key: - if self.api_key_storage_method.GetSelection() == -1: - msg = _("Please select an API key storage method") - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.api_key_storage_method.SetFocus() - return - if not self.api_key.GetValue(): - msg = _( + if self.api_key_storage_method_combo.GetSelection() == wx.NOT_FOUND: + # Translators: An error message in account dialog + return _( + "Please select an API key storage method" + ), self.api_key_storage_method_combo + + if not self.api_key_text_ctrl.GetValue(): + # Translators: An error message in account dialog + return _( "Please enter an API key. It is required for this provider" - ) - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.api_key.SetFocus() - return - organization_index = self.organization.GetSelection() + ), self.api_key_text_ctrl + + if ( + self.provider.allow_custom_base_url + and self.custom_base_url_text_ctrl.GetValue() + ): + if not re.match( + CUSTOM_BASE_URL_PATTERN, + self.custom_base_url_text_ctrl.GetValue(), + ): + # Translators: An error message in account dialog + return _( + "Please enter a valid custom base URL" + ), self.custom_base_url_text_ctrl + + return None + + def _save_account_data(self) -> None: + """Save form data to account object.""" + provider = self.provider + organization_index = self.organization_text_ctrl.GetSelection() active_organization = None - if organization_index > 0: + if ( + organization_index > 0 + and self.account + and self.account.organizations + ): active_organization = self.account.organizations[ organization_index - 1 ].id + api_key_storage_method = None api_key = None if provider.require_api_key: api_key_storage_method = list(key_storage_methods.keys())[ - self.api_key_storage_method.GetSelection() + self.api_key_storage_method_combo.GetSelection() ] - api_key = SecretStr(self.api_key.GetValue()) + api_key = SecretStr(self.api_key_text_ctrl.GetValue()) + + custom_base_url = self.custom_base_url_text_ctrl.GetValue() + if not provider.allow_custom_base_url or not custom_base_url.strip(): + custom_base_url = None + if self.account: - self.account.name = self.name.GetValue() - self.account.provider = provider - self.account.api_key_storage_method = api_key_storage_method - self.account.api_key = api_key - self.account.active_organization_id = active_organization + self._update_existing_account( + provider, + active_organization, + api_key_storage_method, + api_key, + custom_base_url, + ) else: - self.account = Account( - name=self.name.GetValue(), - provider=provider, - api_key_storage_method=api_key_storage_method, - api_key=api_key, - active_organization_id=active_organization, - source=AccountSource.CONFIG, + self._create_new_account( + provider, + active_organization, + api_key_storage_method, + api_key, + custom_base_url, ) - self.EndModal(wx.ID_OK) - def on_cancel(self, event: wx.Event | None): + def _update_existing_account( + self, + provider: Provider, + active_organization: str | None, + api_key_storage_method: KeyStorageMethodEnum | None, + api_key: SecretStr | None, + custom_base_url: str | None, + ) -> None: + """Update existing account with form data.""" + self.account.name = self.name.GetValue() + self.account.provider = provider + self.account.api_key_storage_method = api_key_storage_method + self.account.api_key = api_key + self.account.active_organization_id = active_organization + self.account.custom_base_url = custom_base_url + + def _create_new_account( + self, + provider: Provider, + active_organization: str | None, + api_key_storage_method: KeyStorageMethodEnum | None, + api_key: SecretStr | None, + custom_base_url: str | None, + ) -> None: + """Create new account from form data.""" + self.account = Account( + name=self.name.GetValue(), + provider=provider, + api_key_storage_method=api_key_storage_method, + api_key=api_key, + active_organization_id=active_organization, + source=AccountSource.CONFIG, + custom_base_url=custom_base_url, + ) + + def on_cancel(self, event: wx.CommandEvent) -> None: """Handle the Cancel button click event. Close the dialog without saving any changes. @@ -692,7 +851,12 @@ def init_ui(self): sizer = wx.BoxSizer(wx.VERTICAL) panel.SetSizer(sizer) - label = wx.StaticText(panel, label=_("Accounts"), style=wx.ALIGN_LEFT) + label = wx.StaticText( + panel, + # Translators: A label in account dialog + label=_("Accounts"), + style=wx.ALIGN_LEFT, + ) sizer.Add(label, 0, wx.ALL, 5) self.account_list = wx.ListCtrl(panel, style=wx.LC_REPORT) self.account_list.AppendColumn( diff --git a/basilisk/provider.py b/basilisk/provider.py index d19c2c1a..48972e65 100644 --- a/basilisk/provider.py +++ b/basilisk/provider.py @@ -48,7 +48,7 @@ class Provider: base_url: Optional[str] = field(default=None) organization_mode_available: bool = field(default=False) require_api_key: bool = field(default=True) - custom: bool = field(default=True) + allow_custom_base_url: bool = field(default=False) env_var_name_api_key: Optional[str] = field(default=None) env_var_name_organization_key: Optional[str] = field(default=None) @@ -114,6 +114,7 @@ def engine_cls(self) -> Type[BaseEngine]: require_api_key=True, env_var_name_api_key="MISTRAL_API_KEY", engine_cls_path="basilisk.provider_engine.mistralai_engine.MistralAIEngine", + allow_custom_base_url=True, ), Provider( id="openai", @@ -125,6 +126,7 @@ def engine_cls(self) -> Type[BaseEngine]: env_var_name_api_key="OPENAI_API_KEY", env_var_name_organization_key="OPENAI_ORG_KEY", engine_cls_path="basilisk.provider_engine.openai_engine.OpenAIEngine", + allow_custom_base_url=True, ), Provider( id="openrouter", @@ -135,6 +137,7 @@ def engine_cls(self) -> Type[BaseEngine]: require_api_key=True, env_var_name_api_key="OPENROUTER_API_KEY", engine_cls_path="basilisk.provider_engine.openrouter_engine.OpenRouterEngine", + allow_custom_base_url=True, ), Provider( id="xai", diff --git a/basilisk/provider_engine/openai_engine.py b/basilisk/provider_engine/openai_engine.py index 42250769..750ad539 100644 --- a/basilisk/provider_engine/openai_engine.py +++ b/basilisk/provider_engine/openai_engine.py @@ -78,7 +78,8 @@ def client(self) -> OpenAI: return OpenAI( api_key=self.account.api_key.get_secret_value(), organization=organization_key, - base_url=str(self.account.provider.base_url), + base_url=self.account.custom_base_url + or str(self.account.provider.base_url), ) @cached_property