Skip to content

Commit

Permalink
Merge pull request #253 from NVIDIA/feature/server-side-threads
Browse files Browse the repository at this point in the history
Add support for server-side threads.
  • Loading branch information
drazvan committed Jan 12, 2024
2 parents fac5f1a + e905636 commit aa15628
Show file tree
Hide file tree
Showing 14 changed files with 552 additions and 1 deletion.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased]

### Added

- [#253](https://github.com/NVIDIA/NeMo-Guardrails/pull/253) Support for [server-side threads](./docs/user_guides/server-guide.md#threads).

### Fixed

- [#239](https://github.com/NVIDIA/NeMo-Guardrails/pull/239)Fixed logging issue where `verbose=true` flag did not trigger expected log output.
- [#239](https://github.com/NVIDIA/NeMo-Guardrails/pull/239) Fixed logging issue where `verbose=true` flag did not trigger expected log output.
- [#228](https://github.com/NVIDIA/NeMo-Guardrails/pull/228) Fix docstrings for various functions.

## [0.6.1] - 2023-12-20
Expand Down
37 changes: 37 additions & 0 deletions docs/user_guides/server-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,43 @@ Sample response:
}]
```

### Threads

The Guardrails Server has basic support for storing the conversation threads. This is useful when you can only send the latest user message(s) for a conversation rather than the entire history (e.g., from a third-party integration hook).

#### Configuration

To use server-side threads, you have to register a datastore. To do this, you must create a `config.py` file in the root of the configurations folder (i.e., the folder containing all the guardrails configurations the server must load). Inside `config.py` use the `register_datastore` function to register the datastore you want to use.

Out-of-the-box, NeMo Guardrails has support for `MemoryStore` (useful for quick testing) and `RedisStore`. If you want to use a different backend, you can implement the [`DataStore`](../../nemoguardrails/server/datastore/datastore.py) interface and register a different instance in `config.py`.

Next, when making a call to the `/v1/chat/completions` endpoint, you must also include a `thread_id` field:

```
POST /v1/chat/completions
```
```json
{
"config_id": "config_1",
"thread_id": "1234567890123456",
"messages": [{
"role":"user",
"content":"Hello! What can you do for me?"
}]
}
```

> NOTE: for security reasons, the `thread_id` must have a minimum length of 16 characters.
As an example, check out this [configuration](../../examples/configs/threads).


#### Limitations

Currently, threads are not supported when streaming mode is used (will be added in a future release).

Threads are stored indefinitely; there is no cleanup mechanism.

### Chat UI

You can use the Chat UI to test a guardrails configuration quickly.
Expand Down
79 changes: 79 additions & 0 deletions examples/configs/threads/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Threads

This is a sample server config folder with server-side threads enabled.

To enable server-side threads, you must register a `DataStore` inside the `config.py` file, which should be placed at the root of the folder containing the rails configurations.

```python
from nemoguardrails.server.api import register_datastore
from nemoguardrails.server.datastore.memory_store import MemoryStore

register_datastore(MemoryStore())
```

## Rails Configurations

For demo purposes, this configuration uses a single dialog rail, which makes the bot respond slightly differently when the user greets the bot the second time in a row.

```colang
define user express greeting
"hi"
define bot express greeting
"Hello!"
define bot express greeting again
"Hello again!"
define flow
user express greeting
bot express greeting
user express greeting
bot express greeting again
```

## Running the server

To run the server, use the following command from the root of the project:

```bash
nemoguardrails server --config=examples/configs/threads
```

## Testing

When sending "hi" to the server without a thread ID, it always responds with "Hello!"

```bash
curl -X POST -H "Content-Type: application/json" -d '{"config_id": "config_1", "messages": [{"content": "hi", "role": "user"}]}' http://localhost:8000/v1/chat/completions
```

```
{"messages":[{"role":"assistant","content":"Hello!"}]}
```

```bash
curl -X POST -H "Content-Type: application/json" -d '{"config_id": "config_1", "messages": [{"content": "hi", "role": "user"}]}' http://localhost:8000/v1/chat/completions
```

```
{"messages":[{"role":"assistant","content":"Hello!"}]}
```

If you use a `thread_id`, then the conversation gets stored on the server side, and the second time we get the response "Hello again!".

```bash
curl -X POST -H "Content-Type: application/json" -d '{"config_id": "config_1", "thread_id": "1231231231231231", "messages": [{"content": "hi", "role": "user"}]}' http://localhost:8000/v1/chat/completions
```

```
{"messages":[{"role":"assistant","content":"Hello!"}]}
```

```bash
curl -X POST -H "Content-Type: application/json" -d '{"config_id": "config_1", "thread_id": "1231231231231231", "messages": [{"content": "hi", "role": "user"}]}' http://localhost:8000/v1/chat/completions
```

```
{"messages":[{"role":"assistant","content":"Hello again!"}]}
```
23 changes: 23 additions & 0 deletions examples/configs/threads/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemoguardrails.server.api import register_datastore
from nemoguardrails.server.datastore.memory_store import MemoryStore

# This example uses an in-memory data store.
register_datastore(MemoryStore())

# You can also use RedisStore
# register_datastore(RedisStore("redis://localhost/1"))
8 changes: 8 additions & 0 deletions examples/configs/threads/config_1/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
models: []

# We set the embeddings only flag so that we don't need to use the LLM.
# The test flow will use pre-defined flows and messages.
rails:
dialog:
user_messages:
embeddings_only: True
14 changes: 14 additions & 0 deletions examples/configs/threads/config_1/rails.co
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
define user express greeting
"hi"

define bot express greeting
"Hello!"

define bot express greeting again
"Hello again!"

define flow
user express greeting
bot express greeting
user express greeting
bot express greeting again
55 changes: 55 additions & 0 deletions nemoguardrails/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from starlette.staticfiles import StaticFiles

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.server.datastore.datastore import DataStore
from nemoguardrails.streaming import StreamingHandler

logging.basicConfig(level=logging.INFO)
Expand All @@ -44,6 +45,11 @@
# The headers for each request
api_request_headers = contextvars.ContextVar("headers")

# The datastore that the Server should use.
# This is currently used only for storing threads.
# TODO: refactor to wrap the FastAPI instance inside a RailsServer class
# and get rid of all the global attributes.
datastore: Optional[DataStore] = None

app = FastAPI(
title="Guardrails Server API",
Expand Down Expand Up @@ -86,6 +92,10 @@

class RequestBody(BaseModel):
config_id: str = Field(description="The id of the configuration to be used.")
thread_id: Optional[str] = Field(
default=None,
description="The id of an existing thread to which the messages should be added.",
)
messages: List[dict] = Field(
default=None, description="The list of messages in the current conversation."
)
Expand Down Expand Up @@ -188,6 +198,32 @@ async def chat_completion(body: RequestBody, request: Request):
if body.context:
messages.insert(0, {"role": "context", "content": body.context})

# If we have a `thread_id` specified, we need to look up the thread
datastore_key = None

if body.thread_id:
if datastore is None:
raise RuntimeError("No DataStore has been configured.")

# We make sure the `thread_id` meets the minimum complexity requirement.
if len(body.thread_id) < 16:
return {
"messages": [
{
"role": "assistant",
"content": "The `thread_id` must have a minimum length of 16 characters.",
}
]
}

# Fetch the existing thread messages. For easier management, we prepend
# the string `thread-` to all thread keys.
datastore_key = "thread-" + body.thread_id
thread_messages = json.loads(await datastore.get(datastore_key) or "[]")

# And prepend them.
messages = thread_messages + messages

if (
body.stream
and llm_rails.config.streaming_supported
Expand All @@ -203,9 +239,17 @@ async def chat_completion(body: RequestBody, request: Request):
)
)

# TODO: Add support for thread_ids in streaming mode

return StreamingResponse(streaming_handler)
else:
bot_message = await llm_rails.generate_async(messages=messages)

# If we're using threads, we also need to update the data before returning
# the message.
if body.thread_id:
await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

return {"messages": [bot_message]}

except Exception as ex:
Expand Down Expand Up @@ -238,6 +282,13 @@ async def get_challenges():
return challenges


def register_datastore(datastore_instance: DataStore):
"""Registers a DataStore to be used by the server."""
global datastore

datastore = datastore_instance


@app.on_event("startup")
async def startup_event():
"""Register any additional challenges, if available at startup."""
Expand All @@ -255,6 +306,10 @@ async def startup_event():
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)

# If there is an `init` function, we call it with the reference to the app.
if config_module is not None and hasattr(config_module, "init"):
config_module.init(app)

# Finally, we register the static frontend UI serving

if not app.disable_chat_ui:
Expand Down
14 changes: 14 additions & 0 deletions nemoguardrails/server/datastore/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
42 changes: 42 additions & 0 deletions nemoguardrails/server/datastore/datastore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional


class DataStore:
"""A basic data store interface."""

async def set(self, key: str, value: str):
"""Save data into the datastore.
Args:
key: The key to use.
value: The value associated with the key.
Returns:
None
"""
raise NotImplementedError()

async def get(self, key: str) -> Optional[str]:
"""Return the value for the specified key.
Args:
key: The key to lookup.
Returns:
None if the key does not exist.
"""
raise NotImplementedError()
Loading

0 comments on commit aa15628

Please sign in to comment.