@@ -17,7 +17,9 @@ use common::json::Json;
17
17
use common:: mmap_array:: MmapArray ;
18
18
use common:: remap:: RemappedCollection ;
19
19
use common:: vec2:: Vec2 ;
20
- use k_means:: { k_means, k_means_lookup, k_means_lookup_many} ;
20
+ use k_means:: {
21
+ centroids_square, k_means, k_means_lookup, k_means_lookup_by_dot, k_means_lookup_many_by_dot,
22
+ } ;
21
23
use rayon:: iter:: { IntoParallelIterator , ParallelIterator } ;
22
24
use std:: fs:: create_dir;
23
25
use std:: path:: Path ;
@@ -30,6 +32,7 @@ pub struct Rabitq<O: Op> {
30
32
payloads : MmapArray < Payload > ,
31
33
offsets : Json < Vec < u32 > > ,
32
34
projected_centroids : Json < Vec2 < f32 > > ,
35
+ centroids_square : Json < Vec < f32 > > ,
33
36
projection : Json < Vec < Vec < f32 > > > ,
34
37
}
35
38
@@ -70,15 +73,18 @@ impl<O: Op> Rabitq<O> {
70
73
) -> Box < dyn Iterator < Item = Element > + ' a > {
71
74
let projected_query = O :: proj ( & self . projection , O :: cast ( vector) ) ;
72
75
let lists = select (
73
- k_means_lookup_many ( & projected_query, & self . projected_centroids ) ,
76
+ k_means_lookup_many_by_dot ( & projected_query, & self . projected_centroids , & self . centroids_square ) , // Rabitq has no `centroids` field; the stored squares are computed from the projected centroids
74
77
opts. rabitq_nprobe as usize ,
75
78
) ;
76
79
let mut heap = Vec :: new ( ) ;
77
- for & ( _, i) in lists. iter ( ) {
78
- let preprocessed = self . quantization . preprocess ( & O :: residual (
80
+ for & ( _, payload) in lists. iter ( ) {
81
+ let ( centroid_dot_dis, original_square, centroids_square, i) = payload;
82
+ let preprocessed = self . quantization . preprocess (
79
83
& projected_query,
80
- & self . projected_centroids [ ( i, ) ] ,
81
- ) ) ;
84
+ centroid_dot_dis,
85
+ original_square,
86
+ centroids_square,
87
+ ) ;
82
88
let start = self . offsets [ i] ;
83
89
let end = self . offsets [ i + 1 ] ;
84
90
self . quantization . push_batch (
@@ -135,6 +141,7 @@ fn from_nothing<O: Op>(
135
141
let samples = O :: sample ( collection, nlist) ;
136
142
rayon:: check ( ) ;
137
143
let centroids: Vec2 < f32 > = k_means ( nlist as usize , samples, true , spherical_centroids, false ) ;
144
+ let centroids_squares = centroids_square ( & centroids) ;
138
145
rayon:: check ( ) ;
139
146
let ls = ( 0 ..collection. len ( ) )
140
147
. into_par_iter ( )
@@ -174,8 +181,12 @@ fn from_nothing<O: Op>(
174
181
collection. len ( ) ,
175
182
|vector| {
176
183
let vector = O :: cast ( collection. vector ( vector) ) ;
177
- let target = k_means_lookup ( vector, & centroids) ;
178
- O :: proj ( & projection, & O :: residual ( vector, & centroids[ ( target, ) ] ) )
184
+ let ( centroid_dot_dis, target) =
185
+ k_means_lookup_by_dot ( vector, & centroids, & centroids_squares) ;
186
+ (
187
+ O :: proj ( & projection, & O :: residual ( vector, & centroids[ ( target, ) ] ) ) ,
188
+ centroid_dot_dis,
189
+ )
179
190
} ,
180
191
) ;
181
192
let projected_centroids = Vec2 :: from_vec (
@@ -184,6 +195,7 @@ fn from_nothing<O: Op>(
184
195
. flat_map ( |x| O :: proj ( & projection, & centroids[ ( x, ) ] ) )
185
196
. collect ( ) ,
186
197
) ;
198
+ let projected_centroids_square = centroids_square ( & projected_centroids) ;
187
199
let payloads = MmapArray :: create (
188
200
path. as_ref ( ) . join ( "payloads" ) ,
189
201
( 0 ..collection. len ( ) ) . map ( |i| collection. payload ( i) ) ,
@@ -193,12 +205,17 @@ fn from_nothing<O: Op>(
193
205
path. as_ref ( ) . join ( "projected_centroids" ) ,
194
206
projected_centroids,
195
207
) ;
208
+ let centroids_square = Json :: create (
209
+ path. as_ref ( ) . join ( "centroids_square" ) ,
210
+ projected_centroids_square,
211
+ ) ;
196
212
let projection = Json :: create ( path. as_ref ( ) . join ( "projection" ) , projection) ;
197
213
Rabitq {
198
214
storage,
199
215
payloads,
200
216
offsets,
201
217
projected_centroids,
218
+ centroids_square,
202
219
quantization,
203
220
projection,
204
221
}
@@ -210,18 +227,20 @@ fn open<O: Op>(path: impl AsRef<Path>) -> Rabitq<O> {
210
227
let payloads = MmapArray :: open ( path. as_ref ( ) . join ( "payloads" ) ) ;
211
228
let offsets = Json :: open ( path. as_ref ( ) . join ( "offsets" ) ) ;
212
229
let projected_centroids = Json :: open ( path. as_ref ( ) . join ( "projected_centroids" ) ) ;
230
+ let centroids_square = Json :: open ( path. as_ref ( ) . join ( "centroids_square" ) ) ;
213
231
let projection = Json :: open ( path. as_ref ( ) . join ( "projection" ) ) ;
214
232
Rabitq {
215
233
storage,
216
234
quantization,
217
235
payloads,
218
236
offsets,
219
237
projected_centroids,
238
+ centroids_square,
220
239
projection,
221
240
}
222
241
}
223
242
224
- fn select ( mut lists : Vec < ( f32 , usize ) > , n : usize ) -> Vec < ( f32 , usize ) > {
243
+ fn select < T > ( mut lists : Vec < ( f32 , T ) > , n : usize ) -> Vec < ( f32 , T ) > {
225
244
if lists. is_empty ( ) || n == 0 {
226
245
return Vec :: new ( ) ;
227
246
}
0 commit comments