diff --git a/document_store.py b/document_store.py
new file mode 100644
index 0000000..327cb1c
--- /dev/null
+++ b/document_store.py
@@ -0,0 +1,52 @@
+from pathlib import Path
+from enum import Enum
+from typing import List
+
+import chromadb
+from langchain.docstore.document import Document
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.vectorstores import Chroma
+
+from model import Doc
+
+
+class StorageBackend(Enum):
+    """Supported object-storage backends, keyed by their URL scheme."""
+
+    LOCAL = "local"
+    MEMORY = "memory"
+    GCS = "gcs"
+    S3 = "s3"
+    AZURE = "az"
+
+
+def get_storage_root_path(bucket_name: str, storage_backend: StorageBackend) -> Path:
+    """Return the root path for *bucket_name* on *storage_backend*, e.g. ``gcs://my-bucket``."""
+    return Path(f"{storage_backend.value}://{bucket_name}")
+
+
+def persist_to_bucket(bucket_path: str, store: Chroma) -> None:
+    """Persist *store* to local disk, then upload it to the bucket."""
+    # Chroma.persist() takes no arguments: the persist directory is fixed at construction.
+    store.persist()
+    # TODO: Upload the persisted files on disk to the target bucket (bucket_path).
+
+
+def store_documents(docs: List[Doc], bucket_path: str, storage_backend: StorageBackend) -> None:
+    """Embed *docs* with OpenAI embeddings and index them in a persistent Chroma collection.
+
+    The collection is named after *bucket_path*; *storage_backend* selects where the
+    persisted collection will eventually be uploaded (see ``persist_to_bucket``).
+    """
+    langchain_documents = [doc.to_langchain_document() for doc in docs]
+    embeddings_model = OpenAIEmbeddings()
+    persistent_client = chromadb.PersistentClient()
+    # Let the LangChain wrapper handle embedding and id generation; the raw
+    # chromadb collection.add API expects plain strings plus explicit ids.
+    langchain_chroma = Chroma(
+        client=persistent_client,
+        collection_name=bucket_path,  # collection names must be plain strings, not URLs
+        embedding_function=embeddings_model,  # the Embeddings object, not a bound method
+    )
+    langchain_chroma.add_documents(langchain_documents)
+    print("There are", langchain_chroma._collection.count(), "documents in the collection")
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..f857d95
--- /dev/null
+++ b/main.py
@@ -0,0 +1,28 @@
+from typing import List
+
+from fastapi import FastAPI, HTTPException, status, Body
+from langchain.docstore.document import Document
+
+import document_store
+from document_store import StorageBackend
+from model import ChatMessage, Doc
+
+app = FastAPI()
+
+
+@app.post("/index/documents")
+async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend):
+    """Embed the posted chunks and index them into the bucket's vector store."""
+    document_store.store_documents(chunks, bucket, storage_backend)
+
+
+@app.post("/chat")
+async def chat(chat_message: ChatMessage):
+    """Answer a chat message within a session (not implemented yet)."""
+    pass
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..c1c3740
--- /dev/null
+++ b/model.py
@@ -0,0 +1,21 @@
+from pydantic import BaseModel
+from langchain.docstore.document import Document
+
+
+class ChatMessage(BaseModel):
+    """A single chat turn tied to a conversation session."""
+
+    message: str
+    session_id: str
+
+
+class Doc(BaseModel):
+    """API-facing document chunk: raw text plus arbitrary metadata."""
+
+    content: str
+    metadata: dict
+
+    def to_langchain_document(self) -> Document:
+        """Convert this chunk into a LangChain ``Document``."""
+        # NOTE: the original signature omitted ``self``, so every call raised TypeError.
+        return Document(page_content=self.content, metadata=self.metadata)