Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix filter logic bugs #1753

Merged
merged 3 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 129 additions & 3 deletions js/sdk/__tests__/DocumentsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { r2rClient } from "../src/index";
import { describe, test, beforeAll, expect, afterAll } from "@jest/globals";
import { assert } from "console";
import fs from "fs";
import path from "path";

Expand All @@ -9,6 +10,7 @@ const TEST_OUTPUT_DIR = path.join(__dirname, "test-output");
/**
* marmeladov.txt will have an id of 83ef5342-4275-5b75-92d6-692fa32f8523
* The untitled document will have an id of 5556836e-a51c-57c7-916a-de76c79df2b6
* The default collection id is 122fdf6a-e116-546b-a8f6-e4cb2e2c0a09
*/
describe("r2rClient V3 Documents Integration Tests", () => {
let client: r2rClient;
Expand All @@ -35,7 +37,7 @@ describe("r2rClient V3 Documents Integration Tests", () => {
test("Create document with file path", async () => {
const response = await client.documents.create({
file: { path: "examples/data/marmeladov.txt", name: "marmeladov.txt" },
metadata: { title: "marmeladov.txt" },
metadata: { title: "marmeladov.txt", numericId: 123 },
});

expect(response.results.documentId).toBeDefined();
Expand All @@ -45,7 +47,7 @@ describe("r2rClient V3 Documents Integration Tests", () => {
test("Create document with content", async () => {
const response = await client.documents.create({
raw_text: "This is a test document",
metadata: { title: "Test Document" },
metadata: { title: "Test Document", numericId: 456 },
});

expect(response.results.documentId).toBeDefined();
Expand Down Expand Up @@ -175,7 +177,131 @@ describe("r2rClient V3 Documents Integration Tests", () => {
).rejects.toThrow(/Only one of file, raw_text, or chunks may be provided/);
});

test("Delete Raskolnikov.txt", async () => {
test("Search with $lte filter should only return documents with numericId <= 200", async () => {
  // The filter uses the bare metadata key (no "metadata." prefix).
  const { results } = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        numericId: { $lte: 200 },
      },
    },
  });
  const chunks = results.chunkSearchResults;

  expect(chunks).toBeDefined();

  // Every returned chunk must satisfy the upper bound that was requested.
  const allAtMost200 = chunks.every(
    (chunk) => chunk.metadata?.numericId <= 200,
  );
  expect(allAtMost200).toBe(true);
});

test("Search with $gte filter should only return documents with metadata.numericId >= 400", async () => {
  // The filter uses the explicit "metadata."-prefixed key.
  const { results } = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        "metadata.numericId": { $gte: 400 },
      },
    },
  });
  const chunks = results.chunkSearchResults;

  expect(chunks).toBeDefined();

  // Every returned chunk must satisfy the lower bound that was requested.
  const allAtLeast400 = chunks.every(
    (chunk) => chunk.metadata?.numericId >= 400,
  );
  expect(allAtLeast400).toBe(true);
});

test("Search with $eq filter should only return exact matches", async () => {
  const { results } = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        numericId: { $eq: 123 },
      },
    },
  });
  const chunks = results.chunkSearchResults;

  expect(chunks).toBeDefined();

  // Strict equality: a chunk whose numericId is missing or differs fails the test.
  const allExactlyMatch = chunks.every(
    (chunk) => chunk.metadata?.numericId === 123,
  );
  expect(allExactlyMatch).toBe(true);
});

test("Search with range filter should return documents within range", async () => {
  // The same bounds drive both the filter and the assertion so the two
  // cannot drift apart.
  const lowerBound = 100;
  const upperBound = 500;

  const response = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        // NOTE(review): assumes the filter API accepts a top-level `$and`
        // combining several conditions on the same field — confirm against
        // the server's filter parser.
        $and: [
          { "metadata.numericId": { $gte: lowerBound } },
          { "metadata.numericId": { $lte: upperBound } },
        ],
      },
    },
  });

  expect(response.results.chunkSearchResults).toBeDefined();
  expect(
    response.results.chunkSearchResults.every((result) => {
      const numericId = result.metadata?.numericId;
      return numericId >= lowerBound && numericId <= upperBound;
    }),
  ).toBe(true);
});

