Skip to content

Commit 8c9080d

Browse files
authored
fix: rabitq (#556)
Signed-off-by: usamoi <usamoi@outlook.com>
1 parent 487ea6b commit 8c9080d

29 files changed

+986
-627
lines changed

Cargo.lock

+20-8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/base/src/index.rs

+46-22
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,6 @@ impl IndexOptions {
118118
}
119119
Ok(())
120120
}
121-
QuantizationOptions::RaBitQ(_) => {
122-
if !matches!(self.vector.v, VectorKind::Vecf32) {
123-
return Err(ValidationError::new(
124-
"scalar quantization or product quantization is not support for `vector`",
125-
));
126-
}
127-
Ok(())
128-
}
129121
}
130122
}
131123
fn validate_self(&self) -> Result<(), ValidationError> {
@@ -156,6 +148,18 @@ impl IndexOptions {
156148
));
157149
}
158150
}
151+
IndexingOptions::Rabitq(_) => {
152+
if !matches!(self.vector.d, DistanceKind::L2) {
153+
return Err(ValidationError::new(
154+
"inverted_index is not support for distance that is not l2",
155+
));
156+
}
157+
if !matches!(self.vector.v, VectorKind::Vecf32) {
158+
return Err(ValidationError::new(
159+
"inverted_index is not support for vectors that are not vector",
160+
));
161+
}
162+
}
159163
}
160164
Ok(())
161165
}
@@ -289,6 +293,7 @@ pub enum IndexingOptions {
289293
Ivf(IvfIndexingOptions),
290294
Hnsw(HnswIndexingOptions),
291295
InvertedIndex(InvertedIndexingOptions),
296+
Rabitq(RabitqIndexingOptions),
292297
}
293298

294299
impl IndexingOptions {
@@ -310,6 +315,12 @@ impl IndexingOptions {
310315
};
311316
x
312317
}
318+
pub fn unwrap_rabitq(self) -> RabitqIndexingOptions {
319+
let IndexingOptions::Rabitq(x) = self else {
320+
unreachable!()
321+
};
322+
x
323+
}
313324
}
314325

315326
impl Default for IndexingOptions {
@@ -324,7 +335,8 @@ impl Validate for IndexingOptions {
324335
Self::Flat(x) => x.validate(),
325336
Self::Ivf(x) => x.validate(),
326337
Self::Hnsw(x) => x.validate(),
327-
Self::InvertedIndex(_) => Ok(()),
338+
Self::InvertedIndex(x) => x.validate(),
339+
Self::Rabitq(x) => x.validate(),
328340
}
329341
}
330342
}
@@ -428,15 +440,35 @@ impl Default for HnswIndexingOptions {
428440
}
429441
}
430442

443+
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
444+
#[serde(deny_unknown_fields)]
445+
pub struct RabitqIndexingOptions {
446+
#[serde(default = "RabitqIndexingOptions::default_nlist")]
447+
#[validate(range(min = 1, max = 1_000_000))]
448+
pub nlist: u32,
449+
}
450+
451+
impl RabitqIndexingOptions {
452+
fn default_nlist() -> u32 {
453+
1000
454+
}
455+
}
456+
457+
impl Default for RabitqIndexingOptions {
458+
fn default() -> Self {
459+
Self {
460+
nlist: Self::default_nlist(),
461+
}
462+
}
463+
}
464+
431465
#[derive(Debug, Clone, Serialize, Deserialize)]
432466
#[serde(deny_unknown_fields)]
433467
#[serde(rename_all = "snake_case")]
434468
pub enum QuantizationOptions {
435469
Trivial(TrivialQuantizationOptions),
436470
Scalar(ScalarQuantizationOptions),
437471
Product(ProductQuantizationOptions),
438-
#[serde(rename = "rabitq")]
439-
RaBitQ(RaBitQuantizationOptions),
440472
}
441473

442474
impl Validate for QuantizationOptions {
@@ -445,7 +477,6 @@ impl Validate for QuantizationOptions {
445477
Self::Trivial(x) => x.validate(),
446478
Self::Scalar(x) => x.validate(),
447479
Self::Product(x) => x.validate(),
448-
Self::RaBitQ(x) => x.validate(),
449480
}
450481
}
451482
}
@@ -466,16 +497,6 @@ impl Default for TrivialQuantizationOptions {
466497
}
467498
}
468499

