From 213891cb86c81fff1b1f891bf91bff01ca2e28e8 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Fri, 5 Apr 2024 09:15:50 +0100 Subject: [PATCH 01/19] Distinguish between completion and chat models --- .../jupyter_ai_magics/providers.py | 10 + .../completions/handlers/llm_mixin.py | 4 +- .../jupyter_ai/config/config_schema.json | 23 ++ .../jupyter-ai/jupyter_ai/config_manager.py | 124 +++---- packages/jupyter-ai/jupyter_ai/handlers.py | 6 + packages/jupyter-ai/jupyter_ai/models.py | 11 + packages/jupyter-ai/src/completions/plugin.ts | 47 ++- .../jupyter-ai/src/completions/settings.tsx | 181 ++++++++++ .../src/components/chat-settings.tsx | 230 +------------ .../src/components/model-settings.tsx | 308 ++++++++++++++++++ .../components/settings/use-server-info.ts | 35 +- packages/jupyter-ai/src/handler.ts | 7 + 12 files changed, 706 insertions(+), 280 deletions(-) create mode 100644 packages/jupyter-ai/src/completions/settings.tsx create mode 100644 packages/jupyter-ai/src/components/model-settings.tsx diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index c07061b1..091b78fe 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -239,6 +239,16 @@ class Config: provider is selected. """ + @classmethod + def chat_models(self): + """Models which are suitable for chat.""" + return self.models + + @classmethod + def completion_models(self): + """Models which are suitable for completions.""" + return self.models + # # instance attrs # diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py index 1371e3cb..fa16920d 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py @@ -27,8 +27,8 @@ def __init__(self, *args, **kwargs): self.llm_chain = None def get_llm_chain(self): - lm_provider = self.config_manager.lm_provider - lm_provider_params = self.config_manager.lm_provider_params + lm_provider = self.config_manager.completions_lm_provider + lm_provider_params = self.config_manager.completions_lm_provider_params if not lm_provider or not lm_provider_params: return None diff --git a/packages/jupyter-ai/jupyter_ai/config/config_schema.json b/packages/jupyter-ai/jupyter_ai/config/config_schema.json index ff7c717c..0dd910ab 100644 --- a/packages/jupyter-ai/jupyter_ai/config/config_schema.json +++ b/packages/jupyter-ai/jupyter_ai/config/config_schema.json @@ -16,6 +16,18 @@ "default": null, "readOnly": false }, + "completions_model_provider_id": { + "$comment": "Language model global ID for completions.", + "type": ["string", "null"], + "default": null, + "readOnly": false + }, + "completions_embeddings_provider_id": { + "$comment": "Embedding model global ID for completions.", + "type": ["string", "null"], + "default": null, + "readOnly": false + }, "api_keys": { "$comment": "Dictionary of API keys, mapping key names to key values.", "type": "object", @@ -37,6 +49,17 @@ } }, "additionalProperties": false + }, + "completions_fields": { + "$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.", + "type": "object", + "default": {}, + "patternProperties": { + "^.*$": { + "anyOf": [{ "type": "object" }] + } + }, + "additionalProperties": false } } } diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 392e4460..5969c9f9 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -164,15 +164,22 @@ def _process_existing_config(self, default_config): {k: v for k, v in existing_config.items() if v is not None}, ) config = GlobalConfig(**merged_config) - validated_config = self._validate_lm_em_id(config) + validated_config = self._validate_lm_em_id( + config, lm_key="model_provider_id", em_key="embeddings_provider_id" + ) + validated_config = self._validate_lm_em_id( + config, + lm_key="completions_model_provider_id", + em_key="completions_embeddings_provider_id", + ) # re-write to the file to validate the config and apply any # updates to the config file immediately self._write_config(validated_config) - def _validate_lm_em_id(self, config): - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id + def _validate_lm_em_id(self, config, lm_key, em_key): + lm_id = getattr(config, lm_key) + em_id = getattr(config, em_key) # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. @@ -180,12 +187,12 @@ def _validate_lm_em_id(self, config): self.log.warning( f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." ) - config.model_provider_id = None + setattr(config, lm_key, None) if em_id is not None and not self._validate_model(em_id, raise_exc=False): self.log.warning( f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." ) - config.embeddings_provider_id = None + setattr(config, em_key, None) # if the currently selected language or embedding model ids are # not associated with models, set them to `None` and log a warning. @@ -193,12 +200,12 @@ def _validate_lm_em_id(self, config): self.log.warning( f"No language model is associated with '{lm_id}'. Setting to None." ) - config.model_provider_id = None + setattr(config, lm_key, None) if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: self.log.warning( f"No embedding model is associated with '{em_id}'. Setting to None." ) - config.embeddings_provider_id = None + setattr(config, em_key, None) return config @@ -321,6 +328,9 @@ def _write_config(self, new_config: GlobalConfig): complete `GlobalConfig` object, and should not be called publicly.""" # remove any empty field dictionaries new_config.fields = {k: v for k, v in new_config.fields.items() if v} + new_config.completions_fields = { + k: v for k, v in new_config.completions_fields.items() if v + } self._validate_config(new_config) with open(self.config_path, "w") as f: @@ -328,21 +338,19 @@ def _write_config(self, new_config: GlobalConfig): def delete_api_key(self, key_name: str): config_dict = self._read_config().dict() - lm_provider = self.lm_provider - em_provider = self.em_provider required_keys = [] - if ( - lm_provider - and lm_provider.auth_strategy - and lm_provider.auth_strategy.type == "env" - ): - required_keys.append(lm_provider.auth_strategy.name) - if ( - em_provider - and em_provider.auth_strategy - and em_provider.auth_strategy.type == "env" - ): - required_keys.append(self.em_provider.auth_strategy.name) + for provider in [ + self.lm_provider, + self.em_provider, + self.completions_lm_provider, + self.completions_em_provider, + ]: + if ( + provider + and provider.auth_strategy + and provider.auth_strategy.type == "env" + ): + required_keys.append(provider.auth_strategy.name) if key_name in required_keys: raise KeyInUseError( @@ -390,67 +398,69 @@ def em_gid(self): @property def lm_provider(self): - config = self._read_config() - lm_gid = config.model_provider_id - if lm_gid is None: - return None - - _, Provider = get_lm_provider(config.model_provider_id, self._lm_providers) - return Provider + return self._get_provider("model_provider_id", self._lm_providers) @property def em_provider(self): + return self._get_provider("embeddings_provider_id", self._em_providers) + + @property + def completions_lm_provider(self): + return self._get_provider("completions_model_provider_id", self._lm_providers) + + @property + def completions_em_provider(self): + return self._get_provider( + "completions_embeddings_provider_id", self._em_providers + ) + + def _get_provider(self, key, listing): config = self._read_config() - em_gid = config.embeddings_provider_id - if em_gid is None: + gid = getattr(config, key) + if gid is None: return None - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_lm_provider(gid, listing) return Provider @property def lm_provider_params(self): - # get generic fields - config = self._read_config() - lm_gid = config.model_provider_id - if not lm_gid: - return None + return self._provider_params("model_provider_id", self._lm_providers) - lm_lid = lm_gid.split(":", 1)[1] - fields = config.fields.get(lm_gid, {}) - - # get authn fields - _, Provider = get_lm_provider(lm_gid, self._lm_providers) - authn_fields = {} - if Provider.auth_strategy and Provider.auth_strategy.type == "env": - key_name = Provider.auth_strategy.name - authn_fields[key_name.lower()] = config.api_keys[key_name] + @property + def em_provider_params(self): + return self._provider_params("embeddings_provider_id", self._em_providers) - return { - "model_id": lm_lid, - **fields, - **authn_fields, - } + @property + def completions_lm_provider_params(self): + return self._provider_params( + "completions_model_provider_id", self._lm_providers + ) @property - def em_provider_params(self): + def completions_em_provider_params(self): + return self._provider_params( + "completions_embeddings_provider_id", self._em_providers + ) + + def _provider_params(self, key, listing): # get generic fields config = self._read_config() - em_gid = config.embeddings_provider_id - if not em_gid: + gid = getattr(config, key) + if not gid: return None - em_lid = em_gid.split(":", 1)[1] + lid = gid.split(":", 1)[1] # get authn fields - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_em_provider(gid, listing) authn_fields = {} if Provider.auth_strategy and Provider.auth_strategy.type == "env": key_name = Provider.auth_strategy.name authn_fields[key_name.lower()] = config.api_keys[key_name] return { - "model_id": em_lid, + "model_id": lid, **authn_fields, } diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 6002310b..976b7df6 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -290,6 +290,10 @@ def filter_predicate(local_model_id: str): # filter out every model w/ model ID according to allow/blocklist for provider in providers: provider.models = list(filter(filter_predicate, provider.models)) + provider.chat_models = list(filter(filter_predicate, provider.chat_models)) + provider.completion_models = list( + filter(filter_predicate, provider.completion_models) + ) # filter out every provider with no models which satisfy the allow/blocklist, then return return filter((lambda p: len(p.models) > 0), providers) @@ -311,6 +315,8 @@ def get(self): id=provider.id, name=provider.name, models=provider.models, + chat_models=provider.chat_models(), + completion_models=provider.completion_models(), help=provider.help, auth_strategy=provider.auth_strategy, registry=provider.registry, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 32353a69..5675c636 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -94,6 +94,8 @@ class ListProvidersEntry(BaseModel): auth_strategy: AuthStrategy registry: bool fields: List[Field] + chat_models: Optional[List[str]] + completion_models: Optional[List[str]] class ListProvidersResponse(BaseModel): @@ -121,6 +123,9 @@ class DescribeConfigResponse(BaseModel): # timestamp indicating when the configuration file was last read. should be # passed to the subsequent UpdateConfig request. last_read: int + completions_model_provider_id: Optional[str] + completions_embeddings_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] def forbid_none(cls, v): @@ -137,6 +142,9 @@ class UpdateConfigRequest(BaseModel): # if passed, this will raise an Error if the config was written to after the # time specified by `last_read` to prevent write-write conflicts. last_read: Optional[int] + completions_model_provider_id: Optional[str] + completions_embeddings_provider_id: Optional[str] + completions_fields: Optional[Dict[str, Dict[str, Any]]] _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( forbid_none @@ -154,3 +162,6 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] + completions_model_provider_id: Optional[str] + completions_embeddings_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index 83556114..c2519fd9 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -4,6 +4,8 @@ import { } from '@jupyterlab/application'; import { ICompletionProviderManager } from '@jupyterlab/completer'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; +import { MainAreaWidget } from '@jupyterlab/apputils'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { IEditorLanguageRegistry, IEditorLanguage @@ -12,6 +14,8 @@ import { getEditor } from '../selection-watcher'; import { IJaiStatusItem } from '../tokens'; import { displayName, JaiInlineProvider } from './provider'; import { CompletionWebsocketHandler } from './handler'; +import { jupyternautIcon } from '../icons'; +import { ModelSettingsWidget } from './settings'; export namespace CommandIDs { /** @@ -23,6 +27,10 @@ export namespace CommandIDs { */ export const toggleLanguageCompletions = 'jupyter-ai:toggle-language-completions'; + /** + * Command to open provider/model configuration. + */ + export const configureModel = 'jupyter-ai:configure-completions'; } const INLINE_COMPLETER_PLUGIN = @@ -52,7 +60,8 @@ export const completionPlugin: JupyterFrontEndPlugin = { requires: [ ICompletionProviderManager, IEditorLanguageRegistry, - ISettingRegistry + ISettingRegistry, + IRenderMimeRegistry ], optional: [IJaiStatusItem], activate: async ( @@ -60,6 +69,7 @@ export const completionPlugin: JupyterFrontEndPlugin = { completionManager: ICompletionProviderManager, languageRegistry: IEditorLanguageRegistry, settingRegistry: ISettingRegistry, + rmRegistry: IRenderMimeRegistry, statusItem: IJaiStatusItem | null ): Promise => { if (typeof completionManager.registerInlineProvider === 'undefined') { @@ -176,6 +186,37 @@ export const completionPlugin: JupyterFrontEndPlugin = { } }); + let settingsWidget: MainAreaWidget | null = null; + const newSettingsWidget = () => { + const content = new ModelSettingsWidget({ + rmRegistry, + isProviderEnabled: () => provider.isEnabled(), + openInlineCompleterSettings: () => { + app.commands.execute('settingeditor:open', { + query: 'Inline Completer' + }); + } + }); + const widget = new MainAreaWidget({ content }); + widget.id = 'jupyterlab-inline-completions-model'; + widget.title.label = 'AI Completions Model Settings'; + widget.title.closable = true; + widget.title.icon = jupyternautIcon; + return widget; + }; + app.commands.addCommand(CommandIDs.configureModel, { + execute: () => { + if (!settingsWidget || settingsWidget.isDisposed) { + settingsWidget = newSettingsWidget(); + } + if (!settingsWidget.isAttached) { + app.shell.add(settingsWidget, 'main'); + } + app.shell.activateById(settingsWidget.id); + }, + label: 'Configure Jupyternaut Completions Model' + }); + if (statusItem) { statusItem.addItem({ command: CommandIDs.toggleCompletions, @@ -185,6 +226,10 @@ export const completionPlugin: JupyterFrontEndPlugin = { command: CommandIDs.toggleLanguageCompletions, rank: 2 }); + statusItem.addItem({ + command: CommandIDs.configureModel, + rank: 3 + }); } } }; diff --git a/packages/jupyter-ai/src/completions/settings.tsx b/packages/jupyter-ai/src/completions/settings.tsx new file mode 100644 index 00000000..23e0efb2 --- /dev/null +++ b/packages/jupyter-ai/src/completions/settings.tsx @@ -0,0 +1,181 @@ +import { ReactWidget } from '@jupyterlab/ui-components'; +import React, { useState } from 'react'; + +import { Box } from '@mui/system'; +import { Alert, Button, CircularProgress } from '@mui/material'; + +import { AiService } from '../handler'; +import { + ServerInfoState, + useServerInfo +} from '../components/settings/use-server-info'; +import { ModelSettings, IModelSettings } from '../components/model-settings'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import { minifyUpdate } from '../components/settings/minify'; +import { useStackingAlert } from '../components/mui-extras/stacking-alert'; + +type CompleterSettingsProps = { + rmRegistry: IRenderMimeRegistry; + isProviderEnabled: () => boolean; + openInlineCompleterSettings: () => void; +}; + +/** + * Component that returns the settings view. + */ +export function CompleterSettings(props: CompleterSettingsProps): JSX.Element { + // state fetched on initial render + const server = useServerInfo(); + + // initialize alert helper + const alert = useStackingAlert(); + + // whether the form is currently saving + const [saving, setSaving] = useState(false); + + // provider/model settings + const [modelSettings, setModelSettings] = useState({ + fields: {}, + apiKeys: {}, + emGlobalId: null, + lmGlobalId: null + }); + + const handleSave = async () => { + // compress fields with JSON values + if (server.state !== ServerInfoState.Ready) { + return; + } + + const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; + + for (const fieldKey in fields) { + const fieldVal = fields[fieldKey]; + if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { + continue; + } + + try { + const parsedFieldVal = JSON.parse(fieldVal); + const compressedFieldVal = JSON.stringify(parsedFieldVal); + fields[fieldKey] = compressedFieldVal; + } catch (e) { + continue; + } + } + + let updateRequest: AiService.UpdateConfigRequest = { + completions_model_provider_id: lmGlobalId, + completions_embeddings_provider_id: emGlobalId, + api_keys: apiKeys, + ...(lmGlobalId && { + completions_fields: { + [lmGlobalId]: fields + } + }) + }; + updateRequest = minifyUpdate(server.config, updateRequest); + updateRequest.last_read = server.config.last_read; + + setSaving(true); + try { + await AiService.updateConfig(updateRequest); + } catch (e) { + console.error(e); + const msg = + e instanceof Error || typeof e === 'string' + ? e.toString() + : 'An unknown error occurred. Check the console for more details.'; + alert.show('error', msg); + return; + } finally { + setSaving(false); + } + await server.refetchAll(); + alert.show('success', 'Settings saved successfully.'); + }; + + if (server.state === ServerInfoState.Loading) { + return ( + + + + ); + } + + if (server.state === ServerInfoState.Error) { + return ( + + + {server.error || + 'An unknown error occurred. Check the console for more details.'} + + + ); + } + + return ( + + {props.isProviderEnabled() ? null : ( + + The jupyter-ai inline completion provider is not enabled in the Inline + Completer settings. + + + )} + + + + + + + {alert.jsx} + + ); +} + +export class ModelSettingsWidget extends ReactWidget { + constructor(protected options: CompleterSettingsProps) { + super(); + } + render(): JSX.Element { + return ; + } +} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 889342a2..2a2da034 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState, useMemo } from 'react'; +import React, { useEffect, useState } from 'react'; import { Box } from '@mui/system'; import { @@ -7,23 +7,17 @@ import { FormControl, FormControlLabel, FormLabel, - MenuItem, Radio, RadioGroup, - TextField, CircularProgress } from '@mui/material'; -import { Select } from './select'; import { AiService } from '../handler'; -import { ModelFields } from './settings/model-fields'; import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ExistingApiKeys } from './settings/existing-api-keys'; +import { ModelSettings, IModelSettings } from './model-settings'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { minifyUpdate } from './settings/minify'; import { useStackingAlert } from './mui-extras/stacking-alert'; -import { RendermimeMarkdown } from './rendermime-markdown'; -import { getProviderId, getModelLocalId } from '../utils'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; @@ -38,38 +32,21 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { // initialize alert helper const alert = useStackingAlert(); - const apiKeysAlert = useStackingAlert(); // user inputs - const [lmProvider, setLmProvider] = - useState(null); - const [showLmLocalId, setShowLmLocalId] = useState(false); - const [helpMarkdown, setHelpMarkdown] = useState(null); - const [lmLocalId, setLmLocalId] = useState(''); - const lmGlobalId = useMemo(() => { - if (!lmProvider) { - return null; - } - - return lmProvider.id + ':' + lmLocalId; - }, [lmProvider, lmLocalId]); - - const [emGlobalId, setEmGlobalId] = useState(null); - const emProvider = useMemo(() => { - if (emGlobalId === null || server.state !== ServerInfoState.Ready) { - return null; - } - - return getProvider(emGlobalId, server.emProviders); - }, [emGlobalId, server]); - - const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); - const [fields, setFields] = useState>({}); // whether the form is currently saving const [saving, setSaving] = useState(false); + // provider/model settings + const [modelSettings, setModelSettings] = useState({ + fields: {}, + apiKeys: {}, + emGlobalId: null, + lmGlobalId: null + }); + /** * Effect: initialize inputs after fetching server info. */ @@ -77,79 +54,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { if (server.state !== ServerInfoState.Ready) { return; } - - setLmLocalId(server.lmLocalId); - setEmGlobalId(server.config.embeddings_provider_id); setSendWse(server.config.send_with_shift_enter); - setHelpMarkdown(server.lmProvider?.help ?? null); - if (server.lmProvider?.registry) { - setShowLmLocalId(true); - } - setLmProvider(server.lmProvider); }, [server]); - /** - * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. - * Properties with a value of '' indicate necessary user input. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - - const newApiKeys: Record = {}; - const lmAuth = lmProvider?.auth_strategy; - const emAuth = emProvider?.auth_strategy; - if ( - lmAuth?.type === 'env' && - !server.config.api_keys.includes(lmAuth.name) - ) { - newApiKeys[lmAuth.name] = ''; - } - if (lmAuth?.type === 'multienv') { - lmAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - if ( - emAuth?.type === 'env' && - !server.config.api_keys.includes(emAuth.name) - ) { - newApiKeys[emAuth.name] = ''; - } - if (emAuth?.type === 'multienv') { - emAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - setApiKeys(newApiKeys); - }, [lmProvider, emProvider, server]); - - /** - * Effect: re-initialize fields object whenever the selected LM changes. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready || !lmGlobalId) { - return; - } - - const currFields: Record = - server.config.fields?.[lmGlobalId] ?? {}; - setFields(currFields); - }, [server, lmProvider]); - const handleSave = async () => { // compress fields with JSON values if (server.state !== ServerInfoState.Ready) { return; } + const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; + for (const fieldKey in fields) { const fieldVal = fields[fieldKey]; if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { @@ -181,7 +96,6 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { setSaving(true); try { - await apiKeysAlert.clear(); await AiService.updateConfig(updateRequest); } catch (e) { console.error(e); @@ -244,112 +158,11 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { overflowY: 'auto' }} > - {/* Language model section */} -