test("Search without filters should return both documents", async () => {
  const { results } = await client.retrieval.search({
    query: "Test query",
  });
  const chunks = results.chunkSearchResults;

  expect(chunks).toBeDefined();
  expect(chunks.length).toBeGreaterThan(0);

  // Read the numeric id under either key casing, then drop chunks that
  // carry no id at all.
  const definedIds = chunks
    .map((chunk) => chunk.metadata?.numericId || chunk.metadata?.numericid)
    .filter((id) => id !== undefined);

  expect(definedIds).toContain(123);
  expect(definedIds).toContain(456);
});

// test("Filter on collection_id", async () => {
// const response = await client.retrieval.search({
// query: "Test query",
// searchSettings: {
// filters: {
// collection_ids: {
// $in: ["122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"],
// },
// },
// },
// });
// expect(response.results.chunkSearchResults).toBeDefined();
// expect(response.results.chunkSearchResults.length).toBeGreaterThan(0);
// expect(response.results.chunkSearchResults[0].collectionIds).toContain(
// "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09",
// );
// });

test("Filter on non-existant column should return empty", async () => {
  // A filter on a field that no chunk carries is expected to match nothing.
  const response = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        nonExistentColumn: {
          $eq: ["122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"],
        },
      },
    },
  });

  // Assert on the result set itself; awaiting a bare `expect(promise)`
  // without a matcher verifies nothing.
  expect(response.results.chunkSearchResults).toBeDefined();
  expect(response.results.chunkSearchResults).toHaveLength(0);
});

