From 38a073ff755108634ead22d7bbb2219fc0de8623 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Sat, 12 Oct 2024 17:51:27 -0400 Subject: [PATCH] refactor: combine sampling and outlier operators --- docetl/operations/outliers.py | 82 ------------ docetl/operations/sample.py | 121 +++++++++++++++--- docs/operators/outliers.md | 60 --------- docs/operators/sample.md | 41 ++++-- mkdocs.yml | 1 + ..._cluster.py => test_cluster_and_sample.py} | 104 +++++++++++++++ 6 files changed, 239 insertions(+), 170 deletions(-) delete mode 100644 docetl/operations/outliers.py delete mode 100644 docs/operators/outliers.md rename tests/basic/{test_cluster.py => test_cluster_and_sample.py} (54%) diff --git a/docetl/operations/outliers.py b/docetl/operations/outliers.py deleted file mode 100644 index 391c1ab3..00000000 --- a/docetl/operations/outliers.py +++ /dev/null @@ -1,82 +0,0 @@ -from jinja2 import Environment, Template -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Tuple -import numpy as np -from .base import BaseOperation -from .utils import RichLoopBar -from .clustering_utils import get_embeddings_for_clustering - -class OutliersOperation(BaseOperation): - def __init__( - self, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.max_batch_size: int = self.config.get( - "max_batch_size", kwargs.get("max_batch_size", float("inf")) - ) - - def syntax_check(self) -> None: - """ - Checks the configuration of the OutlierOperation for required keys and valid structure. - - Raises: - ValueError: If required keys are missing - """ - - pass - - - def execute( - self, input_data: List[Dict], is_build: bool = False - ) -> Tuple[List[Dict], float]: - """ - Executes the cluster operation on the input data. Modifies the - input data and returns it in place. - - Args: - input_data (List[Dict]): A list of dictionaries to process. - is_build (bool): Whether the operation is being executed - in the build phase. Defaults to False. - - Returns: - Tuple[List[Dict], float]: A tuple containing the filtered - list of dictionaries and the total cost of the operation. - """ - - embeddings, cost = get_embeddings_for_clustering( - input_data, self.config, self.runner.api - ) - embeddings = np.array(embeddings) - - if self.config.get("center", None) is not None: - center_embeddings, cost2 = get_embeddings_for_clustering( - [self.config["center"]], self.config, self.runner.api - ) - cost += cost2 - center = np.array(center_embeddings[0]) - else: - center = embeddings.mean(axis=0) - - distances = np.sqrt(((embeddings - center)**2).sum(axis=1)) - - if "samples" in self.config: - distance_distribution = np.sort(distances) - samples = self.config["samples"] - if isinstance(samples, float): - samples = int(samples * (len(distance_distribution)-1)) - cutoff = distance_distribution[samples] - elif "std" in self.config: - cutoff = np.sqrt((embeddings.std(axis=0)**2).sum()) * self.config["std"] - - if not self.config.get("keep", False): - include = distances <= cutoff - else: - include = distances > cutoff - - return [ - item - for idx, item in enumerate(input_data) - if include[idx]], cost - diff --git a/docetl/operations/sample.py b/docetl/operations/sample.py index 91ebc344..e7503714 100644 --- a/docetl/operations/sample.py +++ b/docetl/operations/sample.py @@ -1,5 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple +import numpy as np from docetl.operations.base import BaseOperation +from docetl.operations.clustering_utils import get_embeddings_for_clustering class SampleOperation(BaseOperation): @@ -18,7 +20,52 @@ def syntax_check(self) -> None: ValueError: If required keys are missing or invalid in the configuration. TypeError: If configuration values have incorrect types. """ - pass + if "samples" not in self.config and "outliers" not in self.config: + raise ValueError( + "Must specify either 'samples' or 'outliers' in SampleOperation configuration" + ) + + if "samples" in self.config: + if not isinstance(self.config["samples"], (int, float, list)) or ( + isinstance(self.config["samples"], (int, float)) + and self.config["samples"] <= 0 + ): + raise TypeError("'samples' must be a positive integer, float, or list") + + if "outliers" in self.config: + outliers_config = self.config["outliers"] + if "std" not in outliers_config and "samples" not in outliers_config: + raise ValueError( + "Must specify either 'std' or 'samples' in outliers configuration" + ) + + if "std" in outliers_config: + if ( + not isinstance(outliers_config["std"], (int, float)) + or outliers_config["std"] <= 0 + ): + raise TypeError("'std' in outliers must be a positive number") + + if "samples" in outliers_config: + if ( + not isinstance(outliers_config["samples"], (int, float)) + or outliers_config["samples"] <= 0 + ): + raise TypeError( + "'samples' in outliers must be a positive integer or float" + ) + + if "embedding_keys" not in outliers_config: + raise ValueError( + "'embedding_keys' must be specified in outliers configuration" + ) + + if not isinstance(outliers_config["embedding_keys"], list) or not all( + isinstance(key, str) for key in outliers_config["embedding_keys"] + ): + raise TypeError( + "'embedding_keys' in outliers must be a list of strings" + ) def execute( self, input_data: List[Dict], is_build: bool = False @@ -35,26 +82,62 @@ def execute( Tuple[List[Dict], float]: A tuple containing the filtered list of dictionaries and the total cost of the operation. """ + cost = 0 + if not input_data: + return [], cost - samples = self.config["samples"] - if isinstance(samples, list): - keys = list(samples[0].keys()) - key_to_doc = {tuple([doc[key] for key in keys]): doc for doc in input_data} + if "outliers" in self.config: + # Outlier functionality + outliers_config = self.config["outliers"] + embeddings, embedding_cost = get_embeddings_for_clustering( + input_data, outliers_config, self.runner.api + ) + cost += embedding_cost + embeddings = np.array(embeddings) + + center = embeddings.mean(axis=0) + distances = np.sqrt(((embeddings - center) ** 2).sum(axis=1)) + + if "std" in outliers_config: + cutoff = ( + np.sqrt((embeddings.std(axis=0) ** 2).sum()) + * outliers_config["std"] + ) + else: # "samples" in outliers_config + distance_distribution = np.sort(distances) + samples = outliers_config["samples"] + if isinstance(samples, float): + samples = int(samples * (len(distance_distribution) - 1)) + cutoff = distance_distribution[samples] + + keep = outliers_config.get("keep", False) + include = distances > cutoff if keep else distances <= cutoff - output_data = [ - key_to_doc[tuple([sample[key] for key in keys])] for sample in samples - ] + output_data = [item for idx, item in enumerate(input_data) if include[idx]] else: - stratify = None - if "stratify" in self.config: - stratify = [data[self.config["stratify"]] for data in input_data] + samples = self.config["samples"] + if isinstance(samples, list): + keys = list(samples[0].keys()) + key_to_doc = { + tuple([doc[key] for key in keys]): doc for doc in input_data + } - import sklearn.model_selection + output_data = [ + key_to_doc[tuple([sample[key] for key in keys])] + for sample in samples + ] + else: + stratify = None + if "stratify" in self.config: + stratify = [data[self.config["stratify"]] for data in input_data] - output_data, dummy = sklearn.model_selection.train_test_split( - input_data, - train_size=samples, - random_state=self.config.get("random_state", None), - stratify=stratify, - ) - return output_data, 0 + import sklearn.model_selection + + output_data, dummy = sklearn.model_selection.train_test_split( + input_data, + train_size=samples, + random_state=self.config.get("random_state", None), + stratify=stratify, + ) + + return output_data, cost diff --git a/docs/operators/outliers.md b/docs/operators/outliers.md deleted file mode 100644 index b137f578..00000000 --- a/docs/operators/outliers.md +++ /dev/null @@ -1,60 +0,0 @@ -# Outliers operation - -The Outliers operation in DocETL removes outliers from the input (or -keeps only outliers). - -## 🚀 Example: - -```yaml -- name: remove-worst-10 - type: outliers - samples: 0.9 - embedding_keys: - - concept - - description -``` - -This will keep the 90 percent closest to the center (average) -embedding of the keys provided. Altermnatively, you could set samples -to an integer count of items to keep (or a negative number to throw -away). You can also assume a gaussian distribution and set the key std -to a number of standard deviations out from the center, instead of -setting samples. - -Small note about embeddings: If you embed too short values, some -embedding models will yield a very "sparse" distribution, where the -absolute majority of points lie on the surface of a hyperssphere, -meaning that this operation will not work very well! - -### Using it as a poor-mans-RAG -```yaml -- name: remove-worst-10 - type: outliers - samples: 0.01 - embedding_keys: - - concept - - description - center: - concept: Horse - description: A horse is a large steppe roaming and grazing animal. Humans have utilized horses for transport throughout historical times -``` - -If center is provided, it must have the same keys as those listed -under embedding_keys, and their values will be used to calculate the -"center" embedding, instead of using the average of all embeddings of -the input items. This will effectively turn this into a search -operation for items similar to the center provided. - -## Required Parameters - -- `name`: A unique name for the operation. -- `type`: Must be set to "sample". -- `samples`: Either a an integer count of samples, or a float fraction of samples. -- `embedding_keys`: A list of keys to use for the embedding distance calculation. - -## Optional Parameters - -| Parameter | Description | Default | -| ------------------------- | -------------------------------------------------------------------------------- | ----------------------------- | -| `keep` | If set to true, return the outliers instead of the non-outliers | false -| `center` | An explicit center object to be used to calculate the center embedding instead of using the average | The average embedding of all input data diff --git a/docs/operators/sample.md b/docs/operators/sample.md index 62d852a2..7e0a5109 100644 --- a/docs/operators/sample.md +++ b/docs/operators/sample.md @@ -21,21 +21,44 @@ operation you add while developing your pipeline! ``` This sample operation will return a pseudo-randomly selected 10% of -the samples (`samples: 0.1`). The random selection will be seeded with +the samples (samples: 0.1). The random selection will be seeded with a constant (42), meaning the same selection will be returned if you rerun the pipeline (If no random state is given, a different sample will be returned every time). Additionally, the random sampling will -sample each value of the `category` key equally. +sample each value of the category key equally. ## Required Parameters -- `name`: A unique name for the operation. -- `type`: Must be set to "sample". -- `samples`: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples. +- name: A unique name for the operation. +- type: Must be set to "sample". +- samples: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples. ## Optional Parameters -| Parameter | Description | Default | -| ------------- | -------------------------------------------- | ----------------------------------- | -| `random_state | An integer to seed the random generator with | Use the (numpy) global random state | -| `stratify` | The key to stratify by | | +| Parameter | Description | Default | +| ------------ | -------------------------------------------- | ----------------------------------- | +| random_state | An integer to seed the random generator with | Use the (numpy) global random state | +| stratify | The key to stratify by | | + +## Outliers + +The Sample operation can also be used to sample outliers. To do this, instead of specifying "samples", specify an "outliers" object with the following parameters: + +- embedding_keys: A list of keys to use for creating embeddings. +- std: The number of standard deviations to use as the cutoff for outliers. +- samples: The number or fraction of samples to consider as outliers. +- keep: Whether to keep (true) or remove (false) the outliers. Defaults to false. + +You must specify either "std" or "samples" in the outliers configuration, but not both. + +Example: + +```yaml +- name: remove-worst-10 + type: sample + outliers: + embedding_keys: + - concept + - description + samples: 0.9 +``` diff --git a/mkdocs.yml b/mkdocs.yml index e995670e..988c272f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -35,6 +35,7 @@ nav: - Split: operators/split.md - Gather: operators/gather.md - Unnest: operators/unnest.md + - Sample: operators/sample.md - Optimization: - Overview: optimization/overview.md - Example: optimization/example.md diff --git a/tests/basic/test_cluster.py b/tests/basic/test_cluster_and_sample.py similarity index 54% rename from tests/basic/test_cluster.py rename to tests/basic/test_cluster_and_sample.py index 0bcaa717..0c6a2821 100644 --- a/tests/basic/test_cluster.py +++ b/tests/basic/test_cluster_and_sample.py @@ -1,5 +1,6 @@ import pytest from docetl.operations.cluster import ClusterOperation +from docetl.operations.sample import SampleOperation from tests.conftest import api_wrapper, default_model, max_threads @@ -32,36 +33,52 @@ def cluster_config(): def sample_data(): return [ { + "id": 1, "concept": "Shed", "description": "A simple, single-story roofed structure, often used for storage or as a workshop.", + "group": "A", }, { + "id": 2, "concept": "Barn", "description": "A large agricultural building used for storing farm products and sheltering livestock.", + "group": "B", }, { + "id": 3, "concept": "Tree house", "description": "A small house built among the branches of a tree for children to play in.", + "group": "A", }, { + "id": 4, "concept": "Skyscraper", "description": "A very tall building of many stories, typically found in urban areas.", + "group": "B", }, { + "id": 5, "concept": "Castle", "description": "A large fortified building or set of buildings from the medieval period.", + "group": "A", }, { + "id": 6, "concept": "Igloo", "description": "A dome-shaped dwelling made of blocks of solid snow, traditionally built by Inuit people.", + "group": "B", }, { + "id": 7, "concept": "Lighthouse", "description": "A tower with a bright light at the top, used to warn or guide ships at sea.", + "group": "A", }, { + "id": 8, "concept": "Windmill", "description": "A building with sails or vanes that turn in the wind and generate power to grind grain into flour.", + "group": "B", }, ] @@ -115,3 +132,90 @@ def test_cluster_operation_single_item( assert cost == 0 assert "categories" in results[0] assert isinstance(results[0]["categories"], tuple) + + +@pytest.fixture +def sample_config(): + return { + "name": "sample_operation", + "type": "sample", + "random_state": 42, # For reproducibility + } + + +def test_sample_operation_with_count( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 5 + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == 5 + assert cost == 0 + assert all(item in sample_data for item in results) + + +def test_sample_operation_with_fraction( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 0.5 + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == len(sample_data) // 2 + assert cost == 0 + assert all(item in sample_data for item in results) + + +def test_sample_operation_with_list( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_list = [{"id": 1}, {"id": 3}, {"id": 5}] + sample_config["samples"] = sample_list + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == len(sample_list) + assert cost == 0 + assert all(item["id"] in [1, 3, 5] for item in results) + + +def test_sample_operation_with_stratify( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 5 + sample_config["stratify"] = "group" + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == 5 + assert cost == 0 + assert all(item in sample_data for item in results) + assert len(set(item["group"] for item in results)) > 1 + + +def test_sample_operation_with_outliers( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["outliers"] = { + "std": 2, + "embedding_keys": ["concept", "description"], + "keep": True, + } + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) < len(sample_data) + assert cost > 0 + assert all(item in sample_data for item in results) + + +def test_sample_operation_empty_input( + sample_config, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 3 + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute([]) + + assert len(results) == 0 + assert cost == 0