From 075099e1e2498b86438565e608a0eb960b871b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Sat, 15 Feb 2025 08:06:42 +0100 Subject: [PATCH 1/9] feat(custom-base-url): add support for custom base URL configuration - Introduced a `custom_base_url` field to the `Account` model in `account_config.py` for providers that allow custom base URLs. - Updated `EditAccountDialog` in `account_dialog.py` to include a new input field for the custom base URL and adjusted the UI to reflect this change. - Modified `Provider` class in `provider.py` to include an `allow_custom_base_url` attribute, enabling custom base URL configuration for specific providers. - Updated provider configurations for `mistralai`, `openai`, and `openrouter` to support custom base URLs. - Adjusted OpenAIEngine logic to prioritize the custom base URL over the default provider base URL if provided. --- basilisk/config/account_config.py | 3 + basilisk/gui/account_dialog.py | 169 ++++++++++++++-------- basilisk/provider.py | 5 +- basilisk/provider_engine/openai_engine.py | 3 +- 4 files changed, 120 insertions(+), 60 deletions(-) diff --git a/basilisk/config/account_config.py b/basilisk/config/account_config.py index 18f3c33d..0946057b 100644 --- a/basilisk/config/account_config.py +++ b/basilisk/config/account_config.py @@ -131,6 +131,7 @@ class Account(BaseModel): organizations: Optional[list[AccountOrganization]] = Field(default=None) active_organization_id: Optional[UUID4] = Field(default=None) source: AccountSource = Field(default=AccountSource.CONFIG, exclude=True) + custom_base_url: Optional[str] = Field(default=None) def __init__(self, **data: Any): """Initialize an account instance. If an error occurs, log the error and raise an exception.""" @@ -186,6 +187,8 @@ def validate_api_key( if not value: raise ValueError("API key not found in keyring") return SecretStr(value) + elif not data["provider"].require_api_key and value is None: + return else: raise ValueError("Invalid API key storage method") diff --git a/basilisk/gui/account_dialog.py b/basilisk/gui/account_dialog.py index 5e668c50..f75bf1fe 100644 --- a/basilisk/gui/account_dialog.py +++ b/basilisk/gui/account_dialog.py @@ -1,7 +1,7 @@ """Account dialog for managing accounts and organizations in the basiliskLLM application.""" import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional import wx from more_itertools import first, locate @@ -16,6 +16,9 @@ ) from basilisk.provider import get_provider, providers +if TYPE_CHECKING: + from basilisk.provider import Provider + log = logging.getLogger(__name__) key_storage_methods = KeyStorageMethodEnum.get_labels() @@ -434,7 +437,7 @@ def __init__( self.init_ui() if account: self.init_data() - self.update_ui() + self.update_ui() self.Centre() self.Show() self.name.SetFocus() @@ -459,65 +462,73 @@ def init_ui(self): self.name = wx.TextCtrl(panel) sizer.Add(self.name, 0, wx.EXPAND) - label = wx.StaticText( + self.provider_label = wx.StaticText( panel, # Translators: A label in account dialog label=_("&Provider:"), style=wx.ALIGN_LEFT, ) - sizer.Add(label, 0, wx.ALL, 5) + sizer.Add(self.provider_label, 0, wx.ALL, 5) provider_choices = [provider.name for provider in providers] - self.provider = wx.ComboBox( + self.provider_combo = wx.ComboBox( panel, choices=provider_choices, style=wx.CB_READONLY ) - self.provider.Bind(wx.EVT_COMBOBOX, lambda e: self.update_ui()) - sizer.Add(self.provider, 0, wx.EXPAND) + self.provider_combo.Bind(wx.EVT_COMBOBOX, lambda e: self.update_ui()) + sizer.Add(self.provider_combo, 0, wx.EXPAND) - label = wx.StaticText( + self.api_key_storage_method_label = wx.StaticText( panel, style=wx.ALIGN_LEFT, # Translators: A label in account dialog label=_("API &key storage method:"), ) - sizer.Add(label, 0, wx.ALL, 5) - self.api_key_storage_method = wx.ComboBox( + sizer.Add(self.api_key_storage_method_label, 0, wx.ALL, 5) + self.api_key_storage_method_combo = wx.ComboBox( panel, choices=list(key_storage_methods.values()), style=wx.CB_READONLY, ) - sizer.Add(self.api_key_storage_method, 0, wx.EXPAND) - self.api_key_storage_method.Disable() - label = wx.StaticText( + sizer.Add(self.api_key_storage_method_combo, 0, wx.EXPAND) + + self.api_key_label = wx.StaticText( panel, style=wx.ALIGN_LEFT, # Translators: A label in account dialog label=_("API &key:"), ) - sizer.Add(label, 0, wx.ALL, 5) - self.api_key = wx.TextCtrl(panel) - self.api_key.Disable() - sizer.Add(self.api_key, 0, wx.EXPAND) + sizer.Add(self.api_key_label, 0, wx.ALL, 5) + self.api_key_text_ctrl = wx.TextCtrl(panel) + sizer.Add(self.api_key_text_ctrl, 0, wx.EXPAND) - label = wx.StaticText( + self.organization_label = wx.StaticText( panel, label=_("&Organization to use:"), style=wx.ALIGN_LEFT ) - sizer.Add(label, 0, wx.ALL, 5) - self.organization = wx.ComboBox(panel, style=wx.CB_READONLY) - self.organization.Disable() - sizer.Add(self.organization, 0, wx.EXPAND) + sizer.Add(self.organization_label, 0, wx.ALL, 5) + self.organization_text_ctrl = wx.ComboBox(panel, style=wx.CB_READONLY) + sizer.Add(self.organization_text_ctrl, 0, wx.EXPAND) - bSizer = wx.BoxSizer(wx.HORIZONTAL) + self.custom_base_url_label = wx.StaticText( + panel, + # Translators: A label in account dialog + label=_("Custom &base URL:"), + style=wx.ALIGN_LEFT, + ) + sizer.Add(self.custom_base_url_label, 0, wx.ALL, 5) + self.custom_base_url_text_ctrl = wx.TextCtrl(panel) + sizer.Add(self.custom_base_url_text_ctrl, 0, wx.EXPAND) + + buttons_sizer = wx.BoxSizer(wx.HORIZONTAL) btn = wx.Button(panel, wx.ID_OK) btn.SetDefault() btn.Bind(wx.EVT_BUTTON, self.on_ok) - bSizer.Add(btn, 0, wx.ALL, 5) + buttons_sizer.Add(btn, 0, wx.ALL, 5) btn = wx.Button(panel, wx.ID_CANCEL) btn.Bind(wx.EVT_BUTTON, self.on_cancel) - bSizer.Add(btn, 0, wx.ALL, 5) + buttons_sizer.Add(btn, 0, wx.ALL, 5) - sizer.Add(bSizer, 0, wx.ALL, 5) + sizer.Add(buttons_sizer, 0, wx.ALL, 5) def init_data(self): """Initialize the data for the dialog. @@ -526,14 +537,14 @@ def init_data(self): and organization settings are set in the dialog. """ if not self.account: - self.api_key_storage_method.SetSelection(0) + self.api_key_storage_method_combo.SetSelection(0) return self.name.SetValue(self.account.name) index = first( locate(providers, lambda x: x.name == self.account.provider.name), -1, ) - self.provider.SetSelection(index) + self.provider_combo.SetSelection(index) if self.account.api_key and self.account.api_key_storage_method: index = first( locate( @@ -542,9 +553,11 @@ def init_data(self): ), -1, ) - self.api_key_storage_method.SetSelection(index) - self.api_key.SetValue(self.account.api_key.get_secret_value()) - self.organization.Enable( + self.api_key_storage_method_combo.SetSelection(index) + self.api_key_text_ctrl.SetValue( + self.account.api_key.get_secret_value() + ) + self.organization_text_ctrl.Enable( self.account.provider.organization_mode_available ) if not self.account.provider.organization_mode_available: @@ -553,7 +566,7 @@ def init_data(self): choices = [_("Personal")] + [ organization.name for organization in self.account.organizations ] - self.organization.SetItems(choices) + self.organization_text_ctrl.SetItems(choices) if self.account.active_organization_id: index = ( first( @@ -565,24 +578,61 @@ def init_data(self): ) + 1 ) - self.organization.SetSelection(index) + self.organization_text_ctrl.SetSelection(index) + if self.account.custom_base_url: + self.custom_base_url_text_ctrl.SetValue( + self.account.custom_base_url + ) - def update_ui(self): - """Update the user interface of the dialog. + def get_selected_provider(self) -> Optional[Provider]: + """Get the provider object from the selected provider name. - Enable or disable API key and organization fields based on the selected provider's requirements. + Returns: + The provider object if a provider is selected, otherwise None. """ - provider_index = self.provider.GetSelection() - if provider_index == -1: - log.debug("No provider selected") + provider_index = self.provider_combo.GetSelection() + if provider_index == wx.NOT_FOUND: + return None + provider_name = self.provider_combo.GetValue() + return get_provider(name=provider_name) + + def _disable_all_provider_fields(self): + """Disable all provider-dependent fields.""" + fields_to_disable = ( + self.api_key_label, + self.api_key_text_ctrl, + self.api_key_storage_method_label, + self.api_key_storage_method_combo, + self.organization_label, + self.organization_text_ctrl, + self.custom_base_url_label, + self.custom_base_url_text_ctrl, + ) + for field in fields_to_disable: + field.Disable() + + def update_ui(self) -> None: + """Update UI elements based on selected provider settings.""" + provider = self.get_selected_provider() + if not provider: + self._disable_all_provider_fields() return - provider_name = self.provider.GetValue() - provider = get_provider(name=provider_name) - if provider.require_api_key: - self.api_key.Enable() - self.api_key_storage_method.Enable() - if self.account: - self.organization.Enable(provider.organization_mode_available) + self.api_key_label.Enable(provider.require_api_key) + self.api_key_text_ctrl.Enable(provider.require_api_key) + self.api_key_storage_method_label.Enable(provider.require_api_key) + self.api_key_storage_method_combo.Enable(provider.require_api_key) + self.organization_label.Enable(provider.organization_mode_available) + self.organization_text_ctrl.Enable(provider.organization_mode_available) + default_base_url = provider.base_url + self.custom_base_url_label.Enable(provider.allow_custom_base_url) + self.custom_base_url_text_ctrl.Enable(provider.allow_custom_base_url) + if provider.allow_custom_base_url: + custom_base_url_label = _("Custom &base URL:") + if default_base_url: + custom_base_url_label = _( + "Custom &base URL (default: {}):" + ).format(default_base_url) + self.custom_base_url_label.SetLabel(custom_base_url_label) def on_ok(self, event: wx.Event | None): """Handle the OK button click event. @@ -598,28 +648,26 @@ def on_ok(self, event: wx.Event | None): wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) self.name.SetFocus() return - provider_index = self.provider.GetSelection() - if provider_index == -1: + provider = self.get_selected_provider() + if not provider: msg = _("Please select a provider") wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.provider.SetFocus() + self.provider_combo.SetFocus() return - provider_name = self.provider.GetValue() - provider = get_provider(name=provider_name) if provider.require_api_key: - if self.api_key_storage_method.GetSelection() == -1: + if self.api_key_storage_method_combo.GetSelection() == -1: msg = _("Please select an API key storage method") wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.api_key_storage_method.SetFocus() + self.api_key_storage_method_combo.SetFocus() return - if not self.api_key.GetValue(): + if not self.api_key_text_ctrl.GetValue(): msg = _( "Please enter an API key. It is required for this provider" ) wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.api_key.SetFocus() + self.api_key_text_ctrl.SetFocus() return - organization_index = self.organization.GetSelection() + organization_index = self.organization_text_ctrl.GetSelection() active_organization = None if organization_index > 0: active_organization = self.account.organizations[ @@ -629,15 +677,19 @@ def on_ok(self, event: wx.Event | None): api_key = None if provider.require_api_key: api_key_storage_method = list(key_storage_methods.keys())[ - self.api_key_storage_method.GetSelection() + self.api_key_storage_method_combo.GetSelection() ] - api_key = SecretStr(self.api_key.GetValue()) + api_key = SecretStr(self.api_key_text_ctrl.GetValue()) + custom_base_url = self.custom_base_url_text_ctrl.GetValue() + if not provider.allow_custom_base_url or not custom_base_url.strip(): + custom_base_url = None if self.account: self.account.name = self.name.GetValue() self.account.provider = provider self.account.api_key_storage_method = api_key_storage_method self.account.api_key = api_key self.account.active_organization_id = active_organization + self.account.custom_base_url = custom_base_url else: self.account = Account( name=self.name.GetValue(), @@ -646,6 +698,7 @@ def on_ok(self, event: wx.Event | None): api_key=api_key, active_organization_id=active_organization, source=AccountSource.CONFIG, + custom_base_url=custom_base_url, ) self.EndModal(wx.ID_OK) diff --git a/basilisk/provider.py b/basilisk/provider.py index d19c2c1a..48972e65 100644 --- a/basilisk/provider.py +++ b/basilisk/provider.py @@ -48,7 +48,7 @@ class Provider: base_url: Optional[str] = field(default=None) organization_mode_available: bool = field(default=False) require_api_key: bool = field(default=True) - custom: bool = field(default=True) + allow_custom_base_url: bool = field(default=False) env_var_name_api_key: Optional[str] = field(default=None) env_var_name_organization_key: Optional[str] = field(default=None) @@ -114,6 +114,7 @@ def engine_cls(self) -> Type[BaseEngine]: require_api_key=True, env_var_name_api_key="MISTRAL_API_KEY", engine_cls_path="basilisk.provider_engine.mistralai_engine.MistralAIEngine", + allow_custom_base_url=True, ), Provider( id="openai", @@ -125,6 +126,7 @@ def engine_cls(self) -> Type[BaseEngine]: env_var_name_api_key="OPENAI_API_KEY", env_var_name_organization_key="OPENAI_ORG_KEY", engine_cls_path="basilisk.provider_engine.openai_engine.OpenAIEngine", + allow_custom_base_url=True, ), Provider( id="openrouter", @@ -135,6 +137,7 @@ def engine_cls(self) -> Type[BaseEngine]: require_api_key=True, env_var_name_api_key="OPENROUTER_API_KEY", engine_cls_path="basilisk.provider_engine.openrouter_engine.OpenRouterEngine", + allow_custom_base_url=True, ), Provider( id="xai", diff --git a/basilisk/provider_engine/openai_engine.py b/basilisk/provider_engine/openai_engine.py index 42250769..750ad539 100644 --- a/basilisk/provider_engine/openai_engine.py +++ b/basilisk/provider_engine/openai_engine.py @@ -78,7 +78,8 @@ def client(self) -> OpenAI: return OpenAI( api_key=self.account.api_key.get_secret_value(), organization=organization_key, - base_url=str(self.account.provider.base_url), + base_url=self.account.custom_base_url + or str(self.account.provider.base_url), ) @cached_property From 456782d1cbf7d0bc6abe03e86617d51f15cb200a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Sat, 15 Feb 2025 08:24:48 +0100 Subject: [PATCH 2/9] refactor(account_dialog): improve code structure in EditAccountDialog - Removed `TYPE_CHECKING` imports and directly imported `Provider`. - Refactored the `EditAccountDialog` class to improve clarity and maintainability: - Introduced type hints for method parameters and return types. - Improved method names and added docstrings for better readability. - Split large methods into smaller, more focused methods. - Added error messages for form validation as comments for translators. - Reduced code duplication by consolidating similar code blocks. - Revised organization and API key fields handling for clarity. --- basilisk/gui/account_dialog.py | 243 ++++++++++++++++++++++----------- 1 file changed, 162 insertions(+), 81 deletions(-) diff --git a/basilisk/gui/account_dialog.py b/basilisk/gui/account_dialog.py index f75bf1fe..1f55f5c0 100644 --- a/basilisk/gui/account_dialog.py +++ b/basilisk/gui/account_dialog.py @@ -1,7 +1,7 @@ """Account dialog for managing accounts and organizations in the basiliskLLM application.""" import logging -from typing import TYPE_CHECKING, Optional +from typing import Optional import wx from more_itertools import first, locate @@ -14,10 +14,7 @@ KeyStorageMethodEnum, accounts, ) -from basilisk.provider import get_provider, providers - -if TYPE_CHECKING: - from basilisk.provider import Provider +from basilisk.provider import Provider, get_provider, providers log = logging.getLogger(__name__) @@ -431,7 +428,7 @@ def __init__( size: The size of the dialog. account: The account to edit. If None, a new account will be created. """ - wx.Dialog.__init__(self, parent, title=title, size=size) + super().__init__(parent, title=title, size=size) self.parent = parent self.account = account self.init_ui() @@ -501,7 +498,10 @@ def init_ui(self): sizer.Add(self.api_key_text_ctrl, 0, wx.EXPAND) self.organization_label = wx.StaticText( - panel, label=_("&Organization to use:"), style=wx.ALIGN_LEFT + panel, + # Translators: A label in account dialog + label=_("&Organization to use:"), + style=wx.ALIGN_LEFT, ) sizer.Add(self.organization_label, 0, wx.ALL, 5) self.organization_text_ctrl = wx.ComboBox(panel, style=wx.CB_READONLY) @@ -539,34 +539,50 @@ def init_data(self): if not self.account: self.api_key_storage_method_combo.SetSelection(0) return + self.name.SetValue(self.account.name) index = first( locate(providers, lambda x: x.name == self.account.provider.name), -1, ) self.provider_combo.SetSelection(index) + if self.account.api_key and self.account.api_key_storage_method: - index = first( - locate( - key_storage_methods.keys(), - lambda x: x == self.account.api_key_storage_method, - ), - -1, - ) - self.api_key_storage_method_combo.SetSelection(index) - self.api_key_text_ctrl.SetValue( - self.account.api_key.get_secret_value() + self._set_api_key_data() + + self._init_organization_data() + + if self.account.custom_base_url: + self.custom_base_url_text_ctrl.SetValue( + self.account.custom_base_url ) + + def _set_api_key_data(self) -> None: + """Set API key related fields from account data.""" + index = first( + locate( + key_storage_methods.keys(), + lambda x: x == self.account.api_key_storage_method, + ), + -1, + ) + self.api_key_storage_method_combo.SetSelection(index) + self.api_key_text_ctrl.SetValue(self.account.api_key.get_secret_value()) + + def _init_organization_data(self) -> None: + """Initialize organization related fields.""" self.organization_text_ctrl.Enable( self.account.provider.organization_mode_available ) if not self.account.provider.organization_mode_available: return + if self.account.organizations: choices = [_("Personal")] + [ organization.name for organization in self.account.organizations ] self.organization_text_ctrl.SetItems(choices) + if self.account.active_organization_id: index = ( first( @@ -579,10 +595,6 @@ def init_data(self): + 1 ) self.organization_text_ctrl.SetSelection(index) - if self.account.custom_base_url: - self.custom_base_url_text_ctrl.SetValue( - self.account.custom_base_url - ) def get_selected_provider(self) -> Optional[Provider]: """Get the provider object from the selected provider name. @@ -596,9 +608,20 @@ def get_selected_provider(self) -> Optional[Provider]: provider_name = self.provider_combo.GetValue() return get_provider(name=provider_name) - def _disable_all_provider_fields(self): + def update_ui(self) -> None: + """Update UI elements based on selected provider.""" + provider = self.get_selected_provider() + if not provider: + self._disable_all_fields() + return + + self._update_api_key_fields(provider.require_api_key) + self._update_organization_fields(provider.organization_mode_available) + self._update_base_url_fields(provider) + + def _disable_all_fields(self) -> None: """Disable all provider-dependent fields.""" - fields_to_disable = ( + fields = [ self.api_key_label, self.api_key_text_ctrl, self.api_key_storage_method_label, @@ -607,34 +630,32 @@ def _disable_all_provider_fields(self): self.organization_text_ctrl, self.custom_base_url_label, self.custom_base_url_text_ctrl, - ) - for field in fields_to_disable: + ] + for field in fields: field.Disable() - def update_ui(self) -> None: - """Update UI elements based on selected provider settings.""" - provider = self.get_selected_provider() - if not provider: - self._disable_all_provider_fields() - return - self.api_key_label.Enable(provider.require_api_key) - self.api_key_text_ctrl.Enable(provider.require_api_key) - self.api_key_storage_method_label.Enable(provider.require_api_key) - self.api_key_storage_method_combo.Enable(provider.require_api_key) - self.organization_label.Enable(provider.organization_mode_available) - self.organization_text_ctrl.Enable(provider.organization_mode_available) - default_base_url = provider.base_url + def _update_api_key_fields(self, enable: bool) -> None: + """Update API key related fields state.""" + fields = [ + self.api_key_label, + self.api_key_text_ctrl, + self.api_key_storage_method_label, + self.api_key_storage_method_combo, + ] + for field in fields: + field.Enable(enable) + + def _update_organization_fields(self, enable: bool) -> None: + """Update organization related fields state.""" + self.organization_label.Enable(enable) + self.organization_text_ctrl.Enable(enable) + + def _update_base_url_fields(self, provider: Provider) -> None: + """Update base URL related fields.""" self.custom_base_url_label.Enable(provider.allow_custom_base_url) self.custom_base_url_text_ctrl.Enable(provider.allow_custom_base_url) - if provider.allow_custom_base_url: - custom_base_url_label = _("Custom &base URL:") - if default_base_url: - custom_base_url_label = _( - "Custom &base URL (default: {}):" - ).format(default_base_url) - self.custom_base_url_label.SetLabel(custom_base_url_label) - def on_ok(self, event: wx.Event | None): + def on_ok(self, event: wx.CommandEvent) -> None: """Handle the OK button click event. Validate the account settings and create or update the account. @@ -643,36 +664,56 @@ def on_ok(self, event: wx.Event | None): Args: event: The event that triggered the OK button click. If None, the OK button was not clicked. """ - if not self.name.GetValue(): - msg = _("Please enter a name") - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.name.SetFocus() + error_message = self._validate_form() + if error_message: + wx.MessageBox( + error_message, + # Translators: A title for the error message in account dialog + _("Error"), + wx.OK | wx.ICON_ERROR, + ) return + + self._save_account_data() + self.EndModal(wx.ID_OK) + + def _validate_form(self) -> Optional[str]: + """Validate form data and return error message if invalid.""" + if not self.name.GetValue(): + # Translators: An error message in account dialog + return _("Please enter a name") + provider = self.get_selected_provider() if not provider: - msg = _("Please select a provider") - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.provider_combo.SetFocus() - return + # Translators: An error message in account dialog + return _("Please select a provider") + if provider.require_api_key: - if self.api_key_storage_method_combo.GetSelection() == -1: - msg = _("Please select an API key storage method") - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.api_key_storage_method_combo.SetFocus() - return + if self.api_key_storage_method_combo.GetSelection() == wx.NOT_FOUND: + # Translators: An error message in account dialog + return _("Please select an API key storage method") if not self.api_key_text_ctrl.GetValue(): - msg = _( + # Translators: An error message in account dialog + return _( "Please enter an API key. It is required for this provider" ) - wx.MessageBox(msg, _("Error"), wx.OK | wx.ICON_ERROR) - self.api_key_text_ctrl.SetFocus() - return + + return None + + def _save_account_data(self) -> None: + """Save form data to account object.""" + provider = self.get_selected_provider() organization_index = self.organization_text_ctrl.GetSelection() active_organization = None - if organization_index > 0: + if ( + organization_index > 0 + and self.account + and self.account.organizations + ): active_organization = self.account.organizations[ organization_index - 1 ].id + api_key_storage_method = None api_key = None if provider.require_api_key: @@ -680,29 +721,64 @@ def on_ok(self, event: wx.Event | None): self.api_key_storage_method_combo.GetSelection() ] api_key = SecretStr(self.api_key_text_ctrl.GetValue()) + custom_base_url = self.custom_base_url_text_ctrl.GetValue() if not provider.allow_custom_base_url or not custom_base_url.strip(): custom_base_url = None + if self.account: - self.account.name = self.name.GetValue() - self.account.provider = provider - self.account.api_key_storage_method = api_key_storage_method - self.account.api_key = api_key - self.account.active_organization_id = active_organization - self.account.custom_base_url = custom_base_url + self._update_existing_account( + provider, + active_organization, + api_key_storage_method, + api_key, + custom_base_url, + ) else: - self.account = Account( - name=self.name.GetValue(), - provider=provider, - api_key_storage_method=api_key_storage_method, - api_key=api_key, - active_organization_id=active_organization, - source=AccountSource.CONFIG, - custom_base_url=custom_base_url, + self._create_new_account( + provider, + active_organization, + api_key_storage_method, + api_key, + custom_base_url, ) - self.EndModal(wx.ID_OK) - def on_cancel(self, event: wx.Event | None): + def _update_existing_account( + self, + provider: Provider, + active_organization: Optional[str], + api_key_storage_method: Optional[KeyStorageMethodEnum], + api_key: Optional[SecretStr], + custom_base_url: Optional[str], + ) -> None: + """Update existing account with form data.""" + self.account.name = self.name.GetValue() + self.account.provider = provider + self.account.api_key_storage_method = api_key_storage_method + self.account.api_key = api_key + self.account.active_organization_id = active_organization + self.account.custom_base_url = custom_base_url + + def _create_new_account( + self, + provider: Provider, + active_organization: Optional[str], + api_key_storage_method: Optional[KeyStorageMethodEnum], + api_key: Optional[SecretStr], + custom_base_url: Optional[str], + ) -> None: + """Create new account from form data.""" + self.account = Account( + name=self.name.GetValue(), + provider=provider, + api_key_storage_method=api_key_storage_method, + api_key=api_key, + active_organization_id=active_organization, + source=AccountSource.CONFIG, + custom_base_url=custom_base_url, + ) + + def on_cancel(self, event: wx.CommandEvent) -> None: """Handle the Cancel button click event. Close the dialog without saving any changes. @@ -745,7 +821,12 @@ def init_ui(self): sizer = wx.BoxSizer(wx.VERTICAL) panel.SetSizer(sizer) - label = wx.StaticText(panel, label=_("Accounts"), style=wx.ALIGN_LEFT) + label = wx.StaticText( + panel, + # Translators: A label in account dialog + label=_("Accounts"), + style=wx.ALIGN_LEFT, + ) sizer.Add(label, 0, wx.ALL, 5) self.account_list = wx.ListCtrl(panel, style=wx.LC_REPORT) self.account_list.AppendColumn( From 92ef9922c81fef709ebe0999724bce8747537c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Sat, 15 Feb 2025 08:28:56 +0100 Subject: [PATCH 3/9] feat(account_dialog): enhance validation and focus control - Introduced a regex pattern `CUSTOM_BASE_URL_PATTERN` for validating custom base URLs. - Changed `get_selected_provider` method into a `provider` property for enhanced code readability. - Updated `_validate_form` to return a tuple of error message and associated field. - Added focus control to invalid form fields to enhance user experience by directing attention to the error. - Refactor form validation to include custom base URL validation. --- basilisk/gui/account_dialog.py | 47 ++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/basilisk/gui/account_dialog.py b/basilisk/gui/account_dialog.py index 1f55f5c0..694ae739 100644 --- a/basilisk/gui/account_dialog.py +++ b/basilisk/gui/account_dialog.py @@ -1,6 +1,7 @@ """Account dialog for managing accounts and organizations in the basiliskLLM application.""" import logging +import re from typing import Optional import wx @@ -19,6 +20,9 @@ log = logging.getLogger(__name__) key_storage_methods = KeyStorageMethodEnum.get_labels() +CUSTOM_BASE_URL_PATTERN = ( + r"^https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(/[\w./%\-]*)?" +) class EditAccountOrganizationDialog(wx.Dialog): @@ -596,7 +600,8 @@ def _init_organization_data(self) -> None: ) self.organization_text_ctrl.SetSelection(index) - def get_selected_provider(self) -> Optional[Provider]: + @property + def provider(self) -> Optional[Provider]: """Get the provider object from the selected provider name. Returns: @@ -610,7 +615,7 @@ def get_selected_provider(self) -> Optional[Provider]: def update_ui(self) -> None: """Update UI elements based on selected provider.""" - provider = self.get_selected_provider() + provider = self.provider if not provider: self._disable_all_fields() return @@ -666,43 +671,63 @@ def on_ok(self, event: wx.CommandEvent) -> None: """ error_message = self._validate_form() if error_message: + msg, field = error_message wx.MessageBox( - error_message, + msg, # Translators: A title for the error message in account dialog _("Error"), wx.OK | wx.ICON_ERROR, ) + field.SetFocus() return self._save_account_data() self.EndModal(wx.ID_OK) - def _validate_form(self) -> Optional[str]: - """Validate form data and return error message if invalid.""" + def _validate_form(self) -> Optional[tuple[str, wx.Window]]: + """Validate form data and return a tuple of error message and field to focus on if any. + Returns None if form data is valid. + """ if not self.name.GetValue(): # Translators: An error message in account dialog - return _("Please enter a name") + return _("Please enter a name"), self.name - provider = self.get_selected_provider() + provider = self.provider if not provider: # Translators: An error message in account dialog - return _("Please select a provider") + return _("Please select a provider"), self.provider_combo if provider.require_api_key: if self.api_key_storage_method_combo.GetSelection() == wx.NOT_FOUND: # Translators: An error message in account dialog - return _("Please select an API key storage method") + return _( + "Please select an API key storage method" + ), self.api_key_storage_method_combo + if not self.api_key_text_ctrl.GetValue(): # Translators: An error message in account dialog return _( "Please enter an API key. It is required for this provider" - ) + ), self.api_key_text_ctrl + + if ( + self.provider.allow_custom_base_url + and self.custom_base_url_text_ctrl.GetValue() + ): + if not re.match( + CUSTOM_BASE_URL_PATTERN, + self.custom_base_url_text_ctrl.GetValue(), + ): + # Translators: An error message in account dialog + return _( + "Please enter a valid custom base URL" + ), self.custom_base_url_text_ctrl return None def _save_account_data(self) -> None: """Save form data to account object.""" - provider = self.get_selected_provider() + provider = self.provider organization_index = self.organization_text_ctrl.GetSelection() active_organization = None if ( From 9858bac9a8b81602e9680408d85ec418f806c699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Fri, 14 Feb 2025 20:00:30 +0100 Subject: [PATCH 4/9] Update basilisk/config/account_config.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- basilisk/config/account_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basilisk/config/account_config.py b/basilisk/config/account_config.py index 0946057b..31192bd5 100644 --- a/basilisk/config/account_config.py +++ b/basilisk/config/account_config.py @@ -131,7 +131,7 @@ class Account(BaseModel): organizations: Optional[list[AccountOrganization]] = Field(default=None) active_organization_id: Optional[UUID4] = Field(default=None) source: AccountSource = Field(default=AccountSource.CONFIG, exclude=True) - custom_base_url: Optional[str] = Field(default=None) + custom_base_url: Optional[str] = Field(default=None, pattern="^https?://[\\w.-]+(?::\\d+)?(?:/[\\w.-]*)*/?$") def __init__(self, **data: Any): """Initialize an account instance. If an error occurs, log the error and raise an exception.""" From 11ee1b61caa04116a17ee2e32343e339ceed9fd4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Feb 2025 19:00:38 +0000 Subject: [PATCH 5/9] style(pre-commit.ci): auto fixes from pre-commit hooks for more information, see https://pre-commit.ci --- basilisk/config/account_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/basilisk/config/account_config.py b/basilisk/config/account_config.py index 31192bd5..0b832c92 100644 --- a/basilisk/config/account_config.py +++ b/basilisk/config/account_config.py @@ -131,7 +131,9 @@ class Account(BaseModel): organizations: Optional[list[AccountOrganization]] = Field(default=None) active_organization_id: Optional[UUID4] = Field(default=None) source: AccountSource = Field(default=AccountSource.CONFIG, exclude=True) - custom_base_url: Optional[str] = Field(default=None, pattern="^https?://[\\w.-]+(?::\\d+)?(?:/[\\w.-]*)*/?$") + custom_base_url: Optional[str] = Field( + default=None, pattern="^https?://[\\w.-]+(?::\\d+)?(?:/[\\w.-]*)*/?$" + ) def __init__(self, **data: Any): """Initialize an account instance. If an error occurs, log the error and raise an exception.""" From 1bc65695dde22165f93c2a991a2d3aa7181c1098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Fri, 14 Feb 2025 20:12:05 +0100 Subject: [PATCH 6/9] Update basilisk/config/account_config.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- basilisk/config/account_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basilisk/config/account_config.py b/basilisk/config/account_config.py index 0b832c92..9937f71c 100644 --- a/basilisk/config/account_config.py +++ b/basilisk/config/account_config.py @@ -190,7 +190,7 @@ def validate_api_key( raise ValueError("API key not found in keyring") return SecretStr(value) elif not data["provider"].require_api_key and value is None: - return + return None else: raise ValueError("Invalid API key storage method") From d669042425c8b2ef107acacc6537e146a618bd20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Sat, 15 Feb 2025 08:34:14 +0100 Subject: [PATCH 7/9] refactor(config): centralize `CUSTOM_BASE_URL_PATTERN` for unified URL validation - Move the `CUSTOM_BASE_URL_PATTERN` regex from `account_dialog.py` to `account_config.py`. - Update the `Account` class definition to use this centralized pattern for `custom_base_url`. - Adjust imports in `__init__.py` and `account_dialog.py` to utilize the centralized `CUSTOM_BASE_URL_PATTERN`. --- basilisk/config/__init__.py | 9 ++++++++- basilisk/config/account_config.py | 9 ++++++++- basilisk/gui/account_dialog.py | 4 +--- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/basilisk/config/__init__.py b/basilisk/config/__init__.py index 1c1317d7..ed1de433 100644 --- a/basilisk/config/__init__.py +++ b/basilisk/config/__init__.py @@ -1,6 +1,11 @@ """Configuration module for Basilisk.""" -from .account_config import Account, AccountManager, AccountOrganization +from .account_config import ( + CUSTOM_BASE_URL_PATTERN, + Account, + AccountManager, + AccountOrganization, +) from .account_config import get_account_config as accounts from .config_enums import ( AccountSource, @@ -27,6 +32,8 @@ "conf", "ConversationProfile", "conversation_profiles", + "CUSTOM_BASE_URL_PATTERN", + "get_account_source_labels", "KeyStorageMethodEnum", "LogLevelEnum", "ReleaseChannelEnum", diff --git a/basilisk/config/account_config.py b/basilisk/config/account_config.py index 9937f71c..16d6f995 100644 --- a/basilisk/config/account_config.py +++ b/basilisk/config/account_config.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import re from functools import cache, cached_property from os import getenv from typing import Annotated, Any, Iterable, Optional, Union @@ -38,6 +39,10 @@ log = logging.getLogger(__name__) +CUSTOM_BASE_URL_PATTERN = re.compile( + r"^https?://[\w.-]+(?::\d{1,5})?(?:/[\w-]+)*/?$" +) + class AccountOrganization(BaseModel): """Manage organization key for an account.""" @@ -132,7 +137,9 @@ class Account(BaseModel): active_organization_id: Optional[UUID4] = Field(default=None) source: AccountSource = Field(default=AccountSource.CONFIG, exclude=True) custom_base_url: Optional[str] = Field( - default=None, pattern="^https?://[\\w.-]+(?::\\d+)?(?:/[\\w.-]*)*/?$" + default=None, + pattern=CUSTOM_BASE_URL_PATTERN, + description="Custom base URL for the API provider. Must be a valid HTTP/HTTPS URL.", ) def __init__(self, **data: Any): diff --git a/basilisk/gui/account_dialog.py b/basilisk/gui/account_dialog.py index 694ae739..18b78ca9 100644 --- a/basilisk/gui/account_dialog.py +++ b/basilisk/gui/account_dialog.py @@ -9,6 +9,7 @@ from pydantic import SecretStr from basilisk.config import ( + CUSTOM_BASE_URL_PATTERN, Account, AccountOrganization, AccountSource, @@ -20,9 +21,6 @@ log = logging.getLogger(__name__) key_storage_methods = KeyStorageMethodEnum.get_labels() -CUSTOM_BASE_URL_PATTERN = ( - r"^https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(/[\w./%\-]*)?" -) class EditAccountOrganizationDialog(wx.Dialog): From 4f6f3cda1c20e6eb7ae6621d85e5ed262725f0c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Sat, 15 Feb 2025 08:49:29 +0100 Subject: [PATCH 8/9] refactor(account_dialog): update type annotations to Python 3.10+ syntax Replaced 'Optional[T]' with 'T | None' for type hinting in the account_dialog.py file. This change modernizes the code to make use of the Python 3.10+ feature which improves readability. --- basilisk/gui/account_dialog.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/basilisk/gui/account_dialog.py b/basilisk/gui/account_dialog.py index 18b78ca9..ae32ed77 100644 --- a/basilisk/gui/account_dialog.py +++ b/basilisk/gui/account_dialog.py @@ -2,7 +2,6 @@ import logging import re -from typing import Optional import wx from more_itertools import first, locate @@ -30,7 +29,7 @@ def __init__( self, parent: wx.Window, title: str, - organization: Optional[AccountOrganization] = None, + organization: AccountOrganization | None = None, size: tuple[int, int] = (400, 200), ): """Initialize the dialog for editing account organization settings. @@ -599,7 +598,7 @@ def _init_organization_data(self) -> None: self.organization_text_ctrl.SetSelection(index) @property - def provider(self) -> Optional[Provider]: + def provider(self) -> Provider | None: """Get the provider object from the selected provider name. Returns: @@ -682,8 +681,9 @@ def on_ok(self, event: wx.CommandEvent) -> None: self._save_account_data() self.EndModal(wx.ID_OK) - def _validate_form(self) -> Optional[tuple[str, wx.Window]]: + def _validate_form(self) -> tuple[str, wx.Window] | None: """Validate form data and return a tuple of error message and field to focus on if any. + Returns None if form data is valid. """ if not self.name.GetValue(): @@ -769,10 +769,10 @@ def _save_account_data(self) -> None: def _update_existing_account( self, provider: Provider, - active_organization: Optional[str], - api_key_storage_method: Optional[KeyStorageMethodEnum], - api_key: Optional[SecretStr], - custom_base_url: Optional[str], + active_organization: str | None, + api_key_storage_method: KeyStorageMethodEnum | None, + api_key: SecretStr | None, + custom_base_url: str | None, ) -> None: """Update existing account with form data.""" self.account.name = self.name.GetValue() @@ -785,10 +785,10 @@ def _update_existing_account( def _create_new_account( self, provider: Provider, - active_organization: Optional[str], - api_key_storage_method: Optional[KeyStorageMethodEnum], - api_key: Optional[SecretStr], - custom_base_url: Optional[str], + active_organization: str | None, + api_key_storage_method: KeyStorageMethodEnum | None, + api_key: SecretStr | None, + custom_base_url: str | None, ) -> None: """Create new account from form data.""" self.account = Account( From 5c2d5efcbaf573d6bbe8b056d2371c7bb3852d3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Sat, 15 Feb 2025 09:29:24 +0100 Subject: [PATCH 9/9] feat(account_dialog): show default base URL Enhanced the account dialog to display the default base URL when available, aiding in customization clarity. --- basilisk/gui/account_dialog.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/basilisk/gui/account_dialog.py b/basilisk/gui/account_dialog.py index ae32ed77..ddb37f6e 100644 --- a/basilisk/gui/account_dialog.py +++ b/basilisk/gui/account_dialog.py @@ -656,6 +656,13 @@ def _update_base_url_fields(self, provider: Provider) -> None: """Update base URL related fields.""" self.custom_base_url_label.Enable(provider.allow_custom_base_url) self.custom_base_url_text_ctrl.Enable(provider.allow_custom_base_url) + default_base_url = provider.base_url + if default_base_url: + self.custom_base_url_label.SetLabel( + _("Custom &base URL (default: {})").format(default_base_url) + ) + else: + self.custom_base_url_label.SetLabel(_("Custom &base URL")) def on_ok(self, event: wx.CommandEvent) -> None: """Handle the OK button click event.