-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f1efcc4
commit 134d3e6
Showing
15 changed files
with
1,328 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""RAG Option 2.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""RAG chain for Option 2.""" | ||
|
||
from langchain_core.output_parsers.string import StrOutputParser | ||
from langchain_core.prompts import ChatPromptTemplate | ||
from langchain_core.runnables import RunnablePassthrough | ||
from langchain_core.runnables.base import RunnableSequence | ||
from omegaconf.dictconfig import DictConfig | ||
|
||
from backend.utils.llm import get_text_llm | ||
from backend.utils.retriever import get_retriever | ||
|
||
from . import prompts | ||
|
||
|
||
def get_chain(config: DictConfig) -> RunnableSequence: | ||
"""Constructs a RAG pipeline that retrieves text data from documents. | ||
The pipeline consists of the following steps: | ||
1. Retrieval of documents using a retriever object. | ||
2. Prompting the model with the text data. | ||
4. Generating responses using a text language model. | ||
5. Parsing the string output. | ||
Args: | ||
config (DictConfig): Configuration object. | ||
Returns: | ||
RunnableSequence: RAG pipeline. | ||
""" | ||
retriever = get_retriever(config) | ||
model = get_text_llm(config) | ||
|
||
# Prompt template | ||
prompt = ChatPromptTemplate.from_template(prompts.RAG_PROMPT) | ||
|
||
# Define the RAG pipeline | ||
chain = ( | ||
{ | ||
"context": retriever, | ||
"question": RunnablePassthrough(), | ||
} | ||
| prompt | ||
| model | ||
| StrOutputParser() | ||
) | ||
|
||
return chain |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
"""Configuration schema for the RAG Option 2.""" | ||
|
||
from typing import Literal | ||
|
||
from omegaconf import OmegaConf | ||
from omegaconf.dictconfig import DictConfig | ||
from pydantic import BaseModel, ConfigDict, root_validator | ||
from pydantic.dataclasses import dataclass | ||
|
||
|
||
class HydraObject(BaseModel): | ||
"""Configuration for objects to be instantiated by Hydra.""" | ||
|
||
target: str | ||
partial: bool | None | ||
|
||
class Config: | ||
"""Pydantic configuration.""" | ||
|
||
extra = "allow" | ||
fields = {"target": "_target_", "partial": "_partial_"} | ||
|
||
|
||
@dataclass(config=ConfigDict(extra="forbid")) | ||
class PathConfig: | ||
"""Configuration for paths.""" | ||
|
||
docs: str | ||
database: str | ||
export_extracted: str | ||
|
||
|
||
@dataclass(config=ConfigDict(extra="forbid")) | ||
class IngestConfig: | ||
"""Configuration for PDF ingestion.""" | ||
|
||
clear_database: bool | ||
|
||
chunking_enable: bool | ||
chunking_func: HydraObject | ||
|
||
table_format: Literal["text", "html", "image"] | ||
summarize_text: bool | ||
summarize_table: bool | ||
|
||
metadata_keys: list[str] | ||
|
||
export_extracted: bool | ||
|
||
@root_validator(pre=False) | ||
def validate_table_format(cls, values: dict) -> dict: | ||
"""Validate the 'table_format' field in relation to 'summarize_table'. | ||
This validator ensures that if the 'table_format' is set to 'image', | ||
then 'summarize_table' must also be set to True. It enforces the rule | ||
that image tables require summarization. | ||
Args: | ||
values (dict): Dictionnary of field values for the IngestConfig class. | ||
Raises: | ||
ValueError: If 'table_format' is 'image' and 'summarize_table' is not True. | ||
Returns: | ||
dict: The validated field values. | ||
""" | ||
table_format = values.get("table_format") | ||
summarize_table = values.get("summarize_table") | ||
|
||
if table_format == "image" and not summarize_table: | ||
raise ValueError("summarize_table must be True for table_format=image") | ||
|
||
return values | ||
|
||
|
||
@dataclass(config=ConfigDict(extra="forbid")) | ||
class Config: | ||
"""Configuration for the RAG Option 2.""" | ||
|
||
name: str | ||
|
||
path: PathConfig | ||
|
||
text_llm: HydraObject | ||
vision_llm: HydraObject | ||
embedding: HydraObject | ||
vectorstore: HydraObject | ||
store: HydraObject | ||
retriever: HydraObject | ||
|
||
ingest: IngestConfig | ||
|
||
|
||
def validate_config(config: DictConfig) -> Config: | ||
"""Validate the configuration. | ||
Args: | ||
config (DictConfig): Configuration object. | ||
Returns: | ||
Config: Validated configuration object. | ||
""" | ||
# Resolve the DictConfig to a native Python object | ||
cfg_obj = OmegaConf.to_object(config) | ||
# Instantiate the Config class | ||
validated_config = Config(**cfg_obj) | ||
return validated_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
name: rag_2 | ||
|
||
path: | ||
docs: "docs" | ||
export_extracted: "${.docs}/extracted/${..name}" | ||
database: "database/${..name}" | ||
|
||
text_llm: | ||
_target_: langchain_openai.AzureChatOpenAI | ||
azure_endpoint: ${oc.env:TEXT_OPENAI_ENDPOINT} | ||
openai_api_key: ${oc.env:TEXT_OPENAI_API_KEY} | ||
openai_api_version: "2024-02-15-preview" | ||
deployment_name: "gpt4" | ||
temperature: 0.0 | ||
max_tokens: 1024 | ||
|
||
vision_llm: | ||
_target_: langchain_openai.AzureChatOpenAI | ||
azure_endpoint: ${oc.env:VISION_OPENAI_ENDPOINT} | ||
openai_api_key: ${oc.env:VISION_OPENAI_API_KEY} | ||
openai_api_version: "2024-02-15-preview" | ||
deployment_name: "gpt-4-vision" | ||
temperature: 0.0 | ||
max_tokens: 1024 | ||
|
||
embedding: | ||
_target_: langchain_openai.AzureOpenAIEmbeddings | ||
azure_endpoint: ${oc.env:EMBEDDING_OPENAI_ENDPOINT} | ||
openai_api_key: ${oc.env:EMBEDDING_OPENAI_API_KEY} | ||
deployment: "ada" | ||
chunk_size: 500 | ||
|
||
vectorstore: | ||
_target_: langchain_community.vectorstores.Chroma | ||
collection_name: "summaries" | ||
embedding_function: ${..embedding} | ||
persist_directory: "${..path.database}/chroma_db" | ||
|
||
store: | ||
_target_: langchain.storage.LocalFileStore | ||
root_path: "${..path.database}/multi_vector_retriever_metadata/" | ||
|
||
retriever: | ||
_target_: langchain.retrievers.multi_vector.MultiVectorRetriever | ||
vectorstore: ${..vectorstore} | ||
byte_store: ${..store} | ||
id_key: "doc_id" | ||
|
||
ingest: | ||
clear_database: True | ||
|
||
chunking_enable: True | ||
chunking_func: | ||
_target_: unstructured.chunking.title.chunk_by_title | ||
_partial_: True | ||
max_characters: 4000 | ||
new_after_n_chars: 3800 | ||
combine_text_under_n_chars: 2000 | ||
|
||
table_format: "html" # "text" or "html" or "image" | ||
summarize_text: False | ||
summarize_table: True | ||
|
||
vectorstore_source: # retrieval step | ||
text: "content" # "content" or "summary" if enabled | ||
table: "summary" | ||
image: "summary" | ||
|
||
docstore_source: # RAG step | ||
text: "content" | ||
table: "content" | ||
image: "summary" | ||
|
||
export_extracted: True | ||
|
||
metadata_keys: | ||
- filename | ||
- page_number |
Oops, something went wrong.