diff --git a/docetl/operations/code_operations.py b/docetl/operations/code_operations.py new file mode 100644 index 00000000..09a62c9a --- /dev/null +++ b/docetl/operations/code_operations.py @@ -0,0 +1,146 @@ +from typing import Any, Dict, List, Optional, Tuple +from concurrent.futures import ThreadPoolExecutor +from docetl.operations.base import BaseOperation +from docetl.operations.utils import RichLoopBar + +class CodeMapOperation(BaseOperation): + class schema(BaseOperation.schema): + type: str = "code_map" + code: str + drop_keys: Optional[List[str]] = None + + def syntax_check(self) -> None: + config = self.schema(**self.config) + try: + namespace = {} + exec(config.code, namespace) + if "transform" not in namespace: + raise ValueError("Code must define a 'transform' function") + if not callable(namespace["transform"]): + raise ValueError("'transform' must be a callable function") + except Exception as e: + raise ValueError(f"Invalid code configuration: {str(e)}") + + def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: + namespace = {} + exec(self.config["code"], namespace) + transform_fn = namespace["transform"] + + results = [] + with ThreadPoolExecutor() as executor: + futures = [executor.submit(transform_fn, doc) for doc in input_data] + pbar = RichLoopBar( + range(len(futures)), + desc=f"Processing {self.config['name']} (code_map)", + console=self.console, + ) + for i in pbar: + result = futures[i].result() + if self.config.get("drop_keys"): + result = { + k: v for k, v in result.items() + if k not in self.config["drop_keys"] + } + doc = input_data[i] + merged_result = {**doc, **result} + results.append(merged_result) + + return results, 0.0 + +class CodeReduceOperation(BaseOperation): + class schema(BaseOperation.schema): + type: str = "code_reduce" + code: str + + def syntax_check(self) -> None: + config = self.schema(**self.config) + try: + namespace = {} + exec(config.code, namespace) + if "transform" not in namespace: + raise ValueError("Code must define a 'transform' function") + if not callable(namespace["transform"]): + raise ValueError("'transform' must be a callable function") + except Exception as e: + raise ValueError(f"Invalid code configuration: {str(e)}") + + def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: + namespace = {} + exec(self.config["code"], namespace) + reduce_fn = namespace["transform"] + + reduce_keys = self.config.get("reduce_key", "_all") + if not isinstance(reduce_keys, list): + reduce_keys = [reduce_keys] + + if reduce_keys == ["_all"] or reduce_keys == "_all": + grouped_data = [("_all", input_data)] + else: + def get_group_key(item): + return tuple(item[key] for key in reduce_keys) + + grouped_data = {} + for item in input_data: + key = get_group_key(item) + if key not in grouped_data: + grouped_data[key] = [] + grouped_data[key].append(item) + + grouped_data = list(grouped_data.items()) + + results = [] + with ThreadPoolExecutor() as executor: + futures = [executor.submit(reduce_fn, group) for _, group in grouped_data] + pbar = RichLoopBar( + range(len(futures)), + desc=f"Processing {self.config['name']} (code_reduce)", + console=self.console, + ) + for i, (key, group) in zip(pbar, grouped_data): + result = futures[i].result() + + # Apply pass-through at the group level + if self.config.get("pass_through", False) and group: + for k, v in group[0].items(): + if k not in result: + result[k] = v + + results.append(result) + + return results, 0.0 + +class CodeFilterOperation(BaseOperation): + class schema(BaseOperation.schema): + type: str = "code_filter" + code: str + + def syntax_check(self) -> None: + config = self.schema(**self.config) + try: + namespace = {} + exec(config.code, namespace) + if "transform" not in namespace: + raise ValueError("Code must define a 'transform' function") + if not callable(namespace["transform"]): + raise ValueError("'transform' must be a callable function") + except Exception as e: + raise ValueError(f"Invalid code configuration: {str(e)}") + + def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: + namespace = {} + exec(self.config["code"], namespace) + filter_fn = namespace["transform"] + + results = [] + with ThreadPoolExecutor() as executor: + futures = [executor.submit(filter_fn, doc) for doc in input_data] + pbar = RichLoopBar( + range(len(futures)), + desc=f"Processing {self.config['name']} (code_filter)", + console=self.console, + ) + for i in pbar: + should_keep = futures[i].result() + if should_keep: + results.append(input_data[i]) + return results, 0.0 \ No newline at end of file diff --git a/docs/operators/code.md b/docs/operators/code.md new file mode 100644 index 00000000..28178fe2 --- /dev/null +++ b/docs/operators/code.md @@ -0,0 +1,92 @@ +# Code Operations + +Code operations in DocETL allow you to define transformations using Python code rather than LLM prompts. This is useful when you need deterministic processing, complex calculations, or want to leverage existing Python libraries. + +## Motivation + +While LLM-powered operations are powerful for natural language tasks, sometimes you need operations that are: + +- Deterministic and reproducible +- Integrated with external Python libraries +- Focused on structured data transformations +- Math-based or computationally intensive (something an LLM is not good at) + +Code operations provide a way to handle these cases efficiently without LLM overhead. + +## Types of Code Operations + +### Code Map Operation + +The Code Map operation applies a Python function to each item in your input data independently. + +??? example "Example Code Map Operation" + + ```yaml + - name: extract_keywords + type: code_map + code: | + def transform(doc) -> dict: + # Your transformation code here + keywords = doc['text'].lower().split() + return { + 'keywords': keywords, + 'keyword_count': len(keywords) + } + ``` + +The code must define a `transform` function that takes a single document as input and returns a dictionary of transformed values. + +### Code Reduce Operation + +The Code Reduce operation aggregates multiple items into a single result using a Python function. + +??? example "Example Code Reduce Operation" + + ```yaml + - name: aggregate_stats + type: code_reduce + reduce_key: category + code: | + def transform(items) -> dict: + total = sum(item['value'] for item in items) + avg = total / len(items) + return { + 'total': total, + 'average': avg, + 'count': len(items) + } + ``` + +The transform function for reduce operations takes a list of items as input and returns a single aggregated result. + +### Code Filter Operation + +The Code Filter operation allows you to filter items based on custom Python logic. + +??? example "Example Code Filter Operation" + + ```yaml + - name: filter_valid_entries + type: code_filter + code: | + def transform(doc) -> bool: + # Return True to keep the document, False to filter it out + return doc['score'] >= 0.5 and len(doc['text']) > 100 + ``` + +The transform function should return True for items to keep and False for items to filter out. + +## Configuration + +### Required Parameters + +- type: Must be "code_map", "code_reduce", or "code_filter" +- code: Python code containing the transform function. For map, the function must take a single document as input and return a document (a dictionary). For reduce, the function must take a list of documents as input and return a single aggregated document (a dictionary). For filter, the function must take a single document as input and return a boolean value indicating whether to keep the document. + +### Optional Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| drop_keys | List of keys to remove from output (code_map only) | None | +| reduce_key | Key(s) to group by (code_reduce only) | "_all" | +| pass_through | Pass through unmodified keys from first item in group (code_reduce only) | false | diff --git a/mkdocs.yml b/mkdocs.yml index c28dfb66..e9cf8673 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -35,6 +35,7 @@ nav: - Gather: operators/gather.md - Unnest: operators/unnest.md - Sample: operators/sample.md + - Code: operators/code.md - Optimization: - Overview: optimization/overview.md - Example: optimization/example.md diff --git a/pyproject.toml b/pyproject.toml index 2a7faf57..00aeadaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,9 @@ gather = "docetl.operations.gather:GatherOperation" cluster = "docetl.operations.cluster:ClusterOperation" sample = "docetl.operations.sample:SampleOperation" link_resolve = "docetl.operations.link_resolve:LinkResolveOperation" +code_map = "docetl.operations.code_operations:CodeMapOperation" +code_reduce = "docetl.operations.code_operations:CodeReduceOperation" +code_filter = "docetl.operations.code_operations:CodeFilterOperation" [tool.poetry.plugins."docetl.parser"] llama_index_simple_directory_reader = "docetl.parsing_tools:llama_index_simple_directory_reader" diff --git a/tests/basic/test_code_operations.py b/tests/basic/test_code_operations.py new file mode 100644 index 00000000..2d367f8d --- /dev/null +++ b/tests/basic/test_code_operations.py @@ -0,0 +1,176 @@ +import pytest +from docetl.operations.code_operations import CodeMapOperation + +class MockRunner: + """A simple mock runner for testing""" + def __init__(self, config=None): + self.config = config or {} + self.console = None + +@pytest.fixture +def mock_runner(): + return MockRunner() + +@pytest.fixture +def sample_docs(): + return [ + { + "text": "The quick brown fox jumped over two tired turtles. Today is terrific!" + }, + { + "text": "Testing multiple sentences. This is another test. Totally awesome." + }, + { + "text": "No words with t here." + } + ] + +@pytest.fixture +def sentence_counter_op(mock_runner): + return CodeMapOperation( + config={ + "name": "sentence_counter", + "type": "code_map", + "code": """ +def transform(doc): + text = doc.get('text', '') + sentences = [s.strip() for s in text.split('.') if s.strip()] + return { + 'text': text, + 'sentence_count': len(sentences) + } +""" + }, + runner=mock_runner, + default_model="gpt-staru-turbo", + max_threads=4 + ) + +@pytest.fixture +def t_word_extractor_op(mock_runner): + return CodeMapOperation( + config={ + "name": "t_word_extractor", + "type": "code_map", + "code": """ +def transform(doc): + text = doc.get('text', '') + # Only match actual words (more than one character) + t_words = [ + word.strip('.,!?') + for word in text.split() + if word.lower().startswith('t') and len(word.strip('.,!?')) > 1 + ] + return { + 'text': text, + 't_words': t_words, + 't_word_count': len(t_words) + } +""" + }, + runner=mock_runner, + default_model="gpt-staru-turbo", + max_threads=4 + ) + +def test_sentence_counter(sentence_counter_op, sample_docs): + results, cost = sentence_counter_op.execute(sample_docs) + + assert len(results) == 3 + assert results[0]['sentence_count'] == 2 + assert results[1]['sentence_count'] == 3 + assert results[2]['sentence_count'] == 1 + + assert cost == 0.0 + + for original, result in zip(sample_docs, results): + assert result['text'] == original['text'] + +def test_t_word_extractor(t_word_extractor_op, sample_docs): + results, cost = t_word_extractor_op.execute(sample_docs) + + assert len(results[0]['t_words']) == 6 + assert set(results[0]['t_words']) == {'The', 'two', 'tired', 'turtles', 'Today', 'terrific'} + assert results[0]['t_word_count'] == 6 + + assert len(results[1]['t_words']) == 4 + assert set(results[1]['t_words']) == {'Testing', 'This', 'test', 'Totally'} + assert results[1]['t_word_count'] == 4 + + assert len(results[2]['t_words']) == 0 + assert results[2]['t_word_count'] == 0 + + assert cost == 0.0 + +def test_invalid_code(): + """Test that invalid Python code raises appropriate error""" + with pytest.raises(ValueError) as exc_info: + CodeMapOperation( + config={ + "name": "invalid_code", + "type": "code_map", + "code": """ +def transform(doc): + this is invalid python code + return {} +""" + }, + runner=MockRunner(), + default_model="gpt-staru-turbo", + max_threads=4 + ) + assert "Invalid code configuration" in str(exc_info.value) + +def test_missing_transform_function(): + """Test that code without transform function raises error""" + with pytest.raises(ValueError) as exc_info: + CodeMapOperation( + config={ + "name": "missing_transform", + "type": "code_map", + "code": """ +def some_other_function(doc): + return {} +""" + }, + runner=MockRunner(), + default_model="gpt-staru-turbo", + max_threads=4 + ) + assert "Code must define a 'transform' function" in str(exc_info.value) + +def test_empty_input(sentence_counter_op): + """Test handling of empty input list""" + results, cost = sentence_counter_op.execute([]) + assert results == [] + assert cost == 0.0 + +def test_missing_text_field(sentence_counter_op): + """Test handling of documents without 'text' field""" + doc_without_text = [{"other_field": "value"}] + results, cost = sentence_counter_op.execute(doc_without_text) + assert results[0]['sentence_count'] == 0 + +def test_drop_keys(mock_runner): + """Test that drop_keys configuration works""" + op = CodeMapOperation( + config={ + "name": "drop_test", + "type": "code_map", + "code": """ +def transform(doc): + return { + 'keep_this': 'value', + 'drop_this': 'should not appear' + } +""", + "drop_keys": ["drop_this"] + }, + runner=mock_runner, + default_model="gpt-staru-turbo", + max_threads=4 + ) + + results, _ = op.execute([{"text": "dummy"}]) + assert 'keep_this' in results[0] + assert 'drop_this' not in results[0] \ No newline at end of file