@@ -31,6 +31,7 @@ pub struct Rabitq<O: Op> {
31
31
offsets : Json < Vec < u32 > > ,
32
32
projected_centroids : Json < Vec2 < f32 > > ,
33
33
projection : Json < Vec < Vec < f32 > > > ,
34
+ is_residual : Json < bool > ,
34
35
}
35
36
36
37
impl < O : Op > Rabitq < O > {
@@ -74,21 +75,16 @@ impl<O: Op> Rabitq<O> {
74
75
opts. rabitq_nprobe as usize ,
75
76
) ;
76
77
let mut heap = Vec :: new ( ) ;
77
- for & ( _, i) in lists. iter ( ) {
78
+ for & ( dis_v2, i) in lists. iter ( ) {
79
+ let trans_vector = if * self . is_residual {
80
+ & O :: residual ( & projected_query, & self . projected_centroids [ ( i, ) ] )
81
+ } else {
82
+ & projected_query
83
+ } ;
78
84
let preprocessed = if opts. rabitq_fast_scan {
79
- self . quantization
80
- . fscan_preprocess ( & O :: residual (
81
- & projected_query,
82
- & self . projected_centroids [ ( i, ) ] ,
83
- ) )
84
- . into ( )
85
+ self . quantization . fscan_preprocess ( trans_vector, dis_v2)
85
86
} else {
86
- self . quantization
87
- . preprocess ( & O :: residual (
88
- & projected_query,
89
- & self . projected_centroids [ ( i, ) ] ,
90
- ) )
91
- . into ( )
87
+ self . quantization . preprocess ( trans_vector, dis_v2)
92
88
} ;
93
89
let start = self . offsets [ i] ;
94
90
let end = self . offsets [ i + 1 ] ;
@@ -116,6 +112,7 @@ fn from_nothing<O: Op>(
116
112
let RabitqIndexingOptions {
117
113
nlist,
118
114
spherical_centroids,
115
+ residual_quantization,
119
116
} = options. indexing . clone ( ) . unwrap_rabitq ( ) ;
120
117
let projection = {
121
118
use nalgebra:: { DMatrix , QR } ;
@@ -137,6 +134,7 @@ fn from_nothing<O: Op>(
137
134
}
138
135
projection
139
136
} ;
137
+ let is_residual = residual_quantization && O :: SUPPORT_RESIDUAL ;
140
138
rayon:: check ( ) ;
141
139
let samples = O :: sample ( collection, nlist) ;
142
140
rayon:: check ( ) ;
@@ -174,16 +172,30 @@ fn from_nothing<O: Op>(
174
172
let collection = RemappedCollection :: from_collection ( collection, remap) ;
175
173
rayon:: check ( ) ;
176
174
let storage = O :: Storage :: create ( path. as_ref ( ) . join ( "storage" ) , & collection) ;
177
- let quantization = Quantization :: create (
178
- path. as_ref ( ) . join ( "quantization" ) ,
179
- options. vector ,
180
- collection. len ( ) ,
181
- |vector| {
182
- let vector = O :: cast ( collection. vector ( vector) ) ;
183
- let target = k_means_lookup ( vector, & centroids) ;
184
- O :: proj ( & projection, & O :: residual ( vector, & centroids[ ( target, ) ] ) )
185
- } ,
186
- ) ;
175
+
176
+ let quantization = if is_residual {
177
+ Quantization :: create (
178
+ path. as_ref ( ) . join ( "quantization" ) ,
179
+ options. vector ,
180
+ collection. len ( ) ,
181
+ |vector| {
182
+ let vector = O :: cast ( collection. vector ( vector) ) ;
183
+ let target = k_means_lookup ( vector, & centroids) ;
184
+ O :: proj ( & projection, & O :: residual ( vector, & centroids[ ( target, ) ] ) )
185
+ } ,
186
+ )
187
+ } else {
188
+ Quantization :: create (
189
+ path. as_ref ( ) . join ( "quantization" ) ,
190
+ options. vector ,
191
+ collection. len ( ) ,
192
+ |vector| {
193
+ let vector = O :: cast ( collection. vector ( vector) ) ;
194
+ O :: proj ( & projection, vector)
195
+ } ,
196
+ )
197
+ } ;
198
+
187
199
let projected_centroids = Vec2 :: from_vec (
188
200
( centroids. shape_0 ( ) , centroids. shape_1 ( ) ) ,
189
201
( 0 ..centroids. shape_0 ( ) )
@@ -200,13 +212,15 @@ fn from_nothing<O: Op>(
200
212
projected_centroids,
201
213
) ;
202
214
let projection = Json :: create ( path. as_ref ( ) . join ( "projection" ) , projection) ;
215
+ let is_residual = Json :: create ( path. as_ref ( ) . join ( "is_residual" ) , is_residual) ;
203
216
Rabitq {
204
217
storage,
205
218
payloads,
206
219
offsets,
207
220
projected_centroids,
208
221
quantization,
209
222
projection,
223
+ is_residual,
210
224
}
211
225
}
212
226
@@ -217,17 +231,19 @@ fn open<O: Op>(path: impl AsRef<Path>) -> Rabitq<O> {
217
231
let offsets = Json :: open ( path. as_ref ( ) . join ( "offsets" ) ) ;
218
232
let projected_centroids = Json :: open ( path. as_ref ( ) . join ( "projected_centroids" ) ) ;
219
233
let projection = Json :: open ( path. as_ref ( ) . join ( "projection" ) ) ;
234
+ let is_residual = Json :: open ( path. as_ref ( ) . join ( "is_residual" ) ) ;
220
235
Rabitq {
221
236
storage,
222
237
quantization,
223
238
payloads,
224
239
offsets,
225
240
projected_centroids,
226
241
projection,
242
+ is_residual,
227
243
}
228
244
}
229
245
230
- fn select ( mut lists : Vec < ( f32 , usize ) > , n : usize ) -> Vec < ( f32 , usize ) > {
246
+ fn select < T > ( mut lists : Vec < ( f32 , T ) > , n : usize ) -> Vec < ( f32 , T ) > {
231
247
if lists. is_empty ( ) || n == 0 {
232
248
return Vec :: new ( ) ;
233
249
}
0 commit comments