Language model

- - {showLmLocalId && ( - setLmLocalId(e.target.value)} - fullWidth - /> - )} - {helpMarkdown && ( - - )} - {lmGlobalId && ( - - )} - - {/* Embedding model section */} -

Embedding model

- - - {/* API Keys section */} -

API Keys

- {/* API key inputs for newly-used providers */} - {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( - - setApiKeys(apiKeys => ({ - ...apiKeys, - [apiKeyName]: e.target.value - })) - } - /> - ))} - {/* Pre-existing API keys */} - {/* Input */} @@ -391,12 +204,3 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { ); } - -function getProvider( - globalModelId: string, - providers: AiService.ListProvidersResponse -): AiService.ListProvidersEntry | null { - const providerId = getProviderId(globalModelId); - const provider = providers.providers.find(p => p.id === providerId); - return provider ?? null; -} diff --git a/packages/jupyter-ai/src/components/model-settings.tsx b/packages/jupyter-ai/src/components/model-settings.tsx new file mode 100644 index 00000000..e99aa748 --- /dev/null +++ b/packages/jupyter-ai/src/components/model-settings.tsx @@ -0,0 +1,308 @@ +import React, { useEffect, useState, useMemo } from 'react'; + +import { Box } from '@mui/system'; +import { Alert, MenuItem, TextField, CircularProgress } from '@mui/material'; + +import { Select } from './select'; +import { AiService } from '../handler'; +import { ModelFields } from './settings/model-fields'; +import { ServerInfoState, useServerInfo } from './settings/use-server-info'; +import { ExistingApiKeys } from './settings/existing-api-keys'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import { useStackingAlert } from './mui-extras/stacking-alert'; +import { RendermimeMarkdown } from './rendermime-markdown'; +import { getProviderId, getModelLocalId } from '../utils'; + +type ModelSettingsProps = { + rmRegistry: IRenderMimeRegistry; + label: string; + onChange: (settings: IModelSettings) => void; + modelKind: 'chat' | 'completions'; +}; + +export interface IModelSettings { + fields: Record; + apiKeys: Record; + emGlobalId: string | null; + lmGlobalId: string | null; +} + +/** + * Component that returns the settings view in the chat panel. + */ +export function ModelSettings(props: ModelSettingsProps): JSX.Element { + // state fetched on initial render + const server = useServerInfo(); + + // initialize alert helper + const apiKeysAlert = useStackingAlert(); + + // user inputs + const [lmProvider, setLmProvider] = + useState(null); + const [showLmLocalId, setShowLmLocalId] = useState(false); + const [helpMarkdown, setHelpMarkdown] = useState(null); + const [lmLocalId, setLmLocalId] = useState(''); + const lmGlobalId = useMemo(() => { + if (!lmProvider) { + return null; + } + + return lmProvider.id + ':' + lmLocalId; + }, [lmProvider, lmLocalId]); + + const [emGlobalId, setEmGlobalId] = + useState(null); + const emProvider = useMemo(() => { + if (emGlobalId === null || server.state !== ServerInfoState.Ready) { + return null; + } + + return getProvider(emGlobalId, server.emProviders); + }, [emGlobalId, server]); + + const [apiKeys, setApiKeys] = useState({}); + const [fields, setFields] = useState({}); + + /** + * Effect: initialize inputs after fetching server info. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready) { + return; + } + const kind = props.modelKind; + + setLmLocalId(server[kind].lmLocalId); + setEmGlobalId( + kind === 'chat' + ? server.config.embeddings_provider_id + : server.config.completions_embeddings_provider_id + ); + setHelpMarkdown(server[kind].lmProvider?.help ?? null); + if (server[kind].lmProvider?.registry) { + setShowLmLocalId(true); + } + setLmProvider(server[kind].lmProvider); + }, [server]); + + /** + * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. + * Properties with a value of '' indicate necessary user input. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready) { + return; + } + + const newApiKeys: Record = {}; + const lmAuth = lmProvider?.auth_strategy; + const emAuth = emProvider?.auth_strategy; + if ( + lmAuth?.type === 'env' && + !server.config.api_keys.includes(lmAuth.name) + ) { + newApiKeys[lmAuth.name] = ''; + } + if (lmAuth?.type === 'multienv') { + lmAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + if ( + emAuth?.type === 'env' && + !server.config.api_keys.includes(emAuth.name) + ) { + newApiKeys[emAuth.name] = ''; + } + if (emAuth?.type === 'multienv') { + emAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + setApiKeys(newApiKeys); + }, [lmProvider, emProvider, server]); + + /** + * Effect: re-initialize fields object whenever the selected LM changes. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready || !lmGlobalId) { + return; + } + + const currFields: Record = + server.config.fields?.[lmGlobalId] ?? {}; + setFields(currFields); + }, [server, lmProvider]); + + useEffect(() => { + props.onChange({ + fields, + apiKeys, + lmGlobalId, + emGlobalId + }); + }, [lmProvider, emProvider, apiKeys, fields]); + + if (server.state === ServerInfoState.Loading) { + return ( + + + + ); + } + + if (server.state === ServerInfoState.Error) { + return ( + <> + + {server.error || + 'An unknown error occurred. Check the console for more details.'} + + + ); + } + + return ( + <> + {/* Language model section */} +

{props.label}

+ + {showLmLocalId && ( + setLmLocalId(e.target.value)} + fullWidth + /> + )} + {helpMarkdown && ( + + )} + {lmGlobalId && ( + + )} + + {/* Embedding model section */} +

