Skip to content

Commit

Permalink
Remove Execution Wrapper (#905)
Browse files Browse the repository at this point in the history
* Rm files readded by git

* Fix merge botch
  • Loading branch information
NolanTrem authored Aug 21, 2024
1 parent 30f524a commit a2267bc
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 171 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ repos:

- id: black
name: black
entry: bash -c 'cd py && poetry run black $(git diff --cached --name-only --diff-filter=ACM | grep ".py$" | xargs -I {} echo ../{})'
entry: bash -c 'cd py && poetry run black .'
language: system
types: [python]
pass_filenames: false
28 changes: 21 additions & 7 deletions py/cli/commands/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@


@cli.command()
@click.argument("file_paths", nargs=-1, required=True, type=click.Path(exists=True))
@click.option("--document-ids", multiple=True, help="Document IDs for ingestion")
@click.option("--metadatas", type=JSON, help="Metadatas for ingestion as a JSON string")
@click.argument(
"file_paths", nargs=-1, required=True, type=click.Path(exists=True)
)
@click.option(
"--document-ids", multiple=True, help="Document IDs for ingestion"
)
@click.option(
"--metadatas", type=JSON, help="Metadatas for ingestion as a JSON string"
)
@click.option(
"--versions",
multiple=True,
Expand All @@ -27,18 +33,24 @@ def ingest_files(client, file_paths, document_ids, metadatas, versions):
document_ids = list(document_ids) if document_ids else None
versions = list(versions) if versions else None

response = client.ingest_files(file_paths, metadatas, document_ids, versions)
response = client.ingest_files(
file_paths, metadatas, document_ids, versions
)
click.echo(json.dumps(response, indent=2))


@cli.command()
@click.argument("file-paths", nargs=-1, required=True, type=click.Path(exists=True))
@click.argument(
"file-paths", nargs=-1, required=True, type=click.Path(exists=True)
)
@click.option(
"--document-ids",
required=True,
help="Document IDs to update (comma-separated)",
)
@click.option("--metadatas", type=JSON, help="Metadatas for updating as a JSON string")
@click.option(
"--metadatas", type=JSON, help="Metadatas for updating as a JSON string"
)
@click.pass_obj
def update_files(client, file_paths, document_ids, metadatas):
"""Update existing files in R2R."""
Expand Down Expand Up @@ -77,7 +89,9 @@ def ingest_files_from_urls(client, urls):

try:
response = client.ingest_files([temp_file_path])
click.echo(f"File '{filename}' ingested successfully. Response: {response}")
click.echo(
f"File '{filename}' ingested successfully. Response: {response}"
)
ingested_files.append(filename)
finally:
os.unlink(temp_file_path)
Expand Down
28 changes: 21 additions & 7 deletions py/cli/commands/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@


@cli.command()
@click.option("--query", prompt="Enter your search query", help="The search query")
@click.option(
"--query", prompt="Enter your search query", help="The search query"
)
# VectorSearchSettings
@click.option(
"--use-vector-search",
Expand All @@ -18,21 +20,27 @@
type=JSON,
help="Filters to apply to the vector search as a JSON",
)
@click.option("--search-limit", default=None, help="Number of search results to return")
@click.option(
"--search-limit", default=None, help="Number of search results to return"
)
@click.option("--do-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
# KGSearchSettings
@click.option("--use-kg-search", is_flag=True, help="Use knowledge graph search")
@click.option(
"--use-kg-search", is_flag=True, help="Use knowledge graph search"
)
@click.option("--kg-search-type", default=None, help="Local or Global")
@click.option("--kg-search-level", default=None, help="Level of KG search")
@click.option(
"--kg-search-generation-config",
type=JSON,
help="KG search generation config",
)
@click.option("--entity-types", type=JSON, help="Entity types to search for as a JSON")
@click.option(
"--entity-types", type=JSON, help="Entity types to search for as a JSON"
)
@click.option(
"--relationships", type=JSON, help="Relationships to search for as a JSON"
)
Expand Down Expand Up @@ -111,21 +119,27 @@ def search(client, query, **kwargs):
"--use-vector-search", is_flag=True, default=True, help="Use vector search"
)
@click.option("--filters", type=JSON, help="Search filters as JSON")
@click.option("--search-limit", default=10, help="Number of search results to return")
@click.option(
"--search-limit", default=10, help="Number of search results to return"
)
@click.option("--do-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
# KG Search Settings
@click.option("--use-kg-search", is_flag=True, help="Use knowledge graph search")
@click.option(
"--use-kg-search", is_flag=True, help="Use knowledge graph search"
)
@click.option("--kg-search-type", default="global", help="Local or Global")
@click.option(
"--kg-search-level",
default=None,
help="Level of cluster to use for Global KG search",
)
@click.option("--kg-search-model", default=None, help="Model for KG agent")
@click.option("--entity-types", type=JSON, help="Entity types to search for as a JSON")
@click.option(
"--entity-types", type=JSON, help="Entity types to search for as a JSON"
)
@click.option(
"--relationships", type=JSON, help="Relationships to search for as a JSON"
)
Expand Down
8 changes: 6 additions & 2 deletions py/core/base/providers/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class KGProvider(ABC):

def __init__(self, config: KGConfig) -> None:
if not isinstance(config, KGConfig):
raise ValueError("KGProvider must be initialized with a `KGConfig`.")
raise ValueError(
"KGProvider must be initialized with a `KGConfig`."
)
logger.info(f"Initializing KG provider with config: {config}")
self.config = config
self.validate_config()
Expand Down Expand Up @@ -116,7 +118,9 @@ def structured_query(
param_map = {}

@abstractmethod
def vector_query(self, query, **kwargs: Any) -> Tuple[list[Entity], list[float]]:
def vector_query(
self, query, **kwargs: Any
) -> Tuple[list[Entity], list[float]]:
"""Abstract method to query the graph store with a vector store query."""

# TODO - Type this method.
Expand Down
2 changes: 1 addition & 1 deletion py/core/configs/unstructured.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ provider = "unstructured_api"
provider = "unstructured"
method = "recursive"
chunk_size = 512
chunk_overlap = 50
chunk_overlap = 50
5 changes: 4 additions & 1 deletion py/core/main/app_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def r2r_app(
raise ValueError(f"Invalid config name: {config_name}")
config = R2RConfig.from_toml(R2RBuilder.CONFIG_OPTIONS[config_name])

if config.embedding.provider == "openai" and "OPENAI_API_KEY" not in os.environ:
if (
config.embedding.provider == "openai"
and "OPENAI_API_KEY" not in os.environ
):
raise ValueError(
"Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
)
Expand Down
29 changes: 22 additions & 7 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ async def search(
async with manage_run(self.run_manager, "search_app") as run_id:
t0 = time.time()

if kg_search_settings.use_kg_search and self.config.kg.provider is None:
if (
kg_search_settings.use_kg_search
and self.config.kg.provider is None
):
raise R2RException(
status_code=400,
message="Knowledge Graph search is not enabled in the configuration.",
Expand Down Expand Up @@ -164,17 +167,23 @@ async def rag(
)

if len(results) == 0:
raise R2RException(status_code=404, message="No results found")
raise R2RException(
status_code=404, message="No results found"
)
if len(results) > 1:
logger.warning(f"Multiple results found for query: {query}")
logger.warning(
f"Multiple results found for query: {query}"
)

completion_record.search_results = (
results[0].search_results
if hasattr(results[0], "search_results")
else None
)
completion_record.llm_response = (
results[0].completion if hasattr(results[0], "completion") else None
results[0].completion
if hasattr(results[0], "completion")
else None
)
completion_record.completion_end_time = datetime.now()

Expand Down Expand Up @@ -210,7 +219,9 @@ async def stream_rag_response(
):
async def stream_response():
async with manage_run(self.run_manager, "arag"):
async for chunk in await self.pipelines.streaming_rag_pipeline.run(
async for (
chunk
) in await self.pipelines.streaming_rag_pipeline.run(
input=to_async_generator([query]),
run_manager=self.run_manager,
vector_search_settings=vector_search_settings,
Expand Down Expand Up @@ -260,7 +271,9 @@ async def agent(

async def stream_response():
async with manage_run(self.run_manager, "arag_agent"):
async for chunk in self.agents.streaming_rag_agent.arun(
async for (
chunk
) in self.agents.streaming_rag_agent.arun(
messages=messages,
system_instruction=task_prompt_override,
vector_search_settings=vector_search_settings,
Expand Down Expand Up @@ -300,4 +313,6 @@ async def stream_response():
status_code=502,
message="Ollama server not reachable or returned an invalid response",
)
raise R2RException(status_code=500, message="Internal Server Error")
raise R2RException(
status_code=500, message="Internal Server Error"
)
6 changes: 3 additions & 3 deletions py/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ python = ">=3.9,<3.13"
httpx = "^0.27.0"
nest-asyncio = "^1.6.0"
fastapi = "^0.109.2"
click = "^8.1.7"

# Core dependencies (optional)
openai = { version = "^1.11.1" }

pydantic = { extras = ["email"], version = "^2.8.2", optional = true }
python-multipart = { version = "^0.0.9", optional = true }
gunicorn = { version = "^21.2.0", optional = true }
Expand Down
8 changes: 6 additions & 2 deletions py/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def __init__(
self._retrieval,
self._server,
]:
for name, method in inspect.getmembers(group, predicate=inspect.isfunction):
for name, method in inspect.getmembers(
group, predicate=inspect.isfunction
):
if not name.startswith("_"):
self._methods[name] = method

Expand Down Expand Up @@ -181,7 +183,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

def __getattr__(self, name):
if name in self._methods:
return lambda *args, **kwargs: self._methods[name](self, *args, **kwargs)
return lambda *args, **kwargs: self._methods[name](
self, *args, **kwargs
)
raise AttributeError(f"'R2RClient' object has no attribute '{name}'")

def __dir__(self):
Expand Down
8 changes: 6 additions & 2 deletions py/sdk/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ async def ingest_files(
for path in file_paths:
if os.path.isdir(path):
for root, _, files in os.walk(path):
all_file_paths.extend(os.path.join(root, file) for file in files)
all_file_paths.extend(
os.path.join(root, file) for file in files
)
else:
all_file_paths.append(path)

Expand Down Expand Up @@ -94,7 +96,9 @@ async def update_files(
dict: Update results containing processed, failed, and skipped documents.
"""
if len(file_paths) != len(document_ids):
raise ValueError("Number of file paths must match number of document IDs.")
raise ValueError(
"Number of file paths must match number of document IDs."
)

with ExitStack() as stack:
files = [
Expand Down
Loading

0 comments on commit a2267bc

Please sign in to comment.