469-
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
470-
#[serde(deny_unknown_fields)]
471-
pub struct RaBitQuantizationOptions {}
472-
473-
impl Default for RaBitQuantizationOptions {
474-
fn default() -> Self {
475-
Self {}
476-
}
477-
}
478-
479500
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
480501
#[serde(deny_unknown_fields)]
481502
#[validate(schema(function = "Self::validate_self"))]
@@ -558,6 +579,9 @@ pub struct SearchOptions {
558579
#[validate(range(min = 1, max = 65535))]
559580
pub hnsw_ef_search: u32,
560581
#[validate(range(min = 1, max = 65535))]
582+
pub rabitq_nprobe: u32,
583+
pub rabitq_fast_scan: bool,
584+
#[validate(range(min = 1, max = 65535))]
561585
pub diskann_ef_search: u32,
562586
}
563587

crates/base/src/vector/mod.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@ pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static {
3232
}
3333

3434
pub trait VectorBorrowed: Copy + PartialEq + PartialOrd {
35+
// it will be depcrated
3536
type Scalar: ScalarLike;
37+
38+
// it will be depcrated
39+
fn to_vec(&self) -> Vec<Self::Scalar>;
40+
3641
type Owned: VectorOwned<Scalar = Self::Scalar>;
3742

3843
fn own(&self) -> Self::Owned;
3944

4045
fn dims(&self) -> u32;
4146

42-
fn to_vec(&self) -> Vec<Self::Scalar>;
43-
4447
fn norm(&self) -> F32;
4548

4649
fn operator_dot(self, rhs: Self) -> F32;

crates/base/src/vector/svecf32.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,12 @@ impl<'a> VectorBorrowed for SVecf32Borrowed<'a> {
358358
let dims = end - start;
359359
let s = self.indexes.partition_point(|&x| x < start);
360360
let e = self.indexes.partition_point(|&x| x < end);
361-
let indexes = self.indexes[s..e].iter().map(|x| x - start);
362-
let values = &self.values[s..e];
363-
Self::Owned::new_checked(dims, indexes.collect::<Vec<_>>(), values.to_vec())
361+
let indexes = self.indexes[s..e]
362+
.iter()
363+
.map(|x| x - start)
364+
.collect::<Vec<_>>();
365+
let values = self.values[s..e].to_vec();
366+
Self::Owned::new_checked(dims, indexes, values)
364367
}
365368
}
366369

crates/cli/src/args.rs

+2
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ impl QueryArguments {
142142
flat_pq_fast_scan: false,
143143
ivf_sq_fast_scan: false,
144144
ivf_pq_fast_scan: false,
145+
rabitq_fast_scan: true,
146+
rabitq_nprobe: self.probe,
145147
}
146148
}
147149
}

crates/common/src/sample.rs

+19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use crate::vec2::Vec2;
22
use base::operator::{Borrowed, Operator, Owned, Scalar};
3+
use base::scalar::ScalarLike;
4+
use base::scalar::F32;
35
use base::search::Vectors;
46
use base::vector::VectorBorrowed;
57
use base::vector::VectorOwned;
@@ -18,6 +20,23 @@ pub fn sample<O: Operator>(vectors: &impl Vectors<O>) -> Vec2<Scalar<O>> {
1820
samples
1921
}
2022

23+
pub fn sample_cast<O: Operator>(vectors: &impl Vectors<O>) -> Vec2<F32> {
24+
let n = vectors.len();
25+
let m = std::cmp::min(SAMPLES as u32, n);
26+
let f = super::rand::sample_u32(&mut rand::thread_rng(), n, m);
27+
let mut samples = Vec2::zeros((m as usize, vectors.dims() as usize));
28+
for i in 0..m {
29+
let v = vectors
30+
.vector(f[i as usize] as u32)
31+
.to_vec()
32+
.into_iter()
33+
.map(|x| x.to_f())
34+
.collect::<Vec<_>>();
35+
samples[(i as usize,)].copy_from_slice(&v);
36+
}
37+
samples
38+
}
39+
2140
pub fn sample_subvector_transform<O: Operator>(
2241
vectors: &impl Vectors<O>,
2342
s: usize,

crates/index/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ flat = { path = "../flat" }
3030
hnsw = { path = "../hnsw" }
3131
inverted = { path = "../inverted" }
3232
ivf = { path = "../ivf" }
33+
rabitq = { path = "../rabitq" }
3334

3435
[lints]
3536
workspace = true

crates/index/src/indexing/sealed.rs

+8
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ use flat::Flat;
66
use hnsw::Hnsw;
77
use inverted::InvertedIndex;
88
use ivf::Ivf;
9+
use rabitq::Rabitq;
910
use std::path::Path;
1011

1112
pub enum SealedIndexing<O: Op> {
1213
Flat(Flat<O>),
1314
Ivf(Ivf<O>),
1415
Hnsw(Hnsw<O>),
1516
InvertedIndex(InvertedIndex<O>),
17+
Rabitq(Rabitq<O>),
1618
}
1719

1820
impl<O: Op> SealedIndexing<O> {
@@ -28,6 +30,7 @@ impl<O: Op> SealedIndexing<O> {
2830
IndexingOptions::InvertedIndex(_) => {
2931
Self::InvertedIndex(InvertedIndex::create(path, options, source))
3032
}
33+
IndexingOptions::Rabitq(_) => Self::Rabitq(Rabitq::create(path, options, source)),
3134
}
3235
}
3336

@@ -37,6 +40,7 @@ impl<O: Op> SealedIndexing<O> {
3740
IndexingOptions::Ivf(_) => Self::Ivf(Ivf::open(path)),
3841
IndexingOptions::Hnsw(_) => Self::Hnsw(Hnsw::open(path)),
3942
IndexingOptions::InvertedIndex(_) => Self::InvertedIndex(InvertedIndex::open(path)),
43+
IndexingOptions::Rabitq(_) => Self::Rabitq(Rabitq::open(path)),
4044
}
4145
}
4246

@@ -50,6 +54,7 @@ impl<O: Op> SealedIndexing<O> {
5054
SealedIndexing::Ivf(x) => x.vbase(vector, opts),
5155
SealedIndexing::Hnsw(x) => x.vbase(vector, opts),
5256
SealedIndexing::InvertedIndex(x) => x.vbase(vector, opts),
57+
SealedIndexing::Rabitq(x) => x.vbase(vector, opts),
5358
}
5459
}
5560

@@ -59,6 +64,7 @@ impl<O: Op> SealedIndexing<O> {
5964
SealedIndexing::Ivf(x) => x.len(),
6065
SealedIndexing::Hnsw(x) => x.len(),
6166
SealedIndexing::InvertedIndex(x) => x.len(),
67+
SealedIndexing::Rabitq(x) => x.len(),
6268
}
6369
}
6470

@@ -68,6 +74,7 @@ impl<O: Op> SealedIndexing<O> {
6874
SealedIndexing::Ivf(x) => x.vector(i),
6975
SealedIndexing::Hnsw(x) => x.vector(i),
7076
SealedIndexing::InvertedIndex(x) => x.vector(i),
77+
SealedIndexing::Rabitq(x) => x.vector(i),
7178
}
7279
}
7380

@@ -77,6 +84,7 @@ impl<O: Op> SealedIndexing<O> {
7784
SealedIndexing::Ivf(x) => x.payload(i),
7885
SealedIndexing::Hnsw(x) => x.payload(i),
7986
SealedIndexing::InvertedIndex(x) => x.payload(i),
87+
SealedIndexing::Rabitq(x) => x.payload(i),
8088
}
8189
}
8290
}

crates/index/src/lib.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use inverted::operator::OperatorInvertedIndex;
2929
use ivf::operator::OperatorIvf;
3030
use parking_lot::Mutex;
3131
use quantization::operator::OperatorQuantization;
32+
use rabitq::operator::OperatorRabitq;
3233
use serde::{Deserialize, Serialize};
3334
use std::collections::HashMap;
3435
use std::collections::HashSet;
@@ -43,12 +44,22 @@ use thiserror::Error;
4344
use validator::Validate;
4445

4546
pub trait Op:
46-
Operator + OperatorQuantization + OperatorStorage + OperatorIvf + OperatorInvertedIndex
47+
Operator
48+
+ OperatorQuantization
49+
+ OperatorStorage
50+
+ OperatorIvf
51+
+ OperatorInvertedIndex
52+
+ OperatorRabitq
4753
{
4854
}
4955

5056
impl<
51-
T: Operator + OperatorQuantization + OperatorStorage + OperatorIvf + OperatorInvertedIndex,
57+
T: Operator
58+
+ OperatorQuantization
59+
+ OperatorStorage
60+
+ OperatorIvf
61+
+ OperatorInvertedIndex
62+
+ OperatorRabitq,
5263
> Op for T
5364
{
5465
}

crates/index/src/segment/sealed.rs

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ impl<O: Op> SealedSegment<O> {
122122
SealedIndexing::Ivf(x) => x,
123123
SealedIndexing::Hnsw(x) => x,
124124
SealedIndexing::InvertedIndex(x) => x,
125+
SealedIndexing::Rabitq(x) => x,
125126
}
126127
}
127128
}

0 commit comments

Comments
 (0)