Skip to content

Commit

Permalink
Added collapse option to clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
Egil committed Oct 11, 2024
1 parent 22d3a40 commit 0b2af87
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions docetl/operations/cluster.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from jinja2 import Environment, Template
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple
Expand Down Expand Up @@ -99,6 +100,9 @@ def execute(
)

tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings)

if "collapse" in self.config:
tree = self.collapse_tree(tree, collapse = self.config["collapse"])

self.prompt_template = Template(self.config["summary_prompt"])
cost += self.annotate_clustering_tree(tree)
Expand All @@ -122,7 +126,7 @@ def build_tree(i):
# res["embedding"] = list(embeddings[i])
return res
return {
"children": [
"children": [
build_tree(cl.children_[i - nsamples, 0]),
build_tree(cl.children_[i - nsamples, 1]),
],
Expand All @@ -131,6 +135,40 @@ def build_tree(i):

return build_tree(nsamples + len(cl.children_) - 1)

def get_tree_distances(self, t):
res = set()
if "distance" in t:
res.update(set([t["distance"] - child["distance"] for child in t["children"] if "distance" in child]))
if "children" in t:
for child in t["children"]:
res.update(self.get_tree_distances(child))
return res

def _collapse_tree(self, t, parent_dist = None, collapse = None):
if "children" in t:
if ( "distance" in t
and parent_dist is not None
and collapse is not None
and parent_dist - t["distance"] < collapse):
return [grandchild
for child in t["children"]
for grandchild in self._collapse_tree(child, parent_dist=parent_dist, collapse=collapse)]
else:
res = dict(t)
res["children"] = [grandchild
for idx, child in enumerate(t["children"])
for grandchild in self._collapse_tree(child, parent_dist=t["distance"], collapse=collapse)]
return [res]
else:
return [t]

def collapse_tree(self, tree, collapse = None):
if collapse is not None:
tree_distances = np.array(sorted(self.get_tree_distances(tree)))
collapse = tree_distances[int(len(tree_distances) * collapse)]
return self._collapse_tree(tree, collapse=collapse)[0]


def annotate_clustering_tree(self, t):
if "children" in t:
with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
Expand All @@ -149,12 +187,8 @@ def annotate_clustering_tree(self, t):
total_cost += futures[i].result()
pbar.update(i)

assert len(t["children"]) == 2, (
"Agglomerative clustering is supposed to generate clusters with 2 children each, but this cluster has %s"
% len(t["children"])
)
prompt = self.prompt_template.render(
left=t["children"][0], right=t["children"][1]
inputs=t["children"]
)

def validation_fn(response: Dict[str, Any]):
Expand Down

0 comments on commit 0b2af87

Please sign in to comment.