Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Revert "[ENH] Migrate away from MetadataReader and VectorReader" #2950

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 222 additions & 33 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer
from chromadb.rate_limit import RateLimitEnforcer
from chromadb.segment import SegmentManager
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan, Filter, Limit, KNN, Projection
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
Expand All @@ -25,16 +23,18 @@
VersionMismatchError,
)
from chromadb.api.types import (
URI,
CollectionMetadata,
Document,
IDs,
Embeddings,
Embedding,
Metadatas,
Documents,
URIs,
Where,
WhereDocument,
Include,
IncludeEnum,
GetResult,
QueryResult,
validate_metadata,
Expand All @@ -58,6 +58,8 @@
Sequence,
Generator,
List,
cast,
Set,
Any,
Callable,
TypeVar,
Expand Down Expand Up @@ -498,20 +500,68 @@ def _get(
)

coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)

where = validate_where(where) if where is not None and len(where) > 0 else None
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
else None
)

metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

if sort is not None:
raise NotImplementedError("Sorting is not yet supported")

if page and page_size:
offset = (page - 1) * page_size
limit = page_size

records = metadata_segment.get_metadata(
where=where,
where_document=where_document,
ids=ids,
limit=limit,
offset=offset,
request_version_context=request_version_context,
)

if len(records) == 0:
# Nothing to return if there are no records
return GetResult(
ids=[],
embeddings=[] if "embeddings" in include else None,
metadatas=[] if "metadatas" in include else None,
documents=[] if "documents" in include else None,
uris=[] if "uris" in include else None,
data=[] if "data" in include else None,
included=include,
)

vectors: Sequence[t.VectorEmbeddingRecord] = []
if "embeddings" in include:
vector_ids = [r["id"] for r in records]
vector_segment = self._manager.get_segment(collection_id, VectorReader)
vectors = vector_segment.get_vectors(
ids=vector_ids, request_version_context=request_version_context
)

# TODO: Fix type so we don't need to ignore
# It is possible to have a set of records, some with metadata and some without
# Same with documents

metadatas = [r["metadata"] for r in records]

if "documents" in include:
documents = [_doc(m) for m in metadatas]

if "uris" in include:
uris = [_uri(m) for m in metadatas]

