Merge staging to main (after adding cluster operator) (#88)
* Parsers can now return any number of fields, and can access the whole item
* nit: change gpt-4o to gpt-4o-mini in tests
* feat: add verbose parameter for gleaning
* feat: add verbose parameter for gleaning
* fix: tokenizers should be wrapped in try catch
* fix: resort to eval if ast eval does not work
* docs: update docs to reflect new custom parsing API

Co-authored-by: redhog <redhog@users.noreply.github.com>

* Clustering (#84)
* nit: change gpt-4o to gpt-4o-mini in tests
* feat: add verbose parameter for gleaning
* feat: add verbose parameter for gleaning
* fix: tokenizers should be wrapped in try catch
* fix: resort to eval if ast eval does not work
* Merge staging to main (after parsers refactor) (#82)
* Parsers can now return any number of fields, and can access the whole item
* nit: change gpt-4o to gpt-4o-mini in tests
* feat: add verbose parameter for gleaning
* feat: add verbose parameter for gleaning
* fix: tokenizers should be wrapped in try catch
* fix: resort to eval if ast eval does not work
* docs: update docs to reflect new custom parsing API

---------

Co-authored-by: Egil <egil.moller@freecode.no>

* Added new clustering operation
* Reverse path
* Added docs for cluster operator
* Bugfix for docs formatting
* docs: add sample parameter (#87)
* Added new clustering operation
* Reverse path
* Added docs for cluster operator
* Bugfix for docs formatting
* add tests and link to doc

---------

Co-authored-by: Shreya Shankar <ss.shankar505@gmail.com>
Co-authored-by: Egil <egil.moller@freecode.no>

* fix: fixing params in test

---------

Co-authored-by: Egil <egil.moller@freecode.no>
Co-authored-by: redhog <redhog@users.noreply.github.com>
Co-authored-by: Egil Möller <redhog@redhog.org>
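The required and optional keys for the new cluster operator can be read off the syntax_check in the file below. As a minimal configuration sketch: the key names come from that check, while the field names ("concept", "description"), the prompt wording, and the model values are invented for illustration only.

    # Hypothetical cluster operation config; only the key names are taken from syntax_check.
    cluster_config = {
        "name": "cluster_concepts",                      # used in progress output
        "embedding_keys": ["concept", "description"],    # item fields to embed (illustrative names)
        "output_key": "clusters",                        # optional; defaults to "clusters"
        "summary_schema": {"concept": "string", "description": "string"},
        # Jinja2 template; each internal tree node is rendered with its two children as `left` and `right`
        "summary_prompt": (
            "Summarize these two groups into one concept and description:\n"
            "{{ left.concept }}: {{ left.description }}\n"
            "{{ right.concept }}: {{ right.description }}"
        ),
        "model": "gpt-4o-mini",                          # optional LLM override (example value)
        "embedding_model": "text-embedding-3-small",     # optional embedding model (example value)
        "max_batch_size": 4,                             # optional cap on concurrent summarization calls
    }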
1 parent 29ace39 · commit 70604cb · Showing 6 changed files with 522 additions and 0 deletions.
@@ -0,0 +1,206 @@
from jinja2 import Environment, Template
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple
from .base import BaseOperation
from .utils import RichLoopBar
from .clustering_utils import get_embeddings_for_clustering


class ClusterOperation(BaseOperation):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.max_batch_size: int = self.config.get(
            "max_batch_size", kwargs.get("max_batch_size", float("inf"))
        )

    def syntax_check(self) -> None:
        """
        Checks the configuration of the ClusterOperation for required keys and valid structure.

        Raises:
            ValueError: If required keys are missing or invalid in the configuration.
            TypeError: If configuration values have incorrect types.
        """
        required_keys = ["embedding_keys", "summary_schema", "summary_prompt"]
        for key in required_keys:
            if key not in self.config:
                raise ValueError(
                    f"Missing required key '{key}' in ClusterOperation configuration"
                )

        if not isinstance(self.config["embedding_keys"], list):
            raise TypeError("'embedding_keys' must be a list of strings")

        if "output_key" in self.config:
            if not isinstance(self.config["output_key"], str):
                raise TypeError("'output_key' must be a string")

        if not isinstance(self.config["summary_schema"], dict):
            raise TypeError("'summary_schema' must be a dictionary")

        if not isinstance(self.config["summary_prompt"], str):
            raise TypeError("'summary_prompt' must be a string")

        # Check if the summary prompt is a valid Jinja2 template
        try:
            Template(self.config["summary_prompt"])
        except Exception as e:
            raise ValueError(f"Invalid Jinja2 template in 'summary_prompt': {str(e)}")

        # Check optional parameters
        if "max_batch_size" in self.config:
            if not isinstance(self.config["max_batch_size"], int):
                raise TypeError("'max_batch_size' must be an integer")

        if "embedding_model" in self.config:
            if not isinstance(self.config["embedding_model"], str):
                raise TypeError("'embedding_model' must be a string")

        if "model" in self.config:
            if not isinstance(self.config["model"], str):
                raise TypeError("'model' must be a string")

        if "validate" in self.config:
            if not isinstance(self.config["validate"], list):
                raise TypeError("'validate' must be a list of strings")
            for rule in self.config["validate"]:
                if not isinstance(rule, str):
                    raise TypeError("Each validation rule must be a string")

    def execute(
        self, input_data: List[Dict], is_build: bool = False
    ) -> Tuple[List[Dict], float]:
        """
        Executes the cluster operation on the input data. Modifies the
        input data in place and returns it.

        Args:
            input_data (List[Dict]): A list of dictionaries to process.
            is_build (bool): Whether the operation is being executed
                in the build phase. Defaults to False.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the clustered
                list of dictionaries and the total cost of the operation.
        """
        if not input_data:
            return input_data, 0

        if len(input_data) == 1:
            input_data[0][self.config.get("output_key", "clusters")] = ()
            return input_data, 0

        embeddings, cost = get_embeddings_for_clustering(
            input_data, self.config, self.runner.api
        )

        tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings)

        self.prompt_template = Template(self.config["summary_prompt"])
        cost += self.annotate_clustering_tree(tree)
        self.annotate_leaves(tree)

        return input_data, cost

    def agglomerative_cluster_of_embeddings(self, input_data, embeddings):
        """
        Builds a binary cluster tree over the input items using scikit-learn's
        agglomerative clustering. Leaves are the original input dictionaries;
        internal nodes are dicts with "children" (two subtrees) and "distance"
        (the merge distance reported by the clustering).
        """
        import sklearn.cluster

        cl = sklearn.cluster.AgglomerativeClustering(
            compute_full_tree=True, compute_distances=True
        )
        cl.fit(embeddings)

        nsamples = len(embeddings)

        def build_tree(i):
            # Indices below nsamples are original samples (leaves); index
            # nsamples + k is the cluster produced by merge k, whose two
            # members are stored in cl.children_[k].
            if i < nsamples:
                res = input_data[i]
                # res["embedding"] = list(embeddings[i])
                return res
            return {
                "children": [
                    build_tree(cl.children_[i - nsamples, 0]),
                    build_tree(cl.children_[i - nsamples, 1]),
                ],
                "distance": cl.distances_[i - nsamples],
            }

        # The last merge corresponds to the root of the tree.
        return build_tree(nsamples + len(cl.children_) - 1)

    def annotate_clustering_tree(self, t):
        """
        Recursively summarizes each internal node of the cluster tree by
        rendering the summary prompt over its two children and calling the
        LLM with validation; the summary fields are merged into the node.
        Returns the total LLM cost (leaves cost nothing).
        """
        if "children" in t:
            with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
                futures = [
                    executor.submit(self.annotate_clustering_tree, child)
                    for child in t["children"]
                ]

                total_cost = 0
                pbar = RichLoopBar(
                    range(len(futures)),
                    desc=f"Processing {self.config['name']} (cluster) on all documents",
                    console=self.console,
                )
                for i in pbar:
                    total_cost += futures[i].result()
                    pbar.update(i)

            assert len(t["children"]) == 2, (
                "Agglomerative clustering is supposed to generate clusters with 2 children each, but this cluster has %s"
                % len(t["children"])
            )
            prompt = self.prompt_template.render(
                left=t["children"][0], right=t["children"][1]
            )

            def validation_fn(response: Dict[str, Any]):
                output = self.runner.api.parse_llm_response(
                    response,
                    schema=self.config["summary_schema"],
                    manually_fix_errors=self.manually_fix_errors,
                )[0]
                if self.runner.api.validate_output(self.config, output, self.console):
                    return output, True
                return output, False

            output, cost, success = self.runner.api.call_llm_with_validation(
                [{"role": "user", "content": prompt}],
                model=self.config.get("model", self.default_model),
                operation_type="cluster",
                schema=self.config["summary_schema"],
                llm_call_fn=lambda messages: self.runner.api.call_llm(
                    self.config.get("model", self.default_model),
                    "cluster",
                    messages,
                    self.config["summary_schema"],
                    tools=self.config.get("tools", None),
                    console=self.console,
                    timeout_seconds=self.config.get("timeout", 120),
                    max_retries_per_timeout=self.config.get(
                        "max_retries_per_timeout", 2
                    ),
                ),
                validation_fn=validation_fn,
                val_rule=self.config.get("validate", []),
                num_retries=self.num_retries_on_validate_failure,
                console=self.console,
            )
            total_cost += cost

            t.update(output)

            return total_cost
        return 0

    def annotate_leaves(self, tree, path=()):
        """
        Attaches to each leaf item the chain of ancestor cluster summaries,
        nearest ancestor first, under the configured output key
        (default "clusters").
        """
        if "children" in tree:
            item = dict(tree)
            item.pop("children")
            for child in tree["children"]:
                self.annotate_leaves(child, path=(item,) + path)
        else:
            tree[self.config.get("output_key", "clusters")] = path
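To make the tree construction and leaf annotation concrete, here is a minimal standalone sketch of the same scikit-learn indexing convention and path logic. The toy items and embeddings are invented, and the LLM summarization step performed by annotate_clustering_tree is omitted, so the ancestor entries only carry the merge distance.

    import numpy as np
    from sklearn.cluster import AgglomerativeClustering

    items = [{"id": i} for i in range(4)]                      # toy input items
    embeddings = np.array([[0.0, 0.0], [0.1, 0.0],
                           [5.0, 5.0], [5.1, 5.0]])            # toy embeddings

    cl = AgglomerativeClustering(compute_full_tree=True, compute_distances=True)
    cl.fit(embeddings)
    n = len(embeddings)

    def build_tree(i):
        # Same convention as the operation: merge k creates node n + k,
        # whose two members are cl.children_[k].
        if i < n:
            return items[i]
        left, right = cl.children_[i - n]
        return {
            "children": [build_tree(left), build_tree(right)],
            "distance": cl.distances_[i - n],
        }

    def annotate_leaves(tree, path=()):
        # Mirrors ClusterOperation.annotate_leaves: each leaf gets the chain
        # of ancestor nodes, nearest ancestor first, under "clusters".
        if "children" in tree:
            summary = {k: v for k, v in tree.items() if k != "children"}
            for child in tree["children"]:
                annotate_leaves(child, path=(summary,) + path)
        else:
            tree["clusters"] = path

    root = build_tree(n + len(cl.children_) - 1)               # last merge is the root
    annotate_leaves(root)
    print(items[0]["clusters"])  # ancestor clusters of item 0, root last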