1
1
from collections import defaultdict
2
2
from dataclasses import dataclass
3
- from typing import List , Union
3
+ from typing import Dict , List , Union
4
4
import uuid
5
5
import numpy as np
6
6
from sklearn .cluster import DBSCAN
@@ -16,6 +16,13 @@ class Face:
16
16
face_cluster_id : Union [str , None ]
17
17
embedding : np .array
18
18
19
+ def to_face_cluster_face (self ):
20
+ return FaceClusterFace (
21
+ image_id = self .image_id ,
22
+ index = self .index ,
23
+ embedding_hash = "" , # TODO
24
+ )
25
+
19
26
20
27
def main (
21
28
image_repo : ImageRepo ,
@@ -30,13 +37,7 @@ def main(
30
37
face_cluster_repo .update (FaceCluster (
31
38
id = fcid ,
32
39
label = None ,
33
- faces = [
34
- FaceClusterFace (
35
- image_id = face .image_id ,
36
- index = face .index ,
37
- embedding_hash = "" , # TODO
38
- ) for face in cluster
39
- ],
40
+ faces = [face .to_face_cluster_face () for face in cluster ],
40
41
))
41
42
else :
42
43
# existing cluster
@@ -49,19 +50,14 @@ def main(
49
50
else :
50
51
unassigned_faces .append (face )
51
52
53
+ primary_fcid , _ = max (face_by_existing_cluster .items (), key = lambda kv : len (kv [1 ]))
54
+
52
55
if len (face_by_existing_cluster ) > 1 :
53
- raise Exception ( "collision" )
56
+ try_merge_cluster ( face_cluster_repo , face_by_existing_cluster )
54
57
55
- primary_fcid = next (iter (face_by_existing_cluster .keys ()))
56
58
primary_fc = face_cluster_repo .get (primary_fcid )
57
59
for face in unassigned_faces :
58
- primary_fc .faces .append (
59
- FaceClusterFace (
60
- image_id = face .image_id ,
61
- index = face .index ,
62
- embedding_hash = "" , # TODO
63
- )
64
- )
60
+ primary_fc .faces .append (face .to_face_cluster_face ())
65
61
face_cluster_repo .update (primary_fc )
66
62
67
63
@@ -100,3 +96,35 @@ def normalize(v):
100
96
clustered_faces [label ].append (face )
101
97
102
98
return list (clustered_faces .values ())
99
+
100
+
101
+ def try_merge_cluster (
102
+ face_cluster_repo : FaceClusterRepo ,
103
+ face_by_existing_cluster : Dict [str , List [Face ]],
104
+ ):
105
+ primary_fcid , _ = max (face_by_existing_cluster .items (), key = lambda kv : len (kv [1 ]))
106
+
107
+ for face_cluster_id , faces in face_by_existing_cluster .items ():
108
+ print (f"{ face_cluster_id } : { len (faces )} faces" )
109
+ if input (f"Confirm merge into { primary_fcid } (y/n): " ) != "y" :
110
+ raise Exception ("collision" )
111
+
112
+ with face_cluster_repo .db .transactional ():
113
+ primary_fc = face_cluster_repo .get (primary_fcid )
114
+ other_face_clusters = [
115
+ face_cluster_repo .get (fcid )
116
+ for fcid in face_by_existing_cluster .keys ()
117
+ if fcid != primary_fcid
118
+ ]
119
+ for fc in other_face_clusters :
120
+ if len (fc .faces ) != len (face_by_existing_cluster [fc .id ]):
121
+ raise Exception ("cluster split" )
122
+ primary_fc .faces .extend ([
123
+ face .to_face_cluster_face ()
124
+ for face in face_by_existing_cluster [fc .id ]
125
+ ])
126
+
127
+ # TODO: actually delete rows
128
+ fc .faces = []
129
+ face_cluster_repo .update (fc )
130
+ face_cluster_repo .update (primary_fc )
0 commit comments