Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLN] Separate validation and transformation logic #2899

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 94 additions & 78 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from typing import (
TYPE_CHECKING,
Optional,
Union,
)
import numpy as np
from typing import TYPE_CHECKING, Optional, Union

from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
IncludeEnum,
PyEmbedding,
Include,
Metadata,
Expand Down Expand Up @@ -64,17 +60,23 @@ async def add(
ValueError: If you provide an id that already exists

"""
(
ids,
embeddings,
metadatas,
documents,
uris,
) = self._validate_and_prepare_embedding_set(
ids, embeddings, metadatas, documents, images, uris
add_request = self._validate_and_prepare_add_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)

await self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
await self._client._add(
collection_id=self.id,
ids=add_request["ids"],
embeddings=add_request["embeddings"],
metadatas=add_request["metadatas"],
documents=add_request["documents"],
uris=add_request["uris"],
)

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand All @@ -92,7 +94,7 @@ async def get(
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
include: Include = [IncludeEnum.metadatas, IncludeEnum.documents],
) -> GetResult:
"""Get embeddings and their associated data from the data store. If no ids or where filter is provided, returns
all embeddings up to limit, starting at offset.
Expand All @@ -109,25 +111,27 @@ async def get(
GetResult: A GetResult object containing the results.

"""
(
valid_ids,
valid_where,
valid_where_document,
valid_include,
) = self._validate_and_prepare_get_request(ids, where, where_document, include)
get_request = self._validate_and_prepare_get_request(
ids=ids,
where=where,
where_document=where_document,
include=include,
)

get_results = await self._client._get(
self.id,
valid_ids,
valid_where,
None,
limit,
offset,
where_document=valid_where_document,
include=valid_include,
collection_id=self.id,
ids=get_request["ids"],
where=get_request["where"],
where_document=get_request["where_document"],
include=get_request["include"],
sort=None,
limit=limit,
offset=offset,
)

return self._transform_get_response(get_results, valid_include)
return self._transform_get_response(
response=get_results, include=get_request["include"]
)

async def peek(self, limit: int = 10) -> GetResult:
"""Get the first few results in the database up to limit
Expand All @@ -145,7 +149,7 @@ async def query(
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[PyEmbedding],
]
] = None,
query_texts: Optional[OneOrMany[Document]] = None,
Expand All @@ -154,7 +158,11 @@ async def query(
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"],
include: Include = [
IncludeEnum.metadatas,
IncludeEnum.documents,
IncludeEnum.distances,
],
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.

Expand All @@ -178,32 +186,30 @@ async def query(

"""

(
valid_query_embeddings,
valid_n_results,
valid_where,
valid_where_document,
) = self._validate_and_prepare_query_request(
query_embeddings,
query_texts,
query_images,
query_uris,
n_results,
where,
where_document,
include,
# Query data
query_request = self._validate_and_prepare_query_request(
query_embeddings=query_embeddings,
query_texts=query_texts,
query_images=query_images,
query_uris=query_uris,
n_results=n_results,
where=where,
where_document=where_document,
include=include,
)

query_results = await self._client._query(
collection_id=self.id,
query_embeddings=valid_query_embeddings,
n_results=valid_n_results,
where=valid_where,
where_document=valid_where_document,
include=include,
query_embeddings=query_request["embeddings"],
n_results=query_request["n_results"],
where=query_request["where"],
where_document=query_request["where_document"],
include=query_request["include"],
)

return self._transform_query_response(query_results, include)
return self._transform_query_response(
response=query_results, include=query_request["include"]
)

async def modify(
self, name: Optional[str] = None, metadata: Optional[CollectionMetadata] = None
Expand Down Expand Up @@ -233,7 +239,7 @@ async def update(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand All @@ -252,25 +258,31 @@ async def update(
Returns:
None
"""
(
ids,
embeddings,
metadatas,
documents,
uris,
) = self._validate_and_prepare_update_request(
ids, embeddings, metadatas, documents, images, uris
update_request = self._validate_and_prepare_update_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)

await self._client._update(self.id, ids, embeddings, metadatas, documents, uris)
await self._client._update(
collection_id=self.id,
ids=update_request["ids"],
embeddings=update_request["embeddings"],
metadatas=update_request["metadatas"],
documents=update_request["documents"],
uris=update_request["uris"],
)

async def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand All @@ -289,25 +301,24 @@ async def upsert(
Returns:
None
"""
(
ids,
embeddings,
metadatas,
documents,
uris,
) = self._validate_and_prepare_upsert_request(
ids, embeddings, metadatas, documents, images, uris
)

await self._client._upsert(
collection_id=self.id,
upsert_request = self._validate_and_prepare_upsert_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)

await self._client._upsert(
collection_id=self.id,
ids=upsert_request["ids"],
embeddings=upsert_request["embeddings"],
metadatas=upsert_request["metadatas"],
documents=upsert_request["documents"],
uris=upsert_request["uris"],
)

async def delete(
self,
ids: Optional[IDs] = None,
Expand All @@ -327,8 +338,13 @@ async def delete(
Raises:
ValueError: If you don't provide either ids, where, or where_document
"""
(ids, where, where_document) = self._validate_and_prepare_delete_request(
delete_request = self._validate_and_prepare_delete_request(
ids, where, where_document
)

await self._client._delete(self.id, ids, where, where_document)
await self._client._delete(
collection_id=self.id,
ids=delete_request["ids"],
where=delete_request["where"],
where_document=delete_request["where_document"],
)
Loading
Loading