Skip to content

Commit

Permalink
[ENH] Use executor in segment.py (#2951)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Use the executor interface in `SegmentAPI` instead of calling the
segment readers directly. This is a follow-up to
#2950.
 - New functionality
	 - N/A

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*

---------

Co-authored-by: Sicheng Pan <sicheng@trychroma.com>
  • Loading branch information
Sicheng-Pan and Sicheng Pan authored Oct 14, 2024
1 parent f66222f commit a6cf995
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 249 deletions.
255 changes: 33 additions & 222 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer
from chromadb.rate_limit import RateLimitEnforcer
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.segment import SegmentManager
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan, Filter, Limit, KNN, Projection
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
Expand All @@ -23,18 +25,16 @@
VersionMismatchError,
)
from chromadb.api.types import (
URI,
CollectionMetadata,
Document,
IDs,
Embeddings,
Embedding,
Metadatas,
Documents,
URIs,
Where,
WhereDocument,
Include,
IncludeEnum,
GetResult,
QueryResult,
validate_metadata,
Expand All @@ -58,8 +58,6 @@
Sequence,
Generator,
List,
cast,
Set,
Any,
Callable,
TypeVar,
Expand Down Expand Up @@ -500,68 +498,20 @@ def _get(
)

coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)

where = validate_where(where) if where is not None and len(where) > 0 else None
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
else None
)

metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

if sort is not None:
raise NotImplementedError("Sorting is not yet supported")

if page and page_size:
offset = (page - 1) * page_size
limit = page_size

records = metadata_segment.get_metadata(
where=where,
where_document=where_document,
ids=ids,
limit=limit,
offset=offset,
request_version_context=request_version_context,
)

if len(records) == 0:
# Nothing to return if there are no records
return GetResult(
ids=[],
embeddings=[] if "embeddings" in include else None,
metadatas=[] if "metadatas" in include else None,
documents=[] if "documents" in include else None,
uris=[] if "uris" in include else None,
data=[] if "data" in include else None,
included=include,
)

vectors: Sequence[t.VectorEmbeddingRecord] = []
if "embeddings" in include:
vector_ids = [r["id"] for r in records]
vector_segment = self._manager.get_segment(collection_id, VectorReader)
vectors = vector_segment.get_vectors(
ids=vector_ids, request_version_context=request_version_context
)

# TODO: Fix type so we don't need to ignore
# It is possible to have a set of records, some with metadata and some without
# Same with documents

metadatas = [r["metadata"] for r in records]

if "documents" in include:
documents = [_doc(m) for m in metadatas]

if "uris" in include:
uris = [_uri(m) for m in metadatas]

