Skip to content

Commit 9171b14

Browse files
committed
feat: add metrics dot and cos
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
1 parent b7e1a7a commit 9171b14

File tree

6 files changed

+576
-181
lines changed

6 files changed

+576
-181
lines changed

crates/k_means/src/lib.rs

+37
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,26 @@ pub fn k_means_lookup<S: ScalarLike>(vector: &[S], centroids: &Vec2<S>) -> usize
6161
result.1
6262
}
6363

64+
/// returns (centroid_dot_dis, (vector_l2_norm, centroids_l2_norm, index))
65+
pub fn k_means_lookup_by_dot<S: ScalarLike>(
66+
vector: &[S],
67+
centroids: &Vec2<S>,
68+
) -> (f32, f32, f32, usize) {
69+
assert_ne!(centroids.shape_0(), 0);
70+
let mut result = (f32::INFINITY, f32::INFINITY, f32::zero(), f32::zero(), 0);
71+
for i in 0..centroids.shape_0() {
72+
let dot = S::reduce_sum_of_xy(vector, &centroids[(i,)]);
73+
let vector_square = S::reduce_sum_of_x2(vector);
74+
let centroids_square = S::reduce_sum_of_x2(&centroids[(i,)]);
75+
76+
let l2_dis = vector_square + centroids_square - 2.0 * dot;
77+
if l2_dis <= result.0 {
78+
result = (l2_dis, -dot, vector_square, centroids_square, i);
79+
}
80+
}
81+
(result.1, result.2, result.3, result.4)
82+
}
83+
6484
pub fn k_means_lookup_many<S: ScalarLike>(vector: &[S], centroids: &Vec2<S>) -> Vec<(f32, usize)> {
6585
assert_ne!(centroids.shape_0(), 0);
6686
let mut seq = Vec::new();
@@ -71,6 +91,23 @@ pub fn k_means_lookup_many<S: ScalarLike>(vector: &[S], centroids: &Vec2<S>) ->
7191
seq
7292
}
7393

94+
/// returns Vec of <l2_dis, (centroid_dot_dis, vector_l2_norm, centroids_l2_norm, index)>
95+
pub fn k_means_lookup_many_by_dot<S: ScalarLike>(
96+
vector: &[S],
97+
centroids: &Vec2<S>,
98+
) -> Vec<(f32, (f32, f32, f32, usize))> {
99+
assert_ne!(centroids.shape_0(), 0);
100+
let mut seq = Vec::new();
101+
for i in 0..centroids.shape_0() {
102+
let dot = S::reduce_sum_of_xy(vector, &centroids[(i,)]);
103+
let vector_square = S::reduce_sum_of_x2(vector);
104+
let centroids_square = S::reduce_sum_of_x2(&centroids[(i,)]);
105+
let l2_dis = vector_square + centroids_square - 2.0 * dot;
106+
seq.push((l2_dis, (-dot, vector_square, centroids_square, i)));
107+
}
108+
seq
109+
}
110+
74111
fn spherical<S: ScalarLike>(vector: &mut [S]) {
75112
let l = S::reduce_sum_of_x2(vector).sqrt();
76113
S::vector_mul_scalar_inplace(vector, 1.0 / l);

crates/quantization/src/quantize.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,13 @@ mod mul_add {
152152
#[inline(always)]
153153
pub fn quantize<const N: u8>(lut: &[f32]) -> (f32, f32, Vec<u8>) {
154154
let (min, max) = f32::reduce_min_max_of_x(lut);
155-
let k = 0.0f32.max((max - min) / (N as f32));
156-
let b = min;
157-
(k, b, mul_add::mul_add(lut, 1.0 / k, -b / k))
155+
let delta = 0.0f32.max((max - min) / (N as f32));
156+
let lower_bound = min;
157+
(
158+
delta,
159+
lower_bound,
160+
mul_add::mul_add(lut, 1.0 / delta, -lower_bound / delta),
161+
)
158162
}
159163

160164
#[inline(always)]

crates/rabitq/src/lib.rs

+18-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use common::json::Json;
1717
use common::mmap_array::MmapArray;
1818
use common::remap::RemappedCollection;
1919
use common::vec2::Vec2;
20-
use k_means::{k_means, k_means_lookup, k_means_lookup_many};
20+
use k_means::{k_means, k_means_lookup, k_means_lookup_by_dot, k_means_lookup_many_by_dot};
2121
use std::fs::create_dir;
2222
use std::path::Path;
2323
use stoppable_rayon as rayon;
@@ -69,14 +69,18 @@ impl<O: Op> Rabitq<O> {
6969
) -> Box<dyn Iterator<Item = Element> + 'a> {
7070
let projected_query = O::proj(&self.projection, O::cast(vector));
7171
let lists = select(
72-
k_means_lookup_many(&projected_query, &self.centroids),
72+
k_means_lookup_many_by_dot(&projected_query, &self.centroids),
7373
opts.rabitq_nprobe as usize,
7474
);
7575
let mut heap = Vec::new();
76-
for &(_, i) in lists.iter() {
77-
let preprocessed = self
78-
.quantization
79-
.preprocess(&O::residual(&projected_query, &self.centroids[(i,)]));
76+
for &(_, payload) in lists.iter() {
77+
let (centroid_dot_dis, original_square, centroids_square, i) = payload;
78+
let preprocessed = self.quantization.preprocess(
79+
&O::residual(&projected_query, &self.centroids[(i,)]),
80+
centroid_dot_dis,
81+
original_square,
82+
centroids_square,
83+
);
8084
let start = self.offsets[i];
8185
let end = self.offsets[i + 1];
8286
self.quantization.push_batch(
@@ -150,8 +154,13 @@ fn from_nothing<O: Op>(
150154
collection.len(),
151155
|vector| {
152156
let vector = O::cast(collection.vector(vector));
153-
let target = k_means_lookup(vector, &centroids);
154-
O::proj(&projection, &O::residual(vector, &centroids[(target,)]))
157+
let (centroid_dot_dis, original_square, _, target) =
158+
k_means_lookup_by_dot(vector, &centroids);
159+
(
160+
O::proj(&projection, &O::residual(vector, &centroids[(target,)])),
161+
centroid_dot_dis,
162+
original_square,
163+
)
155164
},
156165
);
157166
let projected_centroids = Vec2::from_vec(
@@ -194,7 +203,7 @@ fn open<O: Op>(path: impl AsRef<Path>) -> Rabitq<O> {
194203
}
195204
}
196205

197-
fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> {
206+
fn select<T>(mut lists: Vec<(f32, T)>, n: usize) -> Vec<(f32, T)> {
198207
if lists.is_empty() || n == 0 {
199208
return Vec::new();
200209
}

0 commit comments

Comments
 (0)