diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index e13ecb3672d..57a8cdff306 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -5,8 +5,10 @@ from chromadb.db.system import SysDB from chromadb.quota import QuotaEnforcer from chromadb.rate_limit import RateLimitEnforcer -from chromadb.segment import SegmentManager, MetadataReader, VectorReader +from chromadb.segment import SegmentManager from chromadb.execution.executor.abstract import Executor +from chromadb.execution.expression.operator import Scan, Filter, Limit, KNN, Projection +from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.telemetry.opentelemetry import ( add_attributes_to_current_span, OpenTelemetryClient, @@ -23,18 +25,16 @@ VersionMismatchError, ) from chromadb.api.types import ( - URI, CollectionMetadata, - Document, IDs, Embeddings, - Embedding, Metadatas, Documents, URIs, Where, WhereDocument, Include, + IncludeEnum, GetResult, QueryResult, validate_metadata, @@ -58,8 +58,6 @@ Sequence, Generator, List, - cast, - Set, Any, Callable, TypeVar, @@ -500,11 +498,6 @@ def _get( ) coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) - where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( validate_where_document(where_document) @@ -512,8 +505,6 @@ def _get( else None ) - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) - if sort is not None: raise NotImplementedError("Sorting is not yet supported") @@ -521,47 +512,6 @@ def _get( offset = (page - 1) * page_size limit = page_size - records = metadata_segment.get_metadata( - where=where, - where_document=where_document, - ids=ids, - limit=limit, - offset=offset, - request_version_context=request_version_context, - ) - - if len(records) == 0: - # Nothing to return if there are no records - return GetResult( - ids=[], - embeddings=[] if "embeddings" in include else None, - metadatas=[] if "metadatas" in include else None, - documents=[] if "documents" in include else None, - uris=[] if "uris" in include else None, - data=[] if "data" in include else None, - included=include, - ) - - vectors: Sequence[t.VectorEmbeddingRecord] = [] - if "embeddings" in include: - vector_ids = [r["id"] for r in records] - vector_segment = self._manager.get_segment(collection_id, VectorReader) - vectors = vector_segment.get_vectors( - ids=vector_ids, request_version_context=request_version_context - ) - - # TODO: Fix type so we don't need to ignore - # It is possible to have a set of records, some with metadata and some without - # Same with documents - - metadatas = [r["metadata"] for r in records] - - if "documents" in include: - documents = [_doc(m) for m in metadatas] - - if "uris" in include: - uris = [_uri(m) for m in metadatas] - ids_amount = len(ids) if ids else 0 self._product_telemetry_client.capture( CollectionGetEvent( @@ -574,18 +524,19 @@ def _get( ) ) - return GetResult( - ids=[r["id"] for r in records], - embeddings=[r["embedding"] for r in vectors] - if "embeddings" in include - else None, - metadatas=_clean_metadatas(metadatas) - if "metadatas" in include - else None, # type: ignore - documents=documents if "documents" in include else None, # type: ignore - uris=uris if "uris" in include else None, # type: ignore - data=None, - included=include, + return self._executor.get( + GetPlan( + Scan(coll), + Filter(ids, where, where_document), + Limit(offset or 0, limit), + Projection( + IncludeEnum.documents in include, + IncludeEnum.embeddings in include, + IncludeEnum.metadatas in include, + False, + IncludeEnum.uris in include, + ), + ) ) @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION) @@ -630,21 +581,12 @@ def _delete( ) coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) self._manager.hint_use_collection(collection_id, t.Operation.DELETE) if (where or where_document) or not ids: - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) - records = metadata_segment.get_metadata( - where=where, - where_document=where_document, - ids=ids, - request_version_context=request_version_context, - ) - ids_to_delete = [r["id"] for r in records] + ids_to_delete = self._executor.get( + GetPlan(Scan(coll), Filter(ids, where, where_document)) + )["ids"] else: ids_to_delete = ids @@ -674,13 +616,7 @@ def _delete( def _count(self, collection_id: UUID) -> int: add_attributes_to_current_span({"collection_id": str(collection_id)}) coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) - - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) - return metadata_segment.count(request_version_context) + return self._executor.count(CountPlan(Scan(coll))) @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION) # We retry on version mismatch errors because the version of the collection @@ -736,110 +672,23 @@ def _query( else where_document ) - allowed_ids = None - coll = self._get_collection(collection_id) - request_version_context = t.RequestVersionContext( - collection_version=coll.version, - log_position=coll.log_position, - ) for embedding in query_embeddings: self._validate_dimension(coll, len(embedding), update=False) - if where or where_document: - metadata_reader = self._manager.get_segment(collection_id, MetadataReader) - records = metadata_reader.get_metadata( - where=where, - where_document=where_document, - include_metadata=False, - request_version_context=request_version_context, - ) - allowed_ids = [r["id"] for r in records] - - ids: List[List[str]] = [] - distances: List[List[float]] = [] - embeddings: List[Embeddings] = [] - documents: List[List[Document]] = [] - uris: List[List[URI]] = [] - metadatas: List[List[t.Metadata]] = [] - - # If where conditions returned empty list then no need to proceed - # further and can simply return an empty result set here. - if allowed_ids is not None and allowed_ids == []: - for em in range(len(query_embeddings)): - ids.append([]) - if "distances" in include: - distances.append([]) - if "embeddings" in include: - embeddings.append([]) - if "documents" in include: - documents.append([]) - if "metadatas" in include: - metadatas.append([]) - if "uris" in include: - uris.append([]) - else: - query = t.VectorQuery( - vectors=query_embeddings, - k=n_results, - allowed_ids=allowed_ids, - include_embeddings="embeddings" in include, - options=None, - request_version_context=request_version_context, + return self._executor.knn( + KNNPlan( + Scan(coll), + KNN(query_embeddings, n_results), + Filter(None, where, where_document), + Projection( + IncludeEnum.documents in include, + IncludeEnum.embeddings in include, + IncludeEnum.metadatas in include, + IncludeEnum.distances in include, + IncludeEnum.uris in include, + ), ) - - vector_reader = self._manager.get_segment(collection_id, VectorReader) - results = vector_reader.query_vectors(query) - - for result in results: - ids.append([r["id"] for r in result]) - if "distances" in include: - distances.append([r["distance"] for r in result]) - if "embeddings" in include: - embeddings.append([cast(Embedding, r["embedding"]) for r in result]) - - if "documents" in include or "metadatas" in include or "uris" in include: - all_ids: Set[str] = set() - for id_list in ids: - all_ids.update(id_list) - metadata_reader = self._manager.get_segment( - collection_id, MetadataReader - ) - records = metadata_reader.get_metadata( - ids=list(all_ids), - include_metadata=True, - request_version_context=request_version_context, - ) - metadata_by_id = {r["id"]: r["metadata"] for r in records} - for id_list in ids: - # In the segment based architecture, it is possible for one segment - # to have a record that another segment does not have. This results in - # data inconsistency. For the case of the local segments and the - # local segment manager, there is a case where a thread writes - # a record to the vector segment but not the metadata segment. - # Then a query'ing thread reads from the vector segment and - # queries the metadata segment. The metadata segment does not have - # the record. In this case we choose to return potentially - # incorrect data in the form of None. - metadata_list = [metadata_by_id.get(id, None) for id in id_list] - if "metadatas" in include: - metadatas.append(_clean_metadatas(metadata_list)) # type: ignore - if "documents" in include: - doc_list = [_doc(m) for m in metadata_list] - documents.append(doc_list) # type: ignore - if "uris" in include: - uri_list = [_uri(m) for m in metadata_list] - uris.append(uri_list) # type: ignore - - return QueryResult( - ids=ids, - distances=distances if distances else None, - metadatas=metadatas if metadatas else None, - embeddings=embeddings if embeddings else None, - documents=documents if documents else None, - uris=uris if uris else None, - data=None, - included=include, ) @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION) @@ -960,41 +809,3 @@ def _records( operation=operation, ) yield record - - -def _doc(metadata: Optional[t.Metadata]) -> Optional[str]: - """Retrieve the document (if any) from a Metadata map""" - - if metadata and "chroma:document" in metadata: - return str(metadata["chroma:document"]) - return None - - -def _uri(metadata: Optional[t.Metadata]) -> Optional[str]: - """Retrieve the uri (if any) from a Metadata map""" - - if metadata and "chroma:uri" in metadata: - return str(metadata["chroma:uri"]) - return None - - -def _clean_metadatas( - metadata: List[Optional[t.Metadata]], -) -> List[Optional[t.Metadata]]: - """Remove any chroma-specific metadata keys that the client shouldn't see from a - list of metadata maps.""" - return [_clean_metadata(m) for m in metadata] - - -def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]: - """Remove any chroma-specific metadata keys that the client shouldn't see from a - metadata map.""" - if not metadata: - return None - result = {} - for k, v in metadata.items(): - if not k.startswith("chroma:"): - result[k] = v - if len(result) == 0: - return None - return result diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py index 810bcd68537..a47efe05970 100644 --- a/chromadb/execution/executor/distributed.py +++ b/chromadb/execution/executor/distributed.py @@ -13,8 +13,7 @@ from chromadb.execution.executor.abstract import Executor from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.segment.impl.manager.distributed import DistributedSegmentManager -from chromadb.segment.impl.metadata.grpc_segment import GrpcMetadataSegment -from chromadb.segment.impl.vector.grpc_segment import GrpcVectorSegment +from chromadb.segment import MetadataReader, VectorReader from chromadb.types import VectorQuery, VectorQueryResult, Collection from overrides import overrides @@ -200,8 +199,8 @@ def knn(self, plan: KNNPlan) -> QueryResult: included=included, ) - def _metadata_segment(self, collection: Collection) -> GrpcMetadataSegment: - return self._manager.get_segment(collection.id, GrpcMetadataSegment) + def _metadata_segment(self, collection: Collection) -> MetadataReader: + return self._manager.get_segment(collection.id, MetadataReader) - def _vector_segment(self, collection: Collection) -> GrpcVectorSegment: - return self._manager.get_segment(collection.id, GrpcVectorSegment) + def _vector_segment(self, collection: Collection) -> VectorReader: + return self._manager.get_segment(collection.id, VectorReader) diff --git a/chromadb/execution/executor/local.py b/chromadb/execution/executor/local.py index e483df40fa8..5926c151c97 100644 --- a/chromadb/execution/executor/local.py +++ b/chromadb/execution/executor/local.py @@ -10,8 +10,7 @@ from chromadb.execution.executor.abstract import Executor from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.segment.impl.manager.local import LocalSegmentManager -from chromadb.segment.impl.metadata.sqlite import SqliteMetadataSegment -from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment +from chromadb.segment import MetadataReader, VectorReader from chromadb.types import VectorQuery, VectorQueryResult, Collection from overrides import overrides @@ -197,8 +196,8 @@ def knn(self, plan: KNNPlan) -> QueryResult: included=included, ) - def _metadata_segment(self, collection: Collection) -> SqliteMetadataSegment: - return self._manager.get_segment(collection.id, SqliteMetadataSegment) + def _metadata_segment(self, collection: Collection) -> MetadataReader: + return self._manager.get_segment(collection.id, MetadataReader) - def _vector_segment(self, collection: Collection) -> LocalHnswSegment: - return self._manager.get_segment(collection.id, LocalHnswSegment) + def _vector_segment(self, collection: Collection) -> VectorReader: + return self._manager.get_segment(collection.id, VectorReader) diff --git a/chromadb/segment/__init__.py b/chromadb/segment/__init__.py index 6f93d758926..7247de9fd07 100644 --- a/chromadb/segment/__init__.py +++ b/chromadb/segment/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, TypeVar, Type +from typing import Optional, Sequence, TypeVar from abc import abstractmethod from chromadb.types import ( Collection, @@ -115,18 +115,6 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: returns a sequence of their IDs. Does not update the SysDB.""" pass - # Future Note: To support time travel, add optional parameters to this method to - # retrieve Segment instances that are bounded to events from a specific range of - # time - @abstractmethod - def get_segment(self, collection_id: UUID, type: Type[S]) -> S: - """Return the segment that should be used for servicing queries to a collection. - Implementations should cache appropriately; clients are intended to call this - method repeatedly rather than storing the result (thereby giving this - implementation full control over which segment impls are in or out of memory at - a given time.)""" - pass - @abstractmethod def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: """Signal to the segment manager that a collection is about to be used, so that diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index e117ac2ecee..c1792a56ce2 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -82,7 +82,6 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: "DistributedSegmentManager.get_segment", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) - @override def get_segment(self, collection_id: UUID, type: Type[S]) -> S: if type == MetadataReader: scope = SegmentScope.METADATA diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index fe23f679186..47a45d18928 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -197,7 +197,6 @@ def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope) -> Segmen "LocalSegmentManager.get_segment", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) - @override def get_segment(self, collection_id: UUID, type: Type[S]) -> S: if type == MetadataReader: scope = SegmentScope.METADATA