diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3ebcbe8d..13d19c6b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -28,6 +28,9 @@ jobs:
       - name: Install Poetry
         uses: snok/install-poetry@v1
 
+      - name: Install maturin
+        run: pip install maturin
+
       - name: Copy environment file
         run: cp .env.sample .env
 
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 176e47b8..b3ce3b38 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -1,35 +1,37 @@
- name: docs
- on:
-   push:
-     branches:
-       - master
-       - main
- permissions:
-   contents: write
- jobs:
-   deploy:
-     runs-on: ubuntu-latest
-     steps:
-       - uses: actions/checkout@v4
-       - name: Configure Git Credentials
-         run: |
-           git config user.name github-actions[bot]
-           git config user.email 41898282+github-actions[bot]@users.noreply.github.com
-       - uses: actions/setup-python@v5
-         with:
-           python-version: 3.x
-       - name: Install Poetry
-         uses: snok/install-poetry@v1
-       - name: Copy environment file
-         run: cp .env.sample .env
-       - name: Install dependencies
-         run: make install
-       - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
-       - uses: actions/cache@v4
-         with:
-           key: mkdocs-material-${{ env.cache_id }}
-           path: .cache
-           restore-keys: |
-             mkdocs-material-
-       - run: poetry run mkdocs build
-       - run: poetry run mkdocs gh-deploy --force
+name: docs
+on:
+  push:
+    branches:
+      - master
+      - main
+permissions:
+  contents: write
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Configure Git Credentials
+        run: |
+          git config user.name github-actions[bot]
+          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
+      - uses: actions/setup-python@v5
+        with:
+          python-version: 3.x
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+      - name: Install maturin
+        run: pip install maturin
+      - name: Copy environment file
+        run: cp .env.sample .env
+      - name: Install dependencies
+        run: make install
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - uses: actions/cache@v4
+        with:
+          key: mkdocs-material-${{ env.cache_id }}
+          path: .cache
+          restore-keys: |
+            mkdocs-material-
+      - run: poetry run mkdocs build
+      - run: poetry run mkdocs gh-deploy --force
diff --git a/.gitignore b/.gitignore
index b1d0b6b1..86fd079f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -49,4 +49,7 @@ website/.vercel
 
 # typescript
 website/*.tsbuildinfo
-website/next-env.d.ts
\ No newline at end of file
+website/next-env.d.ts
+
+# Rust
+*target/
\ No newline at end of file
diff --git a/Makefile b/Makefile
index bc147709..5c4213b3 100644
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,34 @@
 # Load environment variables from .env file
 include .env
 
-.PHONY: tests tests-basic lint install mypy update ui-install ui-run
+.PHONY: tests tests-basic lint install mypy update ui-install ui-run build-rust develop clean
+
+# Build commands
+build-rust:
+	maturin develop --release --manifest-path docetl/rust/Cargo.toml
+
+develop: clean build-rust
+	poetry install --all-extras
+
+clean:
+	rm -rf target/
+	rm -rf docetl/rust/target/
+	rm -f docetl/resolver/resolver*.so
+	find . -type d -name "__pycache__" -exec rm -rf {} +
+	find . -type f -name "*.pyc" -delete
+	find . -type f -name "*.pyo" -delete
+	find . -type f -name "*.so" -delete
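+
+# Illustrative local workflow (not part of the original targets): a fresh
+# checkout would typically run `make install` once to get poetry + maturin,
+# then `make develop` after editing the Rust sources to rebuild the extension
+# before running `make tests-basic`.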
+
+# Install command now includes Rust build
+install: clean
+	pip install poetry maturin
+	$(MAKE) develop
 
 # Existing commands
-tests:
+tests: clean build-rust
 	poetry run pytest
 
-tests-basic:
+tests-basic: clean build-rust
 	poetry run pytest tests/basic
 	poetry run pytest tests/test_api.py
 	poetry run pytest tests/test_runner_caching.py
@@ -15,10 +36,6 @@ tests-basic:
 lint:
 	poetry run ruff check docetl/* --fix
 
-install:
-	pip install poetry
-	poetry install --all-extras
-
 mypy:
 	poetry run mypy
 
diff --git a/docetl/operations/fast_resolve.py b/docetl/operations/fast_resolve.py
new file mode 100644
index 00000000..e76dc70c
--- /dev/null
+++ b/docetl/operations/fast_resolve.py
@@ -0,0 +1,378 @@
+from typing import List, Dict, Tuple, Any, Optional
+from concurrent.futures import ThreadPoolExecutor
+
+import jinja2
+from jinja2 import Template
+from rich.console import Console
+from rich.prompt import Confirm
+from rich.status import Status
+
+from .base import BaseOperation
+from docetl.operations.utils import RichLoopBar, rich_as_completed
+from docetl_resolver import FastResolver
+
+
+class FastResolveOperation(BaseOperation):
+    class schema(BaseOperation.schema):
+        type: str = "fast_resolve"
+        comparison_prompt: str
+        resolution_prompt: str
+        output: Optional[Dict[str, Any]] = None
+        embedding_model: Optional[str] = None
+        resolution_model: Optional[str] = None
+        comparison_model: Optional[str] = None
+        blocking_threshold: Optional[float] = None
+        blocking_keys: Optional[List[str]] = None
+        embedding_batch_size: Optional[int] = None
+        compare_batch_size: Optional[int] = None
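+
+    # Illustrative pipeline config for this operation (field names match the
+    # schema above; the values themselves are made up):
+    #
+    #   - name: dedupe_people
+    #     type: fast_resolve
+    #     blocking_threshold: 0.8
+    #     blocking_keys: [name, email]
+    #     comparison_prompt: "Are {{ input1.name }} and {{ input2.name }} the same person?"
+    #     resolution_prompt: "Pick a canonical record from {{ inputs }}"
+    #     output:
+    #       schema: {name: string, email: string}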
+
+    def syntax_check(self):
+        """Check that the operation config is valid."""
+        required_keys = ["comparison_prompt", "output"]
+        for key in required_keys:
+            if key not in self.config:
+                raise ValueError(
+                    f"Missing required key '{key}' in FastResolveOperation configuration"
+                )
+
+        if "schema" not in self.config["output"]:
+            raise ValueError("Missing 'schema' in 'output' configuration")
+
+        if not isinstance(self.config["output"]["schema"], dict):
+            raise TypeError("'schema' in 'output' configuration must be a dictionary")
+
+        if not self.config["output"]["schema"]:
+            raise ValueError("'schema' in 'output' configuration cannot be empty")
+
+        # Check that the prompts are valid Jinja2 templates with the required variables
+        try:
+            comparison_template = Template(self.config["comparison_prompt"])
+            comparison_vars = comparison_template.environment.parse(
+                self.config["comparison_prompt"]
+            ).find_all(jinja2.nodes.Name)
+            comparison_var_names = {var.name for var in comparison_vars}
+            if "input1" not in comparison_var_names or "input2" not in comparison_var_names:
+                raise ValueError(
+                    "'comparison_prompt' must contain both 'input1' and 'input2' variables"
+                )
+
+            if "resolution_prompt" in self.config:
+                reduction_template = Template(self.config["resolution_prompt"])
+                reduction_vars = reduction_template.environment.parse(
+                    self.config["resolution_prompt"]
+                ).find_all(jinja2.nodes.Name)
+                reduction_var_names = {var.name for var in reduction_vars}
+                if "inputs" not in reduction_var_names:
+                    raise ValueError("'resolution_prompt' must contain 'inputs' variable")
+        except Exception as e:
+            raise ValueError(f"Invalid Jinja2 template: {str(e)}")
+
+    def __init__(
+        self,
+        runner: "ConfigWrapper",
+        config: Dict,
+        default_model: str,
+        max_threads: int,
+        console: Optional[Console] = None,
+        status: Optional[Status] = None,
+        is_build: bool = False,
+        **kwargs,
+    ):
+        super().__init__(
+            runner, config, default_model, max_threads, console, status, is_build, **kwargs
+        )
+        self.resolver = FastResolver(
+            blocking_threshold=config.get("blocking_threshold", None),
+            debug=config.get("debug", False),
+            limit_comparisons=config.get("limit_comparisons", None),
+        )
+
+    def batch_embeddings(
+        self, items: List[Dict], batch_size: int = 1000
+    ) -> Tuple[List[List[float]], float]:
+        """Get embeddings for all items in parallel batches."""
+        total_cost = 0
+        blocking_keys = self.config.get("blocking_keys", list(items[0].keys()))
+
+        def process_batch(batch):
+            texts = [
+                " ".join(str(item[key]) for key in blocking_keys if key in item)
+                for item in batch
+            ]
+            response = self.runner.api.gen_embedding(
+                model=self.config.get("embedding_model", "text-embedding-3-small"),
+                input=texts,
+            )
+            cost = response.get("usage", {}).get("total_tokens", 0) * 0.0001
+            return [data["embedding"] for data in response["data"]], cost
+
+        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
+            futures = [
+                executor.submit(process_batch, items[i : i + batch_size])
+                for i in range(0, len(items), batch_size)
+            ]
+
+            # Collect batch results keyed by submission index: as_completed yields
+            # futures in completion order, so extending a flat list directly would
+            # misalign embeddings with their source items.
+            future_to_idx = {future: idx for idx, future in enumerate(futures)}
+            batch_results: List[List[List[float]]] = [[] for _ in futures]
+            for future in rich_as_completed(
+                futures,
+                total=len(futures),
+                desc="Generating embeddings",
+                console=self.console,
+            ):
+                embeddings, cost = future.result()
+                batch_results[future_to_idx[future]] = embeddings
+                total_cost += cost
+
+        all_embeddings = [emb for batch in batch_results for emb in batch]
+        return all_embeddings, total_cost
+
+    def compare_pair(self, item1: Dict, item2: Dict) -> Tuple[bool, float]:
+        """Compare two items using the LLM."""
+        prompt_template = Template(self.config["comparison_prompt"])
+        prompt = prompt_template.render(input1=item1, input2=item2)
+
+        response = self.runner.api.call_llm(
+            self.config.get("comparison_model", self.default_model),
+            "compare",
+            [{"role": "user", "content": prompt}],
+            {"is_match": "bool"},
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
+            bypass_cache=self.config.get("bypass_cache", False),
+        )
+        output = self.runner.api.parse_llm_response(
+            response.response,
+            {"is_match": "bool"},
+        )[0]
+        return output["is_match"], response.total_cost
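+
+    # Clusters arrive from the Rust resolver as lists of integer indices into
+    # the input list (e.g. [0, 3, 7] means those three records were judged the
+    # same entity), so process_cluster maps indices back to records before
+    # resolving them.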
+
+    def process_cluster(self, cluster: List[int], items: List[Dict]) -> Tuple[List[Dict], float]:
+        """Process a cluster of items to generate a resolved output."""
+        if len(cluster) == 1:
+            return [items[cluster[0]]], 0
+
+        cluster_items = [items[i] for i in cluster]
+        reduction_template = Template(self.config["resolution_prompt"])
+        resolution_prompt = reduction_template.render(inputs=cluster_items)
+
+        response = self.runner.api.call_llm(
+            self.config.get("resolution_model", self.default_model),
+            "resolve",
+            [{"role": "user", "content": resolution_prompt}],
+            self.config["output"]["schema"],
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
+            bypass_cache=self.config.get("bypass_cache", False),
+            validation_config=(
+                {
+                    "val_rule": self.config.get("validate", []),
+                    "validation_fn": self.validation_fn,
+                }
+                if self.config.get("validate", None)
+                else None
+            ),
+        )
+
+        if response.validated:
+            resolved = self.runner.api.parse_llm_response(
+                response.response,
+                self.config["output"]["schema"],
+                manually_fix_errors=self.manually_fix_errors,
+            )[0]
+
+            results = []
+            for idx in cluster:
+                item = items[idx].copy()
+                # Save original values before overwriting them with the resolved ones
+                keys_in_output = [k for k in resolved.keys() if k in item.keys()]
+                item[f"_kv_pairs_preresolve_{self.config['name']}"] = {
+                    k: item[k] for k in keys_in_output
+                }
+                item.update(resolved)
+                results.append(item)
+
+            return results, response.total_cost
+
+        return [], response.total_cost
+
+    def validation_fn(self, response: Dict[str, Any]):
+        output = self.runner.api.parse_llm_response(
+            response,
+            schema=self.config["output"]["schema"],
+        )[0]
+        if self.runner.api.validate_output(self.config, output, self.console):
+            return output, True
+        return output, False
+
+    def auto_batch(self, num_pairs: int) -> int:
+        """Calculate an optimal comparison batch size for the given number of pairs."""
+        # Maximum batch size limit for 4o-mini model
+        M = 500
+
+        n = len(self.input_data)
+        m = num_pairs
+
+        # https://www.wolframalpha.com/input?i=k%28k-1%29%2F2+%2B+%28n-k%29%28k-1%29+%3D+m%2C+solve+for+k
+        # Two possible solutions for k:
+        # k = -1/2 sqrt((1 - 2n)^2 - 8m) + n + 1/2
+        # k = 1/2 (sqrt((1 - 2n)^2 - 8m) + 2n + 1)
+        discriminant = (1 - 2 * n) ** 2 - 8 * m
+        if discriminant < 0:
+            # No real solution (more pairs than the quadratic allows); fall back
+            # to the maximum batch size rather than taking sqrt of a negative.
+            return M
+        sqrt_discriminant = discriminant ** 0.5
+
+        k1 = -0.5 * sqrt_discriminant + n + 0.5
+        k2 = 0.5 * (sqrt_discriminant + 2 * n + 1)
+
+        # Take the maximum viable solution
+        k = max(k1, k2)
+        return M if k < 0 else min(int(k), M)
+
+    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
+        """Execute the fast resolve operation."""
+        if not input_data:
+            return [], 0
+
+        blocking_threshold = self.config.get("blocking_threshold")
+        blocking_conditions = self.config.get("blocking_conditions", [])
+
+        if self.status:
+            self.status.stop()
+
+        if not blocking_threshold and not blocking_conditions:
+            # Prompt the user for confirmation
+            if not Confirm.ask(
+                "[yellow]Warning: No blocking keys or conditions specified. "
+                "This may result in a large number of comparisons. "
+                "We recommend specifying at least one blocking key or condition, "
+                "or using the optimizer to automatically come up with these. "
+                "Do you want to continue without blocking?[/yellow]",
+                console=self.console,
+            ):
+                raise ValueError("Operation cancelled by user.")
+
+        self.input_data = input_data
+        total_cost = 0
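+
+        # The parser below only recognizes two shapes of lowercased condition
+        # strings. Illustrative examples, matching the test fixtures:
+        #   "input1['email'].lower() == input2['email'].lower()"
+        #       -> equals rule on ("email", "email")
+        #   "input1['name'].lower() in input2['name'].lower()"
+        #       -> contains or contained_in rule on ("name", "name"),
+        #          depending on which side input1 appears on
+        # Anything else is skipped with a log message.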
" + f"Do you want to continue without blocking?[/yellow]", + console=self.console, + ): + raise ValueError("Operation cancelled by user.") + + self.input_data = input_data + total_cost = 0 + + # Set up blocking rules + blocking_conditions = self.config.get("blocking_conditions", []) + for condition in blocking_conditions: + # Parse the condition string to extract keys and operation + if "==" in condition: + parts = condition.split("==") + if parts[0].strip().endswith(".lower()") and parts[1].strip().endswith(".lower()"): + key1 = parts[0].split("[")[1].split("]")[0].strip('"\'') + key2 = parts[1].split("[")[1].split("]")[0].strip('"\'') + self.resolver.add_equals_rule(key1, key2) + self.console.log(f"Added equals rule: {key1} equals {key2}") + else: + self.console.log(f"Skipped '==' condition - not using .lower(): {condition}") + elif " in " in condition: + parts = condition.split(" in ") + if parts[0].strip().endswith(".lower()") and parts[1].strip().endswith(".lower()"): + key1 = parts[0].split("[")[1].split("]")[0].strip('"\'') + key2 = parts[1].split("[")[1].split("]")[0].strip('"\'') + + if parts[0].strip().startswith("input1"): + self.resolver.add_contains_rule(key1, key2) + self.console.log(f"Added contains rule: {key1} contains {key2}") + else: + self.resolver.add_contained_in_rule(key1, key2) + self.console.log(f"Added contained_in rule: {key1} contained in {key2}") + else: + self.console.log(f"Skipped 'in' condition - not using .lower(): {condition}") + else: + self.console.log(f"Skipped condition - no recognized operator: {condition}") + + # Get embeddings with configurable batch size + embedding_batch_size = self.config.get("embedding_batch_size", 1000) + embeddings, embedding_cost = self.batch_embeddings(input_data, batch_size=embedding_batch_size) + total_cost += embedding_cost + + # Get comparison pairs from Rust, including blocking rules + comparison_pairs = self.resolver.process_embeddings(embeddings, input_data) + + # Calculate and log statistics + total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2 + comparisons_made = len(comparison_pairs) + comparisons_saved = total_possible_comparisons - comparisons_made + + self.console.log( + f"[green]Comparisons saved by blocking: {comparisons_saved} " + f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]" + ) + self.console.log( + f"[blue]Number of pairs to compare: {comparisons_made}[/blue]" + ) + + # Calculate batch size for comparisons + batch_size = self.config.get("compare_batch_size", self.auto_batch(len(comparison_pairs))) + self.console.log(f"Using compare batch size: {batch_size}") + + # Process comparisons in batches with progress bar + pbar = RichLoopBar( + range(0, len(comparison_pairs), batch_size), + desc=f"Processing batches of {batch_size} LLM comparisons", + console=self.console, + ) + + for i in pbar: + batch = comparison_pairs[i:i + batch_size] + + with ThreadPoolExecutor(max_workers=self.max_threads) as executor: + futures = [] + valid_pairs = [] + + # Pre-filter pairs that might already be in same cluster or processed + for i, j in batch: + if (self.resolver.find_cluster(i) != self.resolver.find_cluster(j) and + not self.resolver.is_processed(i, j)): + futures.append( + executor.submit( + self.compare_pair, + input_data[i], + input_data[j] + ) + ) + valid_pairs.append((i, j)) + + # Process results and merge clusters + for future, (i, j) in zip(futures, valid_pairs): + is_match, cost = future.result() + total_cost += cost + # Mark pair as processed regardless of match 
result + self.resolver.mark_processed(i, j) + if is_match: + self.resolver.merge_clusters(i, j) + + pbar.update(i//batch_size) + + # Get final clusters + clusters = self.resolver.get_clusters() + + # Calculate and log cluster statistics + num_records_before = len(input_data) + num_clusters_after = len(clusters) + self.console.log(f"Number of records before resolution: {num_records_before}") + self.console.log(f"Number of distinct records after resolution: {num_clusters_after}") + + # Calculate and log self-join selectivity + true_match_count = sum( + len(cluster) * (len(cluster) - 1) // 2 + for cluster in clusters + if len(cluster) > 1 + ) + true_match_selectivity = true_match_count / total_possible_comparisons if total_possible_comparisons > 0 else 0 + self.console.log(f"Self-join selectivity: {true_match_selectivity:.4f}") + + # Process each cluster in parallel with progress + results = [] + with ThreadPoolExecutor(max_workers=self.max_threads) as executor: + futures = [ + executor.submit(self.process_cluster, cluster, input_data) + for cluster in clusters + ] + + for future in rich_as_completed( + futures, + total=len(futures), + desc="Resolving clusters", + console=self.console + ): + cluster_results, cost = future.result() + results.extend(cluster_results) + total_cost += cost + + if self.status: + self.status.start() + + return results, total_cost \ No newline at end of file diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index 0f896261..0b46cdac 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -477,6 +477,7 @@ def auto_batch() -> int: # Update batch_end to prevent overlapping in the next loop batch_end = next_end + better_batch = better_batch[:batch_size] last_processed = batch_end with ThreadPoolExecutor(max_workers=self.max_threads) as executor: diff --git a/docetl/rust/Cargo.lock b/docetl/rust/Cargo.lock new file mode 100644 index 00000000..5b77a8c3 --- /dev/null +++ b/docetl/rust/Cargo.lock @@ -0,0 +1,485 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "docetl_resolver" +version = "0.1.0" +dependencies = [ + "ndarray", + "pyo3", + "rand", + "rayon", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "libc" +version = "0.2.162" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", + "rayon", +] + 
+[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + 
+[[package]]
+name = "windows-targets"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
+dependencies = [
+ "windows_aarch64_gnullvm",
+ "windows_aarch64_msvc",
+ "windows_i686_gnu",
+ "windows_i686_gnullvm",
+ "windows_i686_msvc",
+ "windows_x86_64_gnu",
+ "windows_x86_64_gnullvm",
+ "windows_x86_64_msvc",
+]
+
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
+
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
+
+[[package]]
+name = "windows_i686_gnu"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
+
+[[package]]
+name = "windows_i686_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
+
+[[package]]
+name = "windows_i686_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
+
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
+
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
+
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
+
+[[package]]
+name = "zerocopy"
+version = "0.7.35"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
+dependencies = [
+ "byteorder",
+ "zerocopy-derive",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.7.35"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
+]
diff --git a/docetl/rust/Cargo.toml b/docetl/rust/Cargo.toml
new file mode 100644
index 00000000..cb2388b6
--- /dev/null
+++ b/docetl/rust/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "docetl_resolver"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "docetl_resolver"
+crate-type = ["cdylib"]
+
+[dependencies]
+pyo3 = { version = "0.19", features = ["extension-module"] }
+ndarray = { version = "0.15", features = ["rayon"] }
+rayon = "1.7"
+rand = "0.8"
\ No newline at end of file
diff --git a/docetl/rust/__init__.py b/docetl/rust/__init__.py
new file mode 100644
index 00000000..99049ee1
--- /dev/null
+++ b/docetl/rust/__init__.py
@@ -0,0 +1 @@
+# This file can be empty; it just marks the directory as a Python package
\ No newline at end of file
diff --git a/docetl/rust/src/lib.rs b/docetl/rust/src/lib.rs
new file mode 100644
index 00000000..c0d8dabc
--- /dev/null
+++ b/docetl/rust/src/lib.rs
@@ -0,0 +1,349 @@
+use pyo3::prelude::*;
+use ndarray::{Array2, Array1, Axis};
+use std::collections::HashSet;
+use pyo3::types::{PyDict, PyList};
+use pyo3::Python;
+use pyo3::types::PyModule;
+
+#[derive(Debug, Clone)]
+struct ComparisonPair {
+    i: usize,
+    j: usize,
+    similarity: f64,
+}
+
+#[derive(Debug, Clone)]
+struct BlockingRule {
+    rule_type: String,
+    key1: String,
+    key2: String,
+}
+
+#[pyclass]
+pub struct FastResolver {
+    #[pyo3(get, set)]
+    pub blocking_threshold: Option<f64>,
+    #[pyo3(get, set)]
+    pub debug: bool,
+    #[pyo3(get, set)]
+    pub limit_comparisons: Option<usize>,
+    parent: Vec<usize>,
+    size: Vec<usize>,
+    clusters: Vec<HashSet<usize>>,
+    processed_pairs: HashSet<(usize, usize)>,
+    blocking_rules: Vec<BlockingRule>,
+}
+
+#[pymethods]
+impl FastResolver {
+    #[new]
+    fn new(blocking_threshold: Option<f64>, debug: Option<bool>, limit_comparisons: Option<usize>) -> Self {
+        FastResolver {
+            blocking_threshold,
+            debug: debug.unwrap_or(false),
+            limit_comparisons,
+            parent: Vec::new(),
+            size: Vec::new(),
+            clusters: Vec::new(),
+            processed_pairs: HashSet::new(),
+            blocking_rules: Vec::new(),
+        }
+    }
+
+    #[staticmethod]
+    fn compute_similarity_matrix(embeddings: Vec<Vec<f64>>) -> Vec<Vec<f64>> {
+        let n = embeddings.len();
+        let n_features = embeddings[0].len();
+
+        // Convert to ndarray more efficiently using one allocation
+        let embedding_data: Vec<f64> = embeddings.into_iter().flatten().collect();
+        let embedding_matrix = Array2::from_shape_vec((n, n_features), embedding_data)
+            .expect("Shape mismatch in embedding conversion");
+
+        // Compute norms using axis operation
+        let norms: Array1<f64> = embedding_matrix.map_axis(Axis(1), |row| {
+            (row.dot(&row)).sqrt()
+        });
+
+        // Compute similarity matrix directly
+        let dot_products = embedding_matrix.dot(&embedding_matrix.t());
+        let norms_matrix = &norms.view().into_shape((n, 1)).unwrap()
+            * &norms.view().into_shape((1, n)).unwrap();
+
+        // Divide element-wise and convert to Vec
+        let similarity = &dot_products / &norms_matrix;
+        similarity.outer_iter()
+            .map(|row| row.to_vec())
+            .collect()
+    }
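+
+    // The matrix above is plain cosine similarity: for embeddings e_i and e_j,
+    // sim(i, j) = (e_i . e_j) / (||e_i|| * ||e_j||), computed for all pairs at
+    // once via a single matrix product rather than n^2 separate dot products.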
+
+    fn add_contains_rule(&mut self, key1: String, key2: String) -> PyResult<()> {
+        self.blocking_rules.push(BlockingRule {
+            rule_type: "contains".to_string(),
+            key1,
+            key2,
+        });
+        Ok(())
+    }
+
+    fn add_contained_in_rule(&mut self, key1: String, key2: String) -> PyResult<()> {
+        self.blocking_rules.push(BlockingRule {
+            rule_type: "contained_in".to_string(),
+            key1,
+            key2,
+        });
+        Ok(())
+    }
+
+    fn add_equals_rule(&mut self, key1: String, key2: String) -> PyResult<()> {
+        self.blocking_rules.push(BlockingRule {
+            rule_type: "equals".to_string(),
+            key1,
+            key2,
+        });
+        Ok(())
+    }
+
+    fn check_blocking_rules(&self, item1: &PyDict, item2: &PyDict) -> PyResult<bool> {
+        for rule in &self.blocking_rules {
+            let val1 = match item1.get_item(&rule.key1) {
+                Some(v) => v.to_string().to_lowercase(),
+                None => {
+                    continue;
+                },
+            };
+            let val2 = match item2.get_item(&rule.key2) {
+                Some(v) => v.to_string().to_lowercase(),
+                None => {
+                    continue;
+                },
+            };
+
+            match rule.rule_type.as_str() {
+                "contains" => {
+                    if val1.contains(&val2) {
+                        return Ok(true);
+                    }
+                }
+                "contained_in" => {
+                    if val2.contains(&val1) {
+                        return Ok(true);
+                    }
+                }
+                "equals" => {
+                    if val1 == val2 {
+                        return Ok(true);
+                    }
+                }
+                _ => continue,
+            }
+        }
+        Ok(false)
+    }
+
+    fn process_items_with_rules<'py>(
+        &mut self,
+        _py: Python<'py>,
+        items: &'py PyList,
+    ) -> PyResult<Vec<(usize, usize)>> {
+        let n_samples = items.len();
+        let mut blocking_pairs = Vec::new();
+
+        // Skip if no blocking rules
+        if self.blocking_rules.is_empty() {
+            return Ok(blocking_pairs);
+        }
+
+        // Print rules once before processing
+        if self.debug {
+            println!("\nChecking blocking rules:");
+            for rule in &self.blocking_rules {
+                match rule.rule_type.as_str() {
+                    "contains" => println!("- CONTAINS rule: input1 {} contains input2 {}", rule.key1, rule.key2),
+                    "contained_in" => println!("- CONTAINED_IN rule: input1 {} is contained in input2 {}", rule.key1, rule.key2),
+                    "equals" => println!("- EQUALS rule: input1 {} equals input2 {}", rule.key1, rule.key2),
+                    _ => println!("- Unknown rule type: {}", rule.rule_type),
+                }
+            }
+            println!(""); // Empty line for readability
+        }
+
+        // Check each pair against blocking rules
+        for i in 0..n_samples {
+            for j in (i+1)..n_samples {
+                let item1 = items.get_item(i)?.downcast::<PyDict>()?;
+                let item2 = items.get_item(j)?.downcast::<PyDict>()?;
+
+                if self.check_blocking_rules(item1, item2)? {
+                    let root1 = self.find_cluster(i);
+                    let root2 = self.find_cluster(j);
+                    if root1 != root2 && !self.is_processed(i, j) {
+                        blocking_pairs.push((i, j));
+                    }
+                }
+            }
+        }
+
+        Ok(blocking_pairs)
+    }
+
+    fn process_embeddings(
+        &mut self,
+        embeddings: Vec<Vec<f64>>,
+        items: Option<&PyList>,
+    ) -> PyResult<Vec<(usize, usize)>> {
+        if embeddings.is_empty() {
+            return Ok(Vec::new());
+        }
+        if !embeddings.iter().all(|v| v.len() == embeddings[0].len()) {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                "All embeddings must have the same dimension"
+            ));
+        }
+        Python::with_gil(|py| {
+            let n_samples = embeddings.len();
+
+            if self.debug {
+                println!("Processing embeddings for {} samples...", n_samples);
+            }
+
+            // Initialize only parent and size vectors
+            self.parent = (0..n_samples).collect();
+            self.size = vec![1; n_samples];
+            self.processed_pairs.clear();
+
+            let mut all_pairs = Vec::new();
+            let mut similarity_pairs = Vec::new();
+
+            if self.debug {
+                println!("Computing similarity matrix...");
+            }
+
+            let similarity_matrix = Self::compute_similarity_matrix(embeddings);
+
+            // Store all pairs with their similarities
+            for i in 0..n_samples {
+                for j in (i+1)..n_samples {
+                    let similarity = similarity_matrix[i][j];
+                    if self.blocking_threshold.map_or(true, |t| similarity >= t) {
+                        similarity_pairs.push(ComparisonPair { i, j, similarity });
+                    }
+                }
+            }
+
+            similarity_pairs.sort_unstable_by(|a, b| {
+                b.similarity.partial_cmp(&a.similarity).unwrap()
+            });
+
+            // Add blocking rule pairs if items were provided
+            if let Some(items_list) = items {
+                if self.debug {
+                    println!("Applying blocking rules...");
+                }
+
+                let blocking_pairs = self.process_items_with_rules(py, items_list)?;
+
+                if self.debug {
+                    println!("Found {} pairs from blocking rules", blocking_pairs.len());
+                }
+
+                all_pairs.extend(blocking_pairs);
+            }
+
+            // Add similarity pairs after blocking pairs
+            all_pairs.extend(similarity_pairs.into_iter().map(|pair| (pair.i, pair.j)));
+
+            // Initialize clusters only after all pairs are collected
+            self.clusters = vec![HashSet::new(); n_samples];
+            for i in 0..n_samples {
+                self.clusters[i].insert(i);
+            }
+
+            if self.debug {
+                println!("Filtering processed pairs...");
+            }
+
+            let mut filtered_pairs: Vec<(usize, usize)> = all_pairs.into_iter()
+                .filter(|(i, j)| {
+                    let root1 = self.find_cluster(*i);
+                    let root2 = self.find_cluster(*j);
+                    root1 != root2 && !self.is_processed(*i, *j)
+                })
+                .collect();
+
+            if let Some(limit) = self.limit_comparisons {
+                if filtered_pairs.len() > limit {
+                    if self.debug {
+                        println!("Limiting to {} pairs out of {}", limit, filtered_pairs.len());
+                    }
+                    filtered_pairs.truncate(limit);
+                }
+            }
+
+            if self.debug {
+                println!("Final number of pairs to process: {}", filtered_pairs.len());
+            }
+
+            Ok(filtered_pairs)
+        })
+    }
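+
+    // The resolver is a classic union-find (disjoint-set) structure over record
+    // indices. Sketch of how a merge plays out, assuming three records:
+    //   parent = [0, 1, 2]       // every record starts as its own root
+    //   merge_clusters(0, 2)     // parent becomes [0, 1, 0], cluster {0, 2}
+    //   find_cluster(2) == 0     // path compression keeps lookups near O(1)
+    // Union by size below keeps the trees shallow.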
+
+    fn find_cluster(&mut self, mut item: usize) -> usize {
+        while self.parent[item] != item {
+            // Path compression: point to grandparent to flatten tree
+            self.parent[item] = self.parent[self.parent[item]];
+            item = self.parent[item];
+        }
+        item
+    }
+
+    fn merge_clusters(&mut self, item1: usize, item2: usize) -> PyResult<()> {
+        if item1 >= self.parent.len() || item2 >= self.parent.len() {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                "Invalid cluster index"
+            ));
+        }
+        let mut root1 = self.find_cluster(item1);
+        let mut root2 = self.find_cluster(item2);
+
+        if root1 != root2 {
+            // Union by size - attach smaller tree to root of larger tree
+            if self.size[root1] < self.size[root2] {
+                std::mem::swap(&mut root1, &mut root2);
+            }
+
+            // Merge root2 into root1
+            self.parent[root2] = root1;
+            self.size[root1] += self.size[root2];
+
+            // Merge clusters
+            let items = self.clusters[root2].drain().collect::<Vec<_>>();
+            self.clusters[root1].extend(items);
+        }
+
+        Ok(())
+    }
+
+    fn get_clusters(&self) -> PyResult<Vec<Vec<usize>>> {
+        Ok(self.clusters.iter()
+            .filter(|c| !c.is_empty())
+            .map(|c| c.iter().copied().collect())
+            .collect())
+    }
+
+    fn is_processed(&self, i: usize, j: usize) -> bool {
+        let (min_idx, max_idx) = if i < j { (i, j) } else { (j, i) };
+        self.processed_pairs.contains(&(min_idx, max_idx))
+    }
+
+    fn mark_processed(&mut self, i: usize, j: usize) {
+        let (min_idx, max_idx) = if i < j { (i, j) } else { (j, i) };
+        self.processed_pairs.insert((min_idx, max_idx));
+    }
+}
+
+#[pymodule]
+fn docetl_resolver(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<FastResolver>()?;
+    Ok(())
+}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 00aeadaf..32eaca8b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,8 +79,14 @@ ignore_missing_imports = true
 show_error_codes = true
 
 [build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core>=1.0.0", "maturin>=1.0,<2.0"]
+build-backend = "maturin"
+
+[tool.maturin]
+python-source = "docetl"
+module-name = "docetl_resolver"
+manifest-path = "docetl/rust/Cargo.toml"
+develop = true
 
 [tool.poetry.plugins."docetl.operation"]
 map = "docetl.operations.map:MapOperation"
@@ -108,4 +114,4 @@ txt_to_string = "docetl.parsing_tools:txt_to_string"
 docx_to_string = "docetl.parsing_tools:docx_to_string"
 pptx_to_string = "docetl.parsing_tools:pptx_to_string"
 azure_di_read = "docetl.parsing_tools:azure_di_read"
-paddleocr_pdf_to_string = "docetl.parsing_tools:paddleocr_pdf_to_string"
+paddleocr_pdf_to_string = "docetl.parsing_tools:paddleocr_pdf_to_string"
\ No newline at end of file
diff --git a/tests/test_fast_resolve.py b/tests/test_fast_resolve.py
new file mode 100644
index 00000000..1fd0fbad
--- /dev/null
+++ b/tests/test_fast_resolve.py
@@ -0,0 +1,218 @@
+import pytest
+import random
+import string
+import time
+from docetl.operations.fast_resolve import FastResolveOperation
+from docetl.operations.resolve import ResolveOperation
+
+
+@pytest.fixture
+def fast_resolve_config():
+    return {
+        "name": "name_email_resolver",
+        "type": "fast_resolve",
+        "blocking_threshold": 0.8,
+        "debug": True,
+        "blocking_keys": ["name", "email"],
+        "blocking_conditions": [
+            "input1['email'].lower() == input2['email'].lower()",  # Exact email match
+            "input1['name'].lower() in input2['name'].lower()",    # Name containment
+            "input2['name'].lower() in input1['name'].lower()"     # Reverse name containment
+        ],
+        "comparison_prompt": """Compare these two entries and determine if they refer to the same person:
+        Person 1: {{ input1.name }} {{ input1.email }}
+        Person 2: {{ input2.name }} {{ input2.email }}
+        Return true if they match, false otherwise.""",
+        "resolution_prompt": """Given these similar entries, determine the canonical form.
+        Choose the most complete name and the most professional email address: {{ inputs }}""",
+        "output": {
+            "schema": {
+                "name": "string",
+                "email": "string"
+            }
+        },
+        "embedding_model": "text-embedding-3-small",
+        "comparison_model": "azure/gpt-4o-mini",
+        "resolution_model": "azure/gpt-4o-mini",
+        "embedding_batch_size": 1000,
+        "limit_comparisons": 1000
+    }
+
+
+def generate_large_dataset(num_base_records=100):
+    """Generate a large dataset with intentional duplicates and transitive relationships.
+
+    Example of transitivity:
+    - John Doe <-> Johnny Doe <-> J. Doe (all same email)
+    - Multiple email variations for same person
+    - Name variations that chain together
+    """
+
+    # Base data to create variations from
+    first_names = ['John', 'Michael', 'William', 'James', 'David', 'Robert', 'Thomas', 'Christopher']
+    last_names = ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Garcia', 'Miller', 'Davis']
+    domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'company.com']
+
+    data = []
+
+    # Create base records with intentional relationships
+    for _ in range(num_base_records):
+        first = random.choice(first_names)
+        last = random.choice(last_names)
+        domain = random.choice(domains)
+
+        # Create base email variations for this person
+        email_variations = [
+            f"{first.lower()}.{last.lower()}@{domain}",
+            f"{first.lower()[0]}{last.lower()}@{domain}",
+            f"{first.lower()}{last.lower()[0]}@{domain}",
+            f"{first.lower()}_{last.lower()}@{domain}"
+        ]
+
+        # Create name variations that chain together
+        name_variations = [
+            f"{first} {last}",          # Standard
+            f"{first}y {last}",         # Diminutive
+            f"{first[0]}. {last}",      # Initial
+            f"{first} {last[0]}.",      # Last initial
+            f"{first[0]}. {last[0]}.",  # Both initials
+        ]
+
+        # Add middle initials to some variations
+        middle_initials = random.sample(string.ascii_uppercase, 2)
+        name_variations.extend([
+            f"{first} {mi}. {last}" for mi in middle_initials
+        ])
{last}" for mi in middle_initials + ]) + + # Create multiple records with combinations of name/email variations + # This ensures transitive relationships + for name in name_variations: + # Use same email for some variations to create strong links + primary_email = random.choice(email_variations) + data.append({"name": name, "email": primary_email}) + + # Add some variations with different emails + if random.random() < 0.3: + data.append({"name": name, "email": random.choice(email_variations)}) + + # Add typo variations + if random.random() < 0.2: + typo_name = name.replace('i', 'y') if 'i' in name else name + 'n' + data.append({"name": typo_name, "email": primary_email}) + + # Add some completely different email domains for same person + alt_domain = random.choice([d for d in domains if d != domain]) + alt_email = f"{first.lower()}.{last.lower()}@{alt_domain}" + data.append({"name": random.choice(name_variations), "email": alt_email}) + + # Shuffle the dataset + random.shuffle(data) + + # Print some statistics about the dataset + print(f"\nGenerated Dataset Statistics:") + print(f"Total records: {len(data)}") + print(f"Unique names: {len(set(r['name'] for r in data))}") + print(f"Unique emails: {len(set(r['email'] for r in data))}") + print(f"Average variations per base record: {len(data) / num_base_records:.1f}") + + return data + + +@pytest.fixture +def fast_resolve_sample_data(): + # Set random seed for reproducibility + random.seed(42) + return generate_large_dataset() + + +def dont_do_test_fast_resolve_operation( + fast_resolve_config, default_model, fast_resolve_sample_data, api_wrapper +): + + distinct_names = set(result["name"] for result in fast_resolve_sample_data) + distinct_emails = set(result["email"] for result in fast_resolve_sample_data) + print(f"Distinct names in input: {len(distinct_names)}") + print(f"Distinct emails in input: {len(distinct_emails)}") + + operation = FastResolveOperation( + api_wrapper, fast_resolve_config, default_model, 256 + ) + results, cost = operation.execute(fast_resolve_sample_data) + + # Calculate and print some statistics + input_count = len(fast_resolve_sample_data) + output_count = len(results) + distinct_output_names = set(result["name"] for result in results) + distinct_output_emails = set(result["email"] for result in results) + + print(f"\nTest Statistics:") + print(f"Input records: {input_count}") + print(f"Output records: {output_count}") + print(f"Distinct names in output: {len(distinct_output_names)}") + print(f"Distinct emails in output: {len(distinct_output_emails)}") + print(f"Reduction ratio: {(input_count - output_count) / input_count:.2%}") + print(f"Total cost: {cost}") + + # Assertions + assert len(distinct_names) < len(fast_resolve_sample_data) + assert output_count == input_count + assert cost > 0 + + +def test_fast_resolve_operation_empty_input( + fast_resolve_config, default_model, max_threads, api_wrapper +): + operation = FastResolveOperation( + api_wrapper, fast_resolve_config, default_model, max_threads + ) + results, cost = operation.execute([]) + + assert len(results) == 0 + assert cost == 0 + + + +def test_compare_resolve_performance( + fast_resolve_config, default_model, api_wrapper +): + """Compare performance between FastResolve and regular Resolve operations.""" + + # Generate a smaller dataset for testing + large_dataset = generate_large_dataset() + print(f"\nTesting with {len(large_dataset)} records") + + # Test FastResolve with blocking rules + start_time = time.time() + fast_operation = 
+    fast_operation = FastResolveOperation(
+        api_wrapper, fast_resolve_config, default_model, 256
+    )
+    fast_results, fast_cost = fast_operation.execute(large_dataset)
+    fast_time = time.time() - start_time
+
+    # Test regular Resolve with sample
+    start_time = time.time()
+    regular_operation = ResolveOperation(
+        api_wrapper, fast_resolve_config, default_model, 256
+    )
+    regular_results, regular_cost = regular_operation.execute(large_dataset)
+    regular_time = time.time() - start_time
+
+    # Print detailed performance metrics
+    print("\nPerformance Comparison:")
+    print(f"FastResolve Time: {fast_time:.2f} seconds")
+    print(f"Regular Resolve Time: {regular_time:.2f} seconds")
+    print(f"FastResolve Cost: ${fast_cost:.4f}")
+    print(f"Regular Resolve Cost: ${regular_cost:.4f}")
+    print(f"Speed Improvement: {(regular_time - fast_time) / regular_time:.1%}")
+    print(f"Cost Savings: {(regular_cost - fast_cost) / regular_cost:.1%}")
+
+    # Additional metrics
+    print("\nResolution Quality Metrics:")
+    print(f"FastResolve output records: {len(fast_results)}")
+    print(f"Distinct names in output: {len(set(r['name'] for r in fast_results))}")
+    print(f"Distinct emails in output: {len(set(r['email'] for r in fast_results))}")
+    print(f"Reduction ratio: {(len(large_dataset) - len(fast_results)) / len(large_dataset):.2%}")
+
+    # Assertions
+    assert fast_time < regular_time, "FastResolve should be faster than regular Resolve"
+    assert len(fast_results) <= len(large_dataset), "Output should not be larger than input"
\ No newline at end of file