diff --git a/docetl/operations/cluster.py b/docetl/operations/cluster.py
index cc6a8209..033e3bf4 100644
--- a/docetl/operations/cluster.py
+++ b/docetl/operations/cluster.py
@@ -1,3 +1,4 @@
+import numpy as np
 from jinja2 import Environment, Template
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional, Tuple
@@ -99,6 +100,9 @@ def execute(
         )
 
         tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings)
+
+        if "collapse" in self.config:
+            tree = self.collapse_tree(tree, collapse=self.config["collapse"])
 
         self.prompt_template = Template(self.config["summary_prompt"])
         cost += self.annotate_clustering_tree(tree)
@@ -122,7 +126,7 @@ def build_tree(i):
                 # res["embedding"] = list(embeddings[i])
                 return res
             return {
-                "children": [
+                "children": [
                     build_tree(cl.children_[i - nsamples, 0]),
                     build_tree(cl.children_[i - nsamples, 1]),
                 ],
@@ -131,6 +135,92 @@ def build_tree(i):
 
         return build_tree(nsamples + len(cl.children_) - 1)
 
+    def get_tree_distances(self, t):
+        """Collect the distinct parent-to-child distance gaps in a tree.
+
+        Args:
+            t: Cluster tree node (dict); may have "distance" and "children".
+
+        Returns:
+            Set of ``t["distance"] - child["distance"]`` values for every
+            parent/child pair in which both nodes have a "distance".
+        """
+        children = t.get("children", ())
+        res = set()
+        if "distance" in t:
+            res.update(
+                t["distance"] - child["distance"]
+                for child in children
+                if "distance" in child
+            )
+        for child in children:
+            res.update(self.get_tree_distances(child))
+        return res
+
+    def _collapse_tree(self, t, parent_dist=None, collapse=None):
+        """Flatten nodes whose distance gap to their kept ancestor is small.
+
+        Args:
+            t: Cluster tree node (dict).
+            parent_dist: "distance" of the nearest ancestor that was kept,
+                or None for the root.
+            collapse: Absolute gap threshold; None disables collapsing.
+
+        Returns:
+            A list of nodes to splice into the parent's "children": the
+            (copied) node itself, or its flattened descendants when the
+            node is collapsed away.
+        """
+        if "children" not in t:
+            return [t]
+        if (
+            "distance" in t
+            and parent_dist is not None
+            and collapse is not None
+            and parent_dist - t["distance"] < collapse
+        ):
+            # Too close to the kept ancestor: hoist the children up to it.
+            return [
+                grandchild
+                for child in t["children"]
+                for grandchild in self._collapse_tree(
+                    child, parent_dist=parent_dist, collapse=collapse
+                )
+            ]
+        res = dict(t)  # shallow copy so the input tree is not mutated
+        res["children"] = [
+            grandchild
+            for child in t["children"]
+            for grandchild in self._collapse_tree(
+                child, parent_dist=t.get("distance"), collapse=collapse
+            )
+        ]
+        return [res]
+
+    def collapse_tree(self, tree, collapse=None):
+        """Collapse a cluster tree using a quantile of its distance gaps.
+
+        Args:
+            tree: Root node of the cluster tree (dict).
+            collapse: Fraction of the sorted distinct gaps used to pick
+                the threshold; None leaves the structure unchanged.
+
+        Returns:
+            The root of the collapsed tree.
+        """
+        if collapse is not None:
+            tree_distances = np.array(sorted(self.get_tree_distances(tree)))
+            if len(tree_distances) == 0:
+                # No gaps to rank, so there is nothing to collapse.
+                collapse = None
+            else:
+                # Clamp so collapse >= 1.0 (or < 0) cannot index out of range.
+                idx = int(len(tree_distances) * collapse)
+                idx = max(0, min(idx, len(tree_distances) - 1))
+                collapse = tree_distances[idx]
+        return self._collapse_tree(tree, collapse=collapse)[0]
+
     def annotate_clustering_tree(self, t):
         if "children" in t:
             with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
@@ -149,12 +239,8 @@ def annotate_clustering_tree(self, t):
                     total_cost += futures[i].result()
                     pbar.update(i)
 
-            assert len(t["children"]) == 2, (
-                "Agglomerative clustering is supposed to generate clusters with 2 children each, but this cluster has %s"
-                % len(t["children"])
-            )
             prompt = self.prompt_template.render(
-                left=t["children"][0], right=t["children"][1]
+                inputs=t["children"]
             )
 
             def validation_fn(response: Dict[str, Any]):
diff --git a/tests/basic/test_cluster.py b/tests/basic/test_cluster.py
index 23835e8e..0bcaa717 100644
--- a/tests/basic/test_cluster.py
+++ b/tests/basic/test_cluster.py
@@ -17,11 +17,10 @@ def cluster_config():
             these two concepts already encompasses the other; in that case,
             you should just use that concept.
 
-            {{left.concept}}:
-            {{left.description}}
-
-            {{right.concept}}:
-            {{right.description}}
+            {% for input in inputs %}
+            {{input.concept}}:
+            {{input.description}}
+            {% endfor %}
 
             Provide the title of the super-concept, and a description.
             """,