Skip to content

Commit 99dc755

Browse files
committed
feat: add metrics dot and cos
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
1 parent 4bac484 commit 99dc755

File tree

7 files changed

+713
-220
lines changed

7 files changed

+713
-220
lines changed

crates/base/src/index.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ impl IndexOptions {
149149
}
150150
}
151151
IndexingOptions::Rabitq(_) => {
152-
if !matches!(self.vector.d, DistanceKind::L2) {
152+
if !matches!(self.vector.d, DistanceKind::L2 | DistanceKind::Dot) {
153153
return Err(ValidationError::new(
154-
"rabitq is not support for distance that is not l2",
154+
"rabitq is not support for distance that is not l2 or dot",
155155
));
156156
}
157157
if !matches!(self.vector.v, VectorKind::Vecf32) {

crates/k_means/src/lib.rs

+51
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,27 @@ pub fn k_means_lookup<S: ScalarLike>(vector: &[S], centroids: &Vec2<S>) -> usize
7171
result.1
7272
}
7373

74+
/// returns (centroid_dot_dis, index)
75+
pub fn k_means_lookup_by_dot<S: ScalarLike>(
76+
vector: &[S],
77+
centroids: &Vec2<S>,
78+
centroids_square: &[f32],
79+
) -> (f32, usize) {
80+
assert_ne!(centroids.shape_0(), 0);
81+
let vector_square = S::reduce_sum_of_x2(vector);
82+
let mut result = (f32::INFINITY, f32::INFINITY, 0);
83+
84+
for i in 0..centroids.shape_0() {
85+
let centroid_square = centroids_square[i];
86+
let dot = S::reduce_sum_of_xy(vector, &centroids[(i,)]);
87+
let l2_dis = vector_square + centroid_square - 2.0 * dot;
88+
if l2_dis <= result.0 {
89+
result = (l2_dis, -dot, i);
90+
}
91+
}
92+
(result.1, result.2)
93+
}
94+
7495
pub fn k_means_lookup_many<S: ScalarLike>(vector: &[S], centroids: &Vec2<S>) -> Vec<(f32, usize)> {
7596
assert_ne!(centroids.shape_0(), 0);
7697
let mut seq = Vec::new();
@@ -80,3 +101,33 @@ pub fn k_means_lookup_many<S: ScalarLike>(vector: &[S], centroids: &Vec2<S>) ->
80101
}
81102
seq
82103
}
104+
105+
/// returns Vec of <l2_dis, (centroid_dot_dis, vector_l2_norm, centroids_l2_norm, index)>
106+
pub fn k_means_lookup_many_by_dot<S: ScalarLike>(
107+
vector: &[S],
108+
centroids: &Vec2<S>,
109+
centroids_square: &[f32],
110+
) -> Vec<(f32, (f32, f32, f32, usize))> {
111+
assert_ne!(centroids.shape_0(), 0);
112+
let vector_square = S::reduce_sum_of_x2(vector);
113+
let mut seq = Vec::new();
114+
115+
for i in 0..centroids.shape_0() {
116+
let centroid_square = centroids_square[i];
117+
let dot = S::reduce_sum_of_xy(vector, &centroids[(i,)]);
118+
let l2_dis = vector_square + centroid_square - 2.0 * dot;
119+
seq.push((l2_dis, (-dot, vector_square, centroid_square, i)));
120+
}
121+
seq
122+
}
123+
124+
pub fn centroids_square<S: ScalarLike>(centroids: &Vec2<S>) -> Vec<f32> {
125+
assert_ne!(centroids.shape_0(), 0);
126+
let mut seq = Vec::new();
127+
128+
for i in 0..centroids.shape_0() {
129+
let centroids_square = S::reduce_sum_of_x2(&centroids[(i,)]);
130+
seq.push(centroids_square);
131+
}
132+
seq
133+
}

crates/quantization/src/quantize.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,13 @@ mod mul_add_round {
152152
#[inline(always)]
153153
pub fn quantize<const N: u8>(lut: &[f32]) -> (f32, f32, Vec<u8>) {
154154
let (min, max) = f32::reduce_min_max_of_x(lut);
155-
let k = 0.0f32.max((max - min) / (N as f32));
156-
let b = min;
157-
(k, b, mul_add_round::mul_add_round(lut, 1.0 / k, -b / k))
155+
let delta = 0.0f32.max((max - min) / (N as f32));
156+
let lower_bound = min;
157+
(
158+
delta,
159+
lower_bound,
160+
mul_add_round::mul_add_round(lut, 1.0 / delta, -lower_bound / delta),
161+
)
158162
}
159163

160164
#[inline(always)]

crates/rabitq/src/lib.rs

+28-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ use common::json::Json;
1717
use common::mmap_array::MmapArray;
1818
use common::remap::RemappedCollection;
1919
use common::vec2::Vec2;
20-
use k_means::{k_means, k_means_lookup, k_means_lookup_many};
20+
use k_means::{
21+
centroids_square, k_means, k_means_lookup, k_means_lookup_by_dot, k_means_lookup_many_by_dot,
22+
};
2123
use rayon::iter::{IntoParallelIterator, ParallelIterator};
2224
use std::fs::create_dir;
2325
use std::path::Path;
@@ -30,6 +32,7 @@ pub struct Rabitq<O: Op> {
3032
payloads: MmapArray<Payload>,
3133
offsets: Json<Vec<u32>>,
3234
projected_centroids: Json<Vec2<f32>>,
35+
centroids_square: Json<Vec<f32>>,
3336
projection: Json<Vec<Vec<f32>>>,
3437
}
3538

@@ -70,15 +73,18 @@ impl<O: Op> Rabitq<O> {
7073
) -> Box<dyn Iterator<Item = Element> + 'a> {
7174
let projected_query = O::proj(&self.projection, O::cast(vector));
7275
let lists = select(
73-
k_means_lookup_many(&projected_query, &self.projected_centroids),
76+
k_means_lookup_many_by_dot(&projected_query, &self.centroids, &self.centroids_square),
7477
opts.rabitq_nprobe as usize,
7578
);
7679
let mut heap = Vec::new();
77-
for &(_, i) in lists.iter() {
78-
let preprocessed = self.quantization.preprocess(&O::residual(
80+
for &(_, payload) in lists.iter() {
81+
let (centroid_dot_dis, original_square, centroids_square, i) = payload;
82+
let preprocessed = self.quantization.preprocess(
7983
&projected_query,
80-
&self.projected_centroids[(i,)],
81-
));
84+
centroid_dot_dis,
85+
original_square,
86+
centroids_square,
87+
);
8288
let start = self.offsets[i];
8389
let end = self.offsets[i + 1];
8490
self.quantization.push_batch(
@@ -135,6 +141,7 @@ fn from_nothing<O: Op>(
135141
let samples = O::sample(collection, nlist);
136142
rayon::check();
137143
let centroids: Vec2<f32> = k_means(nlist as usize, samples, true, spherical_centroids, false);
144+
let centroids_squares = centroids_square(&centroids);
138145
rayon::check();
139146
let ls = (0..collection.len())
140147
.into_par_iter()
@@ -174,8 +181,12 @@ fn from_nothing<O: Op>(
174181
collection.len(),
175182
|vector| {
176183
let vector = O::cast(collection.vector(vector));
177-
let target = k_means_lookup(vector, &centroids);
178-
O::proj(&projection, &O::residual(vector, &centroids[(target,)]))
184+
let (centroid_dot_dis, target) =
185+
k_means_lookup_by_dot(vector, &centroids, &centroids_squares);
186+
(
187+
O::proj(&projection, &O::residual(vector, &centroids[(target,)])),
188+
centroid_dot_dis,
189+
)
179190
},
180191
);
181192
let projected_centroids = Vec2::from_vec(
@@ -184,6 +195,7 @@ fn from_nothing<O: Op>(
184195
.flat_map(|x| O::proj(&projection, &centroids[(x,)]))
185196
.collect(),
186197
);
198+
let projected_centroids_square = centroids_square(&projected_centroids);
187199
let payloads = MmapArray::create(
188200
path.as_ref().join("payloads"),
189201
(0..collection.len()).map(|i| collection.payload(i)),
@@ -193,12 +205,17 @@ fn from_nothing<O: Op>(
193205
path.as_ref().join("projected_centroids"),
194206
projected_centroids,
195207
);
208+
let centroids_square = Json::create(
209+
path.as_ref().join("centroids_square"),
210+
projected_centroids_square,
211+
);
196212
let projection = Json::create(path.as_ref().join("projection"), projection);
197213
Rabitq {
198214
storage,
199215
payloads,
200216
offsets,
201217
projected_centroids,
218+
centroids_square,
202219
quantization,
203220
projection,
204221
}
@@ -210,18 +227,20 @@ fn open<O: Op>(path: impl AsRef<Path>) -> Rabitq<O> {
210227
let payloads = MmapArray::open(path.as_ref().join("payloads"));
211228
let offsets = Json::open(path.as_ref().join("offsets"));
212229
let projected_centroids = Json::open(path.as_ref().join("projected_centroids"));
230+
let centroids_square = Json::open(path.as_ref().join("centroids_square"));
213231
let projection = Json::open(path.as_ref().join("projection"));
214232
Rabitq {
215233
storage,
216234
quantization,
217235
payloads,
218236
offsets,
219237
projected_centroids,
238+
centroids_square,
220239
projection,
221240
}
222241
}
223242

224-
fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> {
243+
fn select<T>(mut lists: Vec<(f32, T)>, n: usize) -> Vec<(f32, T)> {
225244
if lists.is_empty() || n == 0 {
226245
return Vec::new();
227246
}

0 commit comments

Comments
 (0)