Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SageMaker Endpoints in chat #197

Merged
merged 11 commits
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/check-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }}
version_spec: minor
python-version: '3.10.x'
- name: Runner debug info
if: always()
run: |
Expand Down
Binary file added docs/source/_static/chat-sagemaker-endpoints.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
36 changes: 36 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,42 @@ To compose a message, type it in the text box at the bottom of the chat interfac
alt='Screen shot of an example "Hello world" message sent to Jupyternaut, who responds with "Hello world, how are you today?"'
class="screenshot" />

### Usage with SageMaker Endpoints

Jupyter AI supports language models hosted on SageMaker Endpoints that use JSON
APIs. The first step is to authenticate with AWS via the `boto3` SDK and have
the credentials stored in the `default` profile. Guidance on how to do this can
be found in the
[`boto3` documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html).

When selecting the SageMaker Endpoints provider in the settings panel, you will
see the following interface:

<img src="../_static/chat-sagemaker-endpoints.png"
width="50%"
alt='Screenshot of the settings panel with the SageMaker Endpoints provider selected.'
class="screenshot" />

Each of the additional fields under "Language model" is required. These fields
should contain the following data:

- **Local model ID**: The name of your endpoint. This can be retrieved from the
AWS Console at the URL
`https://<region>.console.aws.amazon.com/sagemaker/home?region=<region>#/endpoints`.

- **Region name**: The AWS region your SageMaker endpoint is hosted in, e.g. `us-west-2`.

- **Request schema**: The JSON object the endpoint expects, with the prompt
being substituted into any value that matches the string literal `"<prompt>"`.
In this example, the request schema `{"text_inputs":"<prompt>"}` generates a JSON
object with the prompt stored under the `text_inputs` key.

- **Response path**: A [JSONPath](https://goessner.net/articles/JsonPath/index.html)
string that retrieves the language model's output from the endpoint's JSON
response. In this example, the endpoint returns an object with the schema
`{"generated_texts":["<output>"]}`, hence the response path is
`generated_texts.[0]`.

### Asking about something in your notebook

Jupyter AI's chat interface can include a portion of your notebook in your prompt.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import ClassVar, List, Type
from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy
from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy, Field
from pydantic import BaseModel, Extra
from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings
from langchain.embeddings.base import Embeddings
Expand Down Expand Up @@ -35,7 +35,14 @@ class Config:

provider_klass: ClassVar[Type[Embeddings]]

registry: ClassVar[bool] = False
"""Whether this provider is a registry provider."""

fields: ClassVar[List[Field]] = []
"""Fields expected by this provider in its constructor. Each `Field` `f`
should be passed as a keyword argument, keyed by `f.key`."""


class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider):
id = "openai"
name = "OpenAI"
Expand Down Expand Up @@ -73,3 +80,4 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider):
pypi_package_deps = ["huggingface_hub", "ipywidgets"]
auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
provider_klass = HuggingFaceHubEmbeddings
registry = True
80 changes: 76 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import ClassVar, Dict, List, Union, Literal, Optional

from typing import Any, ClassVar, Dict, List, Union, Literal, Optional
import base64

import io
import json
import copy

from jsonpath_ng import jsonpath, parse
from langchain.schema import BaseModel as BaseLangchainProvider
from langchain.llms import (
AI21,
Expand All @@ -14,6 +15,7 @@
OpenAIChat,
SagemakerEndpoint
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.utils import get_from_dict_or_env
from langchain.llms.utils import enforce_stop_tokens

Expand Down Expand Up @@ -45,6 +47,18 @@ class AwsAuthStrategy(BaseModel):
]
]

class TextField(BaseModel):
    # Discriminated-union tag: identifies this field as a single-line text input.
    type: Literal["text"] = "text"
    key: str  # constructor keyword-argument name the user's value is passed under
    label: str  # human-readable label shown for this input in the settings UI

class MultilineTextField(BaseModel):
    # Discriminated-union tag: identifies this field as a multi-line text input.
    type: Literal["text-multiline"] = "text-multiline"
    key: str  # constructor keyword-argument name the user's value is passed under
    label: str  # human-readable label shown for this input in the settings UI

# A user-supplied input declared by a provider (see `BaseProvider.fields`);
# each field's value is passed to the provider constructor keyed by `key`.
Field = Union[TextField, MultilineTextField]

class BaseProvider(BaseLangchainProvider):
#
# pydantic config
Expand Down Expand Up @@ -75,6 +89,13 @@ class Config:
"""Authentication/authorization strategy. Declares what credentials are
required to use this model provider. Generally should not be `None`."""

registry: ClassVar[bool] = False
"""Whether this provider is a registry provider."""

fields: ClassVar[List[Field]] = []
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""

#
# instance attrs
#
Expand Down Expand Up @@ -144,6 +165,7 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
# tqdm is a dependency of huggingface_hub
pypi_package_deps = ["huggingface_hub", "ipywidgets"]
auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
registry = True

# Override the parent's validate_environment with a custom list of valid tasks
@root_validator()
Expand Down Expand Up @@ -292,12 +314,62 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

