Skip to content

Commit

Permalink
Merge pull request #1 from ptorru/octoai-integration
Browse files Browse the repository at this point in the history
Add OctoAI integrations
  • Loading branch information
ptorru authored Feb 17, 2024
2 parents 52460e4 + a7b890d commit 7128cf1
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ These optional environment variables are used to authenticate to other supported
| `JINA_API_KEY` | API key for Jina AI. Used to authenticate to JinaAI's services for embedding and chat API | You can find your OpenAI API key [here](https://platform.openai.com/account/api-keys). You might need to login or register to OpenAI services |
| `AZURE_OPENAI_ENDOINT`| The URL of the Azure OpenAI endpoint you deployed. | You can find this in the Azure OpenAI portal under _Keys and Endpoints`|
| `AZURE_OPENAI_API_KEY` | The API key to use for your Azure OpenAI models. | You can find this in the Azure OpenAI portal under _Keys and Endpoints`|
| `OCTOAI_API_KEY` | API key for OctoAI. Used to authenticate for open source LLMs served in OctoAI | You can sign up for OctoAI and find your API key [here](https://octo.ai/)

</details>

Expand Down Expand Up @@ -281,4 +282,3 @@ gunicorn canopy_server.app:app --worker-class uvicorn.workers.UvicornWorker --bi
> The server interacts with services like Pinecone and OpenAI using your own authentication credentials.
When deploying the server on a public web hosting provider, it is recommended to enable an authentication mechanism,
so that your server would only take requests from authenticated users.

50 changes: 50 additions & 0 deletions src/canopy/config_templates/octoai.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# ===========================================================
# Configuration file for Canopy Server
# ===========================================================
tokenizer:
# -------------------------------------------------------------------------------------------
# Tokenizer configuration
# Use LLamaTokenizer from HuggingFace with the relevant OSS model (e.g. LLama2)
# -------------------------------------------------------------------------------------------
type: LlamaTokenizer # Options: [OpenAITokenizer, LlamaTokenizer]
params:
model_name: hf-internal-testing/llama-tokenizer

chat_engine:
# -------------------------------------------------------------------------------------------
# Chat engine configuration
# Use OctoAI as the open source LLM provider
# You can find the list of supported LLMs at https://octo.ai/docs/text-gen-solution/rest-api
# -------------------------------------------------------------------------------------------
params:
max_prompt_tokens: 2048 # The maximum number of tokens to use for input prompt to the LLM.
llm: &llm
type: OctoAILLM
params:
model_name: mistral-7b-instruct-fp16 # The name of the model to use.

# query_builder:
# type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator, InstructionQueryGenerator]
# llm:
# type: OctoAILLM
# params:
# model_name: mistral-7b-instruct-fp16

context_engine:
# -------------------------------------------------------------------------------------------------------------
# ContextEngine configuration
# -------------------------------------------------------------------------------------------------------------
knowledge_base:
# -----------------------------------------------------------------------------------------------------------
# KnowledgeBase configuration
# -----------------------------------------------------------------------------------------------------------
record_encoder:
# --------------------------------------------------------------------------
# Configuration for the RecordEncoder subcomponent of the knowledge base.
# Use OctoAI's Embedding endpoint for dense encoding
# --------------------------------------------------------------------------
type: OctoAIRecordEncoder
params:
model_name: # The name of the model to use for encoding
thenlper/gte-large
batch_size: 2048 # The number of document chunks to encode in each call to the encoding model
1 change: 1 addition & 0 deletions src/canopy/knowledge_base/record_encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .jina import JinaRecordEncoder
from .sentence_transformers import SentenceTransformerRecordEncoder
from .hybrid import HybridRecordEncoder
from .octoai import OctoAIRecordEncoder
68 changes: 68 additions & 0 deletions src/canopy/knowledge_base/record_encoder/octoai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
from typing import List
from pinecone_text.dense.openai_encoder import OpenAIEncoder
from canopy.knowledge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery
from canopy.knowledge_base.record_encoder.dense import DenseRecordEncoder
from canopy.models.data_models import Query

OCTOAI_BASE_URL = "https://text.octoai.run/v1"


class OctoAIRecordEncoder(DenseRecordEncoder):
"""
OctoAIRecordEncoder is a type of DenseRecordEncoder that uses the OpenAI `embeddings` API.
The implementation uses the `OpenAIEncoder` class from the `pinecone-text` library.
For more information about see: https://github.com/pinecone-io/pinecone-text
""" # noqa: E501
"""
Initialize the OctoAIRecordEncoder
Args:
api_key: The OctoAI Endpoint API Key
base_url: The Base URL for the OctoAI Endpoint
model_name: The name of the OctoAI embeddings model to use for encoding. See https://octo.ai/docs/text-gen-solution/getting-started
batch_size: The number of documents or queries to encode at once.
Defaults to 1.
**kwargs: Additional arguments to pass to the underlying `pinecone-text. OpenAIEncoder`.
""" # noqa: E501
def __init__(self,
*,
api_key: str = "",
base_url: str = OCTOAI_BASE_URL,
model_name: str = "thenlper/gte-large",
batch_size: int = 1024,
**kwargs):

ae_api_key = api_key or os.environ.get("OCTOAI_API_KEY")
if not ae_api_key:
raise ValueError(
"An OctoAI API token is required to use OctoAI. "
"Please provide it as an argument "
"or set the OCTOAI_API_KEY environment variable."
)
ae_base_url = base_url
encoder = OpenAIEncoder(model_name,
base_url=ae_base_url, api_key=ae_api_key,
**kwargs)
super().__init__(dense_encoder=encoder, batch_size=batch_size)

def encode_documents(self, documents: List[KBDocChunk]) -> List[KBEncodedDocChunk]:
"""
Encode a list of documents, takes a list of KBDocChunk and returns a list of KBEncodedDocChunk.
Args:
documents: A list of KBDocChunk to encode.
Returns:
encoded chunks: A list of KBEncodedDocChunk, with the `values` field populated by the generated embeddings vector.
""" # noqa: E501
return super().encode_documents(documents)

async def _aencode_documents_batch(self,
documents: List[KBDocChunk]
) -> List[KBEncodedDocChunk]:
raise NotImplementedError

async def _aencode_queries_batch(self, queries: List[Query]) -> List[KBQuery]:
raise NotImplementedError
1 change: 1 addition & 0 deletions src/canopy/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .anyscale import AnyscaleLLM
from .azure_openai_llm import AzureOpenAILLM
from .cohere import CohereLLM
from .octoai import OctoAILLM
55 changes: 55 additions & 0 deletions src/canopy/llm/octoai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Optional, Any
import os
from canopy.llm import OpenAILLM
from canopy.llm.models import Function
from canopy.models.data_models import Messages

OCTOAI_BASE_URL = "https://text.octoai.run/v1"


class OctoAILLM(OpenAILLM):
"""
OctoAI LLM wrapper built on top of the OpenAI Python client.
Note: OctoAI requires a valid API key to use this class.
You can set the "OCTOAI_API_KEY" environment variable.
"""

def __init__(
self,
model_name: str = "mistral-7b-instruct-fp16",
*,
base_url: Optional[str] = OCTOAI_BASE_URL,
api_key: Optional[str] = None,
**kwargs: Any,
):
ae_api_key = api_key or os.environ.get("OCTOAI_API_KEY")
if not ae_api_key:
raise ValueError(
"OctoAI API key is required to use OctoAI. "
"If you haven't done it, please sign up at https://octo.ai"
"The key can be provided as an argument or via the OCTOAI_API_KEY environment variable."
)
ae_base_url = base_url
super().__init__(model_name, api_key=ae_api_key, base_url=ae_base_url, **kwargs)

def enforced_function_call(
self,
system_prompt: str,
chat_history: Messages,
function: Function,
*,
max_tokens: Optional[int] = None,
model_params: Optional[dict] = None,
) -> dict:
raise NotImplementedError("OctoAI doesn't support function calling.")

def aenforced_function_call(self,
system_prompt: str,
chat_history: Messages,
function: Function,
*,
max_tokens: Optional[int] = None,
model_params: Optional[dict] = None
):
raise NotImplementedError("OctoAI doesn't support function calling.")
53 changes: 53 additions & 0 deletions tests/system/record_encoder/test_octoai_record_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from canopy.knowledge_base.models import KBDocChunk
from canopy.knowledge_base.record_encoder.octoai import OctoAIRecordEncoder
from canopy.models.data_models import Query


documents = [KBDocChunk(
id=f"doc_1_{i}",
text=f"Sample document {i}",
document_id=f"doc_{i}",
metadata={"test": i},
source="doc_1",
)
for i in range(4)
]

queries = [Query(text="Sample query 1"),
Query(text="Sample query 2"),
Query(text="Sample query 3"),
Query(text="Sample query 4")]


@pytest.fixture
def encoder():
return OctoAIRecordEncoder(batch_size=2)


def test_dimension(encoder):
assert encoder.dimension == 1024


@pytest.mark.parametrize("items,function",
[(documents, "encode_documents"),
(queries, "encode_queries"),
([], "encode_documents"),
([], "encode_queries")])
def test_encode_documents(encoder, items, function):

encoded_documents = getattr(encoder, function)(items)

assert len(encoded_documents) == len(items)
assert all(len(encoded.values) == encoder.dimension
for encoded in encoded_documents)


@pytest.mark.asyncio
@pytest.mark.parametrize("items,function",
[("aencode_documents", documents),
("aencode_queries", queries)])
async def test_aencode_not_implemented(encoder, function, items):
with pytest.raises(NotImplementedError):
await encoder.aencode_queries(items)

0 comments on commit 7128cf1

Please sign in to comment.