Embedding model

+ + + {/* API Keys section */} +

API Keys

+ {/* API key inputs for newly-used providers */} + {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( + + setApiKeys(apiKeys => ({ + ...apiKeys, + [apiKeyName]: e.target.value + })) + } + /> + ))} + {/* Pre-existing API keys */} + + + ); +} + +function getProvider( + globalModelId: string, + providers: AiService.ListProvidersResponse +): AiService.ListProvidersEntry | null { + const providerId = getProviderId(globalModelId); + const provider = providers.providers.find(p => p.id === providerId); + return provider ?? null; +} diff --git a/packages/jupyter-ai/src/components/settings/use-server-info.ts b/packages/jupyter-ai/src/components/settings/use-server-info.ts index 80c77fbe..3695bfe4 100644 --- a/packages/jupyter-ai/src/components/settings/use-server-info.ts +++ b/packages/jupyter-ai/src/components/settings/use-server-info.ts @@ -2,15 +2,20 @@ import { useState, useEffect, useMemo, useCallback } from 'react'; import { AiService } from '../../handler'; import { getProviderId, getModelLocalId } from '../../utils'; -type ServerInfoProperties = { - config: AiService.DescribeConfigResponse; - lmProviders: AiService.ListProvidersResponse; - emProviders: AiService.ListProvidersResponse; +type ProvidersInfo = { lmProvider: AiService.ListProvidersEntry | null; emProvider: AiService.ListProvidersEntry | null; lmLocalId: string; }; +type ServerInfoProperties = { + lmProviders: AiService.ListProvidersResponse; + emProviders: AiService.ListProvidersResponse; + config: AiService.DescribeConfigResponse; + chat: ProvidersInfo; + completions: ProvidersInfo; +}; + type ServerInfoMethods = { refetchAll: () => Promise; refetchApiKeys: () => Promise; @@ -65,13 +70,29 @@ export function useServerInfo(): ServerInfo { const emProvider = emGid === null ? null : getProvider(emGid, emProviders); const lmLocalId = (lmGid && getModelLocalId(lmGid)) ?? ''; + + const cLmGid = config.completions_model_provider_id; + const cEmGid = config.completions_embeddings_provider_id; + const cLmProvider = + cLmGid === null ? null : getProvider(cLmGid, lmProviders); + const cEmProvider = + cEmGid === null ? null : getProvider(cEmGid, emProviders); + const cLmLocalId = (cLmGid && getModelLocalId(cLmGid)) ?? ''; + setServerInfoProps({ config, lmProviders, emProviders, - lmProvider, - emProvider, - lmLocalId + chat: { + lmProvider, + emProvider, + lmLocalId + }, + completions: { + lmProvider: cLmProvider, + emProvider: cEmProvider, + lmLocalId: cLmLocalId + } }); setState(ServerInfoState.Ready); diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 7848dc20..600a257f 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -117,6 +117,8 @@ export namespace AiService { send_with_shift_enter: boolean; fields: Record>; last_read: number; + completions_model_provider_id: string | null; + completions_embeddings_provider_id: string | null; }; export type UpdateConfigRequest = { @@ -126,6 +128,9 @@ export namespace AiService { send_with_shift_enter?: boolean; fields?: Record>; last_read?: number; + completions_model_provider_id?: string | null; + completions_embeddings_provider_id?: string | null; + completions_fields?: Record>; }; export async function getConfig(): Promise { @@ -182,6 +187,8 @@ export namespace AiService { help?: string; auth_strategy: AuthStrategy; registry: boolean; + completion_models: string[]; + chat_models: string[]; fields: Field[]; }; From 2c196ef897903cd660a210856362df9fdb971a7b Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:15:12 +0100 Subject: [PATCH 02/19] Fix tests --- .../jupyter_ai/tests/__snapshots__/test_config_manager.ambr | 4 ++++ .../jupyter-ai/jupyter_ai/tests/completions/test_handlers.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr index cd254bed..b5d75bdc 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr +++ b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr @@ -3,6 +3,10 @@ dict({ 'api_keys': list([ ]), + 'completions_embeddings_provider_id': None, + 'completions_fields': dict({ + }), + 'completions_model_provider_id': None, 'embeddings_provider_id': None, 'fields': dict({ }), diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 1b950af7..0028356a 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -28,7 +28,8 @@ def __init__(self): self.messages = [] self.tasks = [] self.settings["jai_config_manager"] = SimpleNamespace( - lm_provider=MockProvider, lm_provider_params={"model_id": "model"} + completions_lm_provider=MockProvider, + completions_lm_provider_params={"model_id": "model"}, ) self.settings["jai_event_loop"] = SimpleNamespace( create_task=lambda x: self.tasks.append(x) From 578f63c071fa7996b500a9795bfd200e2b87d4e3 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Mon, 8 Apr 2024 16:44:35 +0100 Subject: [PATCH 03/19] Shorten the tab name, move settings button Lint --- packages/jupyter-ai/src/completions/plugin.ts | 2 +- packages/jupyter-ai/src/completions/settings.tsx | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index c2519fd9..3b2b6f2c 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -199,7 +199,7 @@ export const completionPlugin: JupyterFrontEndPlugin = { }); const widget = new MainAreaWidget({ content }); widget.id = 'jupyterlab-inline-completions-model'; - widget.title.label = 'AI Completions Model Settings'; + widget.title.label = 'Completer Model Settings'; widget.title.closable = true; widget.title.icon = jupyternautIcon; return widget; diff --git a/packages/jupyter-ai/src/completions/settings.tsx b/packages/jupyter-ai/src/completions/settings.tsx index 23e0efb2..fea8d7c8 100644 --- a/packages/jupyter-ai/src/completions/settings.tsx +++ b/packages/jupyter-ai/src/completions/settings.tsx @@ -145,12 +145,6 @@ export function CompleterSettings(props: CompleterSettingsProps): JSX.Element { The jupyter-ai inline completion provider is not enabled in the Inline Completer settings. - )} @@ -161,7 +155,10 @@ export function CompleterSettings(props: CompleterSettingsProps): JSX.Element { modelKind="completions" /> - + + From 575f3f48cc854cd247459eeea5357d266171577a Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 17 Apr 2024 12:54:11 +0100 Subject: [PATCH 04/19] Implement the completion model selection in chat UI --- packages/jupyter-ai/src/completions/plugin.ts | 296 ++++++-------- .../jupyter-ai/src/completions/provider.ts | 12 +- .../jupyter-ai/src/completions/settings.tsx | 178 --------- .../src/components/chat-settings.tsx | 378 +++++++++++++++++- packages/jupyter-ai/src/components/chat.tsx | 9 +- .../src/components/model-settings.tsx | 308 -------------- packages/jupyter-ai/src/index.ts | 21 +- packages/jupyter-ai/src/tokens.ts | 14 + .../jupyter-ai/src/widgets/chat-sidebar.tsx | 7 +- 9 files changed, 543 insertions(+), 680 deletions(-) delete mode 100644 packages/jupyter-ai/src/completions/settings.tsx delete mode 100644 packages/jupyter-ai/src/components/model-settings.tsx diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index 3b2b6f2c..adf3db55 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -4,18 +4,14 @@ import { } from '@jupyterlab/application'; import { ICompletionProviderManager } from '@jupyterlab/completer'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; -import { MainAreaWidget } from '@jupyterlab/apputils'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { IEditorLanguageRegistry, IEditorLanguage } from '@jupyterlab/codemirror'; import { getEditor } from '../selection-watcher'; -import { IJaiStatusItem } from '../tokens'; +import { IJaiStatusItem, IJaiCompletionProvider } from '../tokens'; import { displayName, JaiInlineProvider } from './provider'; import { CompletionWebsocketHandler } from './handler'; -import { jupyternautIcon } from '../icons'; -import { ModelSettingsWidget } from './settings'; export namespace CommandIDs { /** @@ -27,10 +23,6 @@ export namespace CommandIDs { */ export const toggleLanguageCompletions = 'jupyter-ai:toggle-language-completions'; - /** - * Command to open provider/model configuration. - */ - export const configureModel = 'jupyter-ai:configure-completions'; } const INLINE_COMPLETER_PLUGIN = @@ -54,182 +46,148 @@ type IcPluginSettings = ISettingRegistry.ISettings & { }; }; -export const completionPlugin: JupyterFrontEndPlugin = { - id: 'jupyter_ai:inline-completions', - autoStart: true, - requires: [ - ICompletionProviderManager, - IEditorLanguageRegistry, - ISettingRegistry, - IRenderMimeRegistry - ], - optional: [IJaiStatusItem], - activate: async ( - app: JupyterFrontEnd, - completionManager: ICompletionProviderManager, - languageRegistry: IEditorLanguageRegistry, - settingRegistry: ISettingRegistry, - rmRegistry: IRenderMimeRegistry, - statusItem: IJaiStatusItem | null - ): Promise => { - if (typeof completionManager.registerInlineProvider === 'undefined') { - // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 - console.warn( - 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' - ); - return; - } - - const completionHandler = new CompletionWebsocketHandler(); - const provider = new JaiInlineProvider({ - completionHandler, - languageRegistry - }); - - await completionHandler.initialize(); - completionManager.registerInlineProvider(provider); - - const findCurrentLanguage = (): IEditorLanguage | null => { - const widget = app.shell.currentWidget; - const editor = getEditor(widget); - if (!editor) { +export const completionPlugin: JupyterFrontEndPlugin = + { + id: 'jupyter_ai:inline-completions', + autoStart: true, + requires: [ + ICompletionProviderManager, + IEditorLanguageRegistry, + ISettingRegistry + ], + optional: [IJaiStatusItem], + provides: IJaiCompletionProvider, + activate: async ( + app: JupyterFrontEnd, + completionManager: ICompletionProviderManager, + languageRegistry: IEditorLanguageRegistry, + settingRegistry: ISettingRegistry, + statusItem: IJaiStatusItem | null + ): Promise => { + if (typeof completionManager.registerInlineProvider === 'undefined') { + // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 + console.warn( + 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' + ); return null; } - return languageRegistry.findByMIME(editor.model.mimeType); - }; - // ic := inline completion - async function getIcSettings() { - return (await settingRegistry.load( - INLINE_COMPLETER_PLUGIN - )) as IcPluginSettings; - } + const completionHandler = new CompletionWebsocketHandler(); + const provider = new JaiInlineProvider({ + completionHandler, + languageRegistry + }); - /** - * Gets the composite settings for the Jupyter AI inline completion provider - * (JaiIcp). - * - * This reads from the `ISettings.composite` property, which merges the user - * settings with the provider defaults, defined in - * `JaiInlineProvider.DEFAULT_SETTINGS`. - */ - async function getJaiIcpSettings() { - const icSettings = await getIcSettings(); - return icSettings.composite.providers[JaiInlineProvider.ID]; - } + await completionHandler.initialize(); + completionManager.registerInlineProvider(provider); - /** - * Updates the JaiIcp user settings. - */ - async function updateJaiIcpSettings( - newJaiIcpSettings: Partial - ) { - const icSettings = await getIcSettings(); - const oldUserIcpSettings = icSettings.user.providers; - const newUserIcpSettings = { - ...oldUserIcpSettings, - [JaiInlineProvider.ID]: { - ...oldUserIcpSettings?.[JaiInlineProvider.ID], - ...newJaiIcpSettings + const findCurrentLanguage = (): IEditorLanguage | null => { + const widget = app.shell.currentWidget; + const editor = getEditor(widget); + if (!editor) { + return null; } + return languageRegistry.findByMIME(editor.model.mimeType); }; - icSettings.set('providers', newUserIcpSettings); - } - app.commands.addCommand(CommandIDs.toggleCompletions, { - execute: async () => { - const jaiIcpSettings = await getJaiIcpSettings(); - updateJaiIcpSettings({ - enabled: !jaiIcpSettings.enabled - }); - }, - label: 'Enable completions by Jupyternaut', - isToggled: () => { - return provider.isEnabled(); + // ic := inline completion + async function getIcSettings() { + return (await settingRegistry.load( + INLINE_COMPLETER_PLUGIN + )) as IcPluginSettings; } - }); - app.commands.addCommand(CommandIDs.toggleLanguageCompletions, { - execute: async () => { - const jaiIcpSettings = await getJaiIcpSettings(); - const language = findCurrentLanguage(); - if (!language) { - return; - } - - const disabledLanguages = [...jaiIcpSettings.disabledLanguages]; - const newDisabledLanguages = disabledLanguages.includes(language.name) - ? disabledLanguages.filter(l => l !== language.name) - : disabledLanguages.concat(language.name); + /** + * Gets the composite settings for the Jupyter AI inline completion provider + * (JaiIcp). + * + * This reads from the `ISettings.composite` property, which merges the user + * settings with the provider defaults, defined in + * `JaiInlineProvider.DEFAULT_SETTINGS`. + */ + async function getJaiIcpSettings() { + const icSettings = await getIcSettings(); + return icSettings.composite.providers[JaiInlineProvider.ID]; + } - updateJaiIcpSettings({ - disabledLanguages: newDisabledLanguages - }); - }, - label: () => { - const language = findCurrentLanguage(); - return language - ? `Disable completions in ${displayName(language)}` - : 'Disable completions in files'; - }, - isToggled: () => { - const language = findCurrentLanguage(); - return !!language && !provider.isLanguageEnabled(language.name); - }, - isVisible: () => { - const language = findCurrentLanguage(); - return !!language; - }, - isEnabled: () => { - const language = findCurrentLanguage(); - return !!language && provider.isEnabled(); + /** + * Updates the JaiIcp user settings. + */ + async function updateJaiIcpSettings( + newJaiIcpSettings: Partial + ) { + const icSettings = await getIcSettings(); + const oldUserIcpSettings = icSettings.user.providers; + const newUserIcpSettings = { + ...oldUserIcpSettings, + [JaiInlineProvider.ID]: { + ...oldUserIcpSettings?.[JaiInlineProvider.ID], + ...newJaiIcpSettings + } + }; + icSettings.set('providers', newUserIcpSettings); } - }); - let settingsWidget: MainAreaWidget | null = null; - const newSettingsWidget = () => { - const content = new ModelSettingsWidget({ - rmRegistry, - isProviderEnabled: () => provider.isEnabled(), - openInlineCompleterSettings: () => { - app.commands.execute('settingeditor:open', { - query: 'Inline Completer' + app.commands.addCommand(CommandIDs.toggleCompletions, { + execute: async () => { + const jaiIcpSettings = await getJaiIcpSettings(); + updateJaiIcpSettings({ + enabled: !jaiIcpSettings.enabled }); + }, + label: 'Enable completions by Jupyternaut', + isToggled: () => { + return provider.isEnabled(); } }); - const widget = new MainAreaWidget({ content }); - widget.id = 'jupyterlab-inline-completions-model'; - widget.title.label = 'Completer Model Settings'; - widget.title.closable = true; - widget.title.icon = jupyternautIcon; - return widget; - }; - app.commands.addCommand(CommandIDs.configureModel, { - execute: () => { - if (!settingsWidget || settingsWidget.isDisposed) { - settingsWidget = newSettingsWidget(); - } - if (!settingsWidget.isAttached) { - app.shell.add(settingsWidget, 'main'); - } - app.shell.activateById(settingsWidget.id); - }, - label: 'Configure Jupyternaut Completions Model' - }); - if (statusItem) { - statusItem.addItem({ - command: CommandIDs.toggleCompletions, - rank: 1 - }); - statusItem.addItem({ - command: CommandIDs.toggleLanguageCompletions, - rank: 2 - }); - statusItem.addItem({ - command: CommandIDs.configureModel, - rank: 3 + app.commands.addCommand(CommandIDs.toggleLanguageCompletions, { + execute: async () => { + const jaiIcpSettings = await getJaiIcpSettings(); + const language = findCurrentLanguage(); + if (!language) { + return; + } + + const disabledLanguages = [...jaiIcpSettings.disabledLanguages]; + const newDisabledLanguages = disabledLanguages.includes(language.name) + ? disabledLanguages.filter(l => l !== language.name) + : disabledLanguages.concat(language.name); + + updateJaiIcpSettings({ + disabledLanguages: newDisabledLanguages + }); + }, + label: () => { + const language = findCurrentLanguage(); + return language + ? `Disable completions in ${displayName(language)}` + : 'Disable completions in files'; + }, + isToggled: () => { + const language = findCurrentLanguage(); + return !!language && !provider.isLanguageEnabled(language.name); + }, + isVisible: () => { + const language = findCurrentLanguage(); + return !!language; + }, + isEnabled: () => { + const language = findCurrentLanguage(); + return !!language && provider.isEnabled(); + } }); + + if (statusItem) { + statusItem.addItem({ + command: CommandIDs.toggleCompletions, + rank: 1 + }); + statusItem.addItem({ + command: CommandIDs.toggleLanguageCompletions, + rank: 2 + }); + } + return provider; } - } -}; + }; diff --git a/packages/jupyter-ai/src/completions/provider.ts b/packages/jupyter-ai/src/completions/provider.ts index b1e0f6a8..786ced85 100644 --- a/packages/jupyter-ai/src/completions/provider.ts +++ b/packages/jupyter-ai/src/completions/provider.ts @@ -9,11 +9,13 @@ import { import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { Notification, showErrorMessage } from '@jupyterlab/apputils'; import { JSONValue, PromiseDelegate } from '@lumino/coreutils'; +import { ISignal, Signal } from '@lumino/signaling'; import { IEditorLanguageRegistry, IEditorLanguage } from '@jupyterlab/codemirror'; import { NotebookPanel } from '@jupyterlab/notebook'; +import { IJaiCompletionProvider } from '../tokens'; import { AiCompleterService as AiService } from './types'; import { DocumentWidget } from '@jupyterlab/docregistry'; import { jupyternautIcon } from '../icons'; @@ -34,7 +36,9 @@ export function displayName(language: IEditorLanguage): string { return language.displayName ?? language.name; } -export class JaiInlineProvider implements IInlineCompletionProvider { +export class JaiInlineProvider + implements IInlineCompletionProvider, IJaiCompletionProvider +{ readonly identifier = JaiInlineProvider.ID; readonly icon = jupyternautIcon.bindprops({ width: 16, top: 1 }); @@ -181,6 +185,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { async configure(settings: { [property: string]: JSONValue }): Promise { this._settings = settings as unknown as JaiInlineProvider.ISettings; + this._settingsChanged.emit(); } isEnabled(): boolean { @@ -191,6 +196,10 @@ export class JaiInlineProvider implements IInlineCompletionProvider { return !this._settings.disabledLanguages.includes(language); } + get settingsChanged(): ISignal { + return this._settingsChanged; + } + /** * Process the stream chunk to make it available in the awaiting generator. */ @@ -250,6 +259,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { private _settings: JaiInlineProvider.ISettings = JaiInlineProvider.DEFAULT_SETTINGS; + private _settingsChanged = new Signal(this); private _streamPromises: Map> = new Map(); diff --git a/packages/jupyter-ai/src/completions/settings.tsx b/packages/jupyter-ai/src/completions/settings.tsx deleted file mode 100644 index fea8d7c8..00000000 --- a/packages/jupyter-ai/src/completions/settings.tsx +++ /dev/null @@ -1,178 +0,0 @@ -import { ReactWidget } from '@jupyterlab/ui-components'; -import React, { useState } from 'react'; - -import { Box } from '@mui/system'; -import { Alert, Button, CircularProgress } from '@mui/material'; - -import { AiService } from '../handler'; -import { - ServerInfoState, - useServerInfo -} from '../components/settings/use-server-info'; -import { ModelSettings, IModelSettings } from '../components/model-settings'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; -import { minifyUpdate } from '../components/settings/minify'; -import { useStackingAlert } from '../components/mui-extras/stacking-alert'; - -type CompleterSettingsProps = { - rmRegistry: IRenderMimeRegistry; - isProviderEnabled: () => boolean; - openInlineCompleterSettings: () => void; -}; - -/** - * Component that returns the settings view. - */ -export function CompleterSettings(props: CompleterSettingsProps): JSX.Element { - // state fetched on initial render - const server = useServerInfo(); - - // initialize alert helper - const alert = useStackingAlert(); - - // whether the form is currently saving - const [saving, setSaving] = useState(false); - - // provider/model settings - const [modelSettings, setModelSettings] = useState({ - fields: {}, - apiKeys: {}, - emGlobalId: null, - lmGlobalId: null - }); - - const handleSave = async () => { - // compress fields with JSON values - if (server.state !== ServerInfoState.Ready) { - return; - } - - const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; - - for (const fieldKey in fields) { - const fieldVal = fields[fieldKey]; - if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { - continue; - } - - try { - const parsedFieldVal = JSON.parse(fieldVal); - const compressedFieldVal = JSON.stringify(parsedFieldVal); - fields[fieldKey] = compressedFieldVal; - } catch (e) { - continue; - } - } - - let updateRequest: AiService.UpdateConfigRequest = { - completions_model_provider_id: lmGlobalId, - completions_embeddings_provider_id: emGlobalId, - api_keys: apiKeys, - ...(lmGlobalId && { - completions_fields: { - [lmGlobalId]: fields - } - }) - }; - updateRequest = minifyUpdate(server.config, updateRequest); - updateRequest.last_read = server.config.last_read; - - setSaving(true); - try { - await AiService.updateConfig(updateRequest); - } catch (e) { - console.error(e); - const msg = - e instanceof Error || typeof e === 'string' - ? e.toString() - : 'An unknown error occurred. Check the console for more details.'; - alert.show('error', msg); - return; - } finally { - setSaving(false); - } - await server.refetchAll(); - alert.show('success', 'Settings saved successfully.'); - }; - - if (server.state === ServerInfoState.Loading) { - return ( - - - - ); - } - - if (server.state === ServerInfoState.Error) { - return ( - - - {server.error || - 'An unknown error occurred. Check the console for more details.'} - - - ); - } - - return ( - - {props.isProviderEnabled() ? null : ( - - The jupyter-ai inline completion provider is not enabled in the Inline - Completer settings. - - )} - - - - - - - - {alert.jsx} - - ); -} - -export class ModelSettingsWidget extends ReactWidget { - constructor(protected options: CompleterSettingsProps) { - super(); - } - render(): JSX.Element { - return ; - } -} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 2a2da034..0d811ca7 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -1,26 +1,39 @@ -import React, { useEffect, useState } from 'react'; +import React, { useEffect, useState, useMemo } from 'react'; import { Box } from '@mui/system'; import { Alert, Button, + IconButton, FormControl, FormControlLabel, FormLabel, + MenuItem, Radio, RadioGroup, + TextField, CircularProgress } from '@mui/material'; +import SettingsIcon from '@mui/icons-material/Settings'; +import WarningAmberIcon from '@mui/icons-material/WarningAmber'; +import { UseSignal } from '@jupyterlab/ui-components'; +import { Select } from './select'; import { AiService } from '../handler'; +import { ModelFields } from './settings/model-fields'; import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ModelSettings, IModelSettings } from './model-settings'; +import { ExistingApiKeys } from './settings/existing-api-keys'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { minifyUpdate } from './settings/minify'; import { useStackingAlert } from './mui-extras/stacking-alert'; +import { RendermimeMarkdown } from './rendermime-markdown'; +import { IJaiCompletionProvider } from '../tokens'; +import { getProviderId, getModelLocalId } from '../utils'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; /** @@ -32,21 +45,53 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { // initialize alert helper const alert = useStackingAlert(); + const apiKeysAlert = useStackingAlert(); // user inputs + const [lmProvider, setLmProvider] = + useState(null); + const [clmProvider, setClmProvider] = + useState(null); + const [showLmLocalId, setShowLmLocalId] = useState(false); + const [showClmLocalId, setShowClmLocalId] = useState(false); + const [chatHelpMarkdown, setChatHelpMarkdown] = useState(null); + const [completionHelpMarkdown, setCompletionHelpMarkdown] = useState< + string | null + >(null); + const [lmLocalId, setLmLocalId] = useState(''); + const [clmLocalId, setClmLocalId] = useState(''); + + const lmGlobalId = useMemo(() => { + if (!lmProvider) { + return null; + } + + return lmProvider.id + ':' + lmLocalId; + }, [lmProvider, lmLocalId]); + const clmGlobalId = useMemo(() => { + if (!clmProvider) { + return null; + } + + return clmProvider.id + ':' + clmLocalId; + }, [clmProvider, clmLocalId]); + + const [emGlobalId, setEmGlobalId] = useState(null); + const emProvider = useMemo(() => { + if (emGlobalId === null || server.state !== ServerInfoState.Ready) { + return null; + } + + return getProvider(emGlobalId, server.emProviders); + }, [emGlobalId, server]); + + const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); + const [fields, setFields] = useState>({}); // whether the form is currently saving const [saving, setSaving] = useState(false); - // provider/model settings - const [modelSettings, setModelSettings] = useState({ - fields: {}, - apiKeys: {}, - emGlobalId: null, - lmGlobalId: null - }); - /** * Effect: initialize inputs after fetching server info. */ @@ -54,17 +99,85 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { if (server.state !== ServerInfoState.Ready) { return; } + + setLmLocalId(server.chat.lmLocalId); + setClmLocalId(server.completions.lmLocalId); + setEmGlobalId(server.config.embeddings_provider_id); setSendWse(server.config.send_with_shift_enter); + setChatHelpMarkdown(server.chat.lmProvider?.help ?? null); + setCompletionHelpMarkdown(server.completions.lmProvider?.help ?? null); + if (server.chat.lmProvider?.registry) { + setShowLmLocalId(true); + } + if (server.completions.lmProvider?.registry) { + setShowClmLocalId(true); + } + setLmProvider(server.chat.lmProvider); + setClmProvider(server.completions.lmProvider); }, [server]); + /** + * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. + * Properties with a value of '' indicate necessary user input. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready) { + return; + } + + const newApiKeys: Record = {}; + const lmAuth = lmProvider?.auth_strategy; + const emAuth = emProvider?.auth_strategy; + if ( + lmAuth?.type === 'env' && + !server.config.api_keys.includes(lmAuth.name) + ) { + newApiKeys[lmAuth.name] = ''; + } + if (lmAuth?.type === 'multienv') { + lmAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + if ( + emAuth?.type === 'env' && + !server.config.api_keys.includes(emAuth.name) + ) { + newApiKeys[emAuth.name] = ''; + } + if (emAuth?.type === 'multienv') { + emAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + setApiKeys(newApiKeys); + }, [lmProvider, emProvider, server]); + + /** + * Effect: re-initialize fields object whenever the selected LM changes. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready || !lmGlobalId) { + return; + } + + const currFields: Record = + server.config.fields?.[lmGlobalId] ?? {}; + setFields(currFields); + }, [server, lmProvider]); + const handleSave = async () => { // compress fields with JSON values if (server.state !== ServerInfoState.Ready) { return; } - const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; - for (const fieldKey in fields) { const fieldVal = fields[fieldKey]; if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { @@ -84,11 +197,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { model_provider_id: lmGlobalId, embeddings_provider_id: emGlobalId, api_keys: apiKeys, - ...(lmGlobalId && { + ...((lmGlobalId || clmGlobalId) && { fields: { - [lmGlobalId]: fields + ...(lmGlobalId && { + [lmGlobalId]: fields + }), + ...(clmGlobalId && { + [clmGlobalId]: fields + }) } }), + completions_model_provider_id: clmGlobalId, send_with_shift_enter: sendWse }; updateRequest = minifyUpdate(server.config, updateRequest); @@ -96,6 +215,7 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { setSaving(true); try { + await apiKeysAlert.clear(); await AiService.updateConfig(updateRequest); } catch (e) { console.error(e); @@ -158,11 +278,195 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { overflowY: 'auto' }} > - Chat language model + + {showLmLocalId && ( + setLmLocalId(e.target.value)} + fullWidth + /> + )} + {chatHelpMarkdown && ( + + )} + {lmGlobalId && ( + + )} + + {/* Embedding model section */} +

Embedding model

+ + + {/* Completer language model section */} +

+ Completer model + {props.completionProvider ? ( + + {(): JSX.Element => ( + + )} + + ) : ( + + )} +

+ + {showClmLocalId && ( + setClmLocalId(e.target.value)} + fullWidth + /> + )} + {completionHelpMarkdown && ( + + )} + {clmGlobalId && ( + + )} + + {/* API Keys section */} +

API Keys

+ {/* API key inputs for newly-used providers */} + {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( + + setApiKeys(apiKeys => ({ + ...apiKeys, + [apiKeyName]: e.target.value + })) + } + /> + ))} + {/* Pre-existing API keys */} + {/* Input */} @@ -204,3 +508,39 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
); } + +function CompleterSettingsButton(props: { + selection: AiService.ListProvidersEntry | null; + provider: IJaiCompletionProvider | null; + openSettings: () => void; +}): JSX.Element { + if (props.selection && !props.provider?.isEnabled()) { + return ( + + + + ); + } + return ( + + + + ); +} + +function getProvider( + globalModelId: string, + providers: AiService.ListProvidersResponse +): AiService.ListProvidersEntry | null { + const providerId = getProviderId(globalModelId); + const provider = providers.providers.find(p => p.id === providerId); + return provider ?? null; +} diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index c33d36f2..721d157f 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -18,6 +18,7 @@ import { import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; import { CollaboratorsContextProvider } from '../contexts/collaborators-context'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ScrollContainer } from './scroll-container'; @@ -185,6 +186,8 @@ export type ChatProps = { themeManager: IThemeManager | null; rmRegistry: IRenderMimeRegistry; chatView?: ChatView; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; enum ChatView { @@ -237,7 +240,11 @@ export function Chat(props: ChatProps): JSX.Element { /> )} {view === ChatView.Settings && ( - + )}
diff --git a/packages/jupyter-ai/src/components/model-settings.tsx b/packages/jupyter-ai/src/components/model-settings.tsx deleted file mode 100644 index e99aa748..00000000 --- a/packages/jupyter-ai/src/components/model-settings.tsx +++ /dev/null @@ -1,308 +0,0 @@ -import React, { useEffect, useState, useMemo } from 'react'; - -import { Box } from '@mui/system'; -import { Alert, MenuItem, TextField, CircularProgress } from '@mui/material'; - -import { Select } from './select'; -import { AiService } from '../handler'; -import { ModelFields } from './settings/model-fields'; -import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ExistingApiKeys } from './settings/existing-api-keys'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; -import { useStackingAlert } from './mui-extras/stacking-alert'; -import { RendermimeMarkdown } from './rendermime-markdown'; -import { getProviderId, getModelLocalId } from '../utils'; - -type ModelSettingsProps = { - rmRegistry: IRenderMimeRegistry; - label: string; - onChange: (settings: IModelSettings) => void; - modelKind: 'chat' | 'completions'; -}; - -export interface IModelSettings { - fields: Record; - apiKeys: Record; - emGlobalId: string | null; - lmGlobalId: string | null; -} - -/** - * Component that returns the settings view in the chat panel. - */ -export function ModelSettings(props: ModelSettingsProps): JSX.Element { - // state fetched on initial render - const server = useServerInfo(); - - // initialize alert helper - const apiKeysAlert = useStackingAlert(); - - // user inputs - const [lmProvider, setLmProvider] = - useState(null); - const [showLmLocalId, setShowLmLocalId] = useState(false); - const [helpMarkdown, setHelpMarkdown] = useState(null); - const [lmLocalId, setLmLocalId] = useState(''); - const lmGlobalId = useMemo(() => { - if (!lmProvider) { - return null; - } - - return lmProvider.id + ':' + lmLocalId; - }, [lmProvider, lmLocalId]); - - const [emGlobalId, setEmGlobalId] = - useState(null); - const emProvider = useMemo(() => { - if (emGlobalId === null || server.state !== ServerInfoState.Ready) { - return null; - } - - return getProvider(emGlobalId, server.emProviders); - }, [emGlobalId, server]); - - const [apiKeys, setApiKeys] = useState({}); - const [fields, setFields] = useState({}); - - /** - * Effect: initialize inputs after fetching server info. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - const kind = props.modelKind; - - setLmLocalId(server[kind].lmLocalId); - setEmGlobalId( - kind === 'chat' - ? server.config.embeddings_provider_id - : server.config.completions_embeddings_provider_id - ); - setHelpMarkdown(server[kind].lmProvider?.help ?? null); - if (server[kind].lmProvider?.registry) { - setShowLmLocalId(true); - } - setLmProvider(server[kind].lmProvider); - }, [server]); - - /** - * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. - * Properties with a value of '' indicate necessary user input. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - - const newApiKeys: Record = {}; - const lmAuth = lmProvider?.auth_strategy; - const emAuth = emProvider?.auth_strategy; - if ( - lmAuth?.type === 'env' && - !server.config.api_keys.includes(lmAuth.name) - ) { - newApiKeys[lmAuth.name] = ''; - } - if (lmAuth?.type === 'multienv') { - lmAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - if ( - emAuth?.type === 'env' && - !server.config.api_keys.includes(emAuth.name) - ) { - newApiKeys[emAuth.name] = ''; - } - if (emAuth?.type === 'multienv') { - emAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - setApiKeys(newApiKeys); - }, [lmProvider, emProvider, server]); - - /** - * Effect: re-initialize fields object whenever the selected LM changes. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready || !lmGlobalId) { - return; - } - - const currFields: Record = - server.config.fields?.[lmGlobalId] ?? {}; - setFields(currFields); - }, [server, lmProvider]); - - useEffect(() => { - props.onChange({ - fields, - apiKeys, - lmGlobalId, - emGlobalId - }); - }, [lmProvider, emProvider, apiKeys, fields]); - - if (server.state === ServerInfoState.Loading) { - return ( - - - - ); - } - - if (server.state === ServerInfoState.Error) { - return ( - <> - - {server.error || - 'An unknown error occurred. Check the console for more details.'} - - - ); - } - - return ( - <> - {/* Language model section */} -

{props.label}

- - {showLmLocalId && ( - setLmLocalId(e.target.value)} - fullWidth - /> - )} - {helpMarkdown && ( - - )} - {lmGlobalId && ( - - )} - - {/* Embedding model section */} -

Embedding model

- - - {/* API Keys section */} -

API Keys

- {/* API key inputs for newly-used providers */} - {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( - - setApiKeys(apiKeys => ({ - ...apiKeys, - [apiKeyName]: e.target.value - })) - } - /> - ))} - {/* Pre-existing API keys */} - - - ); -} - -function getProvider( - globalModelId: string, - providers: AiService.ListProvidersResponse -): AiService.ListProvidersEntry | null { - const providerId = getProviderId(globalModelId); - const provider = providers.providers.find(p => p.id === providerId); - return provider ?? null; -} diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index df2ac69d..d07a0dc3 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -18,6 +18,7 @@ import { ChatHandler } from './chat_handler'; import { buildErrorWidget } from './widgets/chat-error'; import { completionPlugin } from './completions'; import { statusItemPlugin } from './status'; +import { IJaiCompletionProvider } from './tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export type DocumentTracker = IWidgetTracker; @@ -28,14 +29,20 @@ export type DocumentTracker = IWidgetTracker; const plugin: JupyterFrontEndPlugin = { id: 'jupyter_ai:plugin', autoStart: true, - optional: [IGlobalAwareness, ILayoutRestorer, IThemeManager], + optional: [ + IGlobalAwareness, + ILayoutRestorer, + IThemeManager, + IJaiCompletionProvider + ], requires: [IRenderMimeRegistry], activate: async ( app: JupyterFrontEnd, rmRegistry: IRenderMimeRegistry, globalAwareness: Awareness | null, restorer: ILayoutRestorer | null, - themeManager: IThemeManager | null + themeManager: IThemeManager | null, + completionProvider: IJaiCompletionProvider | null ) => { /** * Initialize selection watcher singleton @@ -47,6 +54,12 @@ const plugin: JupyterFrontEndPlugin = { */ const chatHandler = new ChatHandler(); + const openInlineCompleterSettings = () => { + app.commands.execute('settingeditor:open', { + query: 'Inline Completer' + }); + }; + let chatWidget: ReactWidget | null = null; try { await chatHandler.initialize(); @@ -55,7 +68,9 @@ const plugin: JupyterFrontEndPlugin = { chatHandler, globalAwareness, themeManager, - rmRegistry + rmRegistry, + completionProvider, + openInlineCompleterSettings ); } catch (e) { chatWidget = buildErrorWidget(themeManager); diff --git a/packages/jupyter-ai/src/tokens.ts b/packages/jupyter-ai/src/tokens.ts index f240a819..924745ee 100644 --- a/packages/jupyter-ai/src/tokens.ts +++ b/packages/jupyter-ai/src/tokens.ts @@ -1,4 +1,5 @@ import { Token } from '@lumino/coreutils'; +import { ISignal } from '@lumino/signaling'; import type { IRankedMenu } from '@jupyterlab/ui-components'; export interface IJaiStatusItem { @@ -12,3 +13,16 @@ export const IJaiStatusItem = new Token( 'jupyter_ai:IJupyternautStatus', 'Status indicator displayed in the statusbar' ); + +export interface IJaiCompletionProvider { + isEnabled(): boolean; + settingsChanged: ISignal; +} + +/** + * The incline completion provider token. + */ +export const IJaiCompletionProvider = new Token( + 'jupyter_ai:IJaiCompletionProvider', + 'Status the incline completion provider' +); diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index 291593b4..bb575feb 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -7,6 +7,7 @@ import { Chat } from '../components/chat'; import { chatIcon } from '../icons'; import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export function buildChatSidebar( @@ -14,7 +15,9 @@ export function buildChatSidebar( chatHandler: ChatHandler, globalAwareness: Awareness | null, themeManager: IThemeManager | null, - rmRegistry: IRenderMimeRegistry + rmRegistry: IRenderMimeRegistry, + completionProvider: IJaiCompletionProvider | null, + openInlineCompleterSettings: () => void ): ReactWidget { const ChatWidget = ReactWidget.create( ); ChatWidget.id = 'jupyter-ai::chat'; From 9b90e4d894cccdb576cc204f54835c2268a8ba03 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 17 Apr 2024 12:59:07 +0100 Subject: [PATCH 05/19] Improve docstring --- packages/jupyter-ai/src/tokens.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai/src/tokens.ts b/packages/jupyter-ai/src/tokens.ts index 924745ee..f0f30198 100644 --- a/packages/jupyter-ai/src/tokens.ts +++ b/packages/jupyter-ai/src/tokens.ts @@ -20,9 +20,9 @@ export interface IJaiCompletionProvider { } /** - * The incline completion provider token. + * The inline completion provider token. */ export const IJaiCompletionProvider = new Token( 'jupyter_ai:IJaiCompletionProvider', - 'Status the incline completion provider' + 'The jupyter-ai inline completion provider API' ); From 988213457371050f8f27198af5fb2846b93b705d Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:00:37 +0100 Subject: [PATCH 06/19] Call `_validate_lm_em_id` only once, add typing annotations --- .../jupyter-ai/jupyter_ai/config_manager.py | 71 +++++++++++-------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 5969c9f9..def992be 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -97,6 +97,11 @@ class ConfigManager(Configurable): config=True, ) + model_provider_id: Optional[str] + embeddings_provider_id: Optional[str] + completions_model_provider_id: Optional[str] + completions_embeddings_provider_id: Optional[str] + def __init__( self, log: Logger, @@ -164,48 +169,52 @@ def _process_existing_config(self, default_config): {k: v for k, v in existing_config.items() if v is not None}, ) config = GlobalConfig(**merged_config) - validated_config = self._validate_lm_em_id( - config, lm_key="model_provider_id", em_key="embeddings_provider_id" - ) - validated_config = self._validate_lm_em_id( - config, - lm_key="completions_model_provider_id", - em_key="completions_embeddings_provider_id", - ) + validated_config = self._validate_lm_em_id(config) # re-write to the file to validate the config and apply any # updates to the config file immediately self._write_config(validated_config) - def _validate_lm_em_id(self, config, lm_key, em_key): - lm_id = getattr(config, lm_key) - em_id = getattr(config, em_key) + def _validate_lm_em_id(self, config): + lm_provider_keys = ["model_provider_id", "completions_model_provider_id"] + em_provider_keys = [ + "embeddings_provider_id", + "completions_embeddings_provider_id", + ] # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. - if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) - setattr(config, lm_key, None) - if em_id is not None and not self._validate_model(em_id, raise_exc=False): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - setattr(config, em_key, None) + for lm_key in lm_provider_keys: + lm_id = getattr(config, lm_key) + if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): + self.log.warning( + f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + ) + setattr(config, lm_key, None) + for em_key in em_provider_keys: + em_id = getattr(config, em_key) + if em_id is not None and not self._validate_model(em_id, raise_exc=False): + self.log.warning( + f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." + ) + setattr(config, em_key, None) # if the currently selected language or embedding model ids are # not associated with models, set them to `None` and log a warning. - if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - setattr(config, lm_key, None) - if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - setattr(config, em_key, None) + for lm_key in lm_provider_keys: + lm_id = getattr(config, lm_key) + if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + setattr(config, lm_key, None) + for em_key in em_provider_keys: + em_id = getattr(config, em_key) + if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." + ) + setattr(config, em_key, None) return config From ac7cd846fcd39019da583094c96d825ce8d3922e Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:11:29 +0100 Subject: [PATCH 07/19] Remove embeddings provider for completions as the team has no plans to support it :( --- .../jupyter_ai/config/config_schema.json | 6 ------ .../jupyter-ai/jupyter_ai/config_manager.py | 19 +------------------ packages/jupyter-ai/jupyter_ai/models.py | 3 --- .../__snapshots__/test_config_manager.ambr | 1 - .../components/settings/use-server-info.ts | 6 +----- packages/jupyter-ai/src/handler.ts | 2 -- 6 files changed, 2 insertions(+), 35 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config/config_schema.json b/packages/jupyter-ai/jupyter_ai/config/config_schema.json index 0dd910ab..d276dd31 100644 --- a/packages/jupyter-ai/jupyter_ai/config/config_schema.json +++ b/packages/jupyter-ai/jupyter_ai/config/config_schema.json @@ -22,12 +22,6 @@ "default": null, "readOnly": false }, - "completions_embeddings_provider_id": { - "$comment": "Embedding model global ID for completions.", - "type": ["string", "null"], - "default": null, - "readOnly": false - }, "api_keys": { "$comment": "Dictionary of API keys, mapping key names to key values.", "type": "object", diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index def992be..fa8a53ef 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -100,7 +100,6 @@ class ConfigManager(Configurable): model_provider_id: Optional[str] embeddings_provider_id: Optional[str] completions_model_provider_id: Optional[str] - completions_embeddings_provider_id: Optional[str] def __init__( self, @@ -177,10 +176,7 @@ def _process_existing_config(self, default_config): def _validate_lm_em_id(self, config): lm_provider_keys = ["model_provider_id", "completions_model_provider_id"] - em_provider_keys = [ - "embeddings_provider_id", - "completions_embeddings_provider_id", - ] + em_provider_keys = ["embeddings_provider_id"] # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. @@ -352,7 +348,6 @@ def delete_api_key(self, key_name: str): self.lm_provider, self.em_provider, self.completions_lm_provider, - self.completions_em_provider, ]: if ( provider @@ -417,12 +412,6 @@ def em_provider(self): def completions_lm_provider(self): return self._get_provider("completions_model_provider_id", self._lm_providers) - @property - def completions_em_provider(self): - return self._get_provider( - "completions_embeddings_provider_id", self._em_providers - ) - def _get_provider(self, key, listing): config = self._read_config() gid = getattr(config, key) @@ -446,12 +435,6 @@ def completions_lm_provider_params(self): "completions_model_provider_id", self._lm_providers ) - @property - def completions_em_provider_params(self): - return self._provider_params( - "completions_embeddings_provider_id", self._em_providers - ) - def _provider_params(self, key, listing): # get generic fields config = self._read_config() diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 5675c636..6742a922 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -124,7 +124,6 @@ class DescribeConfigResponse(BaseModel): # passed to the subsequent UpdateConfig request. last_read: int completions_model_provider_id: Optional[str] - completions_embeddings_provider_id: Optional[str] completions_fields: Dict[str, Dict[str, Any]] @@ -143,7 +142,6 @@ class UpdateConfigRequest(BaseModel): # time specified by `last_read` to prevent write-write conflicts. last_read: Optional[int] completions_model_provider_id: Optional[str] - completions_embeddings_provider_id: Optional[str] completions_fields: Optional[Dict[str, Dict[str, Any]]] _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( @@ -163,5 +161,4 @@ class GlobalConfig(BaseModel): fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] completions_model_provider_id: Optional[str] - completions_embeddings_provider_id: Optional[str] completions_fields: Dict[str, Dict[str, Any]] diff --git a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr index b5d75bdc..111ebe6e 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr +++ b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr @@ -3,7 +3,6 @@ dict({ 'api_keys': list([ ]), - 'completions_embeddings_provider_id': None, 'completions_fields': dict({ }), 'completions_model_provider_id': None, diff --git a/packages/jupyter-ai/src/components/settings/use-server-info.ts b/packages/jupyter-ai/src/components/settings/use-server-info.ts index 3695bfe4..d583b957 100644 --- a/packages/jupyter-ai/src/components/settings/use-server-info.ts +++ b/packages/jupyter-ai/src/components/settings/use-server-info.ts @@ -13,7 +13,7 @@ type ServerInfoProperties = { emProviders: AiService.ListProvidersResponse; config: AiService.DescribeConfigResponse; chat: ProvidersInfo; - completions: ProvidersInfo; + completions: Omit; }; type ServerInfoMethods = { @@ -72,11 +72,8 @@ export function useServerInfo(): ServerInfo { const lmLocalId = (lmGid && getModelLocalId(lmGid)) ?? ''; const cLmGid = config.completions_model_provider_id; - const cEmGid = config.completions_embeddings_provider_id; const cLmProvider = cLmGid === null ? null : getProvider(cLmGid, lmProviders); - const cEmProvider = - cEmGid === null ? null : getProvider(cEmGid, emProviders); const cLmLocalId = (cLmGid && getModelLocalId(cLmGid)) ?? ''; setServerInfoProps({ @@ -90,7 +87,6 @@ export function useServerInfo(): ServerInfo { }, completions: { lmProvider: cLmProvider, - emProvider: cEmProvider, lmLocalId: cLmLocalId } }); diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 600a257f..512479b7 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -118,7 +118,6 @@ export namespace AiService { fields: Record>; last_read: number; completions_model_provider_id: string | null; - completions_embeddings_provider_id: string | null; }; export type UpdateConfigRequest = { @@ -129,7 +128,6 @@ export namespace AiService { fields?: Record>; last_read?: number; completions_model_provider_id?: string | null; - completions_embeddings_provider_id?: string | null; completions_fields?: Record>; }; From 59156f9691f63ba0146fb3e21eee181509ed8fbb Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:16:09 +0100 Subject: [PATCH 08/19] Use type alias to reduce changeset/make review easier Without this change prettier reformats the plugin with an extra indentation, which leads to bad changeset display on GitHub. --- packages/jupyter-ai/src/completions/plugin.ts | 259 +++++++++--------- 1 file changed, 130 insertions(+), 129 deletions(-) diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index adf3db55..f5963f36 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -46,148 +46,149 @@ type IcPluginSettings = ISettingRegistry.ISettings & { }; }; -export const completionPlugin: JupyterFrontEndPlugin = - { - id: 'jupyter_ai:inline-completions', - autoStart: true, - requires: [ - ICompletionProviderManager, - IEditorLanguageRegistry, - ISettingRegistry - ], - optional: [IJaiStatusItem], - provides: IJaiCompletionProvider, - activate: async ( - app: JupyterFrontEnd, - completionManager: ICompletionProviderManager, - languageRegistry: IEditorLanguageRegistry, - settingRegistry: ISettingRegistry, - statusItem: IJaiStatusItem | null - ): Promise => { - if (typeof completionManager.registerInlineProvider === 'undefined') { - // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 - console.warn( - 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' - ); +type JaiCompletionToken = IJaiCompletionProvider | null; + +export const completionPlugin: JupyterFrontEndPlugin = { + id: 'jupyter_ai:inline-completions', + autoStart: true, + requires: [ + ICompletionProviderManager, + IEditorLanguageRegistry, + ISettingRegistry + ], + optional: [IJaiStatusItem], + provides: IJaiCompletionProvider, + activate: async ( + app: JupyterFrontEnd, + completionManager: ICompletionProviderManager, + languageRegistry: IEditorLanguageRegistry, + settingRegistry: ISettingRegistry, + statusItem: IJaiStatusItem | null + ): Promise => { + if (typeof completionManager.registerInlineProvider === 'undefined') { + // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 + console.warn( + 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' + ); + return null; + } + + const completionHandler = new CompletionWebsocketHandler(); + const provider = new JaiInlineProvider({ + completionHandler, + languageRegistry + }); + + await completionHandler.initialize(); + completionManager.registerInlineProvider(provider); + + const findCurrentLanguage = (): IEditorLanguage | null => { + const widget = app.shell.currentWidget; + const editor = getEditor(widget); + if (!editor) { return null; } + return languageRegistry.findByMIME(editor.model.mimeType); + }; - const completionHandler = new CompletionWebsocketHandler(); - const provider = new JaiInlineProvider({ - completionHandler, - languageRegistry - }); + // ic := inline completion + async function getIcSettings() { + return (await settingRegistry.load( + INLINE_COMPLETER_PLUGIN + )) as IcPluginSettings; + } - await completionHandler.initialize(); - completionManager.registerInlineProvider(provider); + /** + * Gets the composite settings for the Jupyter AI inline completion provider + * (JaiIcp). + * + * This reads from the `ISettings.composite` property, which merges the user + * settings with the provider defaults, defined in + * `JaiInlineProvider.DEFAULT_SETTINGS`. + */ + async function getJaiIcpSettings() { + const icSettings = await getIcSettings(); + return icSettings.composite.providers[JaiInlineProvider.ID]; + } - const findCurrentLanguage = (): IEditorLanguage | null => { - const widget = app.shell.currentWidget; - const editor = getEditor(widget); - if (!editor) { - return null; + /** + * Updates the JaiIcp user settings. + */ + async function updateJaiIcpSettings( + newJaiIcpSettings: Partial + ) { + const icSettings = await getIcSettings(); + const oldUserIcpSettings = icSettings.user.providers; + const newUserIcpSettings = { + ...oldUserIcpSettings, + [JaiInlineProvider.ID]: { + ...oldUserIcpSettings?.[JaiInlineProvider.ID], + ...newJaiIcpSettings } - return languageRegistry.findByMIME(editor.model.mimeType); }; + icSettings.set('providers', newUserIcpSettings); + } - // ic := inline completion - async function getIcSettings() { - return (await settingRegistry.load( - INLINE_COMPLETER_PLUGIN - )) as IcPluginSettings; - } - - /** - * Gets the composite settings for the Jupyter AI inline completion provider - * (JaiIcp). - * - * This reads from the `ISettings.composite` property, which merges the user - * settings with the provider defaults, defined in - * `JaiInlineProvider.DEFAULT_SETTINGS`. - */ - async function getJaiIcpSettings() { - const icSettings = await getIcSettings(); - return icSettings.composite.providers[JaiInlineProvider.ID]; - } - - /** - * Updates the JaiIcp user settings. - */ - async function updateJaiIcpSettings( - newJaiIcpSettings: Partial - ) { - const icSettings = await getIcSettings(); - const oldUserIcpSettings = icSettings.user.providers; - const newUserIcpSettings = { - ...oldUserIcpSettings, - [JaiInlineProvider.ID]: { - ...oldUserIcpSettings?.[JaiInlineProvider.ID], - ...newJaiIcpSettings - } - }; - icSettings.set('providers', newUserIcpSettings); + app.commands.addCommand(CommandIDs.toggleCompletions, { + execute: async () => { + const jaiIcpSettings = await getJaiIcpSettings(); + updateJaiIcpSettings({ + enabled: !jaiIcpSettings.enabled + }); + }, + label: 'Enable completions by Jupyternaut', + isToggled: () => { + return provider.isEnabled(); } + }); - app.commands.addCommand(CommandIDs.toggleCompletions, { - execute: async () => { - const jaiIcpSettings = await getJaiIcpSettings(); - updateJaiIcpSettings({ - enabled: !jaiIcpSettings.enabled - }); - }, - label: 'Enable completions by Jupyternaut', - isToggled: () => { - return provider.isEnabled(); + app.commands.addCommand(CommandIDs.toggleLanguageCompletions, { + execute: async () => { + const jaiIcpSettings = await getJaiIcpSettings(); + const language = findCurrentLanguage(); + if (!language) { + return; } - }); - app.commands.addCommand(CommandIDs.toggleLanguageCompletions, { - execute: async () => { - const jaiIcpSettings = await getJaiIcpSettings(); - const language = findCurrentLanguage(); - if (!language) { - return; - } - - const disabledLanguages = [...jaiIcpSettings.disabledLanguages]; - const newDisabledLanguages = disabledLanguages.includes(language.name) - ? disabledLanguages.filter(l => l !== language.name) - : disabledLanguages.concat(language.name); - - updateJaiIcpSettings({ - disabledLanguages: newDisabledLanguages - }); - }, - label: () => { - const language = findCurrentLanguage(); - return language - ? `Disable completions in ${displayName(language)}` - : 'Disable completions in files'; - }, - isToggled: () => { - const language = findCurrentLanguage(); - return !!language && !provider.isLanguageEnabled(language.name); - }, - isVisible: () => { - const language = findCurrentLanguage(); - return !!language; - }, - isEnabled: () => { - const language = findCurrentLanguage(); - return !!language && provider.isEnabled(); - } - }); + const disabledLanguages = [...jaiIcpSettings.disabledLanguages]; + const newDisabledLanguages = disabledLanguages.includes(language.name) + ? disabledLanguages.filter(l => l !== language.name) + : disabledLanguages.concat(language.name); - if (statusItem) { - statusItem.addItem({ - command: CommandIDs.toggleCompletions, - rank: 1 - }); - statusItem.addItem({ - command: CommandIDs.toggleLanguageCompletions, - rank: 2 + updateJaiIcpSettings({ + disabledLanguages: newDisabledLanguages }); + }, + label: () => { + const language = findCurrentLanguage(); + return language + ? `Disable completions in ${displayName(language)}` + : 'Disable completions in files'; + }, + isToggled: () => { + const language = findCurrentLanguage(); + return !!language && !provider.isLanguageEnabled(language.name); + }, + isVisible: () => { + const language = findCurrentLanguage(); + return !!language; + }, + isEnabled: () => { + const language = findCurrentLanguage(); + return !!language && provider.isEnabled(); } - return provider; + }); + + if (statusItem) { + statusItem.addItem({ + command: CommandIDs.toggleCompletions, + rank: 1 + }); + statusItem.addItem({ + command: CommandIDs.toggleLanguageCompletions, + rank: 2 + }); } - }; + return provider; + } +}; From b06481a75d9f9dce9a74e8a83081ffcb96000216 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:25:25 +0100 Subject: [PATCH 09/19] Rename `_validate_lm_em_id` to `_validate_model_ids` --- packages/jupyter-ai/jupyter_ai/config_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index fa8a53ef..a05f9b3c 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -168,13 +168,13 @@ def _process_existing_config(self, default_config): {k: v for k, v in existing_config.items() if v is not None}, ) config = GlobalConfig(**merged_config) - validated_config = self._validate_lm_em_id(config) + validated_config = self._validate_model_ids(config) # re-write to the file to validate the config and apply any # updates to the config file immediately self._write_config(validated_config) - def _validate_lm_em_id(self, config): + def _validate_model_ids(self, config): lm_provider_keys = ["model_provider_id", "completions_model_provider_id"] em_provider_keys = ["embeddings_provider_id"] From ec5b1b8663fe2af554a5e459258f4f642ce8276a Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Thu, 2 May 2024 16:23:47 +0100 Subject: [PATCH 10/19] Rename `LLMHandlerMixin` to `CompletionsModelMixin` and rename the file from `llm_mixin` to `model_mixin` fro consistency. Of note, the file name does not need `completions_` prefix as the file is in `completions/` subdirectory. --- packages/jupyter-ai/jupyter_ai/completions/handlers/base.py | 4 ++-- .../completions/handlers/{llm_mixin.py => model_mixin.py} | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) rename packages/jupyter-ai/jupyter_ai/completions/handlers/{llm_mixin.py => model_mixin.py} (95%) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index 9eb4f845..32920dc8 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -5,7 +5,7 @@ from typing import Union import tornado -from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin +from jupyter_ai.completions.handlers.model_mixin import CompletionsModelMixin from jupyter_ai.completions.models import ( CompletionError, InlineCompletionList, @@ -18,7 +18,7 @@ class BaseInlineCompletionHandler( - LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler + CompletionsModelMixin, JupyterHandler, tornado.websocket.WebSocketHandler ): """A Tornado WebSocket handler that receives inline completion requests and fulfills them accordingly. This class is instantiated once per WebSocket diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py similarity index 95% rename from packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py rename to packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py index e31abae8..4bd3f9ce 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py @@ -5,8 +5,8 @@ from jupyter_ai_magics.providers import BaseProvider -class LLMHandlerMixin: - """Base class containing shared methods and attributes used by LLM handler classes.""" +class CompletionsModelMixin: + """Mixin class containing methods and attributes used by completions LLM handler.""" handler_kind: str settings: dict From 84165e5efe0a67f90d518bd1af0b94d1deac0c19 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Thu, 2 May 2024 16:38:22 +0100 Subject: [PATCH 11/19] Rename "Chat LM" to "LM"; add title attribute; note using the title attribute because getting the icon to show up nicely (getting they nice grey color and positioning as it gets in buttons, compared to just plain black) was not trivial; I think the icon might be the way to go in the future but I would postpone it to another PR. That said, I still think it should say "Chat LM" because it has no effect on magics nor completions. --- packages/jupyter-ai/src/components/chat-settings.tsx | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 0d811ca7..036388dc 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -279,7 +279,12 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { }} > {/* Chat language model section */} -

Chat language model

+

+ Language model +

void; }): JSX.Element { - if (props.selection && !props.provider?.isEnabled()) { - return ( - - - - ); - } return ( - - - + + {() => + props.selection && !props.provider?.isEnabled() ? ( + + + + ) : ( + + + + ) + } + ); } From a9dc569af411c90cba5782c61ca0293d880a2cd3 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Thu, 2 May 2024 17:10:08 +0100 Subject: [PATCH 14/19] Rename the label in the select to "Inline completion model" --- packages/jupyter-ai/src/components/chat-settings.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 1ca6ff8e..4000ef05 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -288,7 +288,7 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { { const clmGid = e.target.value === 'null' ? null : e.target.value; if (clmGid === null) { From 13d0ecf13417c07139f8031e18dca080e2ca0f94 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Thu, 2 May 2024 17:22:22 +0100 Subject: [PATCH 15/19] Disable selection when completer is not enabled --- .../src/components/chat-settings.tsx | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 4000ef05..a0237887 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -90,6 +90,25 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const [sendWse, setSendWse] = useState(false); const [fields, setFields] = useState>({}); + const [completerIsEnabled, setCompleterIsEnabled] = useState( + props.completionProvider && props.completionProvider.isEnabled() + ); + + const refreshCompleterState = () => { + setCompleterIsEnabled( + props.completionProvider && props.completionProvider.isEnabled() + ); + }; + + useEffect(() => { + props.completionProvider?.settingsChanged.connect(refreshCompleterState); + return () => { + props.completionProvider?.settingsChanged.disconnect( + refreshCompleterState + ); + }; + }, [props.completionProvider]); + // whether the form is currently saving const [saving, setSaving] = useState(false); @@ -381,6 +400,7 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { { const clmGid = e.target.value === 'null' ? null : e.target.value; if (clmGid === null) { @@ -526,32 +525,28 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { function CompleterSettingsButton(props: { selection: AiService.ListProvidersEntry | null; provider: IJaiCompletionProvider | null; + isCompleterEnabled: boolean | null; openSettings: () => void; }): JSX.Element { + if (props.selection && !props.isCompleterEnabled) { + return ( + + + + ); + } return ( - - {() => - props.selection && !props.provider?.isEnabled() ? ( - - - - ) : ( - - - - ) - } - + + + ); } From 18e57f45ce908f8eef66f4e881ef5533eb40ef62 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Thu, 2 May 2024 17:43:19 +0100 Subject: [PATCH 17/19] Use mui tooltips --- .../src/components/chat-settings.tsx | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index ee530d19..87009ecf 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -12,6 +12,7 @@ import { Radio, RadioGroup, TextField, + Tooltip, CircularProgress } from '@mui/material'; import SettingsIcon from '@mui/icons-material/Settings'; @@ -530,23 +531,26 @@ function CompleterSettingsButton(props: { }): JSX.Element { if (props.selection && !props.isCompleterEnabled) { return ( - - - + + + + ); } return ( - - - + + + + + ); } From 5ac3cf307dad46c409c5f4ea3673ceefe9af4a05 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Thu, 2 May 2024 18:09:24 +0100 Subject: [PATCH 18/19] Fix use of `jai_config_manager` --- .../jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py index 4bd3f9ce..fd19498b 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py @@ -26,8 +26,8 @@ def __init__(self, *args, **kwargs) -> None: self._llm_params = None def get_llm(self) -> Optional[BaseProvider]: - lm_provider = self.config_manager.completions_lm_provider - lm_provider_params = self.config_manager.completions_lm_provider_params + lm_provider = self.jai_config_manager.completions_lm_provider + lm_provider_params = self.jai_config_manager.completions_lm_provider_params if not lm_provider or not lm_provider_params: return None From b5587fc8229fa5ac16281db09db99011ee14cea4 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Thu, 2 May 2024 18:18:32 +0100 Subject: [PATCH 19/19] Fix tests --- .../jupyter-ai/jupyter_ai/tests/completions/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 26c274ad..c5b5d1ee 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -34,7 +34,7 @@ def __init__(self, lm_provider=None, lm_provider_params=None): self.application = Application() self.messages = [] self.tasks = [] - self.settings["config_manager"] = SimpleNamespace( + self.settings["jai_config_manager"] = SimpleNamespace( completions_lm_provider=lm_provider or MockProvider, completions_lm_provider_params=lm_provider_params or {"model_id": "model"}, )