class JsonContentHandler(LLMContentHandler):
    """Content handler that serializes prompts into an arbitrary JSON request
    body and extracts the completion from the JSON response via a JSONPath.

    :param request_schema: JSON string describing the request body. Every
        value equal to the string literal ``"<prompt>"`` is replaced with the
        prompt at request time.
    :param response_path: JSONPath expression locating the generated text in
        the endpoint's JSON response.
    """

    content_type = "application/json"
    accepts = "application/json"

    def __init__(self, request_schema: str, response_path: str):
        self.request_schema = json.loads(request_schema)
        self.response_path = response_path
        self.response_parser = parse(response_path)

    def replace_values(self, old_val, new_val, d: Dict[str, Any]) -> Dict[str, Any]:
        """Replace occurrences of `old_val` with `new_val` recursively,
        descending into nested dicts and lists. Mutates and returns `d`."""
        for key, val in d.items():
            if val == old_val:
                d[key] = new_val
            elif isinstance(val, dict):
                self.replace_values(old_val, new_val, val)
            elif isinstance(val, list):
                # Fix: also substitute inside lists (e.g. a schema like
                # {"inputs": ["<prompt>"]}); previously only dict values
                # were walked, so the prompt was never inserted.
                d[key] = [
                    new_val if item == old_val
                    else self.replace_values(old_val, new_val, item)
                    if isinstance(item, dict)
                    else item
                    for item in val
                ]

        return d

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the request body: deep-copy the schema so repeated calls do
        not see a previously-substituted prompt, then inject `prompt`."""
        request_obj = copy.deepcopy(self.request_schema)
        self.replace_values("<prompt>", prompt, request_obj)
        return json.dumps(request_obj).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Extract the generated text from the endpoint's response.

        NOTE(review): `output` is read with `.read()`, so at runtime it is a
        file-like streaming body despite the `bytes` annotation inherited
        from the base class — confirm against the langchain caller.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        matches = self.response_parser.find(response_json)
        if not matches:
            # A misconfigured response path would otherwise surface as an
            # opaque IndexError; name the actual problem for the user.
            raise ValueError(
                f"Response path '{self.response_path}' matched nothing in the endpoint response."
            )
        return matches[0].value

class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
    """Provider for language models served behind a SageMaker endpoint."""

    id = "sagemaker-endpoint"
    name = "Sagemaker Endpoint"
    models = ["*"]
    model_id_key = "endpoint_name"
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()

    # This is a registry provider: users supply their own endpoint name
    # rather than picking from a fixed model list.
    registry = True
    # Extra inputs collected from the user and forwarded to __init__ by key.
    fields = [
        TextField(key="region_name", label="Region name"),
        MultilineTextField(key="request_schema", label="Request schema"),
        TextField(key="response_path", label="Response path"),
    ]

    def __init__(self, *args, **kwargs):
        # Consume the two schema-related inputs before delegating; the
        # remaining kwargs (endpoint name, region, ...) are handled by the
        # parent SagemakerEndpoint constructor.
        schema = kwargs.pop('request_schema')
        path = kwargs.pop('response_path')
        handler = JsonContentHandler(request_schema=schema, response_path=path)
        super().__init__(*args, content_handler=handler, **kwargs)

1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"langchain==0.0.159",
"typing_extensions==4.5.0",
"click~=8.0",
"jsonpath-ng~=1.5.3",
]

[project.optional-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/actors/chat_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def update(self, config: GlobalConfig):
if not provider:
raise ValueError(f"No provider and model found with '{model_id}'")

provider_params = { "model_id": local_model_id}
fields = config.fields.get(model_id, {})
provider_params = { "model_id": local_model_id, **fields }

auth_strategy = provider.auth_strategy
if auth_strategy and auth_strategy.type == "env":
Expand Down
8 changes: 6 additions & 2 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,9 @@ def get(self):
id=provider.id,
name=provider.name,
models=provider.models,
auth_strategy=provider.auth_strategy
auth_strategy=provider.auth_strategy,
registry=provider.registry,
fields=provider.fields,
)
)

Expand All @@ -304,7 +306,9 @@ def get(self):
id=provider.id,
name=provider.name,
models=provider.models,
auth_strategy=provider.auth_strategy
auth_strategy=provider.auth_strategy,
registry=provider.registry,
fields=provider.fields,
)
)

Expand Down
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jupyter_ai_magics.providers import AuthStrategy
from jupyter_ai_magics.providers import AuthStrategy, Field

from pydantic import BaseModel
from typing import Dict, List, Union, Literal, Optional
from typing import Any, Dict, List, Union, Literal, Optional

class PromptRequest(BaseModel):
task_id: str
Expand Down Expand Up @@ -92,6 +92,8 @@ class ListProvidersEntry(BaseModel):
name: str
models: List[str]
auth_strategy: AuthStrategy
registry: bool
fields: List[Field]


class ListProvidersResponse(BaseModel):
Expand All @@ -108,3 +110,4 @@ class GlobalConfig(BaseModel):
embeddings_provider_id: Optional[str] = None
api_keys: Dict[str, str] = {}
send_with_shift_enter: Optional[bool] = None
fields: Dict[str, Dict[str, Any]] = {}
Loading