test("Delete marmeladov.txt", async () => {
const response = await client.documents.delete({
id: "83ef5342-4275-5b75-92d6-692fa32f8523",
});
Expand Down
104 changes: 67 additions & 37 deletions py/core/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ def __init__(
else:
self.top_level_columns = set(top_level_columns)
self.json_column = json_column
self.params: list[Any] = (
params # params are mutated during construction
)
self.params: list[Any] = params # mutated during construction
self.mode = mode

def build(self, expr: FilterExpression) -> Tuple[str, list[Any]]:
Expand Down Expand Up @@ -171,9 +169,8 @@ def _build_expression(self, expr: FilterExpression) -> str:
@staticmethod
def _psql_quote_literal(value: str) -> str:
"""
Safely quote a string literal for PostgreSQL to prevent SQL injection.
This is a simple implementation - in production, you should use proper parameterization
or your database driver's quoting functions.
Simple quoting for demonstration. In production, use parameterized queries or
your DB driver's quoting function instead.
"""
return "'" + value.replace("'", "''") + "'"

Expand All @@ -183,31 +180,81 @@ def _build_condition(self, cond: FilterCondition) -> str:
op = cond.operator
val = cond.value

# Handle special logic for collection_id
# 1. If the filter references "parent_id", handle it as a single-UUID column for graphs:
if key == "parent_id":
return self._build_parent_id_condition(op, val)

# 2. If the filter references "collection_id", handle it as an array column (chunks)
if key == "collection_id":
return self._build_collection_id_condition(op, val)

# 3. Otherwise, decide if it's top-level or metadata:
if field_is_metadata:
return self._build_metadata_condition(key, op, val)
else:
return self._build_column_condition(key, op, val)

def _build_parent_id_condition(self, op: str, val: Any) -> str:
"""
For 'graphs' tables, parent_id is a single UUID (not an array).
We handle the same ops but in a simpler, single-UUID manner.
"""
param_idx = len(self.params) + 1

if op == "$eq":
if not isinstance(val, str):
raise FilterError(
"$eq for parent_id expects a single UUID string"
)
self.params.append(val)
return f"parent_id = ${param_idx}::uuid"

elif op == "$ne":
if not isinstance(val, str):
raise FilterError(
"$ne for parent_id expects a single UUID string"
)
self.params.append(val)
return f"parent_id != ${param_idx}::uuid"

elif op == "$in":
# A list of UUIDs, any of which might match
if not isinstance(val, list):
raise FilterError(
"$in for parent_id expects a list of UUID strings"
)
self.params.append(val)
return f"parent_id = ANY(${param_idx}::uuid[])"

elif op == "$nin":
# A list of UUIDs, none of which may match
if not isinstance(val, list):
raise FilterError(
"$nin for parent_id expects a list of UUID strings"
)
self.params.append(val)
return f"parent_id != ALL(${param_idx}::uuid[])"

else:
# You could add more (like $gt, $lt, etc.) if your schema wants them
raise FilterError(f"Unsupported operator {op} for parent_id")

def _build_collection_id_condition(self, op: str, val: Any) -> str:
"""
For the 'chunks' table, collection_ids is an array of UUIDs.
This logic stays exactly as you had it.
"""
param_idx = len(self.params) + 1

# Handle operations
if op == "$eq":
# Expect a single UUID, ensure val is a string
if not isinstance(val, str):
raise FilterError(
"$eq for collection_id expects a single UUID string"
)
self.params.append(val)
# Check if val is in the collection_ids array
return f"${param_idx}::uuid = ANY(collection_ids)"

elif op == "$ne":
# Not equal means val is not in collection_ids
if not isinstance(val, str):
raise FilterError(
"$ne for collection_id expects a single UUID string"
Expand All @@ -216,31 +263,25 @@ def _build_collection_id_condition(self, op: str, val: Any) -> str:
return f"NOT (${param_idx}::uuid = ANY(collection_ids))"

elif op == "$in":
# Expect a list of UUIDs, any of which may match
if not isinstance(val, list):
raise FilterError(
"$in for collection_id expects a list of UUID strings"
)
self.params.append(val)
# Use overlap to check if any of the given IDs are in collection_ids
return f"collection_ids && ${param_idx}::uuid[]"

elif op == "$nin":
# None of the given UUIDs should be in collection_ids
if not isinstance(val, list):
raise FilterError(
"$nin for collection_id expects a list of UUID strings"
)
self.params.append(val)
# Negate overlap condition
return f"NOT (collection_ids && ${param_idx}::uuid[])"

elif op == "$contains":
# If someone tries "$contains" with a single collection_id, we can check if collection_ids fully contain it
# Usually $contains might mean we want to see if collection_ids contain a certain element.
# That's basically $eq logic. For a single value:
if isinstance(val, str):
self.params.append([val]) # Array of one element
# single string -> array with one element
self.params.append([val])
return f"collection_ids @> ${param_idx}::uuid[]"
elif isinstance(val, list):
self.params.append(val)
Expand Down Expand Up @@ -278,7 +319,6 @@ def _build_column_condition(self, col: str, op: str, val: Any) -> str:
self.params.append(val)
return f"{col} @> ${param_idx}"
elif op == "$any":
# If col == "collection_ids" handle special case
if col == "collection_ids":
self.params.append(f"%{val}%")
return f"array_to_string({col}, ',') LIKE ${param_idx}"
Expand All @@ -296,8 +336,7 @@ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str:
json_col = self.json_column

# Strip "metadata." prefix if present
if key.startswith("metadata."):
key = key[len("metadata.") :]
key = key.removeprefix("metadata.")

# Split on '.' to handle nested keys
parts = key.split(".")
Expand All @@ -310,41 +349,34 @@ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str:
"$gte",
"$eq",
"$ne",
)
) and isinstance(val, (int, float, str))
if op == "$in" or op == "$contains" or isinstance(val, (list, dict)):
use_text_extraction = False

# Build the JSON path expression
if len(parts) == 1:
# Single part key
if use_text_extraction:
path_expr = f"{json_col}->>'{parts[0]}'"
else:
path_expr = f"{json_col}->'{parts[0]}'"
else:
# Multiple segments
inner_parts = parts[:-1]
last_part = parts[-1]
# Build chain for the inner parts
path_expr = json_col
for p in inner_parts:
for p in parts[:-1]:
path_expr += f"->'{p}'"
# Last part
last_part = parts[-1]
if use_text_extraction:
path_expr += f"->>'{last_part}'"
else:
path_expr += f"->'{last_part}'"

# Convert numeric values to strings for text comparison
def prepare_value(v):
if isinstance(v, (int, float)):
return str(v)
return v
return str(v) if isinstance(v, (int, float)) else v

# Now apply the operator logic
if op == "$eq":
if use_text_extraction:
self.params.append(prepare_value(val))
prepared_val = prepare_value(val)
self.params.append(prepared_val)
return f"{path_expr} = ${param_idx}"
else:
self.params.append(json.dumps(val))
Expand Down Expand Up @@ -372,7 +404,6 @@ def prepare_value(v):
if not isinstance(val, list):
raise FilterError("argument to $in filter must be a list")

# For regular scalar values, use ANY with text extraction
if use_text_extraction:
str_vals = [
str(v) if isinstance(v, (int, float)) else v for v in val
Expand Down Expand Up @@ -413,7 +444,6 @@ def apply_filters(
"""
Apply filters with consistent WHERE clause handling
"""

if not filters:
return "", params

Expand Down
Loading
Loading