@@ -17,7 +17,7 @@ use common::json::Json;
17
17
use common:: mmap_array:: MmapArray ;
18
18
use common:: remap:: RemappedCollection ;
19
19
use common:: vec2:: Vec2 ;
20
- use k_means:: { k_means, k_means_lookup, k_means_lookup_many } ;
20
+ use k_means:: { k_means, k_means_lookup, k_means_lookup_by_dot , k_means_lookup_many_by_dot } ;
21
21
use std:: fs:: create_dir;
22
22
use std:: path:: Path ;
23
23
use stoppable_rayon as rayon;
@@ -69,14 +69,18 @@ impl<O: Op> Rabitq<O> {
69
69
) -> Box < dyn Iterator < Item = Element > + ' a > {
70
70
let projected_query = O :: proj ( & self . projection , O :: cast ( vector) ) ;
71
71
let lists = select (
72
- k_means_lookup_many ( & projected_query, & self . centroids ) ,
72
+ k_means_lookup_many_by_dot ( & projected_query, & self . centroids ) ,
73
73
opts. rabitq_nprobe as usize ,
74
74
) ;
75
75
let mut heap = Vec :: new ( ) ;
76
- for & ( _, i) in lists. iter ( ) {
77
- let preprocessed = self
78
- . quantization
79
- . preprocess ( & O :: residual ( & projected_query, & self . centroids [ ( i, ) ] ) ) ;
76
+ for & ( _, payload) in lists. iter ( ) {
77
+ let ( centroid_dot_dis, original_square, centroids_square, i) = payload;
78
+ let preprocessed = self . quantization . preprocess (
79
+ & O :: residual ( & projected_query, & self . centroids [ ( i, ) ] ) ,
80
+ centroid_dot_dis,
81
+ original_square,
82
+ centroids_square,
83
+ ) ;
80
84
let start = self . offsets [ i] ;
81
85
let end = self . offsets [ i + 1 ] ;
82
86
self . quantization . push_batch (
@@ -150,8 +154,13 @@ fn from_nothing<O: Op>(
150
154
collection. len ( ) ,
151
155
|vector| {
152
156
let vector = O :: cast ( collection. vector ( vector) ) ;
153
- let target = k_means_lookup ( vector, & centroids) ;
154
- O :: proj ( & projection, & O :: residual ( vector, & centroids[ ( target, ) ] ) )
157
+ let ( centroid_dot_dis, original_square, _, target) =
158
+ k_means_lookup_by_dot ( vector, & centroids) ;
159
+ (
160
+ O :: proj ( & projection, & O :: residual ( vector, & centroids[ ( target, ) ] ) ) ,
161
+ centroid_dot_dis,
162
+ original_square,
163
+ )
155
164
} ,
156
165
) ;
157
166
let projected_centroids = Vec2 :: from_vec (
@@ -194,7 +203,7 @@ fn open<O: Op>(path: impl AsRef<Path>) -> Rabitq<O> {
194
203
}
195
204
}
196
205
197
- fn select ( mut lists : Vec < ( f32 , usize ) > , n : usize ) -> Vec < ( f32 , usize ) > {
206
+ fn select < T > ( mut lists : Vec < ( f32 , T ) > , n : usize ) -> Vec < ( f32 , T ) > {
198
207
if lists. is_empty ( ) || n == 0 {
199
208
return Vec :: new ( ) ;
200
209
}
0 commit comments