-
Notifications
You must be signed in to change notification settings - Fork 175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/text rewriter #164
Open
Sarang-Nambiar
wants to merge
6
commits into
marvelai-org:Develop
Choose a base branch
from
Sarang-Nambiar:feat/text-rewriter
base: Develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Feat/text rewriter #164
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
2d85c7d
Initial commit
Sarang-Nambiar 86be739
chore: essential functions added in tools.py
Sarang-Nambiar ae2515d
feat: rewriter functionality complete
Sarang-Nambiar 03fe2f9
feat: Added unit test cases + bug fixes
Sarang-Nambiar 0bad428
fix: test case bugs
Sarang-Nambiar f59a1be
chore: imported langchain in core.py + reverted tools_config.json
Sarang-Nambiar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from app.services.logger import setup_logger | ||
from app.utils.document_loaders import get_docs | ||
from app.features.text_rewriter.tools import TextRewriter | ||
from app.api.error_utilities import ToolExecutorError | ||
import langchain | ||
|
||
logger = setup_logger() | ||
|
||
# File types the executor will forward to the document loader.
ALLOWED_FILE_TYPES = {
    "pptx",
    "pdf",
    "docx",
    "txt",
    "csv",
    "youtube_url",
    "url",
    "gsheet",
}
|
||
def executor(
    raw_text: str,
    instructions: str,
    file_url: str,
    file_type: str,
    verbose=False,
):
    """Rewrite text taken from a document URL or from raw input.

    When both ``file_url`` and a supported ``file_type`` are given, the
    document is loaded with ``get_docs`` and its content is rewritten.
    Otherwise ``raw_text`` is rewritten directly.

    Args:
        raw_text: Text to rewrite when no usable file is supplied.
        instructions: How the text should be rewritten.
        file_url: Location of the document to load (may be empty).
        file_type: One of ``ALLOWED_FILE_TYPES`` (may be empty).
        verbose: Enables extra logging here and in the helpers called.

    Returns:
        Whatever ``TextRewriter.rewrite`` returns.

    Raises:
        ValueError: For any failure, including invalid input; the
            original exception is attached as ``__cause__``.
    """
    try:
        if verbose:
            logger.info(f"File URL loaded: {file_url}")

        if file_type and file_url and file_type in ALLOWED_FILE_TYPES:
            logger.info(f"Generating docs. from {file_url} with type {file_type}")
            # Honour the caller's verbosity instead of forcing it on.
            docs = get_docs(file_url, file_type, verbose=verbose)
        elif raw_text:
            docs = None
        else:
            # Only blame the file type when one was actually supplied;
            # an empty value is a "missing input" problem instead.
            if file_type and file_type not in ALLOWED_FILE_TYPES:
                raise ToolExecutorError(f"File type {file_type} not supported")
            raise ToolExecutorError("File URL and file type must be provided")

        output = TextRewriter(instructions, verbose=verbose).rewrite(raw_text, docs)

    except Exception as e:
        error_message = f"Error in executor: {e}"
        logger.error(error_message)
        # Chain the cause so the original traceback is not lost.
        raise ValueError(error_message) from e

    return output
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
{ | ||
"inputs": [ | ||
{ | ||
"label": "Instructions", | ||
"name": "instructions", | ||
"type": "text" | ||
}, | ||
{ | ||
"label": "Text to Rewrite", | ||
"name": "raw_text", | ||
"type": "text" | ||
}, | ||
{ | ||
"label": "File URL", | ||
"name": "file_url", | ||
"type": "text" | ||
}, | ||
{ | ||
"label": "File Type", | ||
"name": "file_type", | ||
"type": "text" | ||
} | ||
] | ||
} |
10 changes: 10 additions & 0 deletions
10
app/features/text_rewriter/prompt/text-rewriter-prompt.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
You are a helpful AI assistant that helps users rewrite text from documents. | ||
|
||
You must follow the instructions given below on how the text should be rewritten: | ||
{instructions} | ||
|
||
You must respond as a JSON object with the below format: | ||
{format_instructions} | ||
|
||
Text to rewrite: | ||
{context} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import pytest | ||
|
||
from app.features.text_rewriter.core import executor | ||
from app.api.error_utilities import InputValidationError | ||
|
||
# Base attributes reused across all tests | ||
# Shared keyword arguments for tests that supply their own raw_text.
base_attributes_without_raw_text = {
    "instructions": "Rewrite the text in a more formal tone.",
}

# Shared keyword arguments for file-based tests: same instructions plus an
# empty raw_text, so the executor has to rely on the file inputs.
base_attributes = {
    **base_attributes_without_raw_text,
    "raw_text": "",
}
|
||
# PDF Tests | ||
def test_executor_pdf_url_valid():
    """Executor returns a dict for the sample PDF URL."""
    request = dict(
        base_attributes,
        file_url="https://filesamples.com/samples/document/pdf/sample1.pdf",
        file_type="pdf",
    )
    result = executor(**request)
    assert isinstance(result, dict)
