Skip to content

Commit

Permalink
fix: avoid divide-by-zero when training an index with a large dimension (#3426)
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace authored Jan 31, 2025
1 parent a7c5216 commit c58814a
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 30 deletions.
6 changes: 3 additions & 3 deletions rust/lance-io/src/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ impl ObjectStore {
Self {
inner: Arc::new(InMemory::new()).traced(),
scheme: String::from("memory"),
block_size: 64 * 1024,
block_size: 4 * 1024,
use_constant_size_upload_parts: false,
list_is_lexically_ordered: true,
io_parallelism: get_num_compute_intensive_cpus(),
Expand Down Expand Up @@ -977,7 +977,7 @@ async fn configure_store(
"memory" => Ok(ObjectStore {
inner: Arc::new(InMemory::new()).traced(),
scheme: String::from("memory"),
block_size: cloud_block_size,
block_size: file_block_size,
use_constant_size_upload_parts: false,
list_is_lexically_ordered: true,
io_parallelism: get_num_compute_intensive_cpus(),
Expand Down Expand Up @@ -1219,7 +1219,6 @@ mod tests {
#[rstest]
#[case("s3://bucket/foo.lance", None)]
#[case("gs://bucket/foo.lance", None)]
#[case("memory:///bucket/foo.lance", None)]
#[case("az://account/bucket/foo.lance",
Some(HashMap::from([
(String::from("account_name"), String::from("account")),
Expand All @@ -1236,6 +1235,7 @@ mod tests {
#[rstest]
#[case("file")]
#[case("file-object-store")]
#[case("memory:///bucket/foo.lance")]
#[tokio::test]
async fn test_block_size_used_file(#[case] prefix: &str) {
let tmp_dir = tempfile::tempdir().unwrap();
Expand Down
114 changes: 88 additions & 26 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2223,42 +2223,102 @@ mod tests {
.await;
}

/// Product-quantization settings shared by the index-creation test cases.
struct TestPqParams {
    num_sub_vectors: usize,
    num_bits: usize,
}

impl TestPqParams {
    /// Smallest useful PQ configuration — 2 sub-vectors at 8 bits each —
    /// which keeps index training cheap in tests.
    fn small() -> Self {
        TestPqParams {
            num_bits: 8,
            num_sub_vectors: 2,
        }
    }
}

/// Which vector-index flavor a test case should build.
///
/// Every current variant starts with `Ivf`, which trips Clippy's
/// enum_variant_names lint; the allow is kept (rather than dropping the
/// prefix) because future variants may not be IVF-based, so the prefix
/// stays meaningful.
#[allow(clippy::enum_variant_names)]
enum TestIndexType {
    /// Plain IVF with no extra compression or graph.
    IvfFlat,
    /// IVF with product quantization.
    IvfPq { pq: TestPqParams },
    /// IVF with an HNSW graph over product-quantized vectors.
    IvfHnswPq { pq: TestPqParams, num_edges: usize },
    /// IVF with an HNSW graph over scalar-quantized vectors.
    IvfHnswSq { num_edges: usize },
}

/// One parameterized scenario for `test_create_index_nulls`.
struct CreateIndexCase {
    /// Distance metric the index is trained with (L2 uses residuals,
    /// Dot does not, so both code paths get covered).
    metric_type: MetricType,
    /// Dimensionality of the generated vectors (and of the query vector).
    dimension: usize,
    /// Number of IVF partitions to train.
    num_partitions: usize,
    /// Index flavor to build for this case.
    index_type: TestIndexType,
}

// We test L2 and Dot, because L2 PQ uses residuals while Dot doesn't,
// so they have slightly different code paths.
#[tokio::test]
#[rstest]
#[case::ivf_pq_l2(VectorIndexParams::with_ivf_pq_params(
MetricType::L2,
IvfBuildParams::new(2),
PQBuildParams::new(2, 8),
))]
#[case::ivf_pq_dot(VectorIndexParams::with_ivf_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
PQBuildParams::new(2, 8),
))]
#[case::ivf_flat(VectorIndexParams::ivf_flat(1, MetricType::Dot))]
#[case::ivf_hnsw_pq(VectorIndexParams::with_ivf_hnsw_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default().num_edges(100),
PQBuildParams::new(2, 8)
))]
#[case::ivf_hnsw_sq(VectorIndexParams::with_ivf_hnsw_sq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default().num_edges(100),
SQBuildParams::default()
))]
#[case::ivf_pq_l2(CreateIndexCase {
metric_type: MetricType::L2,
num_partitions: 2,
dimension: 16,
index_type: TestIndexType::IvfPq { pq: TestPqParams::small() },
})]
#[case::ivf_pq_dot(CreateIndexCase {
metric_type: MetricType::Dot,
num_partitions: 2,
dimension: 2000,
index_type: TestIndexType::IvfPq { pq: TestPqParams::small() },
})]
#[case::ivf_flat(CreateIndexCase { num_partitions: 1, metric_type: MetricType::Dot, dimension: 16, index_type: TestIndexType::IvfFlat })]
#[case::ivf_hnsw_pq(CreateIndexCase {
num_partitions: 2,
metric_type: MetricType::Dot,
dimension: 16,
index_type: TestIndexType::IvfHnswPq { pq: TestPqParams::small(), num_edges: 100 },
})]
#[case::ivf_hnsw_sq(CreateIndexCase {
metric_type: MetricType::Dot,
num_partitions: 2,
dimension: 16,
index_type: TestIndexType::IvfHnswSq { num_edges: 100 },
})]
async fn test_create_index_nulls(
#[case] mut index_params: VectorIndexParams,
#[case] test_case: CreateIndexCase,
#[values(IndexFileVersion::Legacy, IndexFileVersion::V3)] index_version: IndexFileVersion,
) {
let mut index_params = match test_case.index_type {
TestIndexType::IvfPq { pq } => VectorIndexParams::with_ivf_pq_params(
test_case.metric_type,
IvfBuildParams::new(test_case.num_partitions),
PQBuildParams::new(pq.num_sub_vectors, pq.num_bits),
),
TestIndexType::IvfHnswPq { pq, num_edges } => {
VectorIndexParams::with_ivf_hnsw_pq_params(
test_case.metric_type,
IvfBuildParams::new(test_case.num_partitions),
HnswBuildParams::default().num_edges(num_edges),
PQBuildParams::new(pq.num_sub_vectors, pq.num_bits),
)
}
TestIndexType::IvfFlat => {
VectorIndexParams::ivf_flat(test_case.num_partitions, test_case.metric_type)
}
TestIndexType::IvfHnswSq { num_edges } => VectorIndexParams::with_ivf_hnsw_sq_params(
test_case.metric_type,
IvfBuildParams::new(test_case.num_partitions),
HnswBuildParams::default().num_edges(num_edges),
SQBuildParams::default(),
),
};
index_params.version(index_version);

let nrows = 2_000;
let data = gen()
.col("vec", array::rand_vec::<Float32Type>(Dimension::from(16)))
.col(
"vec",
array::rand_vec::<Float32Type>(Dimension::from(test_case.dimension as u32)),
)
.into_batch_rows(RowCount::from(nrows))
.unwrap();

Expand Down Expand Up @@ -2287,7 +2347,9 @@ mod tests {
.await
.unwrap();

let query = vec![0.0; 16].into_iter().collect::<Float32Array>();
let query = vec![0.0; test_case.dimension]
.into_iter()
.collect::<Float32Array>();
let results = dataset
.scan()
.nearest("vec", &query, 2_000)
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ fn random_ranges(
block_size: usize,
byte_width: usize,
) -> impl Iterator<Item = std::ops::Range<u64>> + Send {
let rows_per_batch = block_size / byte_width;
let rows_per_batch = 1.max(block_size / byte_width);
let mut rng = SmallRng::from_entropy();
let num_bins = num_rows.div_ceil(rows_per_batch);

Expand Down

0 comments on commit c58814a

Please sign in to comment.