Skip to content

Commit

Permalink
[CLN] Rename SubmitEmbeddingRecord to OperationRecord. Remove Topic c…
Browse files Browse the repository at this point in the history
…oncept from Segment/Collection. (#1933)

## Description of changes

This PR has two high-level aims
- Clean up artifacts of our design that were implicitly coupled to
pulsar. Namely topics and multiplexing.
- Begin a rename of SubmitEmbeddingRecord, EmbeddingRecord to names that
more accurately reflect their intents.

I apologize for how large this PR is, but in practice, breaking
something up like this is not really feasible AFAICT, unless we allow
test-breaking stacked PRs...

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Renames SubmitEmbeddingRecord to OperationRecord in order to more
correctly identify what it is - a record of an Operation (future PRs
will rename EmbeddingRecord as well to LogRecord to make its intent
clearer).
- An OperationRecord does not need to store collection_id. This was an
artifact of the pulsar log when we needed to demux data. We now improve
the Ingest interface by presenting producers/consumers over logical log
streams by collection.
- Remove the concept of topic from the Producer/Consumer interfaces - it
is no longer needed in a post pulsar-world. This also means
Collection/Segment don't need to store Topic.
- Removed the AssignmentPolicy concept. This only existed for
multiplexing - which is not a concept without pulsar.
- Update the Rust code with the topic field removed and with the
OperationRecord naming.
- Update Go code with the SysDB changes (No assignment policy + no log)
no as well as the OperationRecord naming.
 - New functionality
	 - None

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
HammadB authored Mar 27, 2024
1 parent 739e942 commit 1ce93c7
Show file tree
Hide file tree
Showing 86 changed files with 1,196 additions and 2,652 deletions.
27 changes: 11 additions & 16 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
)

import chromadb.types as t

from typing import Any, Optional, Sequence, Generator, List, cast, Set, Dict
from overrides import override
from uuid import UUID, uuid4
Expand Down Expand Up @@ -123,10 +122,12 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
name=name,
tenant=tenant,
)