|
||
def test_executor_pdf_url_invalid():
    """Executor raises ValueError when file_type is the integer 1."""
    bad_request = {
        **base_attributes,
        "file_url": "https://filesamples.com/samples/document/pdf/sample1.pdf",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
# CSV Tests | ||
def test_executor_csv_url_valid():
    """Executor returns a dict for the sample CSV URL."""
    request = dict(
        base_attributes,
        file_url="https://filesamples.com/samples/document/csv/sample1.csv",
        file_type="csv",
    )
    result = executor(**request)
    assert isinstance(result, dict)
|
||
def test_executor_csv_url_invalid():
    """Executor raises ValueError when file_type is the integer 1."""
    bad_request = {
        **base_attributes,
        "file_url": "https://filesamples.com/samples/document/csv/sample1.csv",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
# TXT Tests | ||
def test_executor_txt_url_valid():
    """Executor returns a dict for the sample TXT URL."""
    request = dict(
        base_attributes,
        file_url="https://filesamples.com/samples/document/txt/sample1.txt",
        file_type="txt",
    )
    result = executor(**request)
    assert isinstance(result, dict)
|
||
def test_executor_txt_url_invalid():
    """Executor raises ValueError when file_type is the integer 1."""
    bad_request = {
        **base_attributes,
        "file_url": "https://filesamples.com/samples/document/txt/sample1.txt",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
# PPTX Tests | ||
def test_executor_pptx_url_invalid():
    """Executor raises ValueError when file_type is the integer 1."""
    bad_request = {
        **base_attributes,
        "file_url": "https://scholar.harvard.edu/files/torman_personal/files/samplepptx.pptx",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
# DOCX Tests | ||
def test_executor_docx_url_valid():
    """Executor returns a dict for the sample DOCX URL."""
    request = dict(
        base_attributes,
        file_url="https://filesamples.com/samples/document/docx/sample1.docx",
        file_type="docx",
    )
    result = executor(**request)
    assert isinstance(result, dict)
|
||
def test_executor_docx_url_invalid():
    """Executor raises ValueError when file_type is the integer 1."""
    bad_request = {
        **base_attributes,
        "file_url": "https://filesamples.com/samples/document/docx/sample1.docx",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
# Invalid file type test | ||
def test_executor_invalid_file_type():
    """Executor raises ValueError for the file type "xlsx"."""
    bad_request = {
        **base_attributes,
        "file_url": "https://filesamples.com/samples/document/xlsx/sample1.xlsx",
        "file_type": "xlsx",
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
# GSheets Tests | ||
def test_executor_gsheets_url_valid():
    """Executor returns a dict for the sample Google Sheets URL."""
    request = dict(
        base_attributes,
        file_url="https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms/edit?gid=0#gid=0",
        file_type="gsheet",
    )
    result = executor(**request)
    assert isinstance(result, dict)
|
||
def test_executor_gsheets_url_invalid():
    """Executor raises ValueError when file_type is the integer 1."""
    bad_request = {
        **base_attributes,
        "file_url": "https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms/edit?gid=0#gid=0",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
# Youtube URL Tests | ||
def test_executor_youtube_url_valid():
    """Executor returns a dict for the sample YouTube URL."""
    request = dict(
        base_attributes,
        file_url="https://www.youtube.com/watch?v=HgBpFaATdoA",
        file_type="youtube_url",
    )
    result = executor(**request)
    assert isinstance(result, dict)
|
||
def test_executor_youtube_url_invalid():
    """Executor raises ValueError when file_type is the integer 1."""
    bad_request = {
        **base_attributes,
        "file_url": "https://www.youtube.com/watch?v=HgBpFaATdoA",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)
    assert isinstance(exc_info.value, ValueError)
|
||
def test_executor_pptx_type_raises_value_error():
    """Executor raises ValueError for this PPTX sample URL.

    Renamed from ``test_executor_pptx_url_invalid``: that name is already
    used by an earlier test in this module, so the second definition
    replaced the first and only one of the two was ever collected.
    """
    # NOTE(review): "pptx" is passed as the file type, yet a failure is
    # expected - confirm that this sample URL is meant to be unloadable.
    with pytest.raises(ValueError) as exc_info:
        executor(
            **base_attributes,
            file_url="https://getsamplefiles.com/download/pptx/sample-1.pptx",
            file_type="pptx",
        )

    assert isinstance(exc_info.value, ValueError)
|
||
# Plain text through text box | ||
def test_executor_plain_text_valid():
    """Executor returns a dict when only raw text is supplied."""
    request = dict(
        base_attributes_without_raw_text,
        raw_text="The quick brown fox jumps over the lazy dog.",
        file_url="",
        file_type="",
    )
    result = executor(**request)

    assert isinstance(result, dict)
|
||
def test_executor_plain_text_invalid():
    """Executor raises ValueError with empty text, empty URL and file_type 1."""
    bad_request = {
        **base_attributes_without_raw_text,
        "raw_text": "",
        "file_url": "",
        "file_type": 1,
    }
    with pytest.raises(ValueError) as exc_info:
        executor(**bad_request)

    assert isinstance(exc_info.value, ValueError)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from pydantic import BaseModel, Field | ||
from typing import List, Dict | ||
import os | ||
from langchain_core.documents import Document | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_core.output_parsers import JsonOutputParser | ||
from langchain_google_genai import GoogleGenerativeAI | ||
|
||
from app.services.logger import setup_logger | ||
|
||
logger = setup_logger(__name__) | ||
|
||
def read_text_file(file_path):
    """Return the contents of a text file located relative to this module.

    Args:
        file_path: Path joined onto the directory containing this source
            file. An absolute path is used as-is, because
            ``os.path.join`` discards the prefix in that case.

    Returns:
        The full file contents as a string.
    """
    # Resolve against this module's directory, not the current working dir.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    absolute_file_path = os.path.join(script_dir, file_path)

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(absolute_file_path, 'r', encoding='utf-8') as file:
        return file.read()
|
||
class TextRewriter:
    """Rewrite text with an LLM chain according to user instructions.

    The chain is ``prompt | model | parser``. Each component can be
    injected; otherwise a default is created on demand.
    """

    def __init__(self, instructions, prompt=None, model=None, parser=None, verbose=False):
        """Store the instructions and resolve the chain components.

        Args:
            instructions: How the text should be rewritten.
            prompt: Prompt template string; defaults to the contents of
                ``prompt/text-rewriter-prompt.txt``.
            model: Model object; defaults to
                ``GoogleGenerativeAI(model="gemini-1.5-flash")``.
            parser: Output parser; defaults to a ``JsonOutputParser``
                built from ``RewrittenOutput``.
            verbose: Enables informational logging.
        """
        # Build each default only when no override is given, so an injected
        # model/parser/prompt never triggers creation of the defaults.
        self.prompt = prompt or read_text_file("prompt/text-rewriter-prompt.txt")
        self.model = model or GoogleGenerativeAI(model="gemini-1.5-flash")
        self.parser = parser or JsonOutputParser(pydantic_object=RewrittenOutput)

        self.instructions = instructions
        self.verbose = verbose

    def compile(self):
        """Build and return the ``prompt | model | parser`` chain."""
        prompt = PromptTemplate(
            template=self.prompt,
            input_variables=["instructions", "context"],
            partial_variables={"format_instructions": self.parser.get_format_instructions()}
        )

        chain = prompt | self.model | self.parser

        if self.verbose:
            logger.info(f"Chain compiled: {chain}")

        return chain

    def validate_output(self, response: Dict) -> bool:
        """Return True only if ``response`` is a dict with a ``rewritten_text`` key."""
        # The type check avoids a TypeError on None and a substring false
        # positive when the response is a plain string.
        return isinstance(response, dict) and 'rewritten_text' in response

    def rewrite(self, raw_text: str, documents: "List[Document] | None"):
        """Rewrite document content, or ``raw_text`` when no documents are given.

        Args:
            raw_text: Text used as context when ``documents`` is empty or None.
            documents: Loaded documents whose ``page_content`` values are
                joined with newlines to form the context.

        Returns:
            The first response that passes ``validate_output``. If none of
            the attempts does, the last response is returned and an error
            is logged.
        """
        chain = self.compile()
        if documents:
            doc_content = "\n".join(doc.page_content for doc in documents)
        else:
            doc_content = raw_text

        max_attempts = 5
        response = None

        for attempt in range(1, max_attempts + 1):
            response = chain.invoke({
                "instructions": self.instructions,
                "context": doc_content
            })

            if self.verbose:
                logger.info(f"Generated response attempt {attempt}: {response}")

            # Validate the response in case of LLM hallucinations.
            if self.validate_output(response):
                break

            # Announce a retry only when another attempt will actually follow.
            if self.verbose and attempt < max_attempts:
                logger.warning("Invalid response generated, retrying...")
        else:
            # Every attempt failed validation: report it regardless of the
            # verbose flag, but keep returning the last response as before.
            logger.error(
                f"No valid response after {max_attempts} attempts; returning last response"
            )

        if self.verbose:
            logger.info(f"Final response generated: {response}")

        return response
|
||
class RewrittenOutput(BaseModel):
    """Schema of the JSON object the model is asked to return."""

    # Required field (``...`` marks it as having no default).
    rewritten_text: str = Field(..., description="The rewritten text")
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please include Langchain
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have made the requested changes