diff --git a/ingest.py b/ingest.py
index f2219020a..e0cd85cfe 100644
--- a/ingest.py
+++ b/ingest.py
@@ -24,6 +24,16 @@
 from constants import CHROMA_SETTINGS
 
 
+load_dotenv()
+
+
+# Load environment variables
+persist_directory = os.environ.get('PERSIST_DIRECTORY')
+source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
+embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
+chunk_size = 500
+chunk_overlap = 50
+
 # Map file extensions to document loaders and their arguments
 LOADER_MAPPING = {
     ".csv": (CSVLoader, {}),
@@ -44,7 +54,6 @@
 load_dotenv()
 
 
-
 def load_single_document(file_path: str) -> Document:
     ext = "." + file_path.rsplit(".", 1)[-1]
     if ext in LOADER_MAPPING:
@@ -55,37 +64,79 @@ def load_single_document(file_path: str) -> Document:
     raise ValueError(f"Unsupported file extension '{ext}'")
 
 
-def load_documents(source_dir: str) -> List[Document]:
-    # Loads all documents from source documents directory
+def load_documents(source_dir: str, ignored_files: List[str] = None) -> List[Document]:
+    """
+    Loads all documents from the source documents directory, ignoring specified files.
+
+    source_dir is searched recursively for every extension in LOADER_MAPPING.
+    ignored_files lists file paths to skip; None (the default) skips nothing.
+    """
     all_files = []
     for ext in LOADER_MAPPING:
         all_files.extend(
             glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
         )
-    return [load_single_document(file_path) for file_path in all_files]
+    # Build the set once so every membership test is O(1) rather than a list scan
+    ignored = set(ignored_files or ())
+    filtered_files = [file_path for file_path in all_files if file_path not in ignored]
+    return [load_single_document(file_path) for file_path in filtered_files]
 
 
-def main():
-    # Load environment variables
-    persist_directory = os.environ.get('PERSIST_DIRECTORY')
-    source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
-    embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
-
-    # Load documents and split in chunks
+def process_documents(ignored_files: List[str] = None) -> List[Document]:
+    """
+    Load documents and split in chunks.
+
+    ignored_files is forwarded to load_documents. When nothing is loaded the
+    process exits with status 0 instead of returning.
+    """
     print(f"Loading documents from {source_directory}")
-    chunk_size = 500
-    chunk_overlap = 50
-    documents = load_documents(source_directory)
+    documents = load_documents(source_directory, ignored_files)
+    if not documents:
+        print("No new documents to load")
+        # exit() is a site-module helper meant for interactive use; raise SystemExit directly
+        raise SystemExit(0)
+    print(f"Loaded {len(documents)} new documents from {source_directory}")
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     texts = text_splitter.split_documents(documents)
-    print(f"Loaded {len(documents)} documents from {source_directory}")
     print(f"Split into {len(texts)} chunks of text (max. {chunk_size} characters each)")
+    return texts
+
+
+def does_vectorstore_exist(persist_directory: str) -> bool:
+    """
+    Checks if vectorstore exists.
+
+    True only when persist_directory has an 'index' folder, both Chroma parquet
+    files, and exactly four .bin/.pkl files inside 'index'.
+    """
+    if os.path.exists(os.path.join(persist_directory, 'index')):
+        if os.path.exists(os.path.join(persist_directory, 'chroma-collections.parquet')) and os.path.exists(os.path.join(persist_directory, 'chroma-embeddings.parquet')):
+            list_index_files = glob.glob(os.path.join(persist_directory, 'index/*.bin'))
+            list_index_files += glob.glob(os.path.join(persist_directory, 'index/*.pkl'))
+            if len(list_index_files) == 4:
+                return True
+    return False
+
 
+def main():
+    """
+    Create the vectorstore, or append new documents when one already exists.
+    """
     # Create embeddings
     embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
 
-    # Create and store locally vectorstore
-    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
+    if does_vectorstore_exist(persist_directory):
+        # Update and store locally vectorstore
+        print(f"Appending to existing vectorstore at {persist_directory}")
+        db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
+        collection = db.get()
+        texts = process_documents([metadata['source'] for metadata in collection['metadatas']])
+        db.add_documents(texts)
+    else:
+        # Create and store locally vectorstore
+        print("Creating new vectorstore")
+        texts = process_documents()
+        db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
     db.persist()
     db = None
 