Validate the column's feature type before building a FAISS index.

`add_vectors` now raises a ValueError when a column name is given and that
column's feature is not a `Sequence`; a regression test exercises this through
`Dataset.add_faiss_index` on the 'filename' column.

diff --git a/src/datasets/search.py b/src/datasets/search.py
index 2f3febe7bfd0..bef09f9cfad4 100644
--- a/src/datasets/search.py
+++ b/src/datasets/search.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 
+from .features import Sequence
 from .utils import logging
 
 
@@ -262,6 +263,12 @@ def add_vectors(
         """
         import faiss  # noqa: F811
 
+        # Fail early with a clear message when the requested column is not a Sequence feature.
+        if column is not None and not isinstance(vectors.features[column], Sequence):
+            raise ValueError(
+                f"Wrong feature type for column '{column}'. Expected 1d array, got {vectors.features[column]}"
+            )
+
         # Create index
         if self.faiss_index is None:
             size = len(vectors[0]) if column is None else len(vectors[0][column])
diff --git a/tests/test_search.py b/tests/test_search.py
index f44d80a57124..c4a9b80be338 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -34,6 +34,14 @@ def test_add_faiss_index(self):
         self.assertEqual(examples["filename"][0], "my_name-train_29")
         dset.drop_index("vecs")
 
+    def test_add_faiss_index_errors(self):
+        """Indexing the 'filename' column must be rejected with a ValueError."""
+        import faiss
+
+        dset: Dataset = self._create_dummy_dataset()
+        with pytest.raises(ValueError, match="Wrong feature type for column 'filename'"):
+            dset.add_faiss_index("filename", batch_size=100, metric_type=faiss.METRIC_INNER_PRODUCT)
+
     def test_add_faiss_index_from_external_arrays(self):
         import faiss
 