-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Théo SARDIN
committed
Dec 12, 2023
1 parent
3274f54
commit 54cabbe
Showing
3 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from langchain.docstore.document import Document | ||
from enum import Enum | ||
from typing_extensions import List | ||
from langchain.vectorstores import Chroma | ||
import chromadb | ||
from langchain.embeddings import OpenAIEmbeddings | ||
|
||
|
||
class StorageBackend(Enum): | ||
LOCAL = "local" | ||
MEMORY = "memory" | ||
GCS = "gcs" | ||
S3 = "s3" | ||
AZURE = "az" | ||
|
||
|
||
def get_storage_root_path(bucket_name, storage_backend: StorageBackend): | ||
root_path = Path(f"{storage_backend.value}://{bucket_name}") | ||
return root_path | ||
|
||
def persist_to_bucket(bucket_path: str, store: Chroma): | ||
store.persist('./db/chroma') | ||
#TODO: Uplaod persisted file on disk to gcs | ||
|
||
|
||
def store_documents(docs: List[Document], bucket_path: str, storage_backend: StorageBackend): | ||
lagnchain_documents = [doc.to_langchain_document() for doc in docs] | ||
embeddings_model = OpenAIEmbeddings() | ||
persistent_client = chromadb.PersistentClient() | ||
collection = persistent_client.get_or_create_collection(get_storage_root_path(bucket_path, storage_backend)) | ||
collection.add(documents=lagnchain_documents) | ||
langchain_chroma = Chroma( | ||
client=persistent_client, | ||
collection_name=bucket_path, | ||
embedding_function=embeddings_model.embed_documents, | ||
) | ||
print("There are", langchain_chroma._collection.count(), "in the collection") | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from fastapi import FastAPI, HTTPException, status, Body | ||
from typing import List | ||
from langchain.docstore.document import Document | ||
from document_store import StorageBackend | ||
import document_store | ||
from model import ChatMessage | ||
from model import Doc | ||
|
||
app = FastAPI() | ||
|
||
@app.post("/index/documents") | ||
async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend): | ||
document_store.store_documents(chunks, bucket, storage_backend) | ||
|
||
@app.post("/chat") | ||
async def chat(chat_message: ChatMessage): | ||
pass | ||
|
||
|
||
if __name__ == "__main__": | ||
import uvicorn | ||
uvicorn.run(app, host="0.0.0.0", port=8000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from pydantic import BaseModel | ||
from langchain.docstore.document import Document | ||
|
||
class ChatMessage(BaseModel): | ||
message: str | ||
session_id: str | ||
|
||
class Doc(BaseModel): | ||
content: str | ||
metadata: dict | ||
|
||
def to_langchain_document(): | ||
return Document(page_content=self.content, metadata=self.metadata) |