Skip to content

Commit

Permalink
Request types
Browse files Browse the repository at this point in the history
  • Loading branch information
atroyn committed Oct 6, 2024
1 parent 91ba2c9 commit 61b96f3
Show file tree
Hide file tree
Showing 6 changed files with 724 additions and 533 deletions.
172 changes: 94 additions & 78 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from typing import (
TYPE_CHECKING,
Optional,
Union,
)
import numpy as np
from typing import TYPE_CHECKING, Optional, Union

from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
IncludeEnum,
PyEmbedding,
Include,
Metadata,
Expand Down Expand Up @@ -64,17 +60,23 @@ async def add(
ValueError: If you provide an id that already exists
"""
(
ids,
embeddings,
metadatas,
documents,
uris,
) = self._validate_and_prepare_embedding_set(
ids, embeddings, metadatas, documents, images, uris
add_request = self._validate_and_prepare_add_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)

await self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
await self._client._add(
collection_id=self.id,
ids=add_request["ids"],
embeddings=add_request["embeddings"],
metadatas=add_request["metadatas"],
documents=add_request["documents"],
uris=add_request["uris"],
)

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand All @@ -92,7 +94,7 @@ async def get(
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
include: Include = [IncludeEnum.metadatas, IncludeEnum.documents],
) -> GetResult:
"""Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
all embeddings up to limit starting at offset.
Expand All @@ -109,25 +111,27 @@ async def get(
GetResult: A GetResult object containing the results.
"""
(
valid_ids,
valid_where,
valid_where_document,
valid_include,
) = self._validate_and_prepare_get_request(ids, where, where_document, include)
get_request = self._validate_and_prepare_get_request(
ids=ids,
where=where,
where_document=where_document,
include=include,
)

get_results = await self._client._get(
self.id,
valid_ids,
valid_where,
None,
limit,
offset,
where_document=valid_where_document,
include=valid_include,
collection_id=self.id,
ids=get_request["ids"],
where=get_request["where"],
where_document=get_request["where_document"],
include=get_request["include"],
sort=None,
limit=limit,
offset=offset,
)

return self._transform_get_response(get_results, valid_include)
return self._transform_get_response(
response=get_results, include=get_request["include"]
)

async def peek(self, limit: int = 10) -> GetResult:
"""Get the first few results in the database up to limit
Expand All @@ -145,7 +149,7 @@ async def query(
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[PyEmbedding],
]
] = None,
query_texts: Optional[OneOrMany[Document]] = None,
Expand All @@ -154,7 +158,11 @@ async def query(
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"],
include: Include = [
IncludeEnum.metadatas,
IncludeEnum.documents,
IncludeEnum.distances,
],
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
Expand All @@ -178,32 +186,30 @@ async def query(
"""

(
valid_query_embeddings,
valid_n_results,
valid_where,
valid_where_document,
) = self._validate_and_prepare_query_request(
query_embeddings,
query_texts,
query_images,
query_uris,
n_results,
where,
where_document,
include,
# Query data
query_request = self._validate_and_prepare_query_request(
query_embeddings=query_embeddings,
query_texts=query_texts,
query_images=query_images,
query_uris=query_uris,
n_results=n_results,
where=where,
where_document=where_document,
include=include,
)

query_results = await self._client._query(
collection_id=self.id,
query_embeddings=valid_query_embeddings,
n_results=valid_n_results,
where=valid_where,
where_document=valid_where_document,
include=include,
query_embeddings=query_request["embeddings"],
n_results=query_request["n_results"],
where=query_request["where"],
where_document=query_request["where_document"],
include=query_request["include"],
)

return self._transform_query_response(query_results, include)
return self._transform_query_response(
response=query_results, include=query_request["include"]
)

async def modify(
self, name: Optional[str] = None, metadata: Optional[CollectionMetadata] = None
Expand Down Expand Up @@ -233,7 +239,7 @@ async def update(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand All @@ -252,25 +258,31 @@ async def update(
Returns:
None
"""
(
ids,
embeddings,
metadatas,
documents,
uris,
) = self._validate_and_prepare_update_request(
ids, embeddings, metadatas, documents, images, uris
update_request = self._validate_and_prepare_update_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)

await self._client._update(self.id, ids, embeddings, metadatas, documents, uris)
await self._client._update(
collection_id=self.id,
ids=update_request["ids"],
embeddings=update_request["embeddings"],
metadatas=update_request["metadatas"],
documents=update_request["documents"],
uris=update_request["uris"],
)

async def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand All @@ -289,25 +301,24 @@ async def upsert(
Returns:
None
"""
(
ids,
embeddings,
metadatas,
documents,
uris,
) = self._validate_and_prepare_upsert_request(
ids, embeddings, metadatas, documents, images, uris
)

await self._client._upsert(
collection_id=self.id,
upsert_request = self._validate_and_prepare_upsert_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)

await self._client._upsert(
collection_id=self.id,
ids=upsert_request["ids"],
embeddings=upsert_request["embeddings"],
metadatas=upsert_request["metadatas"],
documents=upsert_request["documents"],
uris=upsert_request["uris"],
)

async def delete(
self,
ids: Optional[IDs] = None,
Expand All @@ -327,8 +338,13 @@ async def delete(
Raises:
ValueError: If you don't provide either ids, where, or where_document
"""
(ids, where, where_document) = self._validate_and_prepare_delete_request(
delete_request = self._validate_and_prepare_delete_request(
ids, where, where_document
)

await self._client._delete(self.id, ids, where, where_document)
await self._client._delete(
collection_id=self.id,
ids=delete_request["ids"],
where=delete_request["where"],
where_document=delete_request["where_document"],
)
Loading

0 comments on commit 61b96f3

Please sign in to comment.