ids_amount = len(ids) if ids else 0
self._product_telemetry_client.capture(
CollectionGetEvent(
Expand All @@ -524,19 +574,18 @@ def _get(
)
)

return self._executor.get(
GetPlan(
Scan(coll),
Filter(ids, where, where_document),
Limit(offset or 0, limit),
Projection(
IncludeEnum.documents in include,
IncludeEnum.embeddings in include,
IncludeEnum.metadatas in include,
False,
IncludeEnum.uris in include,
),
)
return GetResult(
ids=[r["id"] for r in records],
embeddings=[r["embedding"] for r in vectors]
if "embeddings" in include
else None,
metadatas=_clean_metadatas(metadatas)
if "metadatas" in include
else None, # type: ignore
documents=documents if "documents" in include else None, # type: ignore
uris=uris if "uris" in include else None, # type: ignore
data=None,
included=include,
)

@trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -581,12 +630,21 @@ def _delete(
)

coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

if (where or where_document) or not ids:
ids_to_delete = self._executor.get(
GetPlan(Scan(coll), Filter(ids, where, where_document))
)["ids"]
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_segment.get_metadata(
where=where,
where_document=where_document,
ids=ids,
request_version_context=request_version_context,
)
ids_to_delete = [r["id"] for r in records]
else:
ids_to_delete = ids

Expand Down Expand Up @@ -616,7 +674,13 @@ def _delete(
def _count(self, collection_id: UUID) -> int:
add_attributes_to_current_span({"collection_id": str(collection_id)})
coll = self._get_collection(collection_id)
return self._executor.count(CountPlan(Scan(coll)))
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)

metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
return metadata_segment.count(request_version_context)

@trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
# We retry on version mismatch errors because the version of the collection
Expand Down Expand Up @@ -672,23 +736,110 @@ def _query(
else where_document
)

allowed_ids = None

coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)
for embedding in query_embeddings:
self._validate_dimension(coll, len(embedding), update=False)

return self._executor.knn(
KNNPlan(
Scan(coll),
KNN(query_embeddings, n_results),
Filter(None, where, where_document),
Projection(
IncludeEnum.documents in include,
IncludeEnum.embeddings in include,
IncludeEnum.metadatas in include,
IncludeEnum.distances in include,
IncludeEnum.uris in include,
),
if where or where_document:
metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_reader.get_metadata(
where=where,
where_document=where_document,
include_metadata=False,
request_version_context=request_version_context,
)
allowed_ids = [r["id"] for r in records]

ids: List[List[str]] = []
distances: List[List[float]] = []
embeddings: List[Embeddings] = []
documents: List[List[Document]] = []
uris: List[List[URI]] = []
metadatas: List[List[t.Metadata]] = []

# If where conditions returned empty list then no need to proceed
# further and can simply return an empty result set here.
if allowed_ids is not None and allowed_ids == []:
for em in range(len(query_embeddings)):
ids.append([])
if "distances" in include:
distances.append([])
if "embeddings" in include:
embeddings.append([])
if "documents" in include:
documents.append([])
if "metadatas" in include:
metadatas.append([])
if "uris" in include:
uris.append([])
else:
query = t.VectorQuery(
vectors=query_embeddings,
k=n_results,
allowed_ids=allowed_ids,
include_embeddings="embeddings" in include,
options=None,
request_version_context=request_version_context,
)

vector_reader = self._manager.get_segment(collection_id, VectorReader)
results = vector_reader.query_vectors(query)

for result in results:
ids.append([r["id"] for r in result])
if "distances" in include:
distances.append([r["distance"] for r in result])
if "embeddings" in include:
embeddings.append([cast(Embedding, r["embedding"]) for r in result])

if "documents" in include or "metadatas" in include or "uris" in include:
all_ids: Set[str] = set()
for id_list in ids:
all_ids.update(id_list)
metadata_reader = self._manager.get_segment(
collection_id, MetadataReader
)
records = metadata_reader.get_metadata(
ids=list(all_ids),
include_metadata=True,
request_version_context=request_version_context,
)
metadata_by_id = {r["id"]: r["metadata"] for r in records}
for id_list in ids:
# In the segment based architecture, it is possible for one segment
# to have a record that another segment does not have. This results in
# data inconsistency. For the case of the local segments and the
# local segment manager, there is a case where a thread writes
# a record to the vector segment but not the metadata segment.
# Then a querying thread reads from the vector segment and
# queries the metadata segment. The metadata segment does not have
# the record. In this case we choose to return potentially
# incorrect data in the form of None.
metadata_list = [metadata_by_id.get(id, None) for id in id_list]
if "metadatas" in include:
metadatas.append(_clean_metadatas(metadata_list)) # type: ignore
if "documents" in include:
doc_list = [_doc(m) for m in metadata_list]
documents.append(doc_list) # type: ignore
if "uris" in include:
uri_list = [_uri(m) for m in metadata_list]
uris.append(uri_list) # type: ignore

return QueryResult(
ids=ids,
distances=distances if distances else None,
metadatas=metadatas if metadatas else None,
embeddings=embeddings if embeddings else None,
documents=documents if documents else None,
uris=uris if uris else None,
data=None,
included=include,
)

@trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -809,3 +960,41 @@ def _records(
operation=operation,
)
yield record


def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
"""Retrieve the document (if any) from a Metadata map"""

if metadata and "chroma:document" in metadata:
return str(metadata["chroma:document"])
return None


def _uri(metadata: Optional[t.Metadata]) -> Optional[str]:
"""Retrieve the uri (if any) from a Metadata map"""

if metadata and "chroma:uri" in metadata:
return str(metadata["chroma:uri"])
return None


def _clean_metadatas(
    metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
    """Apply `_clean_metadata` to every entry of a list of metadata maps.

    The result is a new list with one entry per input entry, in the same
    order, so that chroma-specific keys the client shouldn't see are removed
    from each map.
    """
    return list(map(_clean_metadata, metadata))


def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
"""Remove any chroma-specific metadata keys that the client shouldn't see from a
metadata map."""
if not metadata:
return None
result = {}
for k, v in metadata.items():
if not k.startswith("chroma:"):
result[k] = v
if len(result) == 0:
return None
return result
5 changes: 2 additions & 3 deletions chromadb/ingest/impl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from uuid import UUID

from chromadb.db.base import SqlDB
from chromadb.segment import SegmentManager
from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment
from chromadb.segment import SegmentManager, VectorReader

topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"

Expand Down Expand Up @@ -47,4 +46,4 @@ def trigger_vector_segments_max_seq_id_migration(

for collection_id in collection_ids_with_unmigrated_segments:
# Loading the segment triggers the migration on init
segment_manager.get_segment(UUID(collection_id), LocalHnswSegment)
segment_manager.get_segment(UUID(collection_id), VectorReader)
Loading
Loading