diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md index aac92328..6763486d 100644 --- a/docs/source/developers/index.md +++ b/docs/source/developers/index.md @@ -228,6 +228,94 @@ class MyCompletionProvider(BaseProvider, FakeListLLM): ) ``` + +#### Using the full notebook content for completions + +The `InlineCompletionRequest` contains the `path` of the current document (file or notebook). +Inline completion providers can use this path to extract the content of the notebook from the disk; +however, such content may be outdated if the user has not saved the notebook recently. + +The accuracy of the suggestions can be slightly improved by combining the potentially outdated content of previous/following cells +with the `prefix` and `suffix` which describe the up-to-date state of the current cell (identified by `cell_id`). + +Still, reading the full notebook from the disk may be slow for larger notebooks, which conflicts with the low-latency requirement of inline completion. + +A better approach is to use the live copy of the notebook document that is persisted on the jupyter-server when *collaborative* document models are enabled. +Two packages need to be installed to access the collaborative models: +- `jupyter-server-ydoc` (>= 1.0) stores the collaborative models in the jupyter-server at runtime +- `jupyter-docprovider` (>= 1.0) reconfigures JupyterLab/Notebook to use the collaborative models + +Both packages are automatically installed with `jupyter-collaboration` (in v3.0 or newer); however, installing `jupyter-collaboration` is not required to take advantage of *collaborative* models. + +The snippet below demonstrates how to retrieve the content of all cells of a given type from the in-memory copy of the collaborative model (without additional disk reads). 
+
+```python
+from jupyter_ydoc import YNotebook
+
+
+class MyCompletionProvider(BaseProvider, FakeListLLM):
+    """Example provider that builds completion context from the whole notebook."""
+
+    id = "my_provider"
+    name = "My Provider"
+    model_id_key = "model"
+    models = ["model_a"]
+
+    def __init__(self, **kwargs):
+        kwargs["responses"] = ["This fake response will not be used for completion"]
+        super().__init__(**kwargs)
+
+    async def _get_prefix_and_suffix(self, request: InlineCompletionRequest):
+        """Return the `(prefix, suffix)` pair extended with other notebook cells.
+
+        Cells of the same type as the active cell are read from the in-memory
+        collaborative document and joined (separated by blank lines) around the
+        up-to-date `prefix`/`suffix` of the active cell. The single-cell
+        `prefix`/`suffix` is returned unchanged whenever the collaborative
+        notebook document or the active cell cannot be found.
+        """
+        prefix = request.prefix
+        suffix = request.suffix.strip()
+
+        # `server_settings` stays `None` unless the jupyter-ai server extension
+        # has set it, so do not call `.get()` on it unconditionally.
+        server_settings = self.server_settings or {}
+        server_ydoc = server_settings.get("jupyter_server_ydoc", None)
+        if not server_ydoc or not request.path:
+            # fallback to prefix/suffix from single cell
+            return prefix, suffix
+
+        # Match the extension including the dot, not just the trailing letters.
+        is_notebook = request.path.endswith(".ipynb")
+        document = await server_ydoc.get_document(
+            path=request.path,
+            content_type="notebook" if is_notebook else "file",
+            file_format="json" if is_notebook else "text",
+        )
+        if not document or not isinstance(document, YNotebook):
+            return prefix, suffix
+
+        cell_type = "markdown" if request.language == "markdown" else "code"
+
+        is_before_request_cell = True
+        before = []
+        after = [suffix]
+
+        for cell in document.ycells:
+            if is_before_request_cell and cell["id"] == request.cell_id:
+                # Skip the active cell itself: `prefix`/`suffix` from the
+                # request already describe its up-to-date content.
+                is_before_request_cell = False
+                continue
+            if cell["cell_type"] != cell_type:
+                continue
+            source = cell["source"].to_py()
+            if is_before_request_cell:
+                before.append(source)
+            else:
+                after.append(source)
+
+        if is_before_request_cell:
+            # The active cell was never found, so the position of the other
+            # cells relative to the cursor is unknown; do not guess.
+            return prefix, suffix
+
+        before.append(prefix)
+        return "\n\n".join(before), "\n\n".join(after)
+
+    async def generate_inline_completions(self, request: InlineCompletionRequest):
+        """Reply to `request` with one completion based on the notebook context."""
+        prefix, suffix = await self._get_prefix_and_suffix(request)
+
+        return InlineCompletionReply(
+            list=InlineCompletionList(items=[
+                {"insertText": your_llm_function(prefix, suffix)}
+            ]),
+            reply_to=request.number,
+        )
+```
+
+
 ## Prompt templates
 
 Each provider can define **prompt templates** for each supported format.
A prompt diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 1bbc4ce5..bef7efac 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -5,6 +5,7 @@ import io import json from concurrent.futures import ThreadPoolExecutor +from types import MappingProxyType from typing import ( Any, AsyncIterator, @@ -265,6 +266,13 @@ class Config: provider is selected. """ + server_settings: ClassVar[Optional[MappingProxyType[str, Any]]] = None + """Settings passed on from jupyter-ai package. + + The same server settings are shared between all providers. + Providers are not allowed to mutate this dictionary. + """ + @classmethod def chat_models(self): """Models which are suitable for chat.""" diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 420b72ad..efb83b3d 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,11 +1,12 @@ import os import re import time +import types from dask.distributed import Client as DaskClient from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever -from jupyter_ai_magics import JupyternautPersona +from jupyter_ai_magics import BaseProvider, JupyternautPersona from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from tornado.web import StaticFileHandler @@ -202,6 +203,11 @@ def initialize_settings(self): defaults=defaults, ) + # Expose the web app settings to the providers as a read-only mapping + BaseProvider.server_settings = types.MappingProxyType( + self.serverapp.web_app.settings + ) + self.log.info("Registered providers.") self.log.info(f"Registered {self.name} server extension")