Skip to content

Commit f45b3c3

Browse files
committed Dec 1, 2024
merge face clusters
1 parent a8e527e commit f45b3c3

File tree

2 files changed

+46
-17
lines changed

2 files changed

+46
-17
lines changed
 

‎pix/api/face_clusters.py

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def list_face_clusters() -> List[FaceClusterDto]:
2929
image_repo = AppGraph.get_instance(ImageRepo)
3030
result = []
3131
for fc in fc_repo.all():
32+
if not fc.faces: continue
3233
face = fc.faces[0]
3334
result.append(FaceClusterDto(
3435
id=fc.id,

‎pix/task/facecluster.py

+45 -17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22
from dataclasses import dataclass
3-
from typing import List, Union
3+
from typing import Dict, List, Union
44
import uuid
55
import numpy as np
66
from sklearn.cluster import DBSCAN
@@ -16,6 +16,13 @@ class Face:
1616
face_cluster_id: Union[str, None]
1717
embedding: np.array
1818

19+
def to_face_cluster_face(self):
    """Build the ``FaceClusterFace`` record for this face.

    Copies ``image_id`` and ``index`` from this face; the embedding hash
    is left as an empty string until it is implemented.
    """
    cluster_face = FaceClusterFace(
        image_id=self.image_id,
        index=self.index,
        embedding_hash="",  # TODO
    )
    return cluster_face
25+
1926

2027
def main(
2128
image_repo: ImageRepo,
@@ -30,13 +37,7 @@ def main(
3037
face_cluster_repo.update(FaceCluster(
3138
id=fcid,
3239
label=None,
33-
faces=[
34-
FaceClusterFace(
35-
image_id=face.image_id,
36-
index=face.index,
37-
embedding_hash="", # TODO
38-
) for face in cluster
39-
],
40+
faces=[face.to_face_cluster_face() for face in cluster],
4041
))
4142
else:
4243
# existing cluster
@@ -49,19 +50,14 @@ def main(
4950
else:
5051
unassigned_faces.append(face)
5152

53+
primary_fcid, _ = max(face_by_existing_cluster.items(), key=lambda kv: len(kv[1]))
54+
5255
if len(face_by_existing_cluster) > 1:
53-
raise Exception("collision")
56+
try_merge_cluster(face_cluster_repo, face_by_existing_cluster)
5457

55-
primary_fcid = next(iter(face_by_existing_cluster.keys()))
5658
primary_fc = face_cluster_repo.get(primary_fcid)
5759
for face in unassigned_faces:
58-
primary_fc.faces.append(
59-
FaceClusterFace(
60-
image_id=face.image_id,
61-
index=face.index,
62-
embedding_hash="", # TODO
63-
)
64-
)
60+
primary_fc.faces.append(face.to_face_cluster_face())
6561
face_cluster_repo.update(primary_fc)
6662

6763

@@ -100,3 +96,35 @@ def normalize(v):
10096
clustered_faces[label].append(face)
10197

10298
return list(clustered_faces.values())
99+
100+
101+
def try_merge_cluster(
    face_cluster_repo: FaceClusterRepo,
    face_by_existing_cluster: Dict[str, List[Face]],
):
    """Interactively merge all colliding clusters into the largest one.

    The cluster id mapped to the most faces in ``face_by_existing_cluster``
    is the merge target. The face count of every cluster is printed and the
    operator is asked on stdin to confirm. On confirmation, the grouped
    faces of every other cluster are appended to the target, each other
    cluster's face list is emptied, and all touched clusters are written
    back through ``face_cluster_repo.update`` inside a
    ``face_cluster_repo.db.transactional()`` context.

    Args:
        face_cluster_repo: Repository used to load and persist clusters.
        face_by_existing_cluster: Faces of the new group, keyed by the id
            of the existing cluster each one already belongs to.

    Raises:
        Exception: ``"collision"`` when the operator answers anything but
            ``"y"``; ``"cluster split"`` when a non-target cluster stores a
            different number of faces than were grouped under its id.
    """
    def group_size(entry):
        return len(entry[1])

    primary_fcid = max(face_by_existing_cluster.items(), key=group_size)[0]

    # Show the operator what is about to be merged before asking.
    for cluster_id, members in face_by_existing_cluster.items():
        print(f"{cluster_id}: {len(members)} faces")

    answer = input(f"Confirm merge into {primary_fcid} (y/n): ")
    if answer != "y":
        raise Exception("collision")

    with face_cluster_repo.db.transactional():
        primary_fc = face_cluster_repo.get(primary_fcid)
        absorbed_clusters = [
            face_cluster_repo.get(fcid)
            for fcid in face_by_existing_cluster.keys()
            if fcid != primary_fcid
        ]
        for donor in absorbed_clusters:
            incoming = face_by_existing_cluster[donor.id]
            # Only whole clusters are merged; a partial overlap is refused.
            if len(donor.faces) != len(incoming):
                raise Exception("cluster split")
            primary_fc.faces.extend([
                face.to_face_cluster_face()
                for face in incoming
            ])

            # TODO: actually delete rows
            donor.faces = []
            face_cluster_repo.update(donor)
        face_cluster_repo.update(primary_fc)

0 commit comments

Comments
 (0)
Please sign in to comment.