diff --git a/js/sdk/__tests__/DocumentsIntegrationSuperUser.test.ts b/js/sdk/__tests__/DocumentsIntegrationSuperUser.test.ts index 0dcea2871..90f4c6213 100644 --- a/js/sdk/__tests__/DocumentsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/DocumentsIntegrationSuperUser.test.ts @@ -1,5 +1,6 @@ import { r2rClient } from "../src/index"; import { describe, test, beforeAll, expect, afterAll } from "@jest/globals"; +import { assert } from "console"; import fs from "fs"; import path from "path"; @@ -9,6 +10,7 @@ const TEST_OUTPUT_DIR = path.join(__dirname, "test-output"); /** * marmeladov.txt will have an id of 83ef5342-4275-5b75-92d6-692fa32f8523 * The untitled document will have an id of 5556836e-a51c-57c7-916a-de76c79df2b6 + * The default collection id is 122fdf6a-e116-546b-a8f6-e4cb2e2c0a09 */ describe("r2rClient V3 Documents Integration Tests", () => { let client: r2rClient; @@ -35,7 +37,7 @@ describe("r2rClient V3 Documents Integration Tests", () => { test("Create document with file path", async () => { const response = await client.documents.create({ file: { path: "examples/data/marmeladov.txt", name: "marmeladov.txt" }, - metadata: { title: "marmeladov.txt" }, + metadata: { title: "marmeladov.txt", numericId: 123 }, }); expect(response.results.documentId).toBeDefined(); @@ -45,7 +47,7 @@ describe("r2rClient V3 Documents Integration Tests", () => { test("Create document with content", async () => { const response = await client.documents.create({ raw_text: "This is a test document", - metadata: { title: "Test Document" }, + metadata: { title: "Test Document", numericId: 456 }, }); expect(response.results.documentId).toBeDefined(); @@ -175,7 +177,131 @@ describe("r2rClient V3 Documents Integration Tests", () => { ).rejects.toThrow(/Only one of file, raw_text, or chunks may be provided/); }); - test("Delete Raskolnikov.txt", async () => { + test("Search with $lte filter should only return documents with numericId <= 200", async () => { + const response = await client.retrieval.search({ + query: "Test query", + searchSettings: { + filters: { + numericId: { $lte: 200 }, + }, + }, + }); + + expect(response.results.chunkSearchResults).toBeDefined(); + expect( + response.results.chunkSearchResults.every( + (result) => result.metadata?.numericId <= 200, + ), + ).toBe(true); + }); + + test("Search with $gte filter should only return documents with metadata.numericId >= 400", async () => { + const response = await client.retrieval.search({ + query: "Test query", + searchSettings: { + filters: { + "metadata.numericId": { $gte: 400 }, + }, + }, + }); + + expect(response.results.chunkSearchResults).toBeDefined(); + expect( + response.results.chunkSearchResults.every( + (result) => result.metadata?.numericId >= 400, + ), + ).toBe(true); + }); + + test("Search with $eq filter should only return exact matches", async () => { + const response = await client.retrieval.search({ + query: "Test query", + searchSettings: { + filters: { + numericId: { $eq: 123 }, + }, + }, + }); + + expect(response.results.chunkSearchResults).toBeDefined(); + expect( + response.results.chunkSearchResults.every( + (result) => result.metadata?.numericId === 123, + ), + ).toBe(true); + }); + + test("Search with range filter should return documents within range", async () => { + const response = await client.retrieval.search({ + query: "Test query", + searchSettings: { + filters: { + "metadata.numericId": { + $gte: 500, + }, + }, + }, + }); + + expect(response.results.chunkSearchResults).toBeDefined(); + expect( + response.results.chunkSearchResults.every((result) => { + const numericId = result.metadata?.numericId; + return numericId >= 100 && numericId <= 500; + }), + ).toBe(true); + }); + + test("Search without filters should return both documents", async () => { + const response = await client.retrieval.search({ + query: "Test query", + }); + + expect(response.results.chunkSearchResults).toBeDefined(); + expect(response.results.chunkSearchResults.length).toBeGreaterThan(0); + + const numericIds = response.results.chunkSearchResults.map((result) => { + return result.metadata?.numericId || result.metadata?.numericid; + }); + + expect(numericIds.filter((id) => id !== undefined)).toContain(123); + expect(numericIds.filter((id) => id !== undefined)).toContain(456); + }); + + // test("Filter on collection_id", async () => { + // const response = await client.retrieval.search({ + // query: "Test query", + // searchSettings: { + // filters: { + // collection_ids: { + // $in: ["122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"], + // }, + // }, + // }, + // }); + // expect(response.results.chunkSearchResults).toBeDefined(); + // expect(response.results.chunkSearchResults.length).toBeGreaterThan(0); + // expect(response.results.chunkSearchResults[0].collectionIds).toContain( + // "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09", + // ); + // }); + + test("Filter on non-existant column should return empty", async () => { + const response = await expect( + client.retrieval.search({ + query: "Test query", + searchSettings: { + filters: { + nonExistentColumn: { + $eq: ["122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"], + }, + }, + }, + }), + ); + }); + + test("Delete marmeladov.txt", async () => { const response = await client.documents.delete({ id: "83ef5342-4275-5b75-92d6-692fa32f8523", }); diff --git a/py/core/database/filters.py b/py/core/database/filters.py index b76754591..fb6a36625 100644 --- a/py/core/database/filters.py +++ b/py/core/database/filters.py @@ -140,9 +140,7 @@ def __init__( else: self.top_level_columns = set(top_level_columns) self.json_column = json_column - self.params: list[Any] = ( - params # params are mutated during construction - ) + self.params: list[Any] = params # mutated during construction self.mode = mode def build(self, expr: FilterExpression) -> Tuple[str, list[Any]]: @@ -171,9 +169,8 @@ def _build_expression(self, expr: FilterExpression) -> str: @staticmethod def _psql_quote_literal(value: str) -> str: """ - Safely quote a string literal for PostgreSQL to prevent SQL injection. - This is a simple implementation - in production, you should use proper parameterization - or your database driver's quoting functions. + Simple quoting for demonstration. In production, use parameterized queries or + your DB driver's quoting function instead. """ return "'" + value.replace("'", "''") + "'" @@ -183,31 +180,81 @@ def _build_condition(self, cond: FilterCondition) -> str: op = cond.operator val = cond.value - # Handle special logic for collection_id + # 1. If the filter references "parent_id", handle it as a single-UUID column for graphs: + if key == "parent_id": + return self._build_parent_id_condition(op, val) + + # 2. If the filter references "collection_id", handle it as an array column (chunks) if key == "collection_id": return self._build_collection_id_condition(op, val) + # 3. Otherwise, decide if it's top-level or metadata: if field_is_metadata: return self._build_metadata_condition(key, op, val) else: return self._build_column_condition(key, op, val) + def _build_parent_id_condition(self, op: str, val: Any) -> str: + """ + For 'graphs' tables, parent_id is a single UUID (not an array). + We handle the same ops but in a simpler, single-UUID manner. + """ + param_idx = len(self.params) + 1 + + if op == "$eq": + if not isinstance(val, str): + raise FilterError( + "$eq for parent_id expects a single UUID string" + ) + self.params.append(val) + return f"parent_id = ${param_idx}::uuid" + + elif op == "$ne": + if not isinstance(val, str): + raise FilterError( + "$ne for parent_id expects a single UUID string" + ) + self.params.append(val) + return f"parent_id != ${param_idx}::uuid" + + elif op == "$in": + # A list of UUIDs, any of which might match + if not isinstance(val, list): + raise FilterError( + "$in for parent_id expects a list of UUID strings" + ) + self.params.append(val) + return f"parent_id = ANY(${param_idx}::uuid[])" + + elif op == "$nin": + # A list of UUIDs, none of which may match + if not isinstance(val, list): + raise FilterError( + "$nin for parent_id expects a list of UUID strings" + ) + self.params.append(val) + return f"parent_id != ALL(${param_idx}::uuid[])" + + else: + # You could add more (like $gt, $lt, etc.) if your schema wants them + raise FilterError(f"Unsupported operator {op} for parent_id") + def _build_collection_id_condition(self, op: str, val: Any) -> str: + """ + For the 'chunks' table, collection_ids is an array of UUIDs. + This logic stays exactly as you had it. + """ param_idx = len(self.params) + 1 - # Handle operations if op == "$eq": - # Expect a single UUID, ensure val is a string if not isinstance(val, str): raise FilterError( "$eq for collection_id expects a single UUID string" ) self.params.append(val) - # Check if val is in the collection_ids array return f"${param_idx}::uuid = ANY(collection_ids)" elif op == "$ne": - # Not equal means val is not in collection_ids if not isinstance(val, str): raise FilterError( "$ne for collection_id expects a single UUID string" @@ -216,31 +263,25 @@ def _build_collection_id_condition(self, op: str, val: Any) -> str: return f"NOT (${param_idx}::uuid = ANY(collection_ids))" elif op == "$in": - # Expect a list of UUIDs, any of which may match if not isinstance(val, list): raise FilterError( "$in for collection_id expects a list of UUID strings" ) self.params.append(val) - # Use overlap to check if any of the given IDs are in collection_ids return f"collection_ids && ${param_idx}::uuid[]" elif op == "$nin": - # None of the given UUIDs should be in collection_ids if not isinstance(val, list): raise FilterError( "$nin for collection_id expects a list of UUID strings" ) self.params.append(val) - # Negate overlap condition return f"NOT (collection_ids && ${param_idx}::uuid[])" elif op == "$contains": - # If someone tries "$contains" with a single collection_id, we can check if collection_ids fully contain it - # Usually $contains might mean we want to see if collection_ids contain a certain element. - # That's basically $eq logic. For a single value: if isinstance(val, str): - self.params.append([val]) # Array of one element + # single string -> array with one element + self.params.append([val]) return f"collection_ids @> ${param_idx}::uuid[]" elif isinstance(val, list): self.params.append(val) @@ -278,7 +319,6 @@ def _build_column_condition(self, col: str, op: str, val: Any) -> str: self.params.append(val) return f"{col} @> ${param_idx}" elif op == "$any": - # If col == "collection_ids" handle special case if col == "collection_ids": self.params.append(f"%{val}%") return f"array_to_string({col}, ',') LIKE ${param_idx}" @@ -296,8 +336,7 @@ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str: json_col = self.json_column # Strip "metadata." prefix if present - if key.startswith("metadata."): - key = key[len("metadata.") :] + key = key.removeprefix("metadata.") # Split on '.' to handle nested keys parts = key.split(".") @@ -310,26 +349,21 @@ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str: "$gte", "$eq", "$ne", - ) + ) and isinstance(val, (int, float, str)) if op == "$in" or op == "$contains" or isinstance(val, (list, dict)): use_text_extraction = False # Build the JSON path expression if len(parts) == 1: - # Single part key if use_text_extraction: path_expr = f"{json_col}->>'{parts[0]}'" else: path_expr = f"{json_col}->'{parts[0]}'" else: - # Multiple segments - inner_parts = parts[:-1] - last_part = parts[-1] - # Build chain for the inner parts path_expr = json_col - for p in inner_parts: + for p in parts[:-1]: path_expr += f"->'{p}'" - # Last part + last_part = parts[-1] if use_text_extraction: path_expr += f"->>'{last_part}'" else: @@ -337,14 +371,12 @@ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str: # Convert numeric values to strings for text comparison def prepare_value(v): - if isinstance(v, (int, float)): - return str(v) - return v + return str(v) if isinstance(v, (int, float)) else v - # Now apply the operator logic if op == "$eq": if use_text_extraction: - self.params.append(prepare_value(val)) + prepared_val = prepare_value(val) + self.params.append(prepared_val) return f"{path_expr} = ${param_idx}" else: self.params.append(json.dumps(val)) @@ -372,7 +404,6 @@ def prepare_value(v): if not isinstance(val, list): raise FilterError("argument to $in filter must be a list") - # For regular scalar values, use ANY with text extraction if use_text_extraction: str_vals = [ str(v) if isinstance(v, (int, float)) else v for v in val @@ -413,7 +444,6 @@ def apply_filters( """ Apply filters with consistent WHERE clause handling """ - if not filters: return "", params diff --git a/py/core/database/graphs.py b/py/core/database/graphs.py index 235292cb8..a93e58418 100644 --- a/py/core/database/graphs.py +++ b/py/core/database/graphs.py @@ -12,7 +12,7 @@ import asyncpg import httpx -from asyncpg.exceptions import UndefinedTableError, UniqueViolationError +from asyncpg.exceptions import UndefinedColumnError, UniqueViolationError from fastapi import HTTPException from core.base.abstractions import ( @@ -37,6 +37,7 @@ from .base import PostgresConnectionManager from .collections import PostgresCollectionsHandler +from .filters import apply_filters logger = logging.getLogger() @@ -2599,141 +2600,63 @@ async def graph_search( ORDER BY {embedding_type} <=> $1 LIMIT $2; """ + try: + results = await self.connection_manager.fetch_query( + QUERY, tuple(params) + ) - results = await self.connection_manager.fetch_query( - QUERY, tuple(params) - ) - - for result in results: - output = { - prop: result[prop] for prop in property_names if prop in result - } - output["similarity_score"] = 1 - float(result["similarity_score"]) - yield output + for result in results: + output = { + prop: result[prop] + for prop in property_names + if prop in result + } + output["similarity_score"] = 1 - float( + result["similarity_score"] + ) + yield output + except UndefinedColumnError as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while searching: {e}", + ) from e + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the entity: {e}", + ) from e def _build_filters( self, filter_dict: dict, parameters: list[Any], search_type: str ) -> str: - """ - Build a WHERE clause from a nested filter dictionary for the graph search. - For communities we use collection_id as primary key filter; for entities/relationships we use parent_id. - """ - - # Determine primary identifier column depending on search_type - # communities: use collection_id - # entities/relationships: use parent_id - base_id_column = ( - "collection_id" if search_type == "communities" else "parent_id" - ) - - def parse_condition(key: str, value: Any) -> str: - # This function returns a single condition (string) or empty if no valid condition. - # Supported keys: - # - base_id_column (collection_id or parent_id) - # - metadata fields: metadata.some_field - # Supported ops: $eq, $ne, $lt, $lte, $gt, $gte, $in, $contains - if key == base_id_column: - # e.g. {"collection_id": {"$eq": ""}} - if isinstance(value, dict): - op, clause = next(iter(value.items())) - if op == "$eq": - parameters.append(str(clause)) - return f"{base_id_column} = ${len(parameters)}::uuid" - elif op == "$in": - # $in expects a list of UUIDs - parameters.append([str(x) for x in clause]) - return f"{base_id_column} = ANY(${len(parameters)}::uuid[])" - else: - # direct equality? - parameters.append(str(value)) - return f"{base_id_column} = ${len(parameters)}::uuid" + """Use the filter module to build WHERE clause for graph search.""" + if not filter_dict: + return "" - elif key.startswith("metadata."): - # Handle metadata filters - # Example: {"metadata.some_key": {"$eq": "value"}} - field = key.split("metadata.")[1] + working_filter = filter_dict.copy() - if isinstance(value, dict): - op, clause = next(iter(value.items())) - if op == "$eq": - parameters.append(clause) - return f"(metadata->>'{field}') = ${len(parameters)}" - elif op == "$ne": - parameters.append(clause) - return f"(metadata->>'{field}') != ${len(parameters)}" - elif op == "$lt": - parameters.append(clause) - return f"(metadata->>'{field}')::float < ${len(parameters)}::float" - elif op == "$lte": - parameters.append(clause) - return f"(metadata->>'{field}')::float <= ${len(parameters)}::float" - elif op == "$gt": - parameters.append(clause) - return f"(metadata->>'{field}')::float > ${len(parameters)}::float" - elif op == "$gte": - parameters.append(clause) - return f"(metadata->>'{field}')::float >= ${len(parameters)}::float" - elif op == "$in": - # Ensure clause is a list - if not isinstance(clause, list): - raise Exception( - "argument to $in filter must be a list" - ) - # Append the Python list as a parameter; many drivers can convert Python lists to arrays - parameters.append(clause) - # Cast the parameter to a text array type - return f"(metadata->>'{key}')::text = ANY(${len(parameters)}::text[])" - - # elif op == "$in": - # # For $in, we assume an array of values and check if the field is in that set. - # # Note: This is simplistic, adjust as needed. - # parameters.append(clause) - # # convert field to text and check membership - # return f"(metadata->>'{field}') = ANY(SELECT jsonb_array_elements_text(${len(parameters)}::jsonb))" - elif op == "$contains": - # $contains for metadata likely means metadata @> clause in JSON. - # If clause is dict or list, we use json containment. - parameters.append(json.dumps(clause)) - return f"metadata @> ${len(parameters)}::jsonb" - else: - # direct equality - parameters.append(value) - return f"(metadata->>'{field}') = ${len(parameters)}" + # Handle communities case + if search_type == "communities": + working_filter = { + "collection_id" if k == "parent_id" else k: v + for k, v in working_filter.items() + } - # Add additional conditions for other columns if needed - # If key not recognized, return empty so it doesn't break query - return "" + # Handle collection_ids filter for combined search + if "collection_ids" in working_filter: + collection_filter = working_filter.pop("collection_ids") + # Transform the collection_ids filter into an OR condition + working_filter["$or"] = [ + {"collection_ids": collection_filter}, # For chunks table + {"parent_id": collection_filter}, # For graphs tables + ] - def parse_filter(fd: dict) -> str: - filter_conditions = [] - for k, v in fd.items(): - if k == "$and": - and_parts = [parse_filter(sub) for sub in v if sub] - # Remove empty strings - and_parts = [x for x in and_parts if x.strip()] - if and_parts: - filter_conditions.append( - f"({' AND '.join(and_parts)})" - ) - elif k == "$or": - or_parts = [parse_filter(sub) for sub in v if sub] - # Remove empty strings - or_parts = [x for x in or_parts if x.strip()] - if or_parts: - filter_conditions.append(f"({' OR '.join(or_parts)})") - else: - # Regular condition - c = parse_condition(k, v) - if c and c.strip(): - filter_conditions.append(c) - - if not filter_conditions: - return "" - if len(filter_conditions) == 1: - return filter_conditions[0] - return " AND ".join(filter_conditions) - - return parse_filter(filter_dict) + filter_clause, new_params = apply_filters( + filters=working_filter, + params=parameters, + mode="condition_only", + ) + return filter_clause async def _compute_leiden_communities( self, diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 90ed4e20e..23e17d50b 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -409,8 +409,7 @@ def __init__(self, **data): super().__init__(**data) def model_dump(self, *args, **kwargs): - dump = super().model_dump(*args, **kwargs) - return dump + return super().model_dump(*args, **kwargs) @classmethod def get_default(cls, mode: str) -> "SearchSettings":