Skip to content

Commit bb46189

Browse files
authored
feat: add metrics dot and cos (#566)
* feat: add metrics dot and cos
  Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
* add option residual_quantization
  Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
* deprecate residual except l2
  Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
* fix by comments
  Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
---------
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
1 parent 1ed47d8 commit bb46189

File tree

5 files changed

+683
-310
lines changed

5 files changed

+683
-310
lines changed

crates/base/src/index.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ impl IndexOptions {
149149
}
150150
}
151151
IndexingOptions::Rabitq(_) => {
152-
if !matches!(self.vector.d, DistanceKind::L2) {
152+
if !matches!(self.vector.d, DistanceKind::L2 | DistanceKind::Dot) {
153153
return Err(ValidationError::new(
154-
"rabitq is not support for distance that is not l2",
154+
"rabitq is not support for distance that is not l2 or dot",
155155
));
156156
}
157157
if !matches!(self.vector.v, VectorKind::Vecf32) {
@@ -446,8 +446,10 @@ pub struct RabitqIndexingOptions {
446446
#[serde(default = "RabitqIndexingOptions::default_nlist")]
447447
#[validate(range(min = 1, max = 1_000_000))]
448448
pub nlist: u32,
449-
#[serde(default = "IvfIndexingOptions::default_spherical_centroids")]
449+
#[serde(default = "RabitqIndexingOptions::default_spherical_centroids")]
450450
pub spherical_centroids: bool,
451+
#[serde(default = "RabitqIndexingOptions::default_residual_quantization")]
452+
pub residual_quantization: bool,
451453
}
452454

453455
impl RabitqIndexingOptions {
@@ -457,13 +459,17 @@ impl RabitqIndexingOptions {
457459
fn default_spherical_centroids() -> bool {
458460
false
459461
}
462+
fn default_residual_quantization() -> bool {
463+
false
464+
}
460465
}
461466

462467
impl Default for RabitqIndexingOptions {
463468
fn default() -> Self {
464469
Self {
465470
nlist: Self::default_nlist(),
466471
spherical_centroids: Self::default_spherical_centroids(),
472+
residual_quantization: Self::default_residual_quantization(),
467473
}
468474
}
469475
}

crates/rabitq/src/lib.rs

+40-24
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub struct Rabitq<O: Op> {
3131
offsets: Json<Vec<u32>>,
3232
projected_centroids: Json<Vec2<f32>>,
3333
projection: Json<Vec<Vec<f32>>>,
34+
is_residual: Json<bool>,
3435
}
3536

3637
impl<O: Op> Rabitq<O> {
@@ -74,21 +75,16 @@ impl<O: Op> Rabitq<O> {
7475
opts.rabitq_nprobe as usize,
7576
);
7677
let mut heap = Vec::new();
77-
for &(_, i) in lists.iter() {
78+
for &(dis_v2, i) in lists.iter() {
79+
let trans_vector = if *self.is_residual {
80+
&O::residual(&projected_query, &self.projected_centroids[(i,)])
81+
} else {
82+
&projected_query
83+
};
7884
let preprocessed = if opts.rabitq_fast_scan {
79-
self.quantization
80-
.fscan_preprocess(&O::residual(
81-
&projected_query,
82-
&self.projected_centroids[(i,)],
83-
))
84-
.into()
85+
self.quantization.fscan_preprocess(trans_vector, dis_v2)
8586
} else {
86-
self.quantization
87-
.preprocess(&O::residual(
88-
&projected_query,
89-
&self.projected_centroids[(i,)],
90-
))
91-
.into()
87+
self.quantization.preprocess(trans_vector, dis_v2)
9288
};
9389
let start = self.offsets[i];
9490
let end = self.offsets[i + 1];
@@ -116,6 +112,7 @@ fn from_nothing<O: Op>(
116112
let RabitqIndexingOptions {
117113
nlist,
118114
spherical_centroids,
115+
residual_quantization,
119116
} = options.indexing.clone().unwrap_rabitq();
120117
let projection = {
121118
use nalgebra::{DMatrix, QR};
@@ -137,6 +134,7 @@ fn from_nothing<O: Op>(
137134
}
138135
projection
139136
};
137+
let is_residual = residual_quantization && O::SUPPORT_RESIDUAL;
140138
rayon::check();
141139
let samples = O::sample(collection, nlist);
142140
rayon::check();
@@ -174,16 +172,30 @@ fn from_nothing<O: Op>(
174172
let collection = RemappedCollection::from_collection(collection, remap);
175173
rayon::check();
176174
let storage = O::Storage::create(path.as_ref().join("storage"), &collection);
177-
let quantization = Quantization::create(
178-
path.as_ref().join("quantization"),
179-
options.vector,
180-
collection.len(),
181-
|vector| {
182-
let vector = O::cast(collection.vector(vector));
183-
let target = k_means_lookup(vector, &centroids);
184-
O::proj(&projection, &O::residual(vector, &centroids[(target,)]))
185-
},
186-
);
175+
176+
let quantization = if is_residual {
177+
Quantization::create(
178+
path.as_ref().join("quantization"),
179+
options.vector,
180+
collection.len(),
181+
|vector| {
182+
let vector = O::cast(collection.vector(vector));
183+
let target = k_means_lookup(vector, &centroids);
184+
O::proj(&projection, &O::residual(vector, &centroids[(target,)]))
185+
},
186+
)
187+
} else {
188+
Quantization::create(
189+
path.as_ref().join("quantization"),
190+
options.vector,
191+
collection.len(),
192+
|vector| {
193+
let vector = O::cast(collection.vector(vector));
194+
O::proj(&projection, vector)
195+
},
196+
)
197+
};
198+
187199
let projected_centroids = Vec2::from_vec(
188200
(centroids.shape_0(), centroids.shape_1()),
189201
(0..centroids.shape_0())
@@ -200,13 +212,15 @@ fn from_nothing<O: Op>(
200212
projected_centroids,
201213
);
202214
let projection = Json::create(path.as_ref().join("projection"), projection);
215+
let is_residual = Json::create(path.as_ref().join("is_residual"), is_residual);
203216
Rabitq {
204217
storage,
205218
payloads,
206219
offsets,
207220
projected_centroids,
208221
quantization,
209222
projection,
223+
is_residual,
210224
}
211225
}
212226

@@ -217,17 +231,19 @@ fn open<O: Op>(path: impl AsRef<Path>) -> Rabitq<O> {
217231
let offsets = Json::open(path.as_ref().join("offsets"));
218232
let projected_centroids = Json::open(path.as_ref().join("projected_centroids"));
219233
let projection = Json::open(path.as_ref().join("projection"));
234+
let is_residual = Json::open(path.as_ref().join("is_residual"));
220235
Rabitq {
221236
storage,
222237
quantization,
223238
payloads,
224239
offsets,
225240
projected_centroids,
226241
projection,
242+
is_residual,
227243
}
228244
}
229245

230-
fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> {
246+
fn select<T>(mut lists: Vec<(f32, T)>, n: usize) -> Vec<(f32, T)> {
231247
if lists.is_empty() || n == 0 {
232248
return Vec::new();
233249
}

0 commit comments

Comments
 (0)