ids_amount = len(ids) if ids else 0
self._product_telemetry_client.capture(
CollectionGetEvent(
Expand All @@ -574,18 +524,19 @@ def _get(
)
)

return GetResult(
ids=[r["id"] for r in records],
embeddings=[r["embedding"] for r in vectors]
if "embeddings" in include
else None,
metadatas=_clean_metadatas(metadatas)
if "metadatas" in include
else None, # type: ignore
documents=documents if "documents" in include else None, # type: ignore
uris=uris if "uris" in include else None, # type: ignore
data=None,
included=include,
return self._executor.get(
GetPlan(
Scan(coll),
Filter(ids, where, where_document),
Limit(offset or 0, limit),
Projection(
IncludeEnum.documents in include,
IncludeEnum.embeddings in include,
IncludeEnum.metadatas in include,
False,
IncludeEnum.uris in include,
),
)
)

@trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -630,21 +581,12 @@ def _delete(
)

coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

if (where or where_document) or not ids:
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_segment.get_metadata(
where=where,
where_document=where_document,
ids=ids,
request_version_context=request_version_context,
)
ids_to_delete = [r["id"] for r in records]
ids_to_delete = self._executor.get(
GetPlan(Scan(coll), Filter(ids, where, where_document))
)["ids"]
else:
ids_to_delete = ids

Expand Down Expand Up @@ -674,13 +616,7 @@ def _delete(
def _count(self, collection_id: UUID) -> int:
add_attributes_to_current_span({"collection_id": str(collection_id)})
coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)

metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
return metadata_segment.count(request_version_context)
return self._executor.count(CountPlan(Scan(coll)))

@trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
# We retry on version mismatch errors because the version of the collection
Expand Down Expand Up @@ -736,110 +672,23 @@ def _query(
else where_document
)

allowed_ids = None

coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)
for embedding in query_embeddings:
self._validate_dimension(coll, len(embedding), update=False)

if where or where_document:
metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_reader.get_metadata(
where=where,
where_document=where_document,
include_metadata=False,
request_version_context=request_version_context,
)
allowed_ids = [r["id"] for r in records]

ids: List[List[str]] = []
distances: List[List[float]] = []
embeddings: List[Embeddings] = []
documents: List[List[Document]] = []
uris: List[List[URI]] = []
metadatas: List[List[t.Metadata]] = []

# If where conditions returned empty list then no need to proceed
# further and can simply return an empty result set here.
if allowed_ids is not None and allowed_ids == []:
for em in range(len(query_embeddings)):
ids.append([])
if "distances" in include:
distances.append([])
if "embeddings" in include:
embeddings.append([])
if "documents" in include:
documents.append([])
if "metadatas" in include:
metadatas.append([])
if "uris" in include:
uris.append([])
else:
query = t.VectorQuery(
vectors=query_embeddings,
k=n_results,
allowed_ids=allowed_ids,
include_embeddings="embeddings" in include,
options=None,
request_version_context=request_version_context,
return self._executor.knn(
KNNPlan(
Scan(coll),
KNN(query_embeddings, n_results),
Filter(None, where, where_document),
Projection(
IncludeEnum.documents in include,
IncludeEnum.embeddings in include,
IncludeEnum.metadatas in include,
IncludeEnum.distances in include,
IncludeEnum.uris in include,
),
)

vector_reader = self._manager.get_segment(collection_id, VectorReader)
results = vector_reader.query_vectors(query)

for result in results:
ids.append([r["id"] for r in result])
if "distances" in include:
distances.append([r["distance"] for r in result])
if "embeddings" in include:
embeddings.append([cast(Embedding, r["embedding"]) for r in result])

if "documents" in include or "metadatas" in include or "uris" in include:
all_ids: Set[str] = set()
for id_list in ids:
all_ids.update(id_list)
metadata_reader = self._manager.get_segment(
collection_id, MetadataReader
)
records = metadata_reader.get_metadata(
ids=list(all_ids),
include_metadata=True,
request_version_context=request_version_context,
)
metadata_by_id = {r["id"]: r["metadata"] for r in records}
for id_list in ids:
# In the segment based architecture, it is possible for one segment
# to have a record that another segment does not have. This results in
# data inconsistency. For the case of the local segments and the
# local segment manager, there is a case where a thread writes
# a record to the vector segment but not the metadata segment.
# Then a query'ing thread reads from the vector segment and
# queries the metadata segment. The metadata segment does not have
# the record. In this case we choose to return potentially
# incorrect data in the form of None.
metadata_list = [metadata_by_id.get(id, None) for id in id_list]
if "metadatas" in include:
metadatas.append(_clean_metadatas(metadata_list)) # type: ignore
if "documents" in include:
doc_list = [_doc(m) for m in metadata_list]
documents.append(doc_list) # type: ignore
if "uris" in include:
uri_list = [_uri(m) for m in metadata_list]
uris.append(uri_list) # type: ignore

return QueryResult(
ids=ids,
distances=distances if distances else None,
metadatas=metadatas if metadatas else None,
embeddings=embeddings if embeddings else None,
documents=documents if documents else None,
uris=uris if uris else None,
data=None,
included=include,
)

@trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -960,41 +809,3 @@ def _records(
operation=operation,
)
yield record


def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Return the document stored in a metadata map, if there is one.

    The document is kept under the reserved ``"chroma:document"`` key. The
    stored value is converted with ``str``. ``None`` is returned when the
    map is missing, empty, or has no such key.
    """
    # Guard clause: nothing to extract from a falsy map or one without the key.
    if not metadata or "chroma:document" not in metadata:
        return None
    return str(metadata["chroma:document"])


def _uri(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Return the URI stored in a metadata map, if there is one.

    The URI is kept under the reserved ``"chroma:uri"`` key. The stored
    value is converted with ``str``. ``None`` is returned when the map is
    missing, empty, or has no such key.
    """
    # Guard clause: nothing to extract from a falsy map or one without the key.
    if not metadata or "chroma:uri" not in metadata:
        return None
    return str(metadata["chroma:uri"])


def _clean_metadatas(
    metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
    """Apply ``_clean_metadata`` to every entry of a list of metadata maps.

    The result has the same length and order as the input; each element is
    whatever ``_clean_metadata`` returns for the corresponding entry.
    """
    return list(map(_clean_metadata, metadata))


def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
    """Strip reserved keys from a metadata map.

    Every key that starts with ``"chroma:"`` is dropped. Returns ``None``
    when the input is missing or empty, or when no keys remain after
    filtering; otherwise returns a new dict with the remaining entries.
    """
    if not metadata:
        return None
    visible = {
        key: value
        for key, value in metadata.items()
        if not key.startswith("chroma:")
    }
    # An empty dict is falsy, so this collapses "nothing left" to None.
    return visible or None
11 changes: 5 additions & 6 deletions chromadb/execution/executor/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.segment.impl.metadata.grpc_segment import GrpcMetadataSegment
from chromadb.segment.impl.vector.grpc_segment import GrpcVectorSegment
from chromadb.segment import MetadataReader, VectorReader
from chromadb.types import VectorQuery, VectorQueryResult, Collection
from overrides import overrides

Expand Down Expand Up @@ -200,8 +199,8 @@ def knn(self, plan: KNNPlan) -> QueryResult:
included=included,
)

def _metadata_segment(self, collection: Collection) -> GrpcMetadataSegment:
return self._manager.get_segment(collection.id, GrpcMetadataSegment)
def _metadata_segment(self, collection: Collection) -> MetadataReader:
return self._manager.get_segment(collection.id, MetadataReader)

def _vector_segment(self, collection: Collection) -> GrpcVectorSegment:
return self._manager.get_segment(collection.id, GrpcVectorSegment)
def _vector_segment(self, collection: Collection) -> VectorReader:
return self._manager.get_segment(collection.id, VectorReader)
11 changes: 5 additions & 6 deletions chromadb/execution/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.segment.impl.manager.local import LocalSegmentManager
from chromadb.segment.impl.metadata.sqlite import SqliteMetadataSegment
from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment
from chromadb.segment import MetadataReader, VectorReader
from chromadb.types import VectorQuery, VectorQueryResult, Collection
from overrides import overrides

Expand Down Expand Up @@ -197,8 +196,8 @@ def knn(self, plan: KNNPlan) -> QueryResult:
included=included,
)

def _metadata_segment(self, collection: Collection) -> SqliteMetadataSegment:
return self._manager.get_segment(collection.id, SqliteMetadataSegment)
def _metadata_segment(self, collection: Collection) -> MetadataReader:
return self._manager.get_segment(collection.id, MetadataReader)

def _vector_segment(self, collection: Collection) -> LocalHnswSegment:
return self._manager.get_segment(collection.id, LocalHnswSegment)
def _vector_segment(self, collection: Collection) -> VectorReader:
return self._manager.get_segment(collection.id, VectorReader)
Loading

0 comments on commit a6cf995

Please sign in to comment.