Distinguish between completion and chat models (#711)
* Distinguish between completion and chat models

* Fix tests

* Shorten the tab name, move settings button

Lint

* Implement the completion model selection in chat UI

* Improve docstring

* Call `_validate_lm_em_id` only once, add typing annotations

* Remove embeddings provider for completions

as the team has no plans to support it :(

* Use type alias to reduce changeset/make review easier

Without this change, Prettier reformats the plugin with extra
indentation, which leads to a poor changeset display on GitHub.

* Rename `_validate_lm_em_id` to `_validate_model_ids`

* Rename `LLMHandlerMixin` to `CompletionsModelMixin`

and rename the file from `llm_mixin` to `model_mixin` for consistency.
Of note, the file name does not need a `completions_` prefix, as the file
is in the `completions/` subdirectory.

* Rename "Chat LM" to "LM"; add title attribute; note

Using the title attribute because getting the icon to show up nicely
(with the nice grey color and positioning it gets in buttons, rather
than plain black) was not trivial; I think the icon might be the way to
go in the future, but I would postpone it to another PR.

That said, I still think it should say "Chat LM", because it has no
effect on magics or completions.

* Rename heading "Completer model" → "Inline completions model"

* Move `UseSignal` down to `CompleterSettingsButton` implementation

* Rename the label in the select to "Inline completion model"

* Disable selection when completer is not enabled

* Remove use of `UseSignal`, tweak naming of `useState`

from `completerIsEnabled` to `isCompleterEnabled`

* Use MUI tooltips

* Fix use of `jai_config_manager`

* Fix tests
krassowski authored May 3, 2024
1 parent 3b4aa37 commit cf20800
Showing 18 changed files with 402 additions and 117 deletions.
10 changes: 10 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -258,6 +258,16 @@ class Config:
provider is selected.
"""

@classmethod
def chat_models(self):
"""Models which are suitable for chat."""
return self.models

@classmethod
def completion_models(self):
"""Models which are suitable for completions."""
return self.models

#
# instance attrs
#
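Both new classmethods default to returning `models`, so existing providers keep advertising every model for both chat and inline completions. A provider that wants to split its catalogue can override them. A minimal sketch, following the documented custom-provider pattern (the `MyProvider` class, its model IDs, and the canned responses are illustrative, not part of this commit):

from jupyter_ai_magics.providers import BaseProvider
from langchain_community.llms import FakeListLLM


class MyProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["chat-large", "chat-small", "code-fim"]

    def __init__(self, **kwargs):
        # FakeListLLM replays canned responses; a real provider would wrap an
        # actual LangChain model here.
        kwargs["responses"] = ["Hello from a fake model."]
        super().__init__(**kwargs)

    @classmethod
    def chat_models(cls):
        # Advertise only the conversational models in the chat settings UI.
        return ["chat-large", "chat-small"]

    @classmethod
    def completion_models(cls):
        # Advertise only the fill-in-the-middle model for inline completions.
        return ["code-fim"]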
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -5,7 +5,7 @@
from typing import Union

import tornado
from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin
from jupyter_ai.completions.handlers.model_mixin import CompletionsModelMixin
from jupyter_ai.completions.models import (
CompletionError,
InlineCompletionList,
@@ -18,7 +18,7 @@


class BaseInlineCompletionHandler(
LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler
CompletionsModelMixin, JupyterHandler, tornado.websocket.WebSocketHandler
):
"""A Tornado WebSocket handler that receives inline completion requests and
fulfills them accordingly. This class is instantiated once per WebSocket
packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py
@@ -5,8 +5,8 @@
from jupyter_ai_magics.providers import BaseProvider


class LLMHandlerMixin:
"""Base class containing shared methods and attributes used by LLM handler classes."""
class CompletionsModelMixin:
"""Mixin class containing methods and attributes used by completions LLM handler."""

handler_kind: str
settings: dict
@@ -26,8 +26,8 @@ def __init__(self, *args, **kwargs) -> None:
self._llm_params = None

def get_llm(self) -> Optional[BaseProvider]:
lm_provider = self.jai_config_manager.lm_provider
lm_provider_params = self.jai_config_manager.lm_provider_params
lm_provider = self.jai_config_manager.completions_lm_provider
lm_provider_params = self.jai_config_manager.completions_lm_provider_params

if not lm_provider or not lm_provider_params:
return None
17 changes: 17 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config/config_schema.json
@@ -16,6 +16,12 @@
"default": null,
"readOnly": false
},
"completions_model_provider_id": {
"$comment": "Language model global ID for completions.",
"type": ["string", "null"],
"default": null,
"readOnly": false
},
"api_keys": {
"$comment": "Dictionary of API keys, mapping key names to key values.",
"type": "object",
@@ -37,6 +43,17 @@
}
},
"additionalProperties": false
},
"completions_fields": {
"$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.",
"type": "object",
"default": {},
"patternProperties": {
"^.*$": {
"anyOf": [{ "type": "object" }]
}
},
"additionalProperties": false
}
}
}
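With these keys, a saved config can point the chat UI and the inline completer at different models. A hedged sketch of validating such a document against this schema, assuming the `jsonschema` package is available and the script runs from the repository root; the model IDs and field values are made up, and keys not shown in this hunk (such as `send_with_shift_enter`) are omitted:

import json

from jsonschema import validate

# Hypothetical user config: chat and inline completions use different models.
config = {
    "model_provider_id": "openai-chat:gpt-4",
    "completions_model_provider_id": "openai:gpt-3.5-turbo-instruct",
    "api_keys": {"OPENAI_API_KEY": "..."},
    "fields": {},
    "completions_fields": {"openai:gpt-3.5-turbo-instruct": {"max_tokens": 64}},
}

with open("packages/jupyter-ai/jupyter_ai/config/config_schema.json") as f:
    schema = json.load(f)

# Raises jsonschema.ValidationError if the document does not match the schema.
validate(instance=config, schema=schema)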
150 changes: 76 additions & 74 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -97,6 +97,10 @@ class ConfigManager(Configurable):
config=True,
)

model_provider_id: Optional[str]
embeddings_provider_id: Optional[str]
completions_model_provider_id: Optional[str]

def __init__(
self,
log: Logger,
@@ -164,41 +168,49 @@ def _process_existing_config(self, default_config):
{k: v for k, v in existing_config.items() if v is not None},
)
config = GlobalConfig(**merged_config)
validated_config = self._validate_lm_em_id(config)
validated_config = self._validate_model_ids(config)

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(validated_config)

def _validate_lm_em_id(self, config):
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id
def _validate_model_ids(self, config):
lm_provider_keys = ["model_provider_id", "completions_model_provider_id"]
em_provider_keys = ["embeddings_provider_id"]

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(em_id, raise_exc=False):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None
for lm_key in lm_provider_keys:
lm_id = getattr(config, lm_key)
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
setattr(config, lm_key, None)
for em_key in em_provider_keys:
em_id = getattr(config, em_key)
if em_id is not None and not self._validate_model(em_id, raise_exc=False):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
setattr(config, em_key, None)

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None
for lm_key in lm_provider_keys:
lm_id = getattr(config, lm_key)
if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
setattr(config, lm_key, None)
for em_key in em_provider_keys:
em_id = getattr(config, em_key)
if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
setattr(config, em_key, None)

return config

@@ -321,28 +333,28 @@ def _write_config(self, new_config: GlobalConfig):
complete `GlobalConfig` object, and should not be called publicly."""
# remove any empty field dictionaries
new_config.fields = {k: v for k, v in new_config.fields.items() if v}
new_config.completions_fields = {
k: v for k, v in new_config.completions_fields.items() if v
}

self._validate_config(new_config)
with open(self.config_path, "w") as f:
json.dump(new_config.dict(), f, indent=self.indentation_depth)

def delete_api_key(self, key_name: str):
config_dict = self._read_config().dict()
lm_provider = self.lm_provider
em_provider = self.em_provider
required_keys = []
if (
lm_provider
and lm_provider.auth_strategy
and lm_provider.auth_strategy.type == "env"
):
required_keys.append(lm_provider.auth_strategy.name)
if (
em_provider
and em_provider.auth_strategy
and em_provider.auth_strategy.type == "env"
):
required_keys.append(self.em_provider.auth_strategy.name)
for provider in [
self.lm_provider,
self.em_provider,
self.completions_lm_provider,
]:
if (
provider
and provider.auth_strategy
and provider.auth_strategy.type == "env"
):
required_keys.append(provider.auth_strategy.name)

if key_name in required_keys:
raise KeyInUseError(
@@ -390,67 +402,57 @@ def em_gid(self):

@property
def lm_provider(self):
config = self._read_config()
lm_gid = config.model_provider_id
if lm_gid is None:
return None

_, Provider = get_lm_provider(config.model_provider_id, self._lm_providers)
return Provider
return self._get_provider("model_provider_id", self._lm_providers)

@property
def em_provider(self):
return self._get_provider("embeddings_provider_id", self._em_providers)

@property
def completions_lm_provider(self):
return self._get_provider("completions_model_provider_id", self._lm_providers)

def _get_provider(self, key, listing):
config = self._read_config()
em_gid = config.embeddings_provider_id
if em_gid is None:
gid = getattr(config, key)
if gid is None:
return None

_, Provider = get_em_provider(em_gid, self._em_providers)
_, Provider = get_lm_provider(gid, listing)
return Provider

@property
def lm_provider_params(self):
# get generic fields
config = self._read_config()
lm_gid = config.model_provider_id
if not lm_gid:
return None

lm_lid = lm_gid.split(":", 1)[1]
fields = config.fields.get(lm_gid, {})

# get authn fields
_, Provider = get_lm_provider(lm_gid, self._lm_providers)
authn_fields = {}
if Provider.auth_strategy and Provider.auth_strategy.type == "env":
key_name = Provider.auth_strategy.name
authn_fields[key_name.lower()] = config.api_keys[key_name]

return {
"model_id": lm_lid,
**fields,
**authn_fields,
}
return self._provider_params("model_provider_id", self._lm_providers)

@property
def em_provider_params(self):
return self._provider_params("embeddings_provider_id", self._em_providers)

@property
def completions_lm_provider_params(self):
return self._provider_params(
"completions_model_provider_id", self._lm_providers
)

def _provider_params(self, key, listing):
# get generic fields
config = self._read_config()
em_gid = config.embeddings_provider_id
if not em_gid:
gid = getattr(config, key)
if not gid:
return None

em_lid = em_gid.split(":", 1)[1]
lid = gid.split(":", 1)[1]

# get authn fields
_, Provider = get_em_provider(em_gid, self._em_providers)
_, Provider = get_em_provider(gid, listing)
authn_fields = {}
if Provider.auth_strategy and Provider.auth_strategy.type == "env":
key_name = Provider.auth_strategy.name
authn_fields[key_name.lower()] = config.api_keys[key_name]

return {
"model_id": em_lid,
"model_id": lid,
**authn_fields,
}

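The three provider properties and the shared `_provider_params` helper give completions the same two-step lookup that chat already uses: fetch the provider class, fetch its params, instantiate. A rough usage sketch (this mirrors what `CompletionsModelMixin.get_llm` does with these properties, but it is not the exact handler code):

def build_completions_llm(config_manager):
    """Instantiate the currently selected inline-completions model, if any."""
    provider_cls = config_manager.completions_lm_provider
    provider_params = config_manager.completions_lm_provider_params

    # Both are None until a completions model is selected in the settings UI.
    if not provider_cls or not provider_params:
        return None

    # provider_params carries "model_id" plus any env-based auth fields.
    return provider_cls(**provider_params)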
6 changes: 6 additions & 0 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -290,6 +290,10 @@ def filter_predicate(local_model_id: str):
# filter out every model w/ model ID according to allow/blocklist
for provider in providers:
provider.models = list(filter(filter_predicate, provider.models))
provider.chat_models = list(filter(filter_predicate, provider.chat_models))
provider.completion_models = list(
filter(filter_predicate, provider.completion_models)
)

# filter out every provider with no models which satisfy the allow/blocklist, then return
return filter((lambda p: len(p.models) > 0), providers)
@@ -311,6 +315,8 @@ def get(self):
id=provider.id,
name=provider.name,
models=provider.models,
chat_models=provider.chat_models(),
completion_models=provider.completion_models(),
help=provider.help,
auth_strategy=provider.auth_strategy,
registry=provider.registry,
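For reference, the allow/blocklist predicate now prunes all three model lists (`models`, `chat_models`, `completion_models`) in the same way. A self-contained sketch of the idea; the helper name, the blocklist, and the model IDs are made up, and the real predicate in `handlers.py` also handles allowlists:

from typing import List, Set


def filter_models(models: List[str], provider_id: str, blocked: Set[str]) -> List[str]:
    """Drop models whose global ID ("provider:model") appears in the blocklist."""
    return [m for m in models if f"{provider_id}:{m}" not in blocked]


blocked = {"my_provider:chat-small"}
print(filter_models(["chat-large", "chat-small"], "my_provider", blocked))
# -> ['chat-large']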
8 changes: 8 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -94,6 +94,8 @@ class ListProvidersEntry(BaseModel):
auth_strategy: AuthStrategy
registry: bool
fields: List[Field]
chat_models: Optional[List[str]]
completion_models: Optional[List[str]]


class ListProvidersResponse(BaseModel):
@@ -121,6 +123,8 @@ class DescribeConfigResponse(BaseModel):
# timestamp indicating when the configuration file was last read. should be
# passed to the subsequent UpdateConfig request.
last_read: int
completions_model_provider_id: Optional[str]
completions_fields: Dict[str, Dict[str, Any]]


def forbid_none(cls, v):
@@ -137,6 +141,8 @@ class UpdateConfigRequest(BaseModel):
# if passed, this will raise an Error if the config was written to after the
# time specified by `last_read` to prevent write-write conflicts.
last_read: Optional[int]
completions_model_provider_id: Optional[str]
completions_fields: Optional[Dict[str, Dict[str, Any]]]

_validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)(
forbid_none
@@ -154,3 +160,5 @@ class GlobalConfig(BaseModel):
send_with_shift_enter: bool
fields: Dict[str, Dict[str, Any]]
api_keys: Dict[str, str]
completions_model_provider_id: Optional[str]
completions_fields: Dict[str, Dict[str, Any]]
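On the wire, selecting a completions model is just another partial config update. A hedged sketch of a client-side payload (the model ID and field values are hypothetical; only the fields being changed need to be set):

from jupyter_ai.models import UpdateConfigRequest

req = UpdateConfigRequest(
    completions_model_provider_id="my_provider:code-fim",
    completions_fields={"my_provider:code-fim": {"max_tokens": 64}},
)

# Serialized, this is the JSON body the settings UI would send to the
# server's config endpoint; unset fields are left out of the update.
print(req.json(exclude_unset=True))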
@@ -3,6 +3,9 @@
dict({
'api_keys': list([
]),
'completions_fields': dict({
}),
'completions_model_provider_id': None,
'embeddings_provider_id': None,
'fields': dict({
}),
@@ -35,8 +35,8 @@ def __init__(self, lm_provider=None, lm_provider_params=None):
self.messages = []
self.tasks = []
self.settings["jai_config_manager"] = SimpleNamespace(
lm_provider=lm_provider or MockProvider,
lm_provider_params=lm_provider_params or {"model_id": "model"},
completions_lm_provider=lm_provider or MockProvider,
completions_lm_provider_params=lm_provider_params or {"model_id": "model"},
)
self.settings["jai_event_loop"] = SimpleNamespace(
create_task=lambda x: self.tasks.append(x)