diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index e13ecb3672d..57a8cdff306 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -5,8 +5,10 @@ from chromadb.db.system import SysDB from chromadb.quota import QuotaEnforcer from chromadb.rate_limit import RateLimitEnforcer -from chromadb.segment import SegmentManager, MetadataReader, VectorReader +from chromadb.segment import SegmentManager from chromadb.execution.executor.abstract import Executor +from chromadb.execution.expression.operator import Scan, Filter, Limit, KNN, Projection +from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.telemetry.opentelemetry import ( add_attributes_to_current_span, OpenTelemetryClient, @@ -23,18 +25,16 @@ VersionMismatchError, ) from chromadb.api.types import ( - URI, CollectionMetadata, - Document, IDs, Embeddings, - Embedding, Metadatas, Documents, URIs, Where, WhereDocument, Include, + IncludeEnum, GetResult, QueryResult, validate_metadata, @@ -58,8 +58,6 @@ Sequence, Generator, List, - cast, - Set, Any, Callable, TypeVar, @@ -500,11 +498,6 @@ def _get( ) coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) - where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( validate_where_document(where_document) @@ -512,8 +505,6 @@ def _get( else None ) - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) - if sort is not None: raise NotImplementedError("Sorting is not yet supported") @@ -521,47 +512,6 @@ def _get( offset = (page - 1) * page_size limit = page_size - records = metadata_segment.get_metadata( - where=where, - where_document=where_document, - ids=ids, - limit=limit, - offset=offset, - request_version_context=request_version_context, - ) - - if len(records) == 0: - # Nothing to return if there are no records - return GetResult( - ids=[], - embeddings=[] if "embeddings" in include else None, - metadatas=[] if "metadatas" in include else None, - documents=[] if "documents" in include else None, - uris=[] if "uris" in include else None, - data=[] if "data" in include else None, - included=include, - ) - - vectors: Sequence[t.VectorEmbeddingRecord] = [] - if "embeddings" in include: - vector_ids = [r["id"] for r in records] - vector_segment = self._manager.get_segment(collection_id, VectorReader) - vectors = vector_segment.get_vectors( - ids=vector_ids, request_version_context=request_version_context - ) - - # TODO: Fix type so we don't need to ignore - # It is possible to have a set of records, some with metadata and some without - # Same with documents - - metadatas = [r["metadata"] for r in records] - - if "documents" in include: - documents = [_doc(m) for m in metadatas] - - if "uris" in include: - uris = [_uri(m) for m in metadatas] - ids_amount = len(ids) if ids else 0 self._product_telemetry_client.capture( CollectionGetEvent( @@ -574,18 +524,19 @@ def _get( ) ) - return GetResult( - ids=[r["id"] for r in records], - embeddings=[r["embedding"] for r in vectors] - if "embeddings" in include - else None, - metadatas=_clean_metadatas(metadatas) - if "metadatas" in include - else None, # type: ignore - documents=documents if "documents" in include else None, # type: ignore - uris=uris if "uris" in include else None, # type: ignore - data=None, - included=include, + return self._executor.get( + GetPlan( + Scan(coll), + Filter(ids, where, where_document), + Limit(offset or 0, limit), + Projection( + IncludeEnum.documents in include, + IncludeEnum.embeddings in include, + IncludeEnum.metadatas in include, + False, + IncludeEnum.uris in include, + ), + ) ) @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION) @@ -630,21 +581,12 @@ def _delete( ) coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) self._manager.hint_use_collection(collection_id, t.Operation.DELETE) if (where or where_document) or not ids: - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) - records = metadata_segment.get_metadata( - where=where, - where_document=where_document, - ids=ids, - request_version_context=request_version_context, - ) - ids_to_delete = [r["id"] for r in records] + ids_to_delete = self._executor.get( + GetPlan(Scan(coll), Filter(ids, where, where_document)) + )["ids"] else: ids_to_delete = ids @@ -674,13 +616,7 @@ def _delete( def _count(self, collection_id: UUID) -> int: add_attributes_to_current_span({"collection_id": str(collection_id)}) coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) - - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) - return metadata_segment.count(request_version_context) + return self._executor.count(CountPlan(Scan(coll))) @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION) # We retry on version mismatch errors because the version of the collection @@ -736,110 +672,23 @@ def _query( else where_document ) - allowed_ids = None - coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) for embedding in query_embeddings: self._validate_dimension(coll, len(embedding), update=False) - if where or where_document: - metadata_reader = self._manager.get_segment(collection_id, MetadataReader) - records = metadata_reader.get_metadata( - where=where, - where_document=where_document, - include_metadata=False, - request_version_context=request_version_context, - ) - allowed_ids = [r["id"] for r in records] - - ids: List[List[str]] = [] - distances: List[List[float]] = [] - embeddings: List[Embeddings] = [] - documents: List[List[Document]] = [] - uris: List[List[URI]] = [] - metadatas: List[List[t.Metadata]] = [] - - # If where conditions returned empty list then no need to proceed - # further and can simply return an empty result set here. - if allowed_ids is not None and allowed_ids == []: - for em in range(len(query_embeddings)): - ids.append([]) - if "distances" in include: - distances.append([]) - if "embeddings" in include: - embeddings.append([]) - if "documents" in include: - documents.append([]) - if "metadatas" in include: - metadatas.append([]) - if "uris" in include: - uris.append([]) - else: - query = t.VectorQuery( - vectors=query_embeddings, - k=n_results, - allowed_ids=allowed_ids, - include_embeddings="embeddings" in include, - options=None, - request_version_context=request_version_context, + return self._executor.knn( + KNNPlan( + Scan(coll), + KNN(query_embeddings, n_results), + Filter(None, where, where_document), + Projection( + IncludeEnum.documents in include, + IncludeEnum.embeddings in include, + IncludeEnum.metadatas in include, + IncludeEnum.distances in include, + IncludeEnum.uris in include, + ), ) - - vector_reader = self._manager.get_segment(collection_id, VectorReader) - results = vector_reader.query_vectors(query) - - for result in results: - ids.append([r["id"] for r in result]) - if "distances" in include: - distances.append([r["distance"] for r in result]) - if "embeddings" in include: - embeddings.append([cast(Embedding, r["embedding"]) for r in result]) - - if "documents" in include or "metadatas" in include or "uris" in include: - all_ids: Set[str] = set() - for id_list in ids: - all_ids.update(id_list) - metadata_reader = self._manager.get_segment( - collection_id, MetadataReader - ) - records = metadata_reader.get_metadata( - ids=list(all_ids), - include_metadata=True, - request_version_context=request_version_context, - ) - metadata_by_id = {r["id"]: r["metadata"] for r in records} - for id_list in ids: - # In the segment based architecture, it is possible for one segment - # to have a record that another segment does not have. This results in - # data inconsistency. For the case of the local segments and the - # local segment manager, there is a case where a thread writes - # a record to the vector segment but not the metadata segment. - # Then a query'ing thread reads from the vector segment and - # queries the metadata segment. The metadata segment does not have - # the record. In this case we choose to return potentially - # incorrect data in the form of None. - metadata_list = [metadata_by_id.get(id, None) for id in id_list] - if "metadatas" in include: - metadatas.append(_clean_metadatas(metadata_list)) # type: ignore - if "documents" in include: - doc_list = [_doc(m) for m in metadata_list] - documents.append(doc_list) # type: ignore - if "uris" in include: - uri_list = [_uri(m) for m in metadata_list] - uris.append(uri_list) # type: ignore - - return QueryResult( - ids=ids, - distances=distances if distances else None, - metadatas=metadatas if metadatas else None, - embeddings=embeddings if embeddings else None, - documents=documents if documents else None, - uris=uris if uris else None, - data=None, - included=include, ) @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION) @@ -960,41 +809,3 @@ def _records( operation=operation, ) yield record - - -def _doc(metadata: Optional[t.Metadata]) -> Optional[str]: - """Retrieve the document (if any) from a Metadata map""" - - if metadata and "chroma:document" in metadata: - return str(metadata["chroma:document"]) - return None - - -def _uri(metadata: Optional[t.Metadata]) -> Optional[str]: - """Retrieve the uri (if any) from a Metadata map""" - - if metadata and "chroma:uri" in metadata: - return str(metadata["chroma:uri"]) - return None - - -def _clean_metadatas( - metadata: List[Optional[t.Metadata]], -) -> List[Optional[t.Metadata]]: - """Remove any chroma-specific metadata keys that the client shouldn't see from a - list of metadata maps.""" - return [_clean_metadata(m) for m in metadata] - - -def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]: - """Remove any chroma-specific metadata keys that the client shouldn't see from a - metadata map.""" - if not metadata: - return None - result = {} - for k, v in metadata.items(): - if not k.startswith("chroma:"): - result[k] = v - if len(result) == 0: - return None - return result diff --git a/chromadb/ingest/impl/utils.py b/chromadb/ingest/impl/utils.py index 4ad92df6bc3..0d4eb57038f 100644 --- a/chromadb/ingest/impl/utils.py +++ b/chromadb/ingest/impl/utils.py @@ -3,7 +3,8 @@ from uuid import UUID from chromadb.db.base import SqlDB -from chromadb.segment import SegmentManager, VectorReader +from chromadb.segment import SegmentManager +from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment topic_regex = r"persistent:\/\/(?P.+)\/(?P.+)\/(?P.+)" @@ -46,4 +47,4 @@ def trigger_vector_segments_max_seq_id_migration( for collection_id in collection_ids_with_unmigrated_segments: # Loading the segment triggers the migration on init - segment_manager.get_segment(UUID(collection_id), VectorReader) + segment_manager.get_segment(UUID(collection_id), LocalHnswSegment) diff --git a/chromadb/segment/__init__.py b/chromadb/segment/__init__.py index 6f93d758926..9bc5c89f3b1 100644 --- a/chromadb/segment/__init__.py +++ b/chromadb/segment/__init__.py @@ -2,14 +2,8 @@ from abc import abstractmethod from chromadb.types import ( Collection, - MetadataEmbeddingRecord, Operation, RequestVersionContext, - VectorEmbeddingRecord, - Where, - WhereDocument, - VectorQuery, - VectorQueryResult, Segment, SeqId, Metadata, @@ -59,46 +53,6 @@ def delete(self) -> None: S = TypeVar("S", bound=SegmentImplementation) -class MetadataReader(SegmentImplementation): - """Embedding Metadata segment interface""" - - @abstractmethod - def get_metadata( - self, - request_version_context: RequestVersionContext, - where: Optional[Where] = None, - where_document: Optional[WhereDocument] = None, - ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - include_metadata: bool = True, - ) -> Sequence[MetadataEmbeddingRecord]: - """Query for embedding metadata.""" - pass - - -class VectorReader(SegmentImplementation): - """Embedding Vector segment interface""" - - @abstractmethod - def get_vectors( - self, - request_version_context: RequestVersionContext, - ids: Optional[Sequence[str]] = None, - ) -> Sequence[VectorEmbeddingRecord]: - """Get embeddings from the segment. If no IDs are provided, all embeddings are - returned.""" - pass - - @abstractmethod - def query_vectors( - self, query: VectorQuery - ) -> Sequence[Sequence[VectorQueryResult]]: - """Given a vector query, return the top-k nearest neighbors for vector in the - query.""" - pass - - class SegmentManager(Component): """Interface for a pluggable strategy for creating, retrieving and instantiating segments as required""" diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index e117ac2ecee..0cb2e10ac61 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -2,11 +2,11 @@ from chromadb.segment import ( SegmentImplementation, SegmentManager, - MetadataReader, SegmentType, - VectorReader, S, ) +from chromadb.segment.impl.metadata.grpc_segment import GrpcMetadataSegment +from chromadb.segment.impl.vector.grpc_segment import GrpcVectorSegment from chromadb.config import System, get_class from chromadb.db.system import SysDB from overrides import override @@ -84,9 +84,9 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: ) @override def get_segment(self, collection_id: UUID, type: Type[S]) -> S: - if type == MetadataReader: + if type == GrpcMetadataSegment: scope = SegmentScope.METADATA - elif type == VectorReader: + elif type == GrpcVectorSegment: scope = SegmentScope.VECTOR else: raise ValueError(f"Invalid segment type: {type}") diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index fe23f679186..0e1cca14ef4 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -2,11 +2,11 @@ from chromadb.segment import ( SegmentImplementation, SegmentManager, - MetadataReader, SegmentType, - VectorReader, S, ) +from chromadb.segment.impl.metadata.sqlite import SqliteMetadataSegment +from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment import logging from chromadb.segment.impl.manager.cache.cache import ( SegmentLRUCache, @@ -155,10 +155,10 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: for segment in segments: if segment["id"] in self._instances: if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value: - instance = self.get_segment(collection_id, VectorReader) + instance = self.get_segment(collection_id, LocalHnswSegment) instance.delete() elif segment["type"] == SegmentType.SQLITE.value: - instance = self.get_segment(collection_id, MetadataReader) # type: ignore[assignment] + instance = self.get_segment(collection_id, SqliteMetadataSegment) # type: ignore[assignment] instance.delete() del self._instances[segment["id"]] if segment["scope"] is SegmentScope.VECTOR: @@ -199,9 +199,9 @@ def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope) -> Segmen ) @override def get_segment(self, collection_id: UUID, type: Type[S]) -> S: - if type == MetadataReader: + if type == SqliteMetadataSegment: scope = SegmentScope.METADATA - elif type == VectorReader: + elif type == LocalHnswSegment: scope = SegmentScope.VECTOR else: raise ValueError(f"Invalid segment type: {type}") @@ -225,12 +225,14 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S: def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: # The local segment manager responds to hints by pre-loading both the metadata and vector # segments for the given collection. - for type in [MetadataReader, VectorReader]: + for type in [SqliteMetadataSegment, LocalHnswSegment]: # Just use get_segment to load the segment into the cache instance = self.get_segment(collection_id, type) # If the segment is a vector segment, we need to keep segments in an LRU cache # to avoid hitting the OS file handle limit. - if type == VectorReader and self._system.settings.require("is_persistent"): + if type == LocalHnswSegment and self._system.settings.require( + "is_persistent" + ): instance = cast(PersistentLocalHnswSegment, instance) instance.open_persistent_index() self._vector_instances_file_handle_cache.set(collection_id, instance) diff --git a/chromadb/segment/impl/metadata/grpc_segment.py b/chromadb/segment/impl/metadata/grpc_segment.py index 53ffdc72734..81a29f29dd1 100644 --- a/chromadb/segment/impl/metadata/grpc_segment.py +++ b/chromadb/segment/impl/metadata/grpc_segment.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Sequence from chromadb.proto.convert import to_proto_request_version_context from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor -from chromadb.segment import MetadataReader +from chromadb.segment import SegmentImplementation from chromadb.config import System from chromadb.errors import InvalidArgumentError, VersionMismatchError from chromadb.types import Segment, RequestVersionContext @@ -21,7 +21,7 @@ import grpc -class GrpcMetadataSegment(MetadataReader): +class GrpcMetadataSegment(SegmentImplementation): """Embedding Metadata segment interface""" _request_timeout_seconds: int @@ -81,7 +81,6 @@ def max_seqid(self) -> int: "GrpcMetadataSegment.get_metadata", OpenTelemetryGranularity.ALL, ) - @override def get_metadata( self, request_version_context: RequestVersionContext, diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index d54dc95f342..52ce3758d4b 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -1,5 +1,5 @@ from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict, List -from chromadb.segment import MetadataReader +from chromadb.segment import SegmentImplementation from chromadb.ingest import Consumer from chromadb.config import System from chromadb.types import RequestVersionContext, Segment, InclusionExclusionOperator @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) -class SqliteMetadataSegment(MetadataReader): +class SqliteMetadataSegment(SegmentImplementation): _consumer: Consumer _db: SqliteDB _id: UUID @@ -109,7 +109,6 @@ def count(self, request_version_context: RequestVersionContext) -> int: return cast(int, result) @trace_method("SqliteMetadataSegment.get_metadata", OpenTelemetryGranularity.ALL) - @override def get_metadata( self, request_version_context: RequestVersionContext, diff --git a/chromadb/segment/impl/vector/grpc_segment.py b/chromadb/segment/impl/vector/grpc_segment.py index a66de4a71cd..f555d445e6b 100644 --- a/chromadb/segment/impl/vector/grpc_segment.py +++ b/chromadb/segment/impl/vector/grpc_segment.py @@ -8,7 +8,7 @@ to_proto_vector, ) from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor -from chromadb.segment import VectorReader +from chromadb.segment import SegmentImplementation from chromadb.errors import VersionMismatchError from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams from chromadb.telemetry.opentelemetry import ( @@ -35,7 +35,7 @@ import grpc -class GrpcVectorSegment(VectorReader, EnforceOverrides): +class GrpcVectorSegment(SegmentImplementation, EnforceOverrides): _vector_reader_stub: VectorReaderStub _segment: Segment _request_timeout_seconds: int @@ -56,7 +56,6 @@ def __init__(self, system: System, segment: Segment): ) @trace_method("GrpcVectorSegment.get_vectors", OpenTelemetryGranularity.ALL) - @override def get_vectors( self, request_version_context: RequestVersionContext, @@ -87,7 +86,6 @@ def get_vectors( return results @trace_method("GrpcVectorSegment.query_vectors", OpenTelemetryGranularity.ALL) - @override def query_vectors( self, query: VectorQuery ) -> Sequence[Sequence[VectorQueryResult]]: diff --git a/chromadb/segment/impl/vector/local_hnsw.py b/chromadb/segment/impl/vector/local_hnsw.py index 47cec8cf2fc..afeca0bc93c 100644 --- a/chromadb/segment/impl/vector/local_hnsw.py +++ b/chromadb/segment/impl/vector/local_hnsw.py @@ -1,7 +1,7 @@ from overrides import override from typing import Optional, Sequence, Dict, Set, List, cast from uuid import UUID -from chromadb.segment import VectorReader +from chromadb.segment import SegmentImplementation from chromadb.ingest import Consumer from chromadb.config import System, Settings from chromadb.segment.impl.vector.batch import Batch @@ -34,7 +34,7 @@ DEFAULT_CAPACITY = 1000 -class LocalHnswSegment(VectorReader): +class LocalHnswSegment(SegmentImplementation): _id: UUID _consumer: Consumer _collection: Optional[UUID] @@ -103,7 +103,6 @@ def stop(self) -> None: self._consumer.unsubscribe(self._subscription) @trace_method("LocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL) - @override def get_vectors( self, request_version_context: RequestVersionContext, @@ -130,7 +129,6 @@ def get_vectors( return results @trace_method("LocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL) - @override def query_vectors( self, query: VectorQuery ) -> Sequence[Sequence[VectorQueryResult]]: diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index 92f65b27714..907ce94459a 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -11,7 +11,8 @@ import chromadb from chromadb.api import ClientAPI, ServerAPI from chromadb.config import Settings, System -from chromadb.segment import SegmentManager, VectorReader +from chromadb.segment import SegmentManager +from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants from chromadb.test.property.test_embeddings import ( @@ -142,7 +143,7 @@ def test_sync_threshold(settings: Settings) -> None: ) manager = system.instance(SegmentManager) - segment = manager.get_segment(collection.id, VectorReader) + segment = manager.get_segment(collection.id, LocalHnswSegment) def get_index_last_modified_at() -> float: # Time resolution on Windows can be up to 10ms diff --git a/chromadb/test/property/test_segment_manager.py b/chromadb/test/property/test_segment_manager.py index cfb816f3725..0b4b757f3eb 100644 --- a/chromadb/test/property/test_segment_manager.py +++ b/chromadb/test/property/test_segment_manager.py @@ -16,8 +16,8 @@ MultipleResults, ) from typing import Dict, List -from chromadb.segment import VectorReader from chromadb.segment import SegmentManager +from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment from chromadb.types import SegmentScope from chromadb.db.system import SysDB @@ -112,7 +112,7 @@ def create_segment( @rule(coll=collections) def get_segment(self, coll: strategies.Collection) -> None: segment = self.segment_manager.get_segment( - collection_id=coll.id, type=VectorReader + collection_id=coll.id, type=LocalHnswSegment ) self.last_use.add(coll.id) assert segment is not None diff --git a/chromadb/test/segment/test_metadata.py b/chromadb/test/segment/test_metadata.py index 50bab861800..9e05b71f61c 100644 --- a/chromadb/test/segment/test_metadata.py +++ b/chromadb/test/segment/test_metadata.py @@ -30,7 +30,6 @@ ) from pypika import Table from chromadb.ingest import Producer -from chromadb.segment import MetadataReader import uuid import time @@ -140,7 +139,7 @@ def _build_document(i: int) -> str: ) -def sync(segment: MetadataReader, seq_id: SeqId) -> None: +def sync(segment: SqliteMetadataSegment, seq_id: SeqId) -> None: # Try for up to 5 seconds, then throw a TimeoutError start = time.time() while time.time() - start < 5: @@ -668,7 +667,7 @@ def test_upsert( def _test_update( sample_embeddings: Iterator[OperationRecord], producer: Producer, - segment: MetadataReader, + segment: SqliteMetadataSegment, collection_id: uuid.UUID, op: Operation, ) -> None: diff --git a/chromadb/test/segment/test_vector.py b/chromadb/test/segment/test_vector.py index 0d62c827461..07c0b61d709 100644 --- a/chromadb/test/segment/test_vector.py +++ b/chromadb/test/segment/test_vector.py @@ -1,5 +1,5 @@ import pytest -from typing import Generator, List, Callable, Iterator, Type, cast +from typing import Generator, List, Callable, Iterator, Type, cast, Union from chromadb.config import System, Settings from chromadb.test.conftest import ProducerFn from chromadb.types import ( @@ -14,7 +14,6 @@ Vector, ) from chromadb.ingest import Producer -from chromadb.segment import VectorReader import uuid import time @@ -34,6 +33,8 @@ import shutil import numpy as np +VectorReader = Union[LocalHnswSegment, PersistentLocalHnswSegment] + def sqlite() -> Generator[System, None, None]: """Fixture generator for sqlite DB"""