diff --git a/Cargo.lock b/Cargo.lock index 3d048289e..bfedc06ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1008,12 +1008,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "fastdivide" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59668941c55e5c186b8b58c391629af56774ec768f73c08bbcd56f09348eb00b" - [[package]] name = "fastrand" version = "1.9.0" @@ -1540,6 +1534,7 @@ dependencies = [ "log", "parking_lot", "quantization", + "rabitq", "rand", "serde", "serde_json", @@ -2285,10 +2280,8 @@ dependencies = [ "base", "common", "detect", - "fastdivide", "k_means", "log", - "nalgebra", "num-traits", "rand", "serde", @@ -2353,6 +2346,25 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rabitq" +version = "0.0.0" +dependencies = [ + "base", + "common", + "detect", + "k_means", + "log", + "nalgebra", + "num-traits", + "quantization", + "rand", + "serde", + "serde_json", + "stoppable_rayon", + "storage", +] + [[package]] name = "radium" version = "0.7.0" diff --git a/crates/base/src/index.rs b/crates/base/src/index.rs index 8e1359cb1..554e5fc12 100644 --- a/crates/base/src/index.rs +++ b/crates/base/src/index.rs @@ -118,14 +118,6 @@ impl IndexOptions { } Ok(()) } - QuantizationOptions::RaBitQ(_) => { - if !matches!(self.vector.v, VectorKind::Vecf32) { - return Err(ValidationError::new( - "scalar quantization or product quantization is not support for `vector`", - )); - } - Ok(()) - } } } fn validate_self(&self) -> Result<(), ValidationError> { @@ -156,6 +148,18 @@ impl IndexOptions { )); } } + IndexingOptions::Rabitq(_) => { + if !matches!(self.vector.d, DistanceKind::L2) { + return Err(ValidationError::new( + "inverted_index is not support for distance that is not l2", + )); + } + if !matches!(self.vector.v, VectorKind::Vecf32) { + return Err(ValidationError::new( + "inverted_index is not support for vectors that are not vector", + )); + } + } } Ok(()) } @@ -289,6 +293,7 @@ pub enum IndexingOptions { Ivf(IvfIndexingOptions), Hnsw(HnswIndexingOptions), InvertedIndex(InvertedIndexingOptions), + Rabitq(RabitqIndexingOptions), } impl IndexingOptions { @@ -310,6 +315,12 @@ impl IndexingOptions { }; x } + pub fn unwrap_rabitq(self) -> RabitqIndexingOptions { + let IndexingOptions::Rabitq(x) = self else { + unreachable!() + }; + x + } } impl Default for IndexingOptions { @@ -324,7 +335,8 @@ impl Validate for IndexingOptions { Self::Flat(x) => x.validate(), Self::Ivf(x) => x.validate(), Self::Hnsw(x) => x.validate(), - Self::InvertedIndex(_) => Ok(()), + Self::InvertedIndex(x) => x.validate(), + Self::Rabitq(x) => x.validate(), } } } @@ -428,6 +440,28 @@ impl Default for HnswIndexingOptions { } } +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct RabitqIndexingOptions { + #[serde(default = "RabitqIndexingOptions::default_nlist")] + #[validate(range(min = 1, max = 1_000_000))] + pub nlist: u32, +} + +impl RabitqIndexingOptions { + fn default_nlist() -> u32 { + 1000 + } +} + +impl Default for RabitqIndexingOptions { + fn default() -> Self { + Self { + nlist: Self::default_nlist(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] #[serde(rename_all = "snake_case")] @@ -435,8 +469,6 @@ pub enum QuantizationOptions { Trivial(TrivialQuantizationOptions), Scalar(ScalarQuantizationOptions), Product(ProductQuantizationOptions), - #[serde(rename = "rabitq")] - RaBitQ(RaBitQuantizationOptions), } impl Validate for QuantizationOptions { @@ -445,7 +477,6 @@ impl Validate for QuantizationOptions { Self::Trivial(x) => x.validate(), Self::Scalar(x) => x.validate(), Self::Product(x) => x.validate(), - Self::RaBitQ(x) => x.validate(), } } } @@ -466,16 +497,6 @@ impl Default for TrivialQuantizationOptions { } } -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct RaBitQuantizationOptions {} - -impl Default for RaBitQuantizationOptions { - fn default() -> Self { - Self {} - } -} - #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] #[validate(schema(function = "Self::validate_self"))] @@ -558,6 +579,9 @@ pub struct SearchOptions { #[validate(range(min = 1, max = 65535))] pub hnsw_ef_search: u32, #[validate(range(min = 1, max = 65535))] + pub rabitq_nprobe: u32, + pub rabitq_fast_scan: bool, + #[validate(range(min = 1, max = 65535))] pub diskann_ef_search: u32, } diff --git a/crates/base/src/vector/mod.rs b/crates/base/src/vector/mod.rs index f8f9920fc..628f6a1ae 100644 --- a/crates/base/src/vector/mod.rs +++ b/crates/base/src/vector/mod.rs @@ -32,15 +32,18 @@ pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static { } pub trait VectorBorrowed: Copy + PartialEq + PartialOrd { + // it will be depcrated type Scalar: ScalarLike; + + // it will be depcrated + fn to_vec(&self) -> Vec; + type Owned: VectorOwned; fn own(&self) -> Self::Owned; fn dims(&self) -> u32; - fn to_vec(&self) -> Vec; - fn norm(&self) -> F32; fn operator_dot(self, rhs: Self) -> F32; diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 748c4d25e..cf11ccb3b 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -358,9 +358,12 @@ impl<'a> VectorBorrowed for SVecf32Borrowed<'a> { let dims = end - start; let s = self.indexes.partition_point(|&x| x < start); let e = self.indexes.partition_point(|&x| x < end); - let indexes = self.indexes[s..e].iter().map(|x| x - start); - let values = &self.values[s..e]; - Self::Owned::new_checked(dims, indexes.collect::>(), values.to_vec()) + let indexes = self.indexes[s..e] + .iter() + .map(|x| x - start) + .collect::>(); + let values = self.values[s..e].to_vec(); + Self::Owned::new_checked(dims, indexes, values) } } diff --git a/crates/cli/src/args.rs b/crates/cli/src/args.rs index 521b3fab9..e0f8fb7d3 100644 --- a/crates/cli/src/args.rs +++ b/crates/cli/src/args.rs @@ -142,6 +142,8 @@ impl QueryArguments { flat_pq_fast_scan: false, ivf_sq_fast_scan: false, ivf_pq_fast_scan: false, + rabitq_fast_scan: true, + rabitq_nprobe: self.probe, } } } diff --git a/crates/common/src/sample.rs b/crates/common/src/sample.rs index 5b6b0cd31..5c5852547 100644 --- a/crates/common/src/sample.rs +++ b/crates/common/src/sample.rs @@ -1,5 +1,7 @@ use crate::vec2::Vec2; use base::operator::{Borrowed, Operator, Owned, Scalar}; +use base::scalar::ScalarLike; +use base::scalar::F32; use base::search::Vectors; use base::vector::VectorBorrowed; use base::vector::VectorOwned; @@ -18,6 +20,23 @@ pub fn sample(vectors: &impl Vectors) -> Vec2> { samples } +pub fn sample_cast(vectors: &impl Vectors) -> Vec2 { + let n = vectors.len(); + let m = std::cmp::min(SAMPLES as u32, n); + let f = super::rand::sample_u32(&mut rand::thread_rng(), n, m); + let mut samples = Vec2::zeros((m as usize, vectors.dims() as usize)); + for i in 0..m { + let v = vectors + .vector(f[i as usize] as u32) + .to_vec() + .into_iter() + .map(|x| x.to_f()) + .collect::>(); + samples[(i as usize,)].copy_from_slice(&v); + } + samples +} + pub fn sample_subvector_transform( vectors: &impl Vectors, s: usize, diff --git a/crates/index/Cargo.toml b/crates/index/Cargo.toml index 6f3cc5231..29d9faa58 100644 --- a/crates/index/Cargo.toml +++ b/crates/index/Cargo.toml @@ -30,6 +30,7 @@ flat = { path = "../flat" } hnsw = { path = "../hnsw" } inverted = { path = "../inverted" } ivf = { path = "../ivf" } +rabitq = { path = "../rabitq" } [lints] workspace = true diff --git a/crates/index/src/indexing/sealed.rs b/crates/index/src/indexing/sealed.rs index 27aee6dff..e11094bd4 100644 --- a/crates/index/src/indexing/sealed.rs +++ b/crates/index/src/indexing/sealed.rs @@ -6,6 +6,7 @@ use flat::Flat; use hnsw::Hnsw; use inverted::InvertedIndex; use ivf::Ivf; +use rabitq::Rabitq; use std::path::Path; pub enum SealedIndexing { @@ -13,6 +14,7 @@ pub enum SealedIndexing { Ivf(Ivf), Hnsw(Hnsw), InvertedIndex(InvertedIndex), + Rabitq(Rabitq), } impl SealedIndexing { @@ -28,6 +30,7 @@ impl SealedIndexing { IndexingOptions::InvertedIndex(_) => { Self::InvertedIndex(InvertedIndex::create(path, options, source)) } + IndexingOptions::Rabitq(_) => Self::Rabitq(Rabitq::create(path, options, source)), } } @@ -37,6 +40,7 @@ impl SealedIndexing { IndexingOptions::Ivf(_) => Self::Ivf(Ivf::open(path)), IndexingOptions::Hnsw(_) => Self::Hnsw(Hnsw::open(path)), IndexingOptions::InvertedIndex(_) => Self::InvertedIndex(InvertedIndex::open(path)), + IndexingOptions::Rabitq(_) => Self::Rabitq(Rabitq::open(path)), } } @@ -50,6 +54,7 @@ impl SealedIndexing { SealedIndexing::Ivf(x) => x.vbase(vector, opts), SealedIndexing::Hnsw(x) => x.vbase(vector, opts), SealedIndexing::InvertedIndex(x) => x.vbase(vector, opts), + SealedIndexing::Rabitq(x) => x.vbase(vector, opts), } } @@ -59,6 +64,7 @@ impl SealedIndexing { SealedIndexing::Ivf(x) => x.len(), SealedIndexing::Hnsw(x) => x.len(), SealedIndexing::InvertedIndex(x) => x.len(), + SealedIndexing::Rabitq(x) => x.len(), } } @@ -68,6 +74,7 @@ impl SealedIndexing { SealedIndexing::Ivf(x) => x.vector(i), SealedIndexing::Hnsw(x) => x.vector(i), SealedIndexing::InvertedIndex(x) => x.vector(i), + SealedIndexing::Rabitq(x) => x.vector(i), } } @@ -77,6 +84,7 @@ impl SealedIndexing { SealedIndexing::Ivf(x) => x.payload(i), SealedIndexing::Hnsw(x) => x.payload(i), SealedIndexing::InvertedIndex(x) => x.payload(i), + SealedIndexing::Rabitq(x) => x.payload(i), } } } diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index c60202eec..6c450c6c0 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -29,6 +29,7 @@ use inverted::operator::OperatorInvertedIndex; use ivf::operator::OperatorIvf; use parking_lot::Mutex; use quantization::operator::OperatorQuantization; +use rabitq::operator::OperatorRabitq; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::collections::HashSet; @@ -43,12 +44,22 @@ use thiserror::Error; use validator::Validate; pub trait Op: - Operator + OperatorQuantization + OperatorStorage + OperatorIvf + OperatorInvertedIndex + Operator + + OperatorQuantization + + OperatorStorage + + OperatorIvf + + OperatorInvertedIndex + + OperatorRabitq { } impl< - T: Operator + OperatorQuantization + OperatorStorage + OperatorIvf + OperatorInvertedIndex, + T: Operator + + OperatorQuantization + + OperatorStorage + + OperatorIvf + + OperatorInvertedIndex + + OperatorRabitq, > Op for T { } diff --git a/crates/index/src/segment/sealed.rs b/crates/index/src/segment/sealed.rs index a77344e85..4fa8ee476 100644 --- a/crates/index/src/segment/sealed.rs +++ b/crates/index/src/segment/sealed.rs @@ -122,6 +122,7 @@ impl SealedSegment { SealedIndexing::Ivf(x) => x, SealedIndexing::Hnsw(x) => x, SealedIndexing::InvertedIndex(x) => x, + SealedIndexing::Rabitq(x) => x, } } } diff --git a/crates/ivf/src/lib.rs b/crates/ivf/src/lib.rs index 4958ffb0c..49407f1a3 100644 --- a/crates/ivf/src/lib.rs +++ b/crates/ivf/src/lib.rs @@ -2,7 +2,6 @@ #![allow(clippy::needless_range_loop)] pub mod ivf_naive; -pub mod ivf_projection; pub mod ivf_residual; pub mod operator; @@ -12,14 +11,12 @@ use base::index::*; use base::operator::*; use base::search::*; use common::variants::variants; -use ivf_projection::IvfProjection; use ivf_residual::IvfResidual; use std::path::Path; pub enum Ivf { Naive(IvfNaive), Residual(IvfResidual), - Projection(IvfProjection), } impl Ivf { @@ -30,13 +27,7 @@ impl Ivf { .. } = options.indexing.clone().unwrap_ivf(); std::fs::create_dir(path.as_ref()).unwrap(); - let this = if matches!(quantization_options, QuantizationOptions::RaBitQ(_)) { - Self::Projection(IvfProjection::create( - path.as_ref().join("ivf_projection"), - options, - source, - )) - } else if !residual_quantization + let this = if !residual_quantization || matches!(quantization_options, QuantizationOptions::Trivial(_)) || !O::RESIDUAL { @@ -56,15 +47,9 @@ impl Ivf { } pub fn open(path: impl AsRef) -> Self { - match variants( - path.as_ref(), - ["ivf_naive", "ivf_residual", "ivf_projection"], - ) { + match variants(path.as_ref(), ["ivf_naive", "ivf_residual"]) { "ivf_naive" => Self::Naive(IvfNaive::open(path.as_ref().join("ivf_naive"))), "ivf_residual" => Self::Residual(IvfResidual::open(path.as_ref().join("ivf_residual"))), - "ivf_projection" => { - Self::Projection(IvfProjection::open(path.as_ref().join("ivf_projection"))) - } _ => unreachable!(), } } @@ -73,7 +58,6 @@ impl Ivf { match self { Ivf::Naive(x) => x.len(), Ivf::Residual(x) => x.len(), - Ivf::Projection(x) => x.len(), } } @@ -81,7 +65,6 @@ impl Ivf { match self { Ivf::Naive(x) => x.vector(i), Ivf::Residual(x) => x.vector(i), - Ivf::Projection(x) => x.vector(i), } } @@ -89,7 +72,6 @@ impl Ivf { match self { Ivf::Naive(x) => x.payload(i), Ivf::Residual(x) => x.payload(i), - Ivf::Projection(x) => x.payload(i), } } @@ -101,7 +83,6 @@ impl Ivf { match self { Ivf::Naive(x) => x.vbase(vector, opts), Ivf::Residual(x) => x.vbase(vector, opts), - Ivf::Projection(x) => x.vbase(vector, opts), } } } diff --git a/crates/ivf/src/operator.rs b/crates/ivf/src/operator.rs index 55ebc7b38..01abf8d3c 100644 --- a/crates/ivf/src/operator.rs +++ b/crates/ivf/src/operator.rs @@ -7,7 +7,6 @@ use storage::OperatorStorage; pub trait OperatorIvf: OperatorQuantization + OperatorStorage { const RESIDUAL: bool; fn residual(lhs: Borrowed<'_, Self>, rhs: &[Scalar]) -> Owned; - fn residual_dense(lhs: &[Scalar], rhs: &[Scalar]) -> Owned; } impl OperatorIvf for BVectorDot { @@ -15,9 +14,6 @@ impl OperatorIvf for BVectorDot { fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Scalar]) -> Owned { unimplemented!() } - fn residual_dense(_lhs: &[Scalar], _rhs: &[Scalar]) -> Owned { - unimplemented!() - } } impl OperatorIvf for BVectorJaccard { @@ -25,9 +21,6 @@ impl OperatorIvf for BVectorJaccard { fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Scalar]) -> Owned { unimplemented!() } - fn residual_dense(_lhs: &[Scalar], _rhs: &[Scalar]) -> Owned { - unimplemented!() - } } impl OperatorIvf for BVectorHamming { @@ -35,9 +28,6 @@ impl OperatorIvf for BVectorHamming { fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Scalar]) -> Owned { unimplemented!() } - fn residual_dense(_lhs: &[Scalar], _rhs: &[Scalar]) -> Owned { - unimplemented!() - } } impl OperatorIvf for SVecf32Dot { @@ -45,9 +35,6 @@ impl OperatorIvf for SVecf32Dot { fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Scalar]) -> Owned { unimplemented!() } - fn residual_dense(_lhs: &[Scalar], _rhs: &[Scalar]) -> Owned { - unimplemented!() - } } impl OperatorIvf for SVecf32L2 { @@ -76,9 +63,6 @@ impl OperatorIvf for SVecf32L2 { } SVecf32Owned::new(lhs.dims(), indexes, values) } - fn residual_dense(_lhs: &[Scalar], _rhs: &[Scalar]) -> Owned { - unimplemented!() - } } impl OperatorIvf for Vecf32Dot { @@ -86,9 +70,6 @@ impl OperatorIvf for Vecf32Dot { fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Scalar]) -> Owned { unimplemented!() } - fn residual_dense(_lhs: &[Scalar], _rhs: &[Scalar]) -> Owned { - unimplemented!() - } } impl OperatorIvf for Vecf32L2 { @@ -96,13 +77,6 @@ impl OperatorIvf for Vecf32L2 { fn residual(lhs: Borrowed<'_, Self>, rhs: &[Scalar]) -> Owned { lhs.operator_minus(Vecf32Borrowed::new(rhs)) } - fn residual_dense(lhs: &[Scalar], rhs: &[Scalar]) -> Owned { - let mut res = vec![Scalar::::zero(); lhs.len()]; - for i in 0..lhs.len() { - res[i] = lhs[i] - rhs[i]; - } - Vecf32Owned::new(res) - } } impl OperatorIvf for Vecf16Dot { @@ -110,9 +84,6 @@ impl OperatorIvf for Vecf16Dot { fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Scalar]) -> Owned { unimplemented!() } - fn residual_dense(_lhs: &[Scalar], _rhs: &[Scalar]) -> Owned { - unimplemented!() - } } impl OperatorIvf for Vecf16L2 { @@ -120,11 +91,4 @@ impl OperatorIvf for Vecf16L2 { fn residual(lhs: Borrowed<'_, Self>, rhs: &[Scalar]) -> Owned { lhs.operator_minus(Vecf16Borrowed::new(rhs)) } - fn residual_dense(lhs: &[Scalar], rhs: &[Scalar]) -> Owned { - let mut res = vec![Scalar::::zero(); lhs.len()]; - for i in 0..lhs.len() { - res[i] = lhs[i] - rhs[i]; - } - Vecf16Owned::new(res) - } } diff --git a/crates/quantization/Cargo.toml b/crates/quantization/Cargo.toml index 7a709c370..f3abcdfda 100644 --- a/crates/quantization/Cargo.toml +++ b/crates/quantization/Cargo.toml @@ -4,8 +4,6 @@ version.workspace = true edition.workspace = true [dependencies] -fastdivide = "0.4.1" - log.workspace = true num-traits.workspace = true rand.workspace = true @@ -16,7 +14,6 @@ base = { path = "../base" } common = { path = "../common" } detect = { path = "../detect" } k_means = { path = "../k_means" } -nalgebra = { version = "0.33.0", features = ["debug"] } stoppable_rayon = { path = "../stoppable_rayon" } [lints] diff --git a/crates/quantization/src/lib.rs b/crates/quantization/src/lib.rs index 7d9b0b5ad..4e38006c0 100644 --- a/crates/quantization/src/lib.rs +++ b/crates/quantization/src/lib.rs @@ -9,14 +9,12 @@ pub mod fast_scan; pub mod operator; pub mod product; pub mod quantize; -pub mod rabitq; pub mod reranker; pub mod scalar; pub mod trivial; -mod utils; +pub mod utils; use self::product::ProductQuantizer; -use self::rabitq::RaBitQuantizer; use self::scalar::ScalarQuantizer; use crate::operator::OperatorQuantization; use base::index::*; @@ -39,7 +37,6 @@ pub enum Quantizer { Trivial(TrivialQuantizer), Scalar(ScalarQuantizer), Product(ProductQuantizer), - RaBitQ(RaBitQuantizer), } impl Quantizer { @@ -69,19 +66,12 @@ impl Quantizer { vectors, transform, )), - RaBitQ(rabitq_quantization_options) => Self::RaBitQ(RaBitQuantizer::train( - vector_options, - rabitq_quantization_options, - vectors, - transform, - )), } } } pub enum QuantizationPreprocessed { Trivial(O::TrivialQuantizationPreprocessed), - RaBit(O::RabitQuantizationPreprocessed), Scalar(O::QuantizationPreprocessed), Product(O::QuantizationPreprocessed), } @@ -161,9 +151,6 @@ impl Quantization { _ => unreachable!(), } })), - Quantizer::RaBitQ(_) => { - Box::new(std::iter::empty()) as Box> - } } }); let packed_codes = MmapArray::create( @@ -204,9 +191,6 @@ impl Quantization { } _ => Box::new(std::iter::empty()) as Box>, }, - Quantizer::RaBitQ(_x) => { - Box::new(std::iter::empty()) as Box> - } }, ); let meta = MmapArray::create( @@ -221,9 +205,6 @@ impl Quantization { Quantizer::Product(_) => { Box::new(std::iter::empty()) as Box> } - Quantizer::RaBitQ(_) => { - Box::new(std::iter::empty()) as Box> - } }, ); Self { @@ -252,25 +233,6 @@ impl Quantization { Quantizer::Trivial(x) => QuantizationPreprocessed::Trivial(x.preprocess(lhs)), Quantizer::Scalar(x) => QuantizationPreprocessed::Scalar(x.preprocess(lhs)), Quantizer::Product(x) => QuantizationPreprocessed::Product(x.preprocess(lhs)), - _ => unreachable!(), - } - } - - pub fn project(&self, lhs: &[Scalar]) -> Vec> { - match &*self.train { - Quantizer::RaBitQ(x) => x.project(lhs), - _ => unreachable!(), - } - } - - pub fn projection_preprocess( - &self, - lhs: Borrowed<'_, O>, - distance: F32, - ) -> QuantizationPreprocessed { - match &*self.train { - Quantizer::RaBitQ(x) => QuantizationPreprocessed::RaBit(x.preprocess(lhs, distance)), - _ => unreachable!(), } } @@ -299,9 +261,6 @@ impl Quantization { let rhs = &self.codes[start..end]; x.process(lhs, rhs) } - (Quantizer::RaBitQ(x), QuantizationPreprocessed::RaBit(lhs)) => { - x.process(lhs, u as usize) - } _ => unreachable!(), } } @@ -350,19 +309,6 @@ impl Quantization { Trivial(x) => Box::new(x.flat_rerank(heap, r)), Scalar(x) => Box::new(x.flat_rerank(heap, r, sq_rerank_size)), Product(x) => Box::new(x.flat_rerank(heap, r, pq_rerank_size)), - _ => unreachable!(), - } - } - - pub fn ivf_projection_rerank<'a, T: 'a>( - &'a self, - rough_distances: Vec<(F32, u32)>, - r: impl Fn(u32) -> (F32, T) + 'a, - ) -> Box + 'a> { - use Quantizer::*; - match &*self.train { - RaBitQ(x) => Box::new(x.ivf_projection_rerank(rough_distances, r)), - _ => unreachable!(), } } @@ -394,7 +340,6 @@ impl Quantization { }, r, )), - _ => unreachable!(), } } } diff --git a/crates/quantization/src/operator.rs b/crates/quantization/src/operator.rs index 8aaeaf129..387b19a13 100644 --- a/crates/quantization/src/operator.rs +++ b/crates/quantization/src/operator.rs @@ -1,5 +1,4 @@ use crate::product::operator::OperatorProductQuantization; -use crate::rabitq::operator::OperatorRaBitQ; use crate::scalar::operator::OperatorScalarQuantization; use crate::trivial::operator::OperatorTrivialQuantization; use base::operator::*; @@ -184,7 +183,6 @@ pub trait OperatorQuantization: + OperatorTrivialQuantization + OperatorScalarQuantization + OperatorProductQuantization - + OperatorRaBitQ { } diff --git a/crates/quantization/src/rabitq/mod.rs b/crates/quantization/src/rabitq/mod.rs deleted file mode 100644 index 08b08559e..000000000 --- a/crates/quantization/src/rabitq/mod.rs +++ /dev/null @@ -1,153 +0,0 @@ -use std::ops::Div; - -use self::operator::OperatorRaBitQ; -use crate::reranker::error_based::ErrorBasedReranker; -use base::index::{RaBitQuantizationOptions, VectorOptions}; -use base::operator::{Borrowed, Owned, Scalar}; -use base::scalar::{ScalarLike, F32}; -use base::search::{RerankerPop, Vectors}; -use base::vector::{VectorBorrowed, VectorOwned}; - -use num_traits::{Float, One, Zero}; -use rand::{thread_rng, Rng}; -use serde::{Deserialize, Serialize}; - -pub mod operator; - -const EPSILON: f32 = 1.9; -const THETA_LOG_DIM: u32 = 4; -const DEFAULT_X_DOT_PRODUCT: f32 = 0.8; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct RaBitQuantizer { - dim: u32, - dim_pad_64: u32, - projection: Vec>>, - binary_vec_x: Vec>, - distance_to_centroid_square: Vec>, - rand_bias: Vec>, - error_bound: Vec>, - factor_ip: Vec>, - factor_ppc: Vec>, -} - -impl RaBitQuantizer { - pub fn train( - vector_options: VectorOptions, - _options: RaBitQuantizationOptions, - vectors: &impl Vectors, - transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy, - ) -> Self { - let dim_pad = (vector_options.dims + 63) / 64 * 64; - let mut rand_bias = Vec::with_capacity(dim_pad as usize); - let mut rng = thread_rng(); - for _ in 0..dim_pad { - rand_bias.push(Scalar::::from_f32(rng.gen())); - } - let projection = O::gen_random_orthogonal(dim_pad as usize); - let n = vectors.len() as usize; - let mut distance_to_centroid = vec![Scalar::::zero(); n]; - let mut distance_to_centroid_square = vec![Scalar::::zero(); n]; - let mut quantized_x = vec![vec![Scalar::::zero(); dim_pad as usize]; n]; - for i in 0..n { - let vector = transform(vectors.vector(i as u32)).as_borrowed().to_vec(); - distance_to_centroid_square[i] = O::vector_dot_product(&vector, &vector); - distance_to_centroid[i] = distance_to_centroid_square[i].sqrt(); - for j in 0..vector_options.dims as usize { - quantized_x[i][j] = O::vector_dot_product(&projection[j], &vector); - } - } - let mut binary_vec_x = Vec::with_capacity(n); - let mut signed_x = Vec::with_capacity(n); - for i in 0..n { - binary_vec_x.push(O::vector_binarize_u64(&quantized_x[i])); - signed_x.push(O::vector_binarize_one(&quantized_x[i])); - } - let mut dot_product_x = vec![Scalar::::zero(); n]; - for i in 0..n { - let norm = O::vector_dot_product(&quantized_x[i], &quantized_x[i]).sqrt() - * Scalar::::from_f32(dim_pad as f32).sqrt(); - dot_product_x[i] = if norm.is_normal() { - O::vector_dot_product(&quantized_x[i], &signed_x[i]).div(norm) - } else { - Scalar::::from_f32(DEFAULT_X_DOT_PRODUCT) - } - } - - let mut error_bound = Vec::with_capacity(n); - let mut factor_ip = Vec::with_capacity(n); - let mut factor_ppc = Vec::with_capacity(n); - let error_base = Scalar::::from_f32(2.0 * EPSILON / (dim_pad as f32 - 1.0).sqrt()); - let dim_pad_sqrt = Scalar::::from_f32(dim_pad as f32).sqrt(); - let one_vec = vec![Scalar::::one(); dim_pad as usize]; - for i in 0..n { - let xc_over_dot_product = distance_to_centroid[i] / dot_product_x[i]; - error_bound.push( - error_base - * (xc_over_dot_product * xc_over_dot_product - distance_to_centroid_square[i]) - .sqrt(), - ); - let ip = Scalar::::from_f32(-2.0) / dim_pad_sqrt * xc_over_dot_product; - factor_ip.push(ip); - factor_ppc.push(ip * O::vector_dot_product(&one_vec, &signed_x[i])); - } - - Self { - dim: vector_options.dims, - dim_pad_64: dim_pad, - projection, - binary_vec_x, - distance_to_centroid_square, - rand_bias, - error_bound, - factor_ip, - factor_ppc, - } - } - - pub fn project(&self, lhs: &[Scalar]) -> Vec> { - let vec = lhs.to_vec(); - let mut res = Vec::with_capacity(lhs.len()); - for i in 0..lhs.len() { - res.push(O::vector_dot_product(&self.projection[i], &vec)) - } - res - } - - pub fn width(&self) -> usize { - (self.dim / 64) as usize - } - - pub fn preprocess( - &self, - lhs: Borrowed<'_, O>, - centroid_distance: F32, - ) -> O::RabitQuantizationPreprocessed { - O::rabit_quantization_preprocess( - self.dim_pad_64 as usize, - lhs, - centroid_distance, - &self.rand_bias, - ) - } - - pub fn ivf_projection_rerank<'a, T: 'a>( - &'a self, - rough_distances: Vec<(F32, u32)>, - r: impl Fn(u32) -> (F32, T) + 'a, - ) -> impl RerankerPop + 'a { - ErrorBasedReranker::new(rough_distances, r) - } - - pub fn process(&self, preprocessed: &O::RabitQuantizationPreprocessed, i: usize) -> F32 { - O::rabit_quantization_process( - self.distance_to_centroid_square[i], - self.factor_ppc[i], - self.factor_ip[i], - self.error_bound[i], - &self.binary_vec_x[i], - preprocessed, - ) - } -} diff --git a/crates/quantization/src/rabitq/operator.rs b/crates/quantization/src/rabitq/operator.rs deleted file mode 100644 index 8d03b7ad3..000000000 --- a/crates/quantization/src/rabitq/operator.rs +++ /dev/null @@ -1,170 +0,0 @@ -use base::operator::{Borrowed, Operator, Scalar}; -use base::scalar::{ScalarLike, F32}; -use base::vector::VectorBorrowed; - -use nalgebra::debug::RandomOrthogonal; -use nalgebra::{Dim, Dyn}; -use num_traits::{Float, One, ToPrimitive, Zero}; -use rand::{thread_rng, Rng}; - -use super::THETA_LOG_DIM; - -pub trait OperatorRaBitQ: Operator { - type RabitQuantizationPreprocessed; - - fn vector_dot_product(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar; - fn vector_binarize_u64(vec: &[Scalar]) -> Vec; - fn vector_binarize_one(vec: &[Scalar]) -> Vec>; - fn query_vector_binarize_u64(vec: &[u8]) -> Vec; - fn gen_random_orthogonal(dim: usize) -> Vec>>; - fn rabit_quantization_process( - x_centroid_square: Scalar, - factor_ppc: Scalar, - factor_ip: Scalar, - error_bound: Scalar, - binary_x: &[u64], - p: &Self::RabitQuantizationPreprocessed, - ) -> F32; - fn rabit_quantization_preprocess( - dim: usize, - vec: Borrowed<'_, Self>, - distance: F32, - rand_bias: &[Scalar], - ) -> Self::RabitQuantizationPreprocessed; -} - -impl OperatorRaBitQ for O { - // (distance, lower_bound, delta, scalar_sum, binary_vec_y) - type RabitQuantizationPreprocessed = (Scalar, Scalar, Scalar, Scalar, Vec); - - fn rabit_quantization_preprocess( - dim: usize, - vec: Borrowed<'_, Self>, - distance: F32, - rand_bias: &[Scalar], - ) -> Self::RabitQuantizationPreprocessed { - let vec = vec.to_vec(); - let mut lower_bound = Scalar::::infinity(); - let mut upper_bound = Scalar::::neg_infinity(); - for i in 0..dim { - lower_bound = Float::min(lower_bound, vec[i]); - upper_bound = Float::max(upper_bound, vec[i]); - } - let delta = - (upper_bound - lower_bound) / Scalar::::from_f32((1 << THETA_LOG_DIM) as f32 - 1.0); - - // scalar quantization - let mut quantized_y_scalar = vec![0u8; dim]; - let mut scalar_sum = 0u32; - let one_over_delta = Scalar::::one() / delta; - for i in 0..dim { - quantized_y_scalar[i] = ((vec[i] - lower_bound) * one_over_delta + rand_bias[i]) - .to_u8() - .expect("failed to convert a Scalar value to u8"); - scalar_sum += quantized_y_scalar[i] as u32; - } - let binary_vec_y = O::query_vector_binarize_u64(&quantized_y_scalar); - ( - Scalar::::from_f(distance), - lower_bound, - delta, - Scalar::::from_f32(scalar_sum as f32), - binary_vec_y, - ) - } - - fn gen_random_orthogonal(dim: usize) -> Vec>> { - let mut rng = thread_rng(); - let random_orth: RandomOrthogonal = - RandomOrthogonal::new(Dim::from_usize(dim), || rng.gen()); - let random_matrix = random_orth.unwrap(); - let mut projection = vec![Vec::with_capacity(dim); dim]; - // use the transpose of the random matrix as the inverse of the orthogonal matrix, - // but we need to transpose it again to make it efficient for the dot production - for (i, vec) in random_matrix.row_iter().enumerate() { - for val in vec.iter() { - projection[i].push(Scalar::::from_f32(*val)); - } - } - projection - } - - fn vector_dot_product(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - let mut sum = Scalar::::zero(); - let length = std::cmp::min(lhs.len(), rhs.len()); - for i in 0..length { - sum += lhs[i] * rhs[i]; - } - sum - } - - // binarize vector to 0 or 1 in binary format stored in u64 - fn vector_binarize_u64(vec: &[Scalar]) -> Vec { - let mut binary = vec![0u64; (vec.len() + 63) / 64]; - let zero = Scalar::::zero(); - for i in 0..vec.len() { - if vec[i] > zero { - binary[i / 64] |= 1 << (i % 64); - } - } - binary - } - - // binarize vector to +1 or -1 - fn vector_binarize_one(vec: &[Scalar]) -> Vec> { - let mut binary = vec![Scalar::::one(); vec.len()]; - let zero = Scalar::::zero(); - for i in 0..vec.len() { - if vec[i] <= zero { - binary[i] = -Scalar::::one(); - } - } - binary - } - - fn query_vector_binarize_u64(vec: &[u8]) -> Vec { - let length = vec.len(); - let mut binary = vec![0u64; length * THETA_LOG_DIM as usize / 64]; - for j in 0..THETA_LOG_DIM as usize { - for i in 0..length { - binary[(i + j * length) / 64] |= (((vec[i] >> j) & 1) as u64) << (i % 64); - } - } - binary - } - - fn rabit_quantization_process( - x_centroid_square: Scalar, - factor_ppc: Scalar, - factor_ip: Scalar, - error_bound: Scalar, - binary_x: &[u64], - p: &Self::RabitQuantizationPreprocessed, - ) -> F32 { - let estimate = x_centroid_square - + p.0 * p.0 - + p.1 * factor_ppc - + (Scalar::::from_f32(2.0 * asymmetric_binary_dot_product(binary_x, &p.4) as f32) - - p.3) - * factor_ip - * p.2; - (estimate - error_bound * p.0).to_f() - } -} - -fn binary_dot_product(x: &[u64], y: &[u64]) -> u32 { - let mut res = 0; - for i in 0..x.len() { - res += (x[i] & y[i]).count_ones(); - } - res -} - -fn asymmetric_binary_dot_product(x: &[u64], y: &[u64]) -> u32 { - let mut res = 0; - let length = x.len(); - for i in 0..THETA_LOG_DIM as usize { - res += binary_dot_product(x, &y[i * length..(i + 1) * length]) << i; - } - res -} diff --git a/crates/quantization/src/reranker/error.rs b/crates/quantization/src/reranker/error.rs deleted file mode 100644 index 4af405750..000000000 --- a/crates/quantization/src/reranker/error.rs +++ /dev/null @@ -1,38 +0,0 @@ -use base::scalar::F32; -use base::search::RerankerPop; -use common::always_equal::AlwaysEqual; -use std::cmp::Reverse; -use std::collections::BinaryHeap; - -pub struct ErrorFlatReranker { - rerank: R, - heap: BinaryHeap<(Reverse, u32)>, - cache: BinaryHeap<(Reverse, u32, AlwaysEqual)>, -} - -impl ErrorFlatReranker { - pub fn new(heap: Vec<(Reverse, u32)>, rerank: R) -> Self { - Self { - rerank, - heap: heap.into(), - cache: BinaryHeap::new(), - } - } -} - -impl RerankerPop for ErrorFlatReranker -where - R: Fn(u32) -> (F32, T), -{ - fn pop(&mut self) -> Option<(F32, u32, T)> { - while !self.heap.is_empty() - && (self.cache.is_empty() || self.heap.peek().unwrap().0 > self.cache.peek().unwrap().0) - { - let (_, u) = self.heap.pop().unwrap(); - let (accu_u, t) = (self.rerank)(u); - self.cache.push((Reverse(accu_u), u, AlwaysEqual(t))); - } - let (Reverse(accu_u), u, AlwaysEqual(t)) = self.cache.pop()?; - Some((accu_u, u, t)) - } -} diff --git a/crates/quantization/src/reranker/mod.rs b/crates/quantization/src/reranker/mod.rs index fa30ea85b..3a3cdd632 100644 --- a/crates/quantization/src/reranker/mod.rs +++ b/crates/quantization/src/reranker/mod.rs @@ -1,5 +1,3 @@ pub mod disabled; -pub mod error; -pub mod error_based; pub mod window; pub mod window_0; diff --git a/crates/quantization/src/scalar/mod.rs b/crates/quantization/src/scalar/mod.rs index d21cd1f54..67b25f20b 100644 --- a/crates/quantization/src/scalar/mod.rs +++ b/crates/quantization/src/scalar/mod.rs @@ -15,6 +15,7 @@ use num_traits::Float; use serde::Deserialize; use serde::Serialize; use std::cmp::Reverse; +use std::marker::PhantomData; use std::ops::Range; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -22,9 +23,10 @@ use std::ops::Range; pub struct ScalarQuantizer { dims: u32, bits: u32, - max: Vec>, - min: Vec>, - centroids: Vec2>, + max: Vec, + min: Vec, + centroids: Vec2, + _phantom: PhantomData O>, } impl ScalarQuantizer { @@ -36,14 +38,14 @@ impl ScalarQuantizer { ) -> Self { let dims = vector_options.dims; let bits = scalar_quantization_options.bits; - let mut max = vec![Scalar::::neg_infinity(); dims as usize]; - let mut min = vec![Scalar::::infinity(); dims as usize]; + let mut max = vec![F32::neg_infinity(); dims as usize]; + let mut min = vec![F32::infinity(); dims as usize]; let n = vectors.len(); for i in 0..n { let vector = transform(vectors.vector(i)).as_borrowed().to_vec(); for j in 0..dims as usize { - max[j] = std::cmp::max(max[j], vector[j]); - min[j] = std::cmp::min(min[j], vector[j]); + max[j] = std::cmp::max(max[j], vector[j].to_f()); + min[j] = std::cmp::min(min[j], vector[j].to_f()); } } let mut centroids = Vec2::zeros((1 << bits, dims as usize)); @@ -51,7 +53,7 @@ impl ScalarQuantizer { let bas = min[p as usize]; let del = max[p as usize] - min[p as usize]; for j in 0_usize..(1 << bits) { - let val = Scalar::::from_f(F32(j as f32 / ((1 << bits) - 1) as f32)); + let val = F32(j as f32 / ((1 << bits) - 1) as f32); centroids[(j, p as usize)] = bas + val * del; } } @@ -61,6 +63,7 @@ impl ScalarQuantizer { max, min, centroids, + _phantom: PhantomData, } } @@ -83,8 +86,8 @@ impl ScalarQuantizer { let mut codes = Vec::with_capacity(dims as usize); for i in 0..dims as usize { let del = self.max[i] - self.min[i]; - let w = - (((vector[i] - self.min[i]) / del).to_f32() * (((1 << bits) - 1) as f32)) as u32; + let w = (((vector[i].to_f() - self.min[i]) / del).to_f32() * (((1 << bits) - 1) as f32)) + as u32; codes.push(w.clamp(0, 255) as u8); } codes diff --git a/crates/quantization/src/scalar/operator.rs b/crates/quantization/src/scalar/operator.rs index eb63c97b8..8fc3a13c4 100644 --- a/crates/quantization/src/scalar/operator.rs +++ b/crates/quantization/src/scalar/operator.rs @@ -6,8 +6,8 @@ pub trait OperatorScalarQuantization: Operator + OperatorQuantizationProcess { fn scalar_quantization_preprocess( dims: u32, bits: u32, - max: &[Scalar], - min: &[Scalar], + max: &[F32], + min: &[F32], lhs: Borrowed<'_, Self>, ) -> Self::QuantizationPreprocessed; } @@ -16,8 +16,8 @@ impl OperatorScalarQuantization for Vecf32Dot { fn scalar_quantization_preprocess( dims: u32, bits: u32, - max: &[Scalar], - min: &[Scalar], + max: &[F32], + min: &[F32], lhs: Borrowed<'_, Self>, ) -> Self::QuantizationPreprocessed { let mut xy = Vec::with_capacity(dims as _); @@ -26,7 +26,7 @@ impl OperatorScalarQuantization for Vecf32Dot { let del = max[i as usize] - min[i as usize]; xy.extend((0..1 << bits).map(|k| { let x = lhs.slice()[i as usize]; - let val = Scalar::::from_f(F32(k as f32 / ((1 << bits) - 1) as f32)); + let val = F32(k as f32 / ((1 << bits) - 1) as f32); let y = bas + val * del; x * y })); @@ -39,8 +39,8 @@ impl OperatorScalarQuantization for Vecf32L2 { fn scalar_quantization_preprocess( dims: u32, bits: u32, - max: &[Scalar], - min: &[Scalar], + max: &[F32], + min: &[F32], lhs: Borrowed<'_, Self>, ) -> Self::QuantizationPreprocessed { let mut d2 = Vec::with_capacity(dims as _); @@ -49,7 +49,7 @@ impl OperatorScalarQuantization for Vecf32L2 { let del = max[i as usize] - min[i as usize]; d2.extend((0..1 << bits).map(|k| { let x = lhs.slice()[i as usize]; - let val = Scalar::::from_f(F32(k as f32 / ((1 << bits) - 1) as f32)); + let val = F32(k as f32 / ((1 << bits) - 1) as f32); let y = bas + val * del; let d = x - y; d * d @@ -63,8 +63,8 @@ impl OperatorScalarQuantization for Vecf16Dot { fn scalar_quantization_preprocess( dims: u32, bits: u32, - max: &[Scalar], - min: &[Scalar], + max: &[F32], + min: &[F32], lhs: Borrowed<'_, Self>, ) -> Self::QuantizationPreprocessed { let mut xy = Vec::with_capacity(dims as _); @@ -73,8 +73,8 @@ impl OperatorScalarQuantization for Vecf16Dot { let del = max[i as usize] - min[i as usize]; xy.extend((0..1 << bits).map(|k| { let x = lhs.slice()[i as usize].to_f(); - let val = Scalar::::from_f(F32(k as f32 / ((1 << bits) - 1) as f32)); - let y = (bas + val * del).to_f32(); + let val = F32(k as f32 / ((1 << bits) - 1) as f32); + let y = bas + val * del; x * y })); } @@ -86,8 +86,8 @@ impl OperatorScalarQuantization for Vecf16L2 { fn scalar_quantization_preprocess( dims: u32, bits: u32, - max: &[Scalar], - min: &[Scalar], + max: &[F32], + min: &[F32], lhs: Borrowed<'_, Self>, ) -> Self::QuantizationPreprocessed { let mut d2 = Vec::with_capacity(dims as _); @@ -96,8 +96,8 @@ impl OperatorScalarQuantization for Vecf16L2 { let del = max[i as usize] - min[i as usize]; d2.extend((0..1 << bits).map(|k| { let x = lhs.slice()[i as usize].to_f(); - let val = Scalar::::from_f(F32(k as f32 / ((1 << bits) - 1) as f32)); - let y = (bas + val * del).to_f32(); + let val = F32(k as f32 / ((1 << bits) - 1) as f32); + let y = bas + val * del; let d = x - y; d * d })); @@ -112,8 +112,8 @@ macro_rules! unimpl_operator_scalar_quantization { fn scalar_quantization_preprocess( _: u32, _: u32, - _: &[Scalar], - _: &[Scalar], + _: &[F32], + _: &[F32], _: Borrowed<'_, Self>, ) -> Self::QuantizationPreprocessed { unimplemented!() diff --git a/crates/rabitq/Cargo.toml b/crates/rabitq/Cargo.toml new file mode 100644 index 000000000..18b17789e --- /dev/null +++ b/crates/rabitq/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "rabitq" +version.workspace = true +edition.workspace = true + +[dependencies] +log.workspace = true +num-traits.workspace = true +rand.workspace = true +serde.workspace = true +serde_json.workspace = true + +base = { path = "../base" } +common = { path = "../common" } +detect = { path = "../detect" } +k_means = { version = "0.0.0", path = "../k_means" } +nalgebra = { version = "0.33.0", features = ["debug"] } +quantization = { path = "../quantization" } +stoppable_rayon = { path = "../stoppable_rayon" } +storage = { version = "0.0.0", path = "../storage" } + +[lints] +workspace = true diff --git a/crates/ivf/src/ivf_projection.rs b/crates/rabitq/src/lib.rs similarity index 56% rename from crates/ivf/src/ivf_projection.rs rename to crates/rabitq/src/lib.rs index ed05d8055..e479dc46f 100644 --- a/crates/ivf/src/ivf_projection.rs +++ b/crates/rabitq/src/lib.rs @@ -1,29 +1,38 @@ -use super::OperatorIvf as Op; -use base::index::{IndexOptions, IvfIndexingOptions, SearchOptions}; -use base::operator::{Borrowed, Scalar}; +#![allow(clippy::needless_range_loop)] +#![allow(clippy::type_complexity)] +#![allow(clippy::identity_op)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::len_without_is_empty)] + +pub mod operator; +pub mod quant; + +use crate::operator::OperatorRabitq as Op; +use crate::quant::quantization::Quantization; +use base::index::{IndexOptions, RabitqIndexingOptions, SearchOptions}; +use base::operator::Borrowed; +use base::scalar::F32; use base::search::{Collection, Element, Payload, Source, Vectors}; -use base::vector::{VectorBorrowed, VectorOwned}; use common::json::Json; use common::mmap_array::MmapArray; use common::remap::RemappedCollection; use common::vec2::Vec2; use k_means::{k_means, k_means_lookup, k_means_lookup_many}; -use quantization::Quantization; -use stoppable_rayon as rayon; -use storage::Storage; - use std::fs::create_dir; use std::path::Path; +use stoppable_rayon as rayon; +use storage::Storage; -pub struct IvfProjection { +pub struct Rabitq { storage: O::Storage, quantization: Quantization, payloads: MmapArray, offsets: Json>, - centroids: Json>>, + centroids: Json>, + projection: Json>>, } -impl IvfProjection { +impl Rabitq { pub fn create(path: impl AsRef, options: IndexOptions, source: &impl Source) -> Self { let remapped = RemappedCollection::from_source(source); from_nothing(path, options, &remapped) @@ -50,40 +59,32 @@ impl IvfProjection { vector: Borrowed<'a, O>, opts: &'a SearchOptions, ) -> (Vec, Box + 'a>) { - let projected_query = self.quantization.project(&vector.to_vec()); + let projected_query = O::proj(&self.projection, O::cast(vector)); let lists = select( k_means_lookup_many(&projected_query, &self.centroids), - opts.ivf_nprobe as usize, + opts.rabitq_nprobe as usize, ); - let preprocessed_centroids = lists - .iter() - .map(|&(distance, i)| { - self.quantization.projection_preprocess( - O::residual_dense(&projected_query, &self.centroids[(i,)]).as_borrowed(), - distance, - ) - }) - .collect::>(); - let mut rough_distances = Vec::new(); - for (code, i) in lists.iter().map(|(_, i)| *i).enumerate() { + let mut heap = Vec::new(); + for &(_, i) in lists.iter() { + let preprocessed = self + .quantization + .preprocess(&O::residual(&projected_query, &self.centroids[(i,)])); let start = self.offsets[i]; let end = self.offsets[i + 1]; - for u in start..end { - rough_distances.push(( - self.quantization - .process(&self.storage, &preprocessed_centroids[code], u), - u, - )); - } + self.quantization.push_batch( + &preprocessed, + start..end, + &mut heap, + F32(1.9), + opts.rabitq_fast_scan, + ); } - let mut reranker = self - .quantization - .ivf_projection_rerank(rough_distances, move |u| { - ( - O::distance(vector, self.storage.vector(u)), - self.payloads[u as usize], - ) - }); + let mut reranker = self.quantization.rerank(heap, move |u| { + ( + O::distance(vector, self.storage.vector(u)), + self.payloads[u as usize], + ) + }); ( Vec::new(), Box::new(std::iter::from_fn(move || { @@ -100,21 +101,35 @@ fn from_nothing( path: impl AsRef, options: IndexOptions, collection: &impl Collection, -) -> IvfProjection { +) -> Rabitq { create_dir(path.as_ref()).unwrap(); - let IvfIndexingOptions { - nlist, - quantization: quantization_options, - spherical_centroids: _, - residual_quantization: _, - } = options.indexing.clone().unwrap_ivf(); - let samples = common::sample::sample(collection); + let RabitqIndexingOptions { nlist } = options.indexing.clone().unwrap_rabitq(); + let projection = { + use nalgebra::debug::RandomOrthogonal; + use nalgebra::{Dim, Dyn}; + use rand::Rng; + let dims = options.vector.dims as usize; + let mut rng = rand::thread_rng(); + let random_orth: RandomOrthogonal = + RandomOrthogonal::new(Dim::from_usize(dims), || rng.gen()); + let random_matrix = random_orth.unwrap(); + let mut projection = vec![Vec::with_capacity(dims); dims]; + // use the transpose of the random matrix as the inverse of the orthogonal matrix, + // but we need to transpose it again to make it efficient for the dot production + for (i, vec) in random_matrix.row_iter().enumerate() { + for &val in vec.iter() { + projection[i].push(F32(val)); + } + } + projection + }; + let samples = common::sample::sample_cast(collection); rayon::check(); - let centroids = k_means(nlist as usize, samples, false); + let centroids: Vec2 = k_means(nlist as usize, samples, false); rayon::check(); let mut ls = vec![Vec::new(); nlist as usize]; for i in 0..collection.len() { - ls[k_means_lookup(&collection.vector(i).to_vec(), ¢roids)].push(i); + ls[k_means_lookup(O::cast(collection.vector(i)), ¢roids)].push(i); } let mut offsets = vec![0u32; nlist as usize + 1]; for i in 0..nlist { @@ -130,17 +145,17 @@ fn from_nothing( let quantization = Quantization::create( path.as_ref().join("quantization"), options.vector, - quantization_options, - &collection, + collection.len(), |vector| { - let target = k_means_lookup(&vector.to_vec(), ¢roids); - O::residual(vector, ¢roids[(target,)]) + let vector = O::cast(collection.vector(vector)); + let target = k_means_lookup(vector, ¢roids); + O::proj(&projection, &O::residual(vector, ¢roids[(target,)])) }, ); let projected_centroids = Vec2::from_vec( (centroids.shape_0(), centroids.shape_1()), (0..centroids.shape_0()) - .flat_map(|x| quantization.project(¢roids[(x,)])) + .flat_map(|x| O::proj(&projection, ¢roids[(x,)])) .collect(), ); let payloads = MmapArray::create( @@ -149,27 +164,31 @@ fn from_nothing( ); let offsets = Json::create(path.as_ref().join("offsets"), offsets); let centroids = Json::create(path.as_ref().join("centroids"), projected_centroids); - IvfProjection { + let projection = Json::create(path.as_ref().join("projection"), projection); + Rabitq { storage, payloads, offsets, centroids, quantization, + projection, } } -fn open(path: impl AsRef) -> IvfProjection { +fn open(path: impl AsRef) -> Rabitq { let storage = O::Storage::open(path.as_ref().join("storage")); let quantization = Quantization::open(path.as_ref().join("quantization")); let payloads = MmapArray::open(path.as_ref().join("payloads")); let offsets = Json::open(path.as_ref().join("offsets")); let centroids = Json::open(path.as_ref().join("centroids")); - IvfProjection { + let projection = Json::open(path.as_ref().join("projection")); + Rabitq { storage, quantization, payloads, offsets, centroids, + projection, } } diff --git a/crates/rabitq/src/operator.rs b/crates/rabitq/src/operator.rs new file mode 100644 index 000000000..c4e140e72 --- /dev/null +++ b/crates/rabitq/src/operator.rs @@ -0,0 +1,303 @@ +use base::operator::Borrowed; +use base::operator::*; +use base::scalar::F32; +use num_traits::Float; +use storage::OperatorStorage; + +pub trait OperatorRabitq: OperatorStorage { + const RESIDUAL: bool; + fn cast(vector: Borrowed<'_, Self>) -> &[F32]; + fn residual(lhs: &[F32], rhs: &[F32]) -> Vec; + + fn proj(projection: &[Vec], vector: &[F32]) -> Vec; + + type QuantizationPreprocessed0; + type QuantizationPreprocessed1; + + fn rabitq_quantization_preprocess( + vector: &[F32], + ) -> ( + Self::QuantizationPreprocessed0, + Self::QuantizationPreprocessed1, + ); + fn rabitq_quantization_process( + dis_u_2: F32, + factor_ppc: F32, + factor_ip: F32, + factor_err: F32, + code: &[u8], + p0: &Self::QuantizationPreprocessed0, + p1: &Self::QuantizationPreprocessed1, + ) -> (F32, F32); + fn rabitq_quantization_process_1( + dis_u_2: F32, + factor_ppc: F32, + factor_ip: F32, + factor_err: F32, + p0: &Self::QuantizationPreprocessed0, + param: u16, + ) -> (F32, F32); + + const SUPPORT_FAST_SCAN: bool; + fn fast_scan(preprocessed: &Self::QuantizationPreprocessed1) -> Vec; + fn fast_scan_resolve(x: F32) -> F32; +} + +impl OperatorRabitq for Vecf32L2 { + const RESIDUAL: bool = false; + fn cast(vector: Borrowed<'_, Self>) -> &[F32] { + vector.slice() + } + fn residual(lhs: &[F32], rhs: &[F32]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + (0..n).map(|i| lhs[i] - rhs[i]).collect() + } + + type QuantizationPreprocessed0 = (F32, F32, F32, F32); + type QuantizationPreprocessed1 = ((Vec, Vec, Vec, Vec), Vec); + + fn rabitq_quantization_preprocess( + vector: &[F32], + ) -> ( + (F32, F32, F32, F32), + ((Vec, Vec, Vec, Vec), Vec), + ) { + let dis_v_2 = vector.iter().map(|&x| x * x).sum(); + let (k, b, qvector) = quantization::quantize::quantize_15(vector); + let qvector_sum = F32(qvector.iter().fold(0_u32, |x, &y| x + y as u32) as _); + let blut = binarize(&qvector); + let lut = gen(&qvector); + ((dis_v_2, b, k, qvector_sum), (blut, lut)) + } + + fn rabitq_quantization_process( + dis_u_2: F32, + factor_ppc: F32, + factor_ip: F32, + factor_err: F32, + code: &[u8], + p0: &(F32, F32, F32, F32), + p1: &((Vec, Vec, Vec, Vec), Vec), + ) -> (F32, F32) { + rabitq_quantization_process(dis_u_2, factor_ppc, factor_ip, factor_err, code, *p0, p1) + } + + fn rabitq_quantization_process_1( + dis_u_2: F32, + factor_ppc: F32, + factor_ip: F32, + factor_err: F32, + p0: &Self::QuantizationPreprocessed0, + param: u16, + ) -> (F32, F32) { + rabitq_quantization_process_1(dis_u_2, factor_ppc, factor_ip, factor_err, *p0, param) + } + + fn proj(projection: &[Vec], vector: &[F32]) -> Vec { + let dims = vector.len(); + assert_eq!(projection.len(), dims); + (0..dims) + .map(|i| { + assert_eq!(projection[i].len(), dims); + let mut xy = F32(0.0); + for j in 0..dims { + xy += projection[i][j] * vector[j]; + } + xy + }) + .collect() + } + + const SUPPORT_FAST_SCAN: bool = true; + fn fast_scan(preprocessed: &((Vec, Vec, Vec, Vec), Vec)) -> Vec { + preprocessed.1.clone() + } + fn fast_scan_resolve(x: F32) -> F32 { + x + } +} + +macro_rules! unimpl_operator_rabitq { + ($t:ty) => { + impl OperatorRabitq for $t { + const RESIDUAL: bool = false; + fn cast(_: Borrowed<'_, Self>) -> &[F32] { + unimplemented!() + } + + fn residual(_: &[F32], _: &[F32]) -> Vec { + unimplemented!() + } + + fn proj(_: &[Vec], _: &[F32]) -> Vec { + unimplemented!() + } + + type QuantizationPreprocessed0 = std::convert::Infallible; + type QuantizationPreprocessed1 = std::convert::Infallible; + + fn rabitq_quantization_preprocess( + _: &[F32], + ) -> ( + Self::QuantizationPreprocessed0, + Self::QuantizationPreprocessed1, + ) { + unimplemented!() + } + + fn rabitq_quantization_process( + _: F32, + _: F32, + _: F32, + _: F32, + _: &[u8], + _: &Self::QuantizationPreprocessed0, + _: &Self::QuantizationPreprocessed1, + ) -> (F32, F32) { + unimplemented!() + } + + fn rabitq_quantization_process_1( + _: F32, + _: F32, + _: F32, + _: F32, + _: &Self::QuantizationPreprocessed0, + _: u16, + ) -> (F32, F32) { + unimplemented!() + } + + const SUPPORT_FAST_SCAN: bool = false; + fn fast_scan(_: &Self::QuantizationPreprocessed1) -> Vec { + unimplemented!() + } + fn fast_scan_resolve(_: F32) -> F32 { + unimplemented!() + } + } + }; +} + +unimpl_operator_rabitq!(Vecf32Dot); + +unimpl_operator_rabitq!(Vecf16Dot); +unimpl_operator_rabitq!(Vecf16L2); + +unimpl_operator_rabitq!(BVectorDot); +unimpl_operator_rabitq!(BVectorHamming); +unimpl_operator_rabitq!(BVectorJaccard); + +unimpl_operator_rabitq!(SVecf32Dot); +unimpl_operator_rabitq!(SVecf32L2); + +#[inline(always)] +pub fn rabitq_quantization_process( + dis_u_2: F32, + factor_ppc: F32, + factor_ip: F32, + factor_err: F32, + code: &[u8], + params: (F32, F32, F32, F32), + (blut, _lut): &((Vec, Vec, Vec, Vec), Vec), +) -> (F32, F32) { + let abdp = asymmetric_binary_dot_product(code, blut) as u16; + rabitq_quantization_process_1(dis_u_2, factor_ppc, factor_ip, factor_err, params, abdp) +} + +#[inline(always)] +pub fn rabitq_quantization_process_1( + dis_u_2: F32, + factor_ppc: F32, + factor_ip: F32, + factor_err: F32, + (dis_v_2, b, k, qvector_sum): (F32, F32, F32, F32), + abdp: u16, +) -> (F32, F32) { + let rough = + dis_u_2 + dis_v_2 + b * factor_ppc + (F32(2.0 * abdp as f32) - qvector_sum) * factor_ip * k; + let err = factor_err * dis_v_2.sqrt(); + (rough, err) +} + +fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { + let n = vector.len(); + let t0 = { + let mut t = vec![0u8; n.div_ceil(8)]; + for i in 0..n { + t[i / 8] |= ((vector[i] >> 0) & 1) << (i % 8); + } + t + }; + let t1 = { + let mut t = vec![0u8; n.div_ceil(8)]; + for i in 0..n { + t[i / 8] |= ((vector[i] >> 1) & 1) << (i % 8); + } + t + }; + let t2 = { + let mut t = vec![0u8; n.div_ceil(8)]; + for i in 0..n { + t[i / 8] |= ((vector[i] >> 2) & 1) << (i % 8); + } + t + }; + let t3 = { + let mut t = vec![0u8; n.div_ceil(8)]; + for i in 0..n { + t[i / 8] |= ((vector[i] >> 3) & 1) << (i % 8); + } + t + }; + (t0, t1, t2, t3) +} + +fn gen(qvector: &[u8]) -> Vec { + let dims = qvector.len() as u32; + let t = dims.div_ceil(4); + let mut lut = vec![0u8; t as usize * 16]; + for i in 0..t as usize { + let t0 = qvector.get(4 * i + 0).copied().unwrap_or_default(); + let t1 = qvector.get(4 * i + 1).copied().unwrap_or_default(); + let t2 = qvector.get(4 * i + 2).copied().unwrap_or_default(); + let t3 = qvector.get(4 * i + 3).copied().unwrap_or_default(); + lut[16 * i + 0b0000] = 0; + lut[16 * i + 0b0001] = t0; + lut[16 * i + 0b0010] = t1; + lut[16 * i + 0b0011] = t1 + t0; + lut[16 * i + 0b0100] = t2; + lut[16 * i + 0b0101] = t2 + t0; + lut[16 * i + 0b0110] = t2 + t1; + lut[16 * i + 0b0111] = t2 + t1 + t0; + lut[16 * i + 0b1000] = t3; + lut[16 * i + 0b1001] = t3 + t0; + lut[16 * i + 0b1010] = t3 + t1; + lut[16 * i + 0b1011] = t3 + t1 + t0; + lut[16 * i + 0b1100] = t3 + t2; + lut[16 * i + 0b1101] = t3 + t2 + t0; + lut[16 * i + 0b1110] = t3 + t2 + t1; + lut[16 * i + 0b1111] = t3 + t2 + t1 + t0; + } + lut +} + +fn binary_dot_product(x: &[u8], y: &[u8]) -> u32 { + assert_eq!(x.len(), y.len()); + let n = x.len(); + let mut res = 0; + for i in 0..n { + res += (x[i] & y[i]).count_ones(); + } + res +} + +fn asymmetric_binary_dot_product(x: &[u8], y: &(Vec, Vec, Vec, Vec)) -> u32 { + let mut res = 0; + res += binary_dot_product(x, &y.0) << 0; + res += binary_dot_product(x, &y.1) << 1; + res += binary_dot_product(x, &y.2) << 2; + res += binary_dot_product(x, &y.3) << 3; + res +} diff --git a/crates/quantization/src/reranker/error_based.rs b/crates/rabitq/src/quant/error_based.rs similarity index 89% rename from crates/quantization/src/reranker/error_based.rs rename to crates/rabitq/src/quant/error_based.rs index aaf6266e1..e135c1e90 100644 --- a/crates/quantization/src/reranker/error_based.rs +++ b/crates/rabitq/src/quant/error_based.rs @@ -11,17 +11,17 @@ pub struct ErrorBasedReranker { rerank: R, cache: BinaryHeap<(Reverse, u32, AlwaysEqual)>, distance_threshold: F32, - rough_distances: Vec<(F32, u32)>, + heap: Vec<(Reverse, u32)>, ranked: bool, } impl ErrorBasedReranker { - pub fn new(rough_distances: Vec<(F32, u32)>, rerank: R) -> Self { + pub fn new(heap: Vec<(Reverse, u32)>, rerank: R) -> Self { Self { rerank, cache: BinaryHeap::new(), distance_threshold: F32::infinity(), - rough_distances, + heap, ranked: false, } } @@ -36,7 +36,7 @@ where self.ranked = true; let mut recent_max_accurate = F32::neg_infinity(); let mut count = 0; - for &(lowerbound, u) in self.rough_distances.iter() { + for &(Reverse(lowerbound), u) in self.heap.iter() { if lowerbound < self.distance_threshold { let (accurate, t) = (self.rerank)(u); if accurate < self.distance_threshold { diff --git a/crates/rabitq/src/quant/mod.rs b/crates/rabitq/src/quant/mod.rs new file mode 100644 index 000000000..9504a3703 --- /dev/null +++ b/crates/rabitq/src/quant/mod.rs @@ -0,0 +1,3 @@ +pub mod error_based; +pub mod quantization; +pub mod quantizer; diff --git a/crates/rabitq/src/quant/quantization.rs b/crates/rabitq/src/quant/quantization.rs new file mode 100644 index 000000000..81c333975 --- /dev/null +++ b/crates/rabitq/src/quant/quantization.rs @@ -0,0 +1,191 @@ +use super::quantizer::RabitqQuantizer; +use crate::operator::OperatorRabitq; +use base::index::VectorOptions; +use base::scalar::F32; +use base::search::RerankerPop; +use common::json::Json; +use common::mmap_array::MmapArray; +use quantization::utils::InfiniteByteChunks; +use serde::{Deserialize, Serialize}; +use std::cmp::Reverse; +use std::ops::Range; +use std::path::Path; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub enum Quantizer { + Rabitq(RabitqQuantizer), +} + +impl Quantizer { + pub fn train(vector_options: VectorOptions) -> Self { + Self::Rabitq(RabitqQuantizer::train(vector_options)) + } +} + +pub enum QuantizationPreprocessed { + Rabitq( + ( + ::QuantizationPreprocessed0, + ::QuantizationPreprocessed1, + ), + ), +} + +pub struct Quantization { + train: Json>, + codes: MmapArray, + packed_codes: MmapArray, + meta: MmapArray, +} + +impl Quantization { + pub fn create( + path: impl AsRef, + vector_options: VectorOptions, + n: u32, + vectors: impl Fn(u32) -> Vec, + ) -> Self { + std::fs::create_dir(path.as_ref()).unwrap(); + fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 { + b0 | (b1 << 1) | (b2 << 2) | (b3 << 3) | (b4 << 4) | (b5 << 5) | (b6 << 6) | (b7 << 7) + } + fn merge_4([b0, b1, b2, b3]: [u8; 4]) -> u8 { + b0 | (b1 << 2) | (b2 << 4) | (b3 << 6) + } + fn merge_2([b0, b1]: [u8; 2]) -> u8 { + b0 | (b1 << 4) + } + let train = Quantizer::train(vector_options); + let train = Json::create(path.as_ref().join("train"), train); + let codes = MmapArray::create(path.as_ref().join("codes"), { + match &*train { + Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { + let vector = vectors(i); + let (_, _, _, _, codes) = x.encode(&vector); + let bytes = x.bytes(); + match x.bits() { + 1 => InfiniteByteChunks::new(codes.into_iter()) + .map(merge_8) + .take(bytes as usize) + .collect(), + 2 => InfiniteByteChunks::new(codes.into_iter()) + .map(merge_4) + .take(bytes as usize) + .collect(), + 4 => InfiniteByteChunks::new(codes.into_iter()) + .map(merge_2) + .take(bytes as usize) + .collect(), + 8 => codes, + _ => unreachable!(), + } + })), + } + }); + let packed_codes = MmapArray::create( + path.as_ref().join("packed_codes"), + match &*train { + Quantizer::Rabitq(x) => { + use quantization::fast_scan::b4::{pack, BLOCK_SIZE}; + let blocks = n.div_ceil(BLOCK_SIZE); + Box::new((0..blocks).flat_map(|block| { + let t = x.dims().div_ceil(4); + let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { + let id = BLOCK_SIZE * block + i as u32; + let (_, _, _, _, e) = x.encode(&vectors(std::cmp::min(id, n - 1))); + InfiniteByteChunks::new(e.into_iter()) + .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) + .take(t as usize) + .collect() + }); + pack(t, raw) + })) as Box> + } + }, + ); + let meta = MmapArray::create( + path.as_ref().join("meta"), + match &*train { + Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { + let (a, b, c, d, _) = x.encode(&vectors(i)); + [a, b, c, d].into_iter() + })), + }, + ); + Self { + train, + codes, + packed_codes, + meta, + } + } + + pub fn open(path: impl AsRef) -> Self { + let train = Json::open(path.as_ref().join("train")); + let codes = MmapArray::open(path.as_ref().join("codes")); + let packed_codes = MmapArray::open(path.as_ref().join("packed_codes")); + let meta = MmapArray::open(path.as_ref().join("meta")); + Self { + train, + codes, + packed_codes, + meta, + } + } + + pub fn preprocess(&self, lhs: &[F32]) -> QuantizationPreprocessed { + match &*self.train { + Quantizer::Rabitq(x) => QuantizationPreprocessed::Rabitq(x.preprocess(lhs)), + } + } + + pub fn process(&self, preprocessed: &QuantizationPreprocessed, u: u32) -> F32 { + match (&*self.train, preprocessed) { + (Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => { + let bytes = x.bytes() as usize; + let start = u as usize * bytes; + let end = start + bytes; + let a = self.meta[4 * u as usize + 0]; + let b = self.meta[4 * u as usize + 1]; + let c = self.meta[4 * u as usize + 2]; + let d = self.meta[4 * u as usize + 3]; + let codes = &self.codes[start..end]; + x.process(&lhs.0, &lhs.1, (a, b, c, d, codes)) + } + } + } + + pub fn push_batch( + &self, + preprocessed: &QuantizationPreprocessed, + rhs: Range, + heap: &mut Vec<(Reverse, u32)>, + rq_epsilon: F32, + rq_fast_scan: bool, + ) { + match (&*self.train, preprocessed) { + (Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => x.push_batch( + lhs, + rhs, + heap, + &self.codes, + &self.packed_codes, + &self.meta, + rq_epsilon, + rq_fast_scan, + ), + } + } + + pub fn rerank<'a, T: 'a>( + &'a self, + heap: Vec<(Reverse, u32)>, + r: impl Fn(u32) -> (F32, T) + 'a, + ) -> Box + 'a> { + use Quantizer::*; + match &*self.train { + Rabitq(x) => Box::new(x.rerank(heap, r)), + } + } +} diff --git a/crates/rabitq/src/quant/quantizer.rs b/crates/rabitq/src/quant/quantizer.rs new file mode 100644 index 000000000..3cc670ffa --- /dev/null +++ b/crates/rabitq/src/quant/quantizer.rs @@ -0,0 +1,207 @@ +use super::error_based::ErrorBasedReranker; +use crate::operator::OperatorRabitq; +use base::index::VectorOptions; +use base::scalar::F32; +use base::search::RerankerPop; +use num_traits::Float; +use serde::{Deserialize, Serialize}; +use std::cmp::Reverse; +use std::marker::PhantomData; +use std::ops::Range; + +pub const EPSILON: f32 = 1.9; +pub const THETA_LOG_DIM: u32 = 4; +pub const DEFAULT_X_DOT_PRODUCT: f32 = 0.8; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct RabitqQuantizer { + dims: u32, + _maker: PhantomData O>, +} + +impl RabitqQuantizer { + pub fn train(vector_options: VectorOptions) -> Self { + let dims = vector_options.dims; + Self { + dims, + _maker: PhantomData, + } + } + + pub fn bits(&self) -> u32 { + 1 + } + + pub fn bytes(&self) -> u32 { + self.dims.div_ceil(8) + } + + pub fn dims(&self) -> u32 { + self.dims + } + + pub fn width(&self) -> u32 { + self.dims + } + + pub fn encode(&self, vector: &[F32]) -> (F32, F32, F32, F32, Vec) { + let dis_u = vector.iter().map(|&x| x * x).sum::().sqrt(); + let sum_of_abs_x = vector.iter().map(|x| x.abs()).sum::(); + let sum_of_x_2 = vector.iter().map(|&x| x * x).sum::(); + let x0 = sum_of_abs_x / (sum_of_x_2 * F32(self.dims as _)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = F32(self.dims as f32).sqrt(); + let max_x1 = F32(1.0) / F32((self.dims as f32 - 1.0).sqrt()); + let factor_err = F32(2.0) * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = F32(-2.0) / fac_norm * x_x0; + let factor_ppc = factor_ip * vector.iter().map(|x| x.signum()).sum::(); + let mut codes = Vec::new(); + for i in 0..self.dims { + codes.push(vector[i as usize].is_sign_positive() as u8); + } + (dis_u * dis_u, factor_ppc, factor_ip, factor_err, codes) + } + + pub fn preprocess( + &self, + lhs: &[F32], + ) -> (O::QuantizationPreprocessed0, O::QuantizationPreprocessed1) { + O::rabitq_quantization_preprocess(lhs) + } + + pub fn process( + &self, + p0: &O::QuantizationPreprocessed0, + p1: &O::QuantizationPreprocessed1, + (a, b, c, d, e): (F32, F32, F32, F32, &[u8]), + ) -> F32 { + let (est, _) = O::rabitq_quantization_process(a, b, c, d, e, p0, p1); + est + } + + pub fn process_lowerbound( + &self, + p0: &O::QuantizationPreprocessed0, + p1: &O::QuantizationPreprocessed1, + (a, b, c, d, e): (F32, F32, F32, F32, &[u8]), + epsilon: F32, + ) -> F32 { + let (est, err) = O::rabitq_quantization_process(a, b, c, d, e, p0, p1); + est - err * epsilon + } + + pub fn push_batch( + &self, + (p0, p1): &(O::QuantizationPreprocessed0, O::QuantizationPreprocessed1), + rhs: Range, + heap: &mut Vec<(Reverse, u32)>, + codes: &[u8], + packed_codes: &[u8], + meta: &[F32], + epsilon: F32, + fast_scan: bool, + ) { + if fast_scan && O::SUPPORT_FAST_SCAN && quantization::fast_scan::b4::is_supported() { + use quantization::fast_scan::b4::{fast_scan, BLOCK_SIZE}; + let s = rhs.start.next_multiple_of(BLOCK_SIZE); + let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); + heap.extend((rhs.start..s).map(|u| { + ( + Reverse(self.process_lowerbound( + p0, + p1, + { + let bytes = self.bytes() as usize; + let start = u as usize * bytes; + let end = start + bytes; + let a = meta[4 * u as usize + 0]; + let b = meta[4 * u as usize + 1]; + let c = meta[4 * u as usize + 2]; + let d = meta[4 * u as usize + 3]; + (a, b, c, d, &codes[start..end]) + }, + epsilon, + )), + u, + ) + })); + let lut = O::fast_scan(p1); + for i in (s..e).step_by(BLOCK_SIZE as _) { + let t = self.dims.div_ceil(4); + let bytes = (t * 16) as usize; + let start = (i / BLOCK_SIZE) as usize * bytes; + let end = start + bytes; + heap.extend({ + let res = fast_scan(t, &packed_codes[start..end], &lut); + (i..i + BLOCK_SIZE) + .map(|u| { + ( + Reverse({ + let a = meta[4 * u as usize + 0]; + let b = meta[4 * u as usize + 1]; + let c = meta[4 * u as usize + 2]; + let d = meta[4 * u as usize + 3]; + let param = res[(u - i) as usize]; + let (est, err) = + O::rabitq_quantization_process_1(a, b, c, d, p0, param); + est - err * epsilon + }), + u, + ) + }) + .collect::>() + }); + } + heap.extend((e..rhs.end).map(|u| { + ( + Reverse(self.process_lowerbound( + p0, + p1, + { + let bytes = self.bytes() as usize; + let start = u as usize * bytes; + let end = start + bytes; + let a = meta[4 * u as usize + 0]; + let b = meta[4 * u as usize + 1]; + let c = meta[4 * u as usize + 2]; + let d = meta[4 * u as usize + 3]; + (a, b, c, d, &codes[start..end]) + }, + epsilon, + )), + u, + ) + })); + return; + } + heap.extend(rhs.map(|u| { + ( + Reverse(self.process_lowerbound( + p0, + p1, + { + let bytes = self.bytes() as usize; + let start = u as usize * bytes; + let end = start + bytes; + let a = meta[4 * u as usize + 0]; + let b = meta[4 * u as usize + 1]; + let c = meta[4 * u as usize + 2]; + let d = meta[4 * u as usize + 3]; + (a, b, c, d, &codes[start..end]) + }, + epsilon, + )), + u, + ) + })); + } + + pub fn rerank<'a, T: 'a>( + &'a self, + heap: Vec<(Reverse, u32)>, + r: impl Fn(u32) -> (F32, T) + 'a, + ) -> impl RerankerPop + 'a { + ErrorBasedReranker::new(heap, r) + } +} diff --git a/src/gucs/executing.rs b/src/gucs/executing.rs index bb20bf0e6..d4f9391a3 100644 --- a/src/gucs/executing.rs +++ b/src/gucs/executing.rs @@ -9,8 +9,6 @@ static FLAT_PQ_RERANK_SIZE: GucSetting = GucSetting::::new(0); static FLAT_PQ_FAST_SCAN: GucSetting = GucSetting::::new(false); -static FLAT_RQ_FAST_SCAN: GucSetting = GucSetting::::new(true); - static IVF_SQ_RERANK_SIZE: GucSetting = GucSetting::::new(0); static IVF_SQ_FAST_SCAN: GucSetting = GucSetting::::new(false); @@ -19,12 +17,14 @@ static IVF_PQ_RERANK_SIZE: GucSetting = GucSetting::::new(0); static IVF_PQ_FAST_SCAN: GucSetting = GucSetting::::new(false); -static IVF_RQ_FAST_SCAN: GucSetting = GucSetting::::new(true); - static IVF_NPROBE: GucSetting = GucSetting::::new(10); static HNSW_EF_SEARCH: GucSetting = GucSetting::::new(100); +static RABITQ_NPROBE: GucSetting = GucSetting::::new(10); + +static RABITQ_FAST_SCAN: GucSetting = GucSetting::::new(true); + static DISKANN_EF_SEARCH: GucSetting = GucSetting::::new(100); pub unsafe fn init() { @@ -64,14 +64,6 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); - GucRegistry::define_bool_guc( - "vectors.flat_rq_fast_scan", - "Enables fast scan or not.", - "https://docs.pgvecto.rs/usage/search.html", - &FLAT_RQ_FAST_SCAN, - GucContext::Userset, - GucFlags::default(), - ); GucRegistry::define_int_guc( "vectors.ivf_sq_rerank_size", "Scalar quantization reranker size.", @@ -108,14 +100,6 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); - GucRegistry::define_bool_guc( - "vectors.ivf_rq_fast_scan", - "Enables fast scan or not.", - "https://docs.pgvecto.rs/usage/search.html", - &IVF_RQ_FAST_SCAN, - GucContext::Userset, - GucFlags::default(), - ); GucRegistry::define_int_guc( "vectors.ivf_nprobe", "`nprobe` argument of IVF algorithm.", @@ -136,6 +120,24 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_int_guc( + "vectors.rabitq_nprobe", + "`nprobe` argument of RaBitQ algorithm.", + "https://docs.pgvecto.rs/usage/search.html", + &RABITQ_NPROBE, + 1, + u16::MAX as _, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_bool_guc( + "vectors.rabitq_fast_scan", + "Enables fast scan or not.", + "https://docs.pgvecto.rs/usage/search.html", + &RABITQ_FAST_SCAN, + GucContext::Userset, + GucFlags::default(), + ); GucRegistry::define_int_guc( "vectors.diskann_ef_search", "`ef_search` argument of DiskANN algorithm.", @@ -160,6 +162,8 @@ pub fn search_options() -> SearchOptions { ivf_pq_fast_scan: IVF_PQ_FAST_SCAN.get(), ivf_nprobe: IVF_NPROBE.get() as u32, hnsw_ef_search: HNSW_EF_SEARCH.get() as u32, + rabitq_nprobe: RABITQ_NPROBE.get() as u32, + rabitq_fast_scan: RABITQ_FAST_SCAN.get(), diskann_ef_search: DISKANN_EF_SEARCH.get() as u32, } }