Merge staging to main (after adding cluster operator) #88

Merged
merged 10 commits into main from staging on Oct 9, 2024
37 changes: 13 additions & 24 deletions docetl/dataset.py
@@ -148,20 +148,14 @@ def _validate_parsing(
for tool in parsing_tools:
if (
not isinstance(tool, dict)
or "input_key" not in tool
or "function" not in tool
or "output_key" not in tool
):
raise ValueError(
"Each parsing tool must be a dictionary with 'input_key', 'function', and 'output_key' keys"
"Each parsing tool must be a dictionary with a 'function' key and any arguments required by that function"
)
if (
not isinstance(tool["input_key"], str)
or not isinstance(tool["function"], str)
or not isinstance(tool["output_key"], str)
):
if not isinstance(tool["function"], str):
raise ValueError(
"'input_key', 'function', and 'output_key' in parsing tools must be strings"
"'function' in parsing tools must be a string"
)
if "function_kwargs" in tool and not isinstance(
tool["function_kwargs"], dict
@@ -213,19 +207,12 @@ def load(self) -> List[Dict]:
def _process_item(
self,
item: Dict[str, Any],
input_key: str,
output_key: str,
func: Callable,
**function_kwargs: Dict[str, Any],
):
if input_key not in item:
raise ValueError(f"Input key {input_key} not found in item: {item}")
result = func(item[input_key], **function_kwargs)
if isinstance(result, list):
return [item.copy() | {output_key: res} for res in result]
else:
return [item | {output_key: result}]

result = func(item, **function_kwargs)
return [item.copy() | res for res in result]

def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
"""
Apply parsing tools to the data.
@@ -240,7 +227,13 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
ValueError: If a parsing tool is not found or if an input key is missing from an item.
"""
for tool in self.parsing:
input_key = tool["input_key"]
function_kwargs = dict(tool)
function_kwargs.pop("function")
# FIXME: The following is just for backwards compatibility
# with the existing yaml format...
if "function_kwargs" in function_kwargs:
function_kwargs.update(function_kwargs.pop("function_kwargs"))

try:
func = get_parser(tool["function"])
except KeyError:
@@ -261,17 +254,13 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
f"Parsing tool {tool['function']} not found. Please define it or use one of our existing parsing tools: {get_parsing_tools()}"
)

output_key = tool["output_key"]
function_kwargs = tool.get("function_kwargs", {})
new_data = []

with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
self._process_item,
item,
input_key,
output_key,
func,
**function_kwargs,
)
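Note on the dataset.py change: parsing tools no longer take reserved input_key/output_key fields. Only "function" is required, and every other key in the tool dict is forwarded to the parser as a keyword argument (a legacy "function_kwargs" block is still merged in for backwards compatibility). A minimal sketch of the new shape — the parser name and extra argument are illustrative, not confirmed built-ins:

# Hedged sketch of the new parsing-tool config; "txt_to_string" and
# "encoding" are illustrative names, not guaranteed built-in parsers.
parsing = [
    {
        "function": "txt_to_string",   # required, must name a registered parser
        "encoding": "utf-8",           # any extra key is passed to the parser as a kwarg
        # "function_kwargs": {...}     # legacy form, still merged into the kwargs
    }
]

# Each parser now receives the whole item and returns a list of dicts that are
# merged into copies of that item:
#     result = func(item, **function_kwargs)
#     return [item.copy() | res for res in result]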
206 changes: 206 additions & 0 deletions docetl/operations/cluster.py
@@ -0,0 +1,206 @@
from jinja2 import Environment, Template
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple
from .base import BaseOperation
from .utils import RichLoopBar
from .clustering_utils import get_embeddings_for_clustering


class ClusterOperation(BaseOperation):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.max_batch_size: int = self.config.get(
"max_batch_size", kwargs.get("max_batch_size", float("inf"))
)

def syntax_check(self) -> None:
"""
Checks the configuration of the ClusterOperation for required keys and valid structure.

Raises:
ValueError: If required keys are missing or invalid in the configuration.
TypeError: If configuration values have incorrect types.
"""
required_keys = ["embedding_keys", "summary_schema", "summary_prompt"]
for key in required_keys:
if key not in self.config:
raise ValueError(
f"Missing required key '{key}' in ClusterOperation configuration"
)

if not isinstance(self.config["embedding_keys"], list):
raise TypeError("'embedding_keys' must be a list of strings")

if "output_key" in self.config:
if not isinstance(self.config["output_key"], str):
raise TypeError("'output_key' must be a string")

if not isinstance(self.config["summary_schema"], dict):
raise TypeError("'summary_schema' must be a dictionary")

if not isinstance(self.config["summary_prompt"], str):
raise TypeError("'prompt' must be a string")

# Check if the prompt is a valid Jinja2 template
try:
Template(self.config["summary_prompt"])
except Exception as e:
raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")

# Check optional parameters
if "max_batch_size" in self.config:
if not isinstance(self.config["max_batch_size"], int):
raise TypeError("'max_batch_size' must be an integer")

if "embedding_model" in self.config:
if not isinstance(self.config["embedding_model"], str):
raise TypeError("'embedding_model' must be a string")

if "model" in self.config:
if not isinstance(self.config["model"], str):
raise TypeError("'model' must be a string")

if "validate" in self.config:
if not isinstance(self.config["validate"], list):
raise TypeError("'validate' must be a list of strings")
for rule in self.config["validate"]:
if not isinstance(rule, str):
raise TypeError("Each validation rule must be a string")

def execute(
self, input_data: List[Dict], is_build: bool = False
) -> Tuple[List[Dict], float]:
"""
Executes the cluster operation on the input data. Modifies the
input data and returns it in place.

Args:
input_data (List[Dict]): A list of dictionaries to process.
is_build (bool): Whether the operation is being executed
in the build phase. Defaults to False.

Returns:
Tuple[List[Dict], float]: A tuple containing the clustered
list of dictionaries and the total cost of the operation.
"""
if not input_data:
return input_data, 0

if len(input_data) == 1:
input_data[0][self.config.get("output_key", "clusters")] = ()
return input_data, 0

embeddings, cost = get_embeddings_for_clustering(
input_data, self.config, self.runner.api
)

tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings)

self.prompt_template = Template(self.config["summary_prompt"])
cost += self.annotate_clustering_tree(tree)
self.annotate_leaves(tree)

return input_data, cost

def agglomerative_cluster_of_embeddings(self, input_data, embeddings):
import sklearn.cluster

cl = sklearn.cluster.AgglomerativeClustering(
compute_full_tree=True, compute_distances=True
)
cl.fit(embeddings)

nsamples = len(embeddings)

def build_tree(i):
if i < nsamples:
res = input_data[i]
# res["embedding"] = list(embeddings[i])
return res
return {
"children": [
build_tree(cl.children_[i - nsamples, 0]),
build_tree(cl.children_[i - nsamples, 1]),
],
"distance": cl.distances_[i - nsamples],
}

return build_tree(nsamples + len(cl.children_) - 1)

def annotate_clustering_tree(self, t):
if "children" in t:
with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
futures = [
executor.submit(self.annotate_clustering_tree, child)
for child in t["children"]
]

total_cost = 0
pbar = RichLoopBar(
range(len(futures)),
desc=f"Processing {self.config['name']} (map) on all documents",
console=self.console,
)
for i in pbar:
total_cost += futures[i].result()
pbar.update(i)

assert len(t["children"]) == 2, (
"Agglomerative clustering is supposed to generate clusters with 2 children each, but this cluster has %s"
% len(t["children"])
)
prompt = self.prompt_template.render(
left=t["children"][0], right=t["children"][1]
)

def validation_fn(response: Dict[str, Any]):
output = self.runner.api.parse_llm_response(
response,
schema=self.config["summary_schema"],
manually_fix_errors=self.manually_fix_errors,
)[0]
if self.runner.api.validate_output(self.config, output, self.console):
return output, True
return output, False

output, cost, success = self.runner.api.call_llm_with_validation(
[{"role": "user", "content": prompt}],
model=self.config.get("model", self.default_model),
operation_type="cluster",
schema=self.config["summary_schema"],
llm_call_fn=lambda messages: self.runner.api.call_llm(
self.config.get("model", self.default_model),
"cluster",
messages,
self.config["summary_schema"],
tools=self.config.get("tools", None),
console=self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
),
validation_fn=validation_fn,
val_rule=self.config.get("validate", []),
num_retries=self.num_retries_on_validate_failure,
console=self.console,
)
total_cost += cost

t.update(output)

return total_cost
return 0

def annotate_leaves(self, tree, path=()):
if "children" in tree:
item = dict(tree)
item.pop("children")
for child in tree["children"]:
self.annotate_leaves(child, path=(item,) + path)
else:
tree[self.config.get("output_key", "clusters")] = path
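For readers new to the operator, here is a hedged sketch of a ClusterOperation config that satisfies syntax_check above; the operation name, keys, prompt text, and model are illustrative:

# Hedged sketch of a cluster operation config; all values are illustrative.
cluster_config = {
    "name": "cluster_findings",                      # illustrative
    "type": "cluster",
    "embedding_keys": ["title", "summary"],          # fields embedded for agglomerative clustering
    "output_key": "categories",                      # optional; defaults to "clusters"
    "summary_schema": {"title": "string", "summary": "string"},
    # Jinja2 template, rendered with the two children of each internal tree
    # node as `left` and `right` (see annotate_clustering_tree).
    "summary_prompt": (
        "Write a title and summary that covers both groups below.\n"
        "Group A: {{ left.title }} - {{ left.summary }}\n"
        "Group B: {{ right.title }} - {{ right.summary }}"
    ),
    "model": "gpt-4o-mini",                          # optional
    "max_batch_size": 4,                             # optional; caps summarization concurrency
}

Each input document then comes back with output_key set to the tuple of annotated ancestor clusters, ordered from the closest (smallest) cluster out to the root.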
1 change: 1 addition & 0 deletions docetl/operations/map.py
@@ -171,6 +171,7 @@ def validation_fn(response: Dict[str, Any]):
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
verbose=self.config.get("verbose", False),
),
validation_fn=validation_fn,
val_rule=self.config.get("validate", []),
6 changes: 6 additions & 0 deletions docetl/operations/reduce.py
@@ -273,6 +273,11 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
Returns:
Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.
"""
if self.config.get("gleaning", {}).get("validation_prompt", None):
self.console.log(
f"Using gleaning with validation prompt: {self.config.get('gleaning', {}).get('validation_prompt', '')}"
)

reduce_keys = self.config["reduce_key"]
if isinstance(reduce_keys, str):
reduce_keys = [reduce_keys]
@@ -860,6 +865,7 @@ def _batch_reduce(
console=self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
verbose=self.config.get("verbose", False),
)
item_cost += gleaning_cost
else:
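The map and reduce changes only surface gleaning details: both now forward a verbose flag to call_llm_with_gleaning, and reduce additionally logs the configured gleaning validation prompt. A hedged sketch of the relevant config keys (only "verbose" and "gleaning.validation_prompt" come from this diff; everything else is illustrative):

# Hedged sketch of a reduce operation using the new flags; values are illustrative.
reduce_config = {
    "name": "summarize_by_topic",        # illustrative
    "type": "reduce",
    "reduce_key": "topic",
    "verbose": True,                     # new: passed through to the gleaning call
    "gleaning": {
        "validation_prompt": "Check that every claim is supported by the grouped inputs.",
        # other gleaning settings (e.g. number of rounds) go alongside
    },
    # prompt and output schema omitted here
}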
12 changes: 9 additions & 3 deletions docetl/operations/split.py
@@ -51,9 +51,15 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
split_key = self.config["split_key"]
method = self.config["method"]
method_kwargs = self.config["method_kwargs"]
encoder = tiktoken.encoding_for_model(
self.config["method_kwargs"].get("model", self.default_model).split("/")[-1]
)
try:
encoder = tiktoken.encoding_for_model(
self.config["method_kwargs"]
.get("model", self.default_model)
.split("/")[-1]
)
except Exception:
encoder = tiktoken.encoding_for_model("gpt-4o")

results = []
cost = 0.0

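The split fix guards the tokenizer lookup: model names that tiktoken does not recognize (for example provider-prefixed or non-OpenAI models) now fall back to the gpt-4o encoding instead of raising. A rough standalone sketch of the same behavior:

# Rough standalone sketch of the fallback added above.
import tiktoken

def get_encoder(model: str):
    try:
        # Strip any provider prefix, e.g. "azure/gpt-4o-mini" -> "gpt-4o-mini"
        return tiktoken.encoding_for_model(model.split("/")[-1])
    except Exception:
        # Unknown model names fall back to the gpt-4o tokenizer
        return tiktoken.encoding_for_model("gpt-4o")

encoder = get_encoder("ollama/llama3")  # not a tiktoken model -> gpt-4o encoding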
18 changes: 12 additions & 6 deletions docetl/operations/utils.py
@@ -359,7 +359,11 @@ def safe_eval(expression: str, output: Dict) -> bool:
# Safely evaluate the expression
return bool(aeval(expression))
except Exception:
return False
# try to evaluate with python eval
try:
return bool(eval(expression, locals={"output": output}))
except Exception:
return False


class APIWrapper(object):
@@ -720,6 +724,7 @@ def call_llm_with_gleaning(
console: Console = Console(),
timeout_seconds: int = 120,
max_retries_per_timeout: int = 2,
verbose: bool = False,
) -> Tuple[str, float]:
"""
Call LLM with a gleaning process, including validation and improvement rounds.
@@ -789,7 +794,7 @@
# Call LLM for validation
self.runner.rate_limiter.try_acquire("llm_call", weight=1)
validator_response = completion(
model="gpt-4o-mini",
model=model,
messages=truncate_messages(
messages + [{"role": "user", "content": validator_prompt}], model
),
@@ -817,9 +822,10 @@
if not suggestion["should_refine"]:
break

# console.log(
# f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}"
# )
if verbose:
console.log(
f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}"
)

# Prompt for improvement
improvement_prompt = f"""Based on the validation feedback:
@@ -1166,4 +1172,4 @@ def rich_as_completed(futures, total=None, desc=None, leave=True, console=None):
with RichLoopBar(total=total, desc=desc, leave=leave, console=console) as pbar:
for future in as_completed(futures):
yield future
pbar.update()
pbar.update()
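In utils.py the notable behavior changes are: validation expressions that asteval cannot evaluate are retried with plain eval, the gleaning validator now uses the operation's own model instead of a hard-coded gpt-4o-mini, and the previously commented-out gleaning improvement log is back behind the new verbose flag. A rough sketch of the extended rule evaluation (the namespace is passed positionally here, since passing locals= to eval as a keyword is not supported on all Python versions):

# Rough sketch of the two-stage expression check, assuming the same semantics
# as safe_eval above; the rule and output below are illustrative.
from asteval import Interpreter

def check_rule(expression: str, output: dict) -> bool:
    aeval = Interpreter()
    aeval.symtable["output"] = output
    try:
        result = aeval(expression)
        if aeval.error:               # asteval records errors instead of raising
            raise ValueError(aeval.error_msg)
        return bool(result)
    except Exception:
        try:
            # Fallback: plain eval with `output` exposed in the namespace
            return bool(eval(expression, {"output": output}))
        except Exception:
            return False

print(check_rule("len(output['keywords']) <= 5", {"keywords": ["a", "b"]}))  # True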
2 changes: 1 addition & 1 deletion docetl/optimizers/reduce_optimizer.py
@@ -1305,7 +1305,7 @@ def _calculate_compression_ratio(
reduce_key = op_config["reduce_key"]
input_schema = op_config.get("input", {}).get("schema", {})
output_schema = op_config["output"]["schema"]
model = op_config.get("model", "gpt-4o")
model = op_config.get("model", "gpt-4o-mini")

compression_ratios = {}