@trace_method("SegmentAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
return self._sysdb.get_database(name=name, tenant=tenant)

@trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
Expand All @@ -136,6 +137,7 @@ def create_tenant(self, name: str) -> None:
self._sysdb.create_tenant(
name=name,
)

@trace_method("SegmentAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> t.Tenant:
Expand Down Expand Up @@ -374,15 +376,14 @@ def _add(
for r in _records(
t.Operation.ADD,
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
CollectionAddEvent(
Expand Down Expand Up @@ -417,15 +418,14 @@ def _update(
for r in _records(
t.Operation.UPDATE,
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
CollectionUpdateEvent(
Expand Down Expand Up @@ -462,15 +462,14 @@ def _upsert(
for r in _records(
t.Operation.UPSERT,
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

return True

Expand Down Expand Up @@ -632,12 +631,10 @@ def _delete(
return []

records_to_submit = []
for r in _records(
operation=t.Operation.DELETE, ids=ids_to_delete, collection_id=collection_id
):
for r in _records(operation=t.Operation.DELETE, ids=ids_to_delete):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
CollectionDeleteEvent(
Expand Down Expand Up @@ -803,7 +800,7 @@ def max_batch_size(self) -> int:
# used for channel assignment in the distributed version of the system.
@trace_method("SegmentAPI._validate_embedding_record", OpenTelemetryGranularity.ALL)
def _validate_embedding_record(
self, collection: t.Collection, record: t.SubmitEmbeddingRecord
self, collection: t.Collection, record: t.OperationRecord
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
add_attributes_to_current_span({"collection_id": str(collection["id"])})
Expand Down Expand Up @@ -845,12 +842,11 @@ def _get_collection(self, collection_id: UUID) -> t.Collection:
def _records(
operation: t.Operation,
ids: IDs,
collection_id: UUID,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> Generator[t.SubmitEmbeddingRecord, None, None]:
) -> Generator[t.OperationRecord, None, None]:
"""Convert parallel lists of embeddings, metadatas and documents to a sequence of
SubmitEmbeddingRecords"""

Expand All @@ -877,13 +873,12 @@ def _records(
else:
metadata = {"chroma:uri": uri}

record = t.SubmitEmbeddingRecord(
record = t.OperationRecord(
id=id,
embedding=embeddings[i] if embeddings else None,
encoding=t.ScalarEncoding.FLOAT32, # Hardcode for now
metadata=metadata,
operation=operation,
collection_id=collection_id,
)
yield record

Expand Down
9 changes: 4 additions & 5 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
"chromadb.ingest.Producer": "chroma_producer_impl",
"chromadb.ingest.Consumer": "chroma_consumer_impl",
"chromadb.quota.QuotaProvider": "chroma_quota_provider_impl",
"chromadb.ingest.CollectionAssignmentPolicy": "chroma_collection_assignment_policy_impl", # noqa
"chromadb.db.system.SysDB": "chroma_sysdb_impl",
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
Expand All @@ -86,9 +85,12 @@
class Settings(BaseSettings): # type: ignore
environment: str = ""

# Legacy config has to be kept around because pydantic will error
# Legacy config that has to be kept around because pydantic will error
# on nonexisting keys
chroma_db_impl: Optional[str] = None
chroma_collection_assignment_policy_impl: str = (
"chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy"
)
# Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI"
chroma_api_impl: str = "chromadb.api.segment.SegmentAPI"
chroma_product_telemetry_impl: str = "chromadb.telemetry.product.posthog.Posthog"
Expand All @@ -109,9 +111,6 @@ class Settings(BaseSettings): # type: ignore
# Distributed architecture specific components
chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
chroma_collection_assignment_policy_impl: str = (
"chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy"
)
worker_memberlist_name: str = "query-service-memberlist"
chroma_coordinator_host = "localhost"

Expand Down
21 changes: 0 additions & 21 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.telemetry.opentelemetry import OpenTelemetryClient
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import (
Collection,
Expand Down Expand Up @@ -145,14 +144,12 @@ def get_segments(
id: Optional[UUID] = None,
type: Optional[str] = None,
scope: Optional[SegmentScope] = None,
topic: Optional[str] = None,
collection: Optional[UUID] = None,
) -> Sequence[Segment]:
request = GetSegmentsRequest(
id=id.hex if id else None,
type=type,
scope=to_proto_segment_scope(scope) if scope else None,
topic=topic,
collection=collection.hex if collection else None,
)
response = self._sys_db_stub.GetSegments(request)
Expand All @@ -166,14 +163,9 @@ def get_segments(
def update_segment(
self,
id: UUID,
topic: OptionalArgument[Optional[str]] = Unspecified(),
collection: OptionalArgument[Optional[UUID]] = Unspecified(),
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
write_topic = None
if topic != Unspecified():
write_topic = cast(Union[str, None], topic)

write_collection = None
if collection != Unspecified():
write_collection = cast(Union[UUID, None], collection)
Expand All @@ -184,17 +176,12 @@ def update_segment(

request = UpdateSegmentRequest(
id=id.hex,
topic=write_topic,
collection=write_collection.hex if write_collection else None,
metadata=to_proto_update_metadata(write_metadata)
if write_metadata
else None,
)

if topic is None:
request.ClearField("topic")
request.reset_topic = True

if collection is None:
request.ClearField("collection")
request.reset_collection = True
Expand Down Expand Up @@ -252,7 +239,6 @@ def delete_collection(
def get_collections(
self,
id: Optional[UUID] = None,
topic: Optional[str] = None,
name: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand All @@ -262,7 +248,6 @@ def get_collections(
# TODO: implement limit and offset in the gRPC service
request = GetCollectionsRequest(
id=id.hex if id else None,
topic=topic,
name=name,
tenant=tenant,
database=database,
Expand All @@ -277,15 +262,10 @@ def get_collections(
def update_collection(
self,
id: UUID,
topic: OptionalArgument[str] = Unspecified(),
name: OptionalArgument[str] = Unspecified(),
dimension: OptionalArgument[Optional[int]] = Unspecified(),
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
write_topic = None
if topic != Unspecified():
write_topic = cast(str, topic)

write_name = None
if name != Unspecified():
write_name = cast(str, name)
Expand All @@ -300,7 +280,6 @@ def update_collection(

request = UpdateCollectionRequest(
id=id.hex,
topic=write_topic,
name=write_name,
dimension=write_dimension,
metadata=to_proto_update_metadata(write_metadata)
Expand Down
18 changes: 1 addition & 17 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Dict, cast
from uuid import UUID
from overrides import overrides
from chromadb.ingest import CollectionAssignmentPolicy
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
from chromadb.proto.convert import (
from_proto_metadata,
Expand Down Expand Up @@ -38,7 +37,7 @@
UpdateCollectionRequest,
UpdateCollectionResponse,
UpdateSegmentRequest,
UpdateSegmentResponse
UpdateSegmentResponse,
)
from chromadb.proto.coordinator_pb2_grpc import (
SysDBServicer,
Expand All @@ -55,7 +54,6 @@ class GrpcMockSysDB(SysDBServicer, Component):

_server: grpc.Server
_server_port: int
_assignment_policy: CollectionAssignmentPolicy
_segments: Dict[str, Segment] = {}
_tenants_to_databases_to_collections: Dict[
str, Dict[str, Dict[str, Collection]]
Expand All @@ -64,7 +62,6 @@ class GrpcMockSysDB(SysDBServicer, Component):

def __init__(self, system: System):
self._server_port = system.settings.require("chroma_server_grpc_port")
self._assignment_policy = system.instance(CollectionAssignmentPolicy)
return super().__init__(system)

@overrides
Expand Down Expand Up @@ -203,7 +200,6 @@ def GetSegments(
if request.HasField("scope")
else None
)
target_topic = request.topic if request.HasField("topic") else None
target_collection = (
UUID(hex=request.collection) if request.HasField("collection") else None
)
Expand All @@ -216,8 +212,6 @@ def GetSegments(
continue
if target_scope and segment["scope"] != target_scope:
continue
if target_topic and segment["topic"] != target_topic:
continue
if target_collection and segment["collection"] != target_collection:
continue
found_segments.append(segment)
Expand All @@ -238,10 +232,6 @@ def UpdateSegment(
)
else:
segment = self._segments[id_to_update.hex]
if request.HasField("topic"):
segment["topic"] = request.topic
if request.HasField("reset_topic") and request.reset_topic:
segment["topic"] = None
if request.HasField("collection"):
segment["collection"] = UUID(hex=request.collection)
if request.HasField("reset_collection") and request.reset_collection:
Expand Down Expand Up @@ -326,7 +316,6 @@ def CreateCollection(
name=request.name,
metadata=from_proto_metadata(request.metadata),
dimension=request.dimension,
topic=self._assignment_policy.assign_collection(id),
database=database,
tenant=tenant,
)
Expand Down Expand Up @@ -368,7 +357,6 @@ def GetCollections(
self, request: GetCollectionsRequest, context: grpc.ServicerContext
) -> GetCollectionsResponse:
target_id = UUID(hex=request.id) if request.HasField("id") else None
target_topic = request.topic if request.HasField("topic") else None
target_name = request.name if request.HasField("name") else None

tenant = request.tenant
Expand All @@ -387,8 +375,6 @@ def GetCollections(
for collection in collections.values():
if target_id and collection["id"] != target_id:
continue
if target_topic and collection["topic"] != target_topic:
continue
if target_name and collection["name"] != target_name:
continue
found_collections.append(collection)
Expand Down Expand Up @@ -418,8 +404,6 @@ def UpdateCollection(
)
else:
collection = collections[id_to_update.hex]
if request.HasField("topic"):
collection["topic"] = request.topic
if request.HasField("name"):
collection["name"] = request.name
if request.HasField("dimension"):
Expand Down
Loading

0 comments on commit 1ce93c7

Please sign in to comment.