Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/text rewriter #164

Open
wants to merge 6 commits into
base: Develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
39 changes: 39 additions & 0 deletions app/features/text_rewriter/core.py

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please include Langchain

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made the requested changes

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from app.services.logger import setup_logger
from app.utils.document_loaders import get_docs
from app.features.text_rewriter.tools import TextRewriter
from app.api.error_utilities import ToolExecutorError
import langchain

logger = setup_logger()

ALLOWED_FILE_TYPES = {"pptx", "pdf", "docx", "txt", "csv", "youtube_url", "url", "gsheet"}

def executor(
raw_text: str,
instructions: str,
file_url: str,
file_type: str,
verbose=False):

try:
if verbose: logger.info(f"File URL loaded: {file_url}")

if file_type and file_url and file_type in ALLOWED_FILE_TYPES:
logger.info(f"Generating docs. from {file_url} with type {file_type}")
docs = get_docs(file_url, file_type, verbose=True)
elif raw_text:
docs = None
else:
if file_type not in ALLOWED_FILE_TYPES:
raise ToolExecutorError(f"File type {file_type} not supported")
raise ToolExecutorError("File URL and file type must be provided")

output = TextRewriter(instructions, verbose=verbose).rewrite(raw_text, docs)

except Exception as e:
error_message = f"Error in executor: {e}"
logger.error(error_message)
raise ValueError(error_message)

return output

24 changes: 24 additions & 0 deletions app/features/text_rewriter/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"inputs": [
{
"label": "Instructions",
"name": "instructions",
"type": "text"
},
{
"label": "Text to Rewrite",
"name": "raw_text",
"type": "text"
},
{
"label": "File URL",
"name": "file_url",
"type": "text"
},
{
"label": "File Type",
"name": "file_type",
"type": "text"
}
]
}
10 changes: 10 additions & 0 deletions app/features/text_rewriter/prompt/text-rewriter-prompt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
You are a helpful AI assistant thatt helps users rewrite text from documents.

You must follow the instructions given below on how the text should be rewritten:
{instructions}

You must respond as a JSON object with the below format:
{format_instructions}

Text to rewrite:
{context}
Empty file.
175 changes: 175 additions & 0 deletions app/features/text_rewriter/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import pytest

from app.features.text_rewriter.core import executor
from app.api.error_utilities import InputValidationError

# Base attributes reused across all tests
base_attributes = {
"instructions": "Rewrite the text in a more formal tone.",
"raw_text": "",
}

base_attributes_without_raw_text = {
"instructions": "Rewrite the text in a more formal tone.",
}

# PDF Tests
def test_executor_pdf_url_valid():
rewritten_text = executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/pdf/sample1.pdf",
file_type="pdf"
)
assert isinstance(rewritten_text, dict)

def test_executor_pdf_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/pdf/sample1.pdf",
file_type=1
)
assert isinstance(exc_info.value, ValueError)

# CSV Tests
def test_executor_csv_url_valid():
rewritten_text = executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/csv/sample1.csv",
file_type="csv"
)
assert isinstance(rewritten_text, dict)

def test_executor_csv_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/csv/sample1.csv",
file_type=1
)
assert isinstance(exc_info.value, ValueError)

# TXT Tests
def test_executor_txt_url_valid():
rewritten_text = executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/txt/sample1.txt",
file_type="txt"
)
assert isinstance(rewritten_text, dict)

def test_executor_txt_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/txt/sample1.txt",
file_type=1
)
assert isinstance(exc_info.value, ValueError)

# PPTX Tests
def test_executor_pptx_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://scholar.harvard.edu/files/torman_personal/files/samplepptx.pptx",
file_type=1
)
assert isinstance(exc_info.value, ValueError)

# DOCX Tests
def test_executor_docx_url_valid():
rewritten_text = executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/docx/sample1.docx",
file_type="docx"
)
assert isinstance(rewritten_text, dict)

def test_executor_docx_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/docx/sample1.docx",
file_type=1
)
assert isinstance(exc_info.value, ValueError)

# Invalid file type test
def test_executor_invalid_file_type():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/xlsx/sample1.xlsx",
file_type="xlsx"
)
assert isinstance(exc_info.value, ValueError)

# GSheets Tests
def test_executor_gsheets_url_valid():
rewritten_text = executor(
**base_attributes,
file_url="https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms/edit?gid=0#gid=0",
file_type="gsheet"
)
assert isinstance(rewritten_text, dict)

def test_executor_gsheets_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms/edit?gid=0#gid=0",
file_type=1
)
assert isinstance(exc_info.value, ValueError)

# Youtube URL Tests
def test_executor_youtube_url_valid():
rewritten_text = executor(
**base_attributes,
file_url="https://www.youtube.com/watch?v=HgBpFaATdoA",
file_type="youtube_url"
)
assert isinstance(rewritten_text, dict)

def test_executor_youtube_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://www.youtube.com/watch?v=HgBpFaATdoA",
file_type=1
)
assert isinstance(exc_info.value, ValueError)

def test_executor_pptx_url_invalid():

with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url = "https://getsamplefiles.com/download/pptx/sample-1.pptx",
file_type = "pptx",
)

assert isinstance(exc_info.value, ValueError)

# Plain text through text box
def test_executor_plain_text_valid():
rewritten_text = executor(
**base_attributes_without_raw_text,
raw_text="The quick brown fox jumps over the lazy dog.",
file_url="",
file_type="",
)

assert isinstance(rewritten_text, dict)

def test_executor_plain_text_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes_without_raw_text,
raw_text="",
file_url="",
file_type=1,
)

assert isinstance(exc_info.value, ValueError)
88 changes: 88 additions & 0 deletions app/features/text_rewriter/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pydantic import BaseModel, Field
from typing import List, Dict
import os
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_google_genai import GoogleGenerativeAI

from app.services.logger import setup_logger

logger = setup_logger(__name__)

def read_text_file(file_path):
# Get the directory containing the script file
script_dir = os.path.dirname(os.path.abspath(__file__))

# Combine the script directory with the relative file path
absolute_file_path = os.path.join(script_dir, file_path)

with open(absolute_file_path, 'r') as file:
return file.read()

class TextRewriter:
def __init__(self, instructions, prompt=None, model=None, parser=None, verbose=False):
default_config = {
"model": GoogleGenerativeAI(model="gemini-1.5-flash"),
"parser": JsonOutputParser(pydantic_object=RewrittenOutput),
"prompt": read_text_file("prompt/text-rewriter-prompt.txt"),
}

self.prompt = prompt or default_config["prompt"]
self.model = model or default_config["model"]
self.parser = parser or default_config["parser"]

self.instructions = instructions
self.verbose = verbose

def compile(self):
prompt = PromptTemplate(
template=self.prompt,
input_variables=["instructions", "context"],
partial_variables={"format_instructions": self.parser.get_format_instructions()}
)

chain = prompt | self.model | self.parser

if self.verbose: logger.info(f"Chain compiled: {chain}")

return chain

def validate_output(self, response: Dict) -> bool:
if 'rewritten_text' in response:
return True
return False

def rewrite(self, raw_text: str, documents: List[Document]):
chain = self.compile()
if documents:
doc_content = "\n".join([doc.page_content for doc in documents])
else:
doc_content = raw_text

attempts = 0
max_attempts = 5

while attempts < max_attempts:
response = chain.invoke({
"instructions": self.instructions,
"context": doc_content
})

if self.verbose:
logger.info(f"Generated response attempt {attempts + 1}: {response}")

# validate response incase of LLM hallucinations
if self.validate_output(response):
break

if self.verbose: logger.warning(f"Invalid response generated, retrying...")
# if response is invalid, retry
attempts += 1

if self.verbose: logger.info(f"Final response generated: {response}")

return response

class RewrittenOutput(BaseModel):
rewritten_text: str = Field(description="The rewritten text")