diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index 02491ba1035..2a3993cef92 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -1,14 +1,10 @@ -from typing import ( - TYPE_CHECKING, - Optional, - Union, -) -import numpy as np +from typing import TYPE_CHECKING, Optional, Union from chromadb.api.types import ( URI, CollectionMetadata, Embedding, + IncludeEnum, PyEmbedding, Include, Metadata, @@ -64,17 +60,23 @@ async def add( ValueError: If you provide an id that already exists """ - ( - ids, - embeddings, - metadatas, - documents, - uris, - ) = self._validate_and_prepare_embedding_set( - ids, embeddings, metadatas, documents, images, uris + add_request = self._validate_and_prepare_add_request( + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + images=images, + uris=uris, ) - await self._client._add(ids, self.id, embeddings, metadatas, documents, uris) + await self._client._add( + collection_id=self.id, + ids=add_request["ids"], + embeddings=add_request["embeddings"], + metadatas=add_request["metadatas"], + documents=add_request["documents"], + uris=add_request["uris"], + ) async def count(self) -> int: """The total number of embeddings added to the database @@ -92,7 +94,7 @@ async def get( limit: Optional[int] = None, offset: Optional[int] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents"], + include: Include = [IncludeEnum.metadatas, IncludeEnum.documents], ) -> GetResult: """Get embeddings and their associate data from the data store. If no ids or where filter is provided returns all embeddings up to limit starting at offset. @@ -109,25 +111,27 @@ async def get( GetResult: A GetResult object containing the results. """ - ( - valid_ids, - valid_where, - valid_where_document, - valid_include, - ) = self._validate_and_prepare_get_request(ids, where, where_document, include) + get_request = self._validate_and_prepare_get_request( + ids=ids, + where=where, + where_document=where_document, + include=include, + ) get_results = await self._client._get( - self.id, - valid_ids, - valid_where, - None, - limit, - offset, - where_document=valid_where_document, - include=valid_include, + collection_id=self.id, + ids=get_request["ids"], + where=get_request["where"], + where_document=get_request["where_document"], + include=get_request["include"], + sort=None, + limit=limit, + offset=offset, ) - return self._transform_get_response(get_results, valid_include) + return self._transform_get_response( + response=get_results, include=get_request["include"] + ) async def peek(self, limit: int = 10) -> GetResult: """Get the first few results in the database up to limit @@ -145,7 +149,7 @@ async def query( query_embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[PyEmbedding], ] ] = None, query_texts: Optional[OneOrMany[Document]] = None, @@ -154,7 +158,11 @@ async def query( n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents", "distances"], + include: Include = [ + IncludeEnum.metadatas, + IncludeEnum.documents, + IncludeEnum.distances, + ], ) -> QueryResult: """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. @@ -178,32 +186,30 @@ async def query( """ - ( - valid_query_embeddings, - valid_n_results, - valid_where, - valid_where_document, - ) = self._validate_and_prepare_query_request( - query_embeddings, - query_texts, - query_images, - query_uris, - n_results, - where, - where_document, - include, + # Query data + query_request = self._validate_and_prepare_query_request( + query_embeddings=query_embeddings, + query_texts=query_texts, + query_images=query_images, + query_uris=query_uris, + n_results=n_results, + where=where, + where_document=where_document, + include=include, ) query_results = await self._client._query( collection_id=self.id, - query_embeddings=valid_query_embeddings, - n_results=valid_n_results, - where=valid_where, - where_document=valid_where_document, - include=include, + query_embeddings=query_request["embeddings"], + n_results=query_request["n_results"], + where=query_request["where"], + where_document=query_request["where_document"], + include=query_request["include"], ) - return self._transform_query_response(query_results, include) + return self._transform_query_response( + response=query_results, include=query_request["include"] + ) async def modify( self, name: Optional[str] = None, metadata: Optional[CollectionMetadata] = None @@ -233,7 +239,7 @@ async def update( embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[PyEmbedding], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, @@ -252,17 +258,23 @@ async def update( Returns: None """ - ( - ids, - embeddings, - metadatas, - documents, - uris, - ) = self._validate_and_prepare_update_request( - ids, embeddings, metadatas, documents, images, uris + update_request = self._validate_and_prepare_update_request( + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + images=images, + uris=uris, ) - await self._client._update(self.id, ids, embeddings, metadatas, documents, uris) + await self._client._update( + collection_id=self.id, + ids=update_request["ids"], + embeddings=update_request["embeddings"], + metadatas=update_request["metadatas"], + documents=update_request["documents"], + uris=update_request["uris"], + ) async def upsert( self, @@ -270,7 +282,7 @@ async def upsert( embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[PyEmbedding], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, @@ -289,25 +301,24 @@ async def upsert( Returns: None """ - ( - ids, - embeddings, - metadatas, - documents, - uris, - ) = self._validate_and_prepare_upsert_request( - ids, embeddings, metadatas, documents, images, uris - ) - - await self._client._upsert( - collection_id=self.id, + upsert_request = self._validate_and_prepare_upsert_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, + images=images, uris=uris, ) + await self._client._upsert( + collection_id=self.id, + ids=upsert_request["ids"], + embeddings=upsert_request["embeddings"], + metadatas=upsert_request["metadatas"], + documents=upsert_request["documents"], + uris=upsert_request["uris"], + ) + async def delete( self, ids: Optional[IDs] = None, @@ -327,8 +338,13 @@ async def delete( Raises: ValueError: If you don't provide either ids, where, or where_document """ - (ids, where, where_document) = self._validate_and_prepare_delete_request( + delete_request = self._validate_and_prepare_delete_request( ids, where, where_document ) - await self._client._delete(self.id, ids, where, where_document) + await self._client._delete( + collection_id=self.id, + ids=delete_request["ids"], + where=delete_request["where"], + where_document=delete_request["where_document"], + ) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index f3aaf2f57dd..953d1995503 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,11 +1,11 @@ from typing import TYPE_CHECKING, Optional, Union -import numpy as np from chromadb.api.models.CollectionCommon import CollectionCommon from chromadb.api.types import ( URI, CollectionMetadata, Embedding, + IncludeEnum, PyEmbedding, Include, Metadata, @@ -41,7 +41,7 @@ def count(self) -> int: def add( self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] + embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], @@ -72,17 +72,24 @@ def add( ValueError: If you provide an id that already exists """ - ( - ids, - embeddings, - metadatas, - documents, - uris, - ) = self._validate_and_prepare_embedding_set( - ids, embeddings, metadatas, documents, images, uris + + add_request = self._validate_and_prepare_add_request( + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + images=images, + uris=uris, ) - self._client._add(ids, self.id, embeddings, metadatas, documents, uris) + self._client._add( + collection_id=self.id, + ids=add_request["ids"], + embeddings=add_request["embeddings"], + metadatas=add_request["metadatas"], + documents=add_request["documents"], + uris=add_request["uris"], + ) def get( self, @@ -91,7 +98,7 @@ def get( limit: Optional[int] = None, offset: Optional[int] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents"], + include: Include = [IncludeEnum.metadatas, IncludeEnum.documents], ) -> GetResult: """Get embeddings and their associate data from the data store. If no ids or where filter is provided returns all embeddings up to limit starting at offset. @@ -108,25 +115,27 @@ def get( GetResult: A GetResult object containing the results. """ - ( - valid_ids, - valid_where, - valid_where_document, - valid_include, - ) = self._validate_and_prepare_get_request(ids, where, where_document, include) + get_request = self._validate_and_prepare_get_request( + ids=ids, + where=where, + where_document=where_document, + include=include, + ) get_results = self._client._get( - self.id, - valid_ids, - valid_where, - None, - limit, - offset, - where_document=valid_where_document, - include=valid_include, + collection_id=self.id, + ids=get_request["ids"], + where=get_request["where"], + where_document=get_request["where_document"], + include=get_request["include"], + sort=None, + limit=limit, + offset=offset, ) - return self._transform_get_response(get_results, include) + return self._transform_get_response( + response=get_results, include=get_request["include"] + ) def peek(self, limit: int = 10) -> GetResult: """Get the first few results in the database up to limit @@ -141,7 +150,7 @@ def peek(self, limit: int = 10) -> GetResult: def query( self, - query_embeddings: Optional[ # type: ignore[type-arg] + query_embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], @@ -153,7 +162,11 @@ def query( n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents", "distances"], + include: Include = [ + IncludeEnum.metadatas, + IncludeEnum.documents, + IncludeEnum.distances, + ], ) -> QueryResult: """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. @@ -178,32 +191,29 @@ def query( """ - ( - valid_query_embeddings, - valid_n_results, - valid_where, - valid_where_document, - ) = self._validate_and_prepare_query_request( - query_embeddings, - query_texts, - query_images, - query_uris, - n_results, - where, - where_document, - include, + query_request = self._validate_and_prepare_query_request( + query_embeddings=query_embeddings, + query_texts=query_texts, + query_images=query_images, + query_uris=query_uris, + n_results=n_results, + where=where, + where_document=where_document, + include=include, ) query_results = self._client._query( collection_id=self.id, - query_embeddings=valid_query_embeddings, - n_results=valid_n_results, - where=valid_where, - where_document=valid_where_document, - include=include, + query_embeddings=query_request["embeddings"], + n_results=query_request["n_results"], + where=query_request["where"], + where_document=query_request["where_document"], + include=query_request["include"], ) - return self._transform_query_response(query_results, include) + return self._transform_query_response( + response=query_results, include=query_request["include"] + ) def modify( self, name: Optional[str] = None, metadata: Optional[CollectionMetadata] = None @@ -230,10 +240,10 @@ def modify( def update( self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] + embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[PyEmbedding], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, @@ -252,22 +262,28 @@ def update( Returns: None """ - ( - ids, - embeddings, - metadatas, - documents, - uris, - ) = self._validate_and_prepare_update_request( - ids, embeddings, metadatas, documents, images, uris + update_request = self._validate_and_prepare_update_request( + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + images=images, + uris=uris, ) - self._client._update(self.id, ids, embeddings, metadatas, documents, uris) + self._client._update( + collection_id=self.id, + ids=update_request["ids"], + embeddings=update_request["embeddings"], + metadatas=update_request["metadatas"], + documents=update_request["documents"], + uris=update_request["uris"], + ) def upsert( self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] + embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], @@ -289,25 +305,24 @@ def upsert( Returns: None """ - ( - ids, - embeddings, - metadatas, - documents, - uris, - ) = self._validate_and_prepare_upsert_request( - ids, embeddings, metadatas, documents, images, uris - ) - - self._client._upsert( - collection_id=self.id, + upsert_request = self._validate_and_prepare_upsert_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, + images=images, uris=uris, ) + self._client._upsert( + collection_id=self.id, + ids=upsert_request["ids"], + embeddings=upsert_request["embeddings"], + metadatas=upsert_request["metadatas"], + documents=upsert_request["documents"], + uris=upsert_request["uris"], + ) + def delete( self, ids: Optional[IDs] = None, @@ -327,8 +342,13 @@ def delete( Raises: ValueError: If you don't provide either ids, where, or where_document """ - (ids, where, where_document) = self._validate_and_prepare_delete_request( + delete_request = self._validate_and_prepare_delete_request( ids, where, where_document ) - self._client._delete(self.id, ids, where, where_document) + self._client._delete( + collection_id=self.id, + ids=delete_request["ids"], + where=delete_request["where"], + where_document=delete_request["where_document"], + ) diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 23b04f653b2..0f161c5eab6 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -1,56 +1,52 @@ +from dataclasses import asdict from typing import ( TYPE_CHECKING, Dict, Generic, Optional, - Tuple, Any, TypeVar, Union, cast, ) +from chromadb.types import Metadata import numpy as np from uuid import UUID import chromadb.utils.embedding_functions as ef from chromadb.api.types import ( URI, + AddRequest, CollectionMetadata, DataLoader, + DeleteRequest, Embedding, Embeddings, + FilterSet, + GetRequest, + IncludeEnum, PyEmbedding, Embeddable, GetResult, Include, Loadable, - Metadata, - Metadatas, Document, - Documents, Image, - Images, + QueryRequest, QueryResult, - URIs, IDs, EmbeddingFunction, ID, OneOrMany, - maybe_cast_one_to_many_ids, - maybe_cast_one_to_many_embedding, - maybe_cast_one_to_many_metadata, - maybe_cast_one_to_many_document, - maybe_cast_one_to_many_image, - maybe_cast_one_to_many_uri, + UpdateRequest, + UpsertRequest, + maybe_cast_one_to_many, validate_ids, - validate_include, validate_metadata, - validate_metadatas, - validate_embeddings, validate_embedding_function, - validate_n_results, validate_where, validate_where_document, + RecordSet, ) # TODO: We should rename the types in chromadb.types to be Models where @@ -149,7 +145,7 @@ def __repr__(self) -> str: def get_model(self) -> CollectionModel: return self._model - def _validate_embedding_set( + def _validate_and_prepare_add_request( self, ids: OneOrMany[ID], embeddings: Optional[ @@ -160,89 +156,137 @@ def _validate_embedding_set( ], metadatas: Optional[OneOrMany[Metadata]], documents: Optional[OneOrMany[Document]], - images: Optional[OneOrMany[Image]] = None, - uris: Optional[OneOrMany[URI]] = None, - require_embeddings_or_data: bool = True, - ) -> Tuple[ - IDs, - Optional[Embeddings], - Optional[Metadatas], - Optional[Documents], - Optional[Images], - Optional[URIs], - ]: - valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) - valid_embeddings = ( - validate_embeddings( - self._normalize_embeddings(maybe_cast_one_to_many_embedding(embeddings)) - ) - if embeddings is not None - else None - ) - valid_metadatas = ( - validate_metadatas(maybe_cast_one_to_many_metadata(metadatas)) - if metadatas is not None - else None - ) - valid_documents = ( - maybe_cast_one_to_many_document(documents) - if documents is not None - else None - ) - valid_images = ( - maybe_cast_one_to_many_image(images) if images is not None else None + images: Optional[OneOrMany[Image]], + uris: Optional[OneOrMany[URI]], + ) -> AddRequest: + # Unpack + add_records = RecordSet.unpack( + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + images=images, + uris=uris, ) - valid_uris = maybe_cast_one_to_many_uri(uris) if uris is not None else None + # Validate + add_records.validate() + add_records.validate_contains_any({"ids"}) + + # Prepare + if add_records.embeddings is None: + add_records.validate_for_embedding() + add_embeddings = self._embed_record_set(add_records) + else: + add_embeddings = add_records.embeddings + + return AddRequest( + ids=add_records.ids, + embeddings=add_embeddings, + metadatas=add_records.metadatas, + documents=add_records.documents, + uris=add_records.uris, + ) - # Check that one of embeddings or ducuments or images is provided - if require_embeddings_or_data: - if ( - valid_embeddings is None - and valid_documents is None - and valid_images is None - and valid_uris is None - ): - raise ValueError( - "You must provide embeddings, documents, images, or uris." - ) + def _validate_and_prepare_get_request( + self, + ids: Optional[OneOrMany[ID]], + where: Optional[Where], + where_document: Optional[WhereDocument], + include: Include, + ) -> GetRequest: + # Unpack + unpacked_ids: Optional[IDs] = maybe_cast_one_to_many(ids) + filters = FilterSet(where=where, where_document=where_document, include=include) - # Only one of documents or images can be provided - if valid_documents is not None and valid_images is not None: - raise ValueError("You can only provide documents or images, not both.") + # Validate + if unpacked_ids is not None: + validate_ids(unpacked_ids) + filters.validate() - # Check that, if they're provided, the lengths of the arrays match the length of ids - if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids): - raise ValueError( - f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}" - ) - if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids): - raise ValueError( - f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}" - ) - if valid_documents is not None and len(valid_documents) != len(valid_ids): - raise ValueError( - f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}" - ) - if valid_images is not None and len(valid_images) != len(valid_ids): - raise ValueError( - f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}" - ) - if valid_uris is not None and len(valid_uris) != len(valid_ids): + # Prepare + if "data" in include and self._data_loader is None: raise ValueError( - f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}" + "You must set a data loader on the collection if loading from URIs." ) - return ( - valid_ids, - valid_embeddings, - valid_metadatas, - valid_documents, - valid_images, - valid_uris, + # We need to include uris in the result from the API to load datas + if "data" in include and "uris" not in include: + filters.include.append("uris") # type: ignore[arg-type] + + return GetRequest( + ids=unpacked_ids, + where=filters.where, + where_document=filters.where_document, + include=filters.include, + ) + + def _validate_and_prepare_query_request( + self, + query_embeddings: Optional[ + Union[ + OneOrMany[Embedding], + OneOrMany[PyEmbedding], + ] + ], + query_texts: Optional[OneOrMany[Document]], + query_images: Optional[OneOrMany[Image]], + query_uris: Optional[OneOrMany[URI]], + n_results: int, + where: Optional[Where], + where_document: Optional[WhereDocument], + include: Include, + ) -> QueryRequest: + # Unpack + query_records = RecordSet.unpack( + embeddings=query_embeddings, + documents=query_texts, + images=query_images, + uris=query_uris, + ) + + filters = FilterSet( + where=where, + where_document=where_document, + include=include, + n_results=n_results, ) - def _validate_and_prepare_embedding_set( + # Validate + query_records.validate() + filters.validate() + + # Prepare + if query_records.embeddings is None: + query_records.validate_for_embedding() + request_embeddings = self._embed_record_set(query_records) + else: + request_embeddings = query_records.embeddings + + if filters.where is None: + request_where = {} + else: + request_where = filters.where + + if filters.where_document is None: + request_where_document = {} + else: + request_where_document = filters.where_document + + # We need to manually include uris in the result from the API to load datas + request_include = filters.include + if "data" in request_include and "uris" not in request_include: + request_include.append(IncludeEnum.uris) + + return QueryRequest( + embeddings=request_embeddings, + where=request_where, + where_document=request_where_document, + include=request_include, + n_results=cast(int, filters.n_results), + ) + + def _validate_and_prepare_update_request( self, ids: OneOrMany[ID], embeddings: Optional[ @@ -255,68 +299,116 @@ def _validate_and_prepare_embedding_set( documents: Optional[OneOrMany[Document]], images: Optional[OneOrMany[Image]], uris: Optional[OneOrMany[URI]], - ) -> Tuple[ - IDs, - Embeddings, - Optional[Metadatas], - Optional[Documents], - Optional[URIs], - ]: - ( - ids, - embeddings, - metadatas, - documents, - images, - uris, - ) = self._validate_embedding_set( - ids, embeddings, metadatas, documents, images, uris + ) -> UpdateRequest: + # Unpack + update_records = RecordSet.unpack( + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + images=images, + uris=uris, ) - # We need to compute the embeddings if they're not provided - if embeddings is None: - # At this point, we know that one of documents or images are provided from the validation above - if documents is not None: - embeddings = self._embed(input=documents) - elif images is not None: - embeddings = self._embed(input=images) + # Validate + update_records.validate() + update_records.validate_contains_any({"ids"}) + + # Prepare + if update_records.embeddings is None: + # TODO: Handle URI updates. + if ( + update_records.documents is not None + or update_records.images is not None + ): + update_records.validate_for_embedding( + embeddable_fields={"documents", "images"} + ) + update_embeddings = self._embed_record_set(update_records) else: - if uris is None: - raise ValueError( - "You must provide either embeddings, documents, images, or uris." - ) - if self._data_loader is None: - raise ValueError( - "You must set a data loader on the collection if loading from URIs." - ) - embeddings = self._embed(self._data_loader(uris)) - - return ids, embeddings, metadatas, documents, uris + update_embeddings = None + else: + update_embeddings = update_records.embeddings + + return UpdateRequest( + ids=update_records.ids, + embeddings=update_embeddings, + metadatas=update_records.metadatas, + documents=update_records.documents, + uris=update_records.uris, + ) - def _validate_and_prepare_get_request( + def _validate_and_prepare_upsert_request( self, - ids: Optional[OneOrMany[ID]], - where: Optional[Where], - where_document: Optional[WhereDocument], - include: Include, - ) -> Tuple[Optional[IDs], Optional[Where], Optional[WhereDocument], Include,]: - valid_where = validate_where(where) if where else None - valid_where_document = ( - validate_where_document(where_document) if where_document else None + ids: OneOrMany[ID], + embeddings: Optional[ + Union[ + OneOrMany[Embedding], + OneOrMany[PyEmbedding], + ] + ] = None, + metadatas: Optional[OneOrMany[Metadata]] = None, + documents: Optional[OneOrMany[Document]] = None, + images: Optional[OneOrMany[Image]] = None, + uris: Optional[OneOrMany[URI]] = None, + ) -> UpsertRequest: + # Unpack + upsert_records = RecordSet.unpack( + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + images=images, + uris=uris, ) - valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None - valid_include = validate_include(include, allow_distances=False) - if "data" in include and self._data_loader is None: - raise ValueError( - "You must set a data loader on the collection if loading from URIs." + # Validate + upsert_records.validate() + upsert_records.validate_contains_any({"ids"}) + + # Prepare + if upsert_records.embeddings is None: + # TODO: Handle URI upserts. + upsert_records.validate_for_embedding( + embeddable_fields={"documents", "images"} ) + upsert_embeddings = self._embed_record_set(upsert_records) + + return UpsertRequest( + ids=upsert_records.ids, + embeddings=upsert_embeddings, + metadatas=upsert_records.metadatas, + documents=upsert_records.documents, + uris=upsert_records.uris, + ) - # We need to include uris in the result from the API to load datas - if "data" in include and "uris" not in include: - valid_include.append("uris") # type: ignore[arg-type] + def _validate_and_prepare_delete_request( + self, + ids: Optional[IDs], + where: Optional[Where], + where_document: Optional[WhereDocument], + ) -> DeleteRequest: + if ids is None and where is None and where_document is None: + raise ValueError( + "At least one of ids, where, or where_document must be provided." + ) - return valid_ids, valid_where, valid_where_document, valid_include + # Unpack + if ids is not None: + request_ids = cast(IDs, maybe_cast_one_to_many(ids)) + validate_ids(request_ids) + else: + request_ids = None + + # Validate - Note that FilterSet is not used here since there is no Include or n_results + if where_document is not None: + validate_where_document(where_document) + if where is not None: + validate_where(where) + + return DeleteRequest( + ids=request_ids, where=where, where_document=where_document + ) def _transform_peek_response(self, response: GetResult) -> GetResult: if response["embeddings"] is not None: @@ -343,91 +435,6 @@ def _transform_get_response( return response - def _validate_and_prepare_query_request( - self, - query_embeddings: Optional[ - Union[ - OneOrMany[Embedding], - OneOrMany[PyEmbedding], - ] - ], - query_texts: Optional[OneOrMany[Document]], - query_images: Optional[OneOrMany[Image]], - query_uris: Optional[OneOrMany[URI]], - n_results: int, - where: Optional[Where], - where_document: Optional[WhereDocument], - include: Include, - ) -> Tuple[Embeddings, int, Where, WhereDocument,]: - # Users must provide only one of query_embeddings, query_texts, query_images, or query_uris - if not ( - (query_embeddings is not None) - ^ (query_texts is not None) - ^ (query_images is not None) - ^ (query_uris is not None) - ): - raise ValueError( - "You must provide one of query_embeddings, query_texts, query_images, or query_uris." - ) - - valid_where = validate_where(where) if where else {} - valid_where_document = ( - validate_where_document(where_document) if where_document else {} - ) - valid_query_embeddings = ( - validate_embeddings( - self._normalize_embeddings( - maybe_cast_one_to_many_embedding(query_embeddings) - ) - ) - if query_embeddings is not None - else None - ) - valid_query_texts = ( - maybe_cast_one_to_many_document(query_texts) - if query_texts is not None - else None - ) - valid_query_images = ( - maybe_cast_one_to_many_image(query_images) - if query_images is not None - else None - ) - valid_query_uris = ( - maybe_cast_one_to_many_uri(query_uris) if query_uris is not None else None - ) - valid_include = validate_include(include, allow_distances=True) - valid_n_results = validate_n_results(n_results) - - # If query_embeddings are not provided, we need to compute them from the inputs - if valid_query_embeddings is None: - if query_texts is not None: - valid_query_embeddings = self._embed(input=valid_query_texts) - elif query_images is not None: - valid_query_embeddings = self._embed(input=valid_query_images) - else: - if valid_query_uris is None: - raise ValueError( - "You must provide either query_embeddings, query_texts, query_images, or query_uris." - ) - if self._data_loader is None: - raise ValueError( - "You must set a data loader on the collection if loading from URIs." - ) - valid_query_embeddings = self._embed( - self._data_loader(valid_query_uris) - ) - - if "data" in include and "uris" not in include: - valid_include.append("uris") # type: ignore[arg-type] - - return ( - valid_query_embeddings, - valid_n_results, - valid_where, - valid_where_document, - ) - def _transform_query_response( self, response: QueryResult, include: Include ) -> QueryResult: @@ -465,113 +472,25 @@ def _update_model_after_modify_success( if metadata: self._model["metadata"] = metadata - def _validate_and_prepare_update_request( - self, - ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] - Union[ - OneOrMany[Embedding], - OneOrMany[np.ndarray], - ] - ], - metadatas: Optional[OneOrMany[Metadata]], - documents: Optional[OneOrMany[Document]], - images: Optional[OneOrMany[Image]], - uris: Optional[OneOrMany[URI]], - ) -> Tuple[ - IDs, - Embeddings, - Optional[Metadatas], - Optional[Documents], - Optional[URIs], - ]: - ( - ids, - embeddings, - metadatas, - documents, - images, - uris, - ) = self._validate_embedding_set( - ids, - embeddings, - metadatas, - documents, - images, - uris, - require_embeddings_or_data=False, - ) - - if embeddings is None: - if documents is not None: - embeddings = self._embed(input=documents) - elif images is not None: - embeddings = self._embed(input=images) - - return ids, cast(Embeddings, embeddings), metadatas, documents, uris - - def _validate_and_prepare_upsert_request( - self, - ids: OneOrMany[ID], - embeddings: Optional[ - Union[ - OneOrMany[Embedding], - OneOrMany[PyEmbedding], - ] - ], - metadatas: Optional[OneOrMany[Metadata]], - documents: Optional[OneOrMany[Document]], - images: Optional[OneOrMany[Image]], - uris: Optional[OneOrMany[URI]], - ) -> Tuple[ - IDs, - Embeddings, - Optional[Metadatas], - Optional[Documents], - Optional[URIs], - ]: - ( - ids, - embeddings, - metadatas, - documents, - images, - uris, - ) = self._validate_embedding_set( - ids, embeddings, metadatas, documents, images, uris - ) - - if embeddings is None: - if documents is not None: - embeddings = self._embed(input=documents) - else: - embeddings = self._embed(input=images) - - return ids, embeddings, metadatas, documents, uris - - def _validate_and_prepare_delete_request( - self, - ids: Optional[IDs], - where: Optional[Where], - where_document: Optional[WhereDocument], - ) -> Tuple[Optional[IDs], Optional[Where], Optional[WhereDocument]]: - ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None - where = validate_where(where) if where else None - where_document = ( - validate_where_document(where_document) if where_document else None + def _embed_record_set(self, record_set: RecordSet) -> Embeddings: + record_dict = asdict(record_set) + for field in record_set.get_embeddable_fields(): + if record_dict[field] is not None: + # uris require special handling + if field == "uris": + if self._data_loader is None: + raise ValueError( + "You must set a data loader on the collection if loading from URIs." + ) + return self._embed(input=self._data_loader(uris=record_dict[field])) + else: + return self._embed(input=record_dict[field]) + raise ValueError( + "Record does not contain any fields that can be embedded." + f"Embeddable Fields: {record_set.get_embeddable_fields()}" + f"Record Fields: {record_dict.keys()}" ) - return (ids, where, where_document) - - @staticmethod - def _normalize_embeddings( - embeddings: Union[ - OneOrMany[Embedding], - OneOrMany[PyEmbedding], - ] - ) -> Embeddings: - return cast(Embeddings, [np.array(embedding) for embedding in embeddings]) - def _embed(self, input: Any) -> Embeddings: if self._embedding_function is None: raise ValueError( diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index ca4d7a01644..6e78c8b820b 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -481,12 +481,15 @@ def _get( log_position=coll.log_position, ) - where = validate_where(where) if where is not None and len(where) > 0 else None - where_document = ( + if where is not None and len(where) > 0: + validate_where(where) + else: + where = None + + if where_document is not None and len(where_document) > 0: validate_where_document(where_document) - if where_document is not None and len(where_document) > 0 - else None - ) + else: + where_document = None metadata_segment = self._manager.get_segment(collection_id, MetadataReader) @@ -580,12 +583,16 @@ def _delete( } ) - where = validate_where(where) if where is not None and len(where) > 0 else None - where_document = ( + if where is not None and len(where) > 0: + validate_where(where) + else: + # TODO: Do we still need empty ( {} ) where? + where = None + + if where_document is not None and len(where_document) > 0: validate_where_document(where_document) - if where_document is not None and len(where_document) > 0 - else None - ) + else: + where_document = None # You must have at least one of non-empty ids, where, or where_document. if ( @@ -707,12 +714,10 @@ def _query( ) ) - where = validate_where(where) if where is not None and len(where) > 0 else where - where_document = ( + if where is not None and len(where) > 0: + validate_where(where) + if where_document is not None and len(where_document) > 0: validate_where_document(where_document) - if where_document is not None and len(where_document) > 0 - else where_document - ) allowed_ids = None diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 4a9b2667e0a..5b9002e617d 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,4 +1,5 @@ -from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast +from dataclasses import dataclass, asdict +from typing import Optional, Set, Union, TypeVar, List, Dict, Any, Tuple, cast from numpy.typing import NDArray import numpy as np from typing_extensions import TypedDict, Protocol, runtime_checkable @@ -27,32 +28,23 @@ T = TypeVar("T") OneOrMany = Union[T, List[T]] -# URIs -URI = str -URIs = List[URI] +def maybe_cast_one_to_many(target: Optional[OneOrMany[T]]) -> Optional[List[T]]: + if target is None: + return None + if isinstance(target, list): + return target + return [target] -def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs: - if isinstance(target, str): - # One URI - return cast(URIs, [target]) - # Already a sequence - return cast(URIs, target) +# URIs +URI = str +URIs = List[URI] # IDs ID = str IDs = List[ID] - -def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs: - if isinstance(target, str): - # One ID - return cast(IDs, [target]) - # Already a sequence - return cast(IDs, target) - - # Embeddings PyEmbedding = PyVector PyEmbeddings = List[PyEmbedding] @@ -60,32 +52,38 @@ def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs: Embeddings = List[Embedding] -def maybe_cast_one_to_many_embedding( - target: Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]] -) -> Embeddings: - if isinstance(target, List): - # One Embedding +def normalize_embeddings( + target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]] +) -> Optional[Embeddings]: + if target is None: + return None + if isinstance(target, list): + # One PyEmbedding if isinstance(target[0], (int, float)): - return cast(Embeddings, [target]) + return [np.array(target, dtype=np.float32)] + if isinstance(target[0], list): + if isinstance(target[0][0], (int, float)): + return [np.array(embedding, dtype=np.float32) for embedding in target] + if isinstance(target[0], np.ndarray): + return cast(Embeddings, target) + elif isinstance(target, np.ndarray): - if isinstance(target[0], (np.floating, np.integer)): + if target.ndim == 1: + # A single embedding as a numpy array return cast(Embeddings, [target]) - # Already a sequence - return cast(Embeddings, target) + if target.ndim == 2: + # 2-D numpy array (comes out of embedding models) + # TODO: Enforce this at the embedding function level + return list(target) + + raise ValueError( + f"Expected embeddings to be a list of floats or ints, a list of lists, a numpy array, or a list of numpy arrays, got {target}" + ) # Metadatas Metadatas = List[Metadata] - -def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas: - # One Metadata dict - if isinstance(target, dict): - return cast(Metadatas, [target]) - # Already a sequence - return cast(Metadatas, target) - - CollectionMetadata = Dict[str, Any] UpdateCollectionMetadata = UpdateMetadata @@ -100,14 +98,6 @@ def is_document(target: Any) -> bool: return True -def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents: - # One Document - if is_document(target): - return cast(Documents, [target]) - # Already a sequence - return cast(Documents, target) - - # Images ImageDType = Union[np.uint, np.int64, np.float64] Image = NDArray[ImageDType] @@ -122,11 +112,144 @@ def is_image(target: Any) -> bool: return True -def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images: - if is_image(target): - return cast(Images, [target]) - # Already a sequence - return cast(Images, target) +@dataclass +class RecordSet: + """ + Internal representation of a record in the database, where all fields + have been unpacked into lists of the right type, and are normalized. + """ + + ids: IDs + embeddings: Optional[ + Embeddings + ] # Optional because embeddings may not have been computed yet + documents: Optional[Documents] + images: Optional[Images] + uris: Optional[URIs] + metadatas: Optional[Metadatas] + + @staticmethod + def get_embeddable_fields() -> Set[str]: + """ + Returns the set of fields that can be embedded. + """ + return {"documents", "images", "uris"} + + @staticmethod + def unpack( + ids: Optional[OneOrMany[ID]] = None, + embeddings: Optional[ + Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]] + ] = None, + documents: Optional[OneOrMany[Document]] = None, + images: Optional[OneOrMany[Image]] = None, + uris: Optional[OneOrMany[URI]] = None, + metadatas: Optional[OneOrMany[Metadata]] = None, + ) -> "RecordSet": + """ + Unpacks the fields of a Record into lists of the right type, and normalizes them. + """ + + return RecordSet( + ids=cast(IDs, maybe_cast_one_to_many(ids)), + embeddings=normalize_embeddings(embeddings), + documents=maybe_cast_one_to_many(documents), + images=maybe_cast_one_to_many(images), + uris=maybe_cast_one_to_many(uris), + metadatas=maybe_cast_one_to_many(metadatas), + ) + + def validate(self) -> None: + """ + Validates the RecordSet, ensuring that all fields are of the right type and length. + """ + + # TODO: We should get all the failing validations, not just the first one + + self._validate_length_consistency() + + # Validate individual fields + if self.ids is not None: + validate_ids(self.ids) + if self.embeddings is not None: + validate_embeddings(self.embeddings) + if self.metadatas is not None: + validate_metadatas(self.metadatas) + if self.documents is not None: + validate_documents(self.documents) + if self.images is not None: + validate_images(self.images) + + # TODO: Validate URIs + + def _validate_length_consistency(self) -> None: + data_dict = asdict(self) + + lengths = [len(lst) for lst in data_dict.values() if lst is not None] + + if not lengths: + raise ValueError("All Record fields are None") + + zero_lengths = [ + key for key, lst in data_dict.items() if lst is not None and len(lst) == 0 + ] + + if zero_lengths: + raise ValueError(f"Non-empty lists are required for {zero_lengths}") + + # If we have exactly one non-None list, we're good + if len(set(lengths)) > 1: + error_str = ", ".join( + f"{key}: {len(lst)}" + for key, lst in data_dict.items() + if lst is not None + ) + raise ValueError(f"Unequal lengths for fields: {error_str}") + + def validate_for_embedding( + self, embeddable_fields: Optional[Set[str]] = None + ) -> None: + """ + Validates that the Record is ready to be embedded, i.e. that it contains exactly one of the embeddable fields. + """ + + if self.embeddings is not None: + raise ValueError( + "Attempting to embed a record that already has embeddings. " + ) + if embeddable_fields is None: + embeddable_fields = self.get_embeddable_fields() + self.validate_contains_one(embeddable_fields) + + def validate_contains_any(self, contains_any: Set[str]) -> None: + """ + Validates that at least one of the fields in contains_any is not None. + """ + self._validate_contains(contains_any) + + if not any(getattr(self, field) is not None for field in contains_any): + raise ValueError( + f"At least one of {', '.join(contains_any)} must be provided." + ) + + def validate_contains_one(self, contains_one: Set[str]) -> None: + """ + Validates that exactly one of the fields in contains_one is not None. + """ + self._validate_contains(contains_one) + if sum(getattr(self, field) is not None for field in contains_one) != 1: + raise ValueError( + f"Exactly one of {', '.join(contains_one)} must be provided." + ) + + def _validate_contains(self, contains: Set[str]) -> None: + """ + Validates that all fields in contains are valid fields of the Record. + """ + if any(field not in asdict(self) for field in contains): + raise ValueError( + f"Invalid field in contains: {', '.join(contains)}, available fields: {', '.join(asdict(self).keys())}" + ) Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID) @@ -163,6 +286,24 @@ class IncludeEnum(str, Enum): Where = Where WhereDocumentOperator = WhereDocumentOperator + +@dataclass +class FilterSet: + include: Include + where: Optional[Where] = None + where_document: Optional[WhereDocument] = None + n_results: Optional[int] = None + + def validate(self) -> None: + validate_include(self.include, allow_distances=True) + if self.where is not None: + validate_where(self.where) + if self.where_document is not None: + validate_where_document(self.where_document) + if self.n_results is not None: + validate_n_results(self.n_results) + + Embeddable = Union[Documents, Images] D = TypeVar("D", bound=Embeddable, contravariant=True) @@ -171,6 +312,24 @@ class IncludeEnum(str, Enum): L = TypeVar("L", covariant=True, bound=Loadable) +class AddRequest(TypedDict): + ids: IDs + embeddings: Embeddings + metadatas: Optional[Metadatas] + documents: Optional[Documents] + uris: Optional[URIs] + + +# Add result doesn't exist. + + +class GetRequest(TypedDict): + ids: Optional[IDs] + where: Optional[Where] + where_document: Optional[WhereDocument] + include: Include + + class GetResult(TypedDict): ids: List[ID] embeddings: Optional[ @@ -183,6 +342,14 @@ class GetResult(TypedDict): included: Include +class QueryRequest(TypedDict): + embeddings: Embeddings + where: Where + where_document: WhereDocument + include: Include + n_results: int + + class QueryResult(TypedDict): ids: List[IDs] embeddings: Optional[ @@ -200,6 +367,37 @@ class QueryResult(TypedDict): included: Include +class UpdateRequest(TypedDict): + ids: IDs + embeddings: Optional[Embeddings] + metadatas: Optional[Metadatas] + documents: Optional[Documents] + uris: Optional[URIs] + + +# Update result doesn't exist. + + +class UpsertRequest(TypedDict): + ids: IDs + embeddings: Embeddings + metadatas: Optional[Metadatas] + documents: Optional[Documents] + uris: Optional[URIs] + + +# Upsert result doesn't exist. + + +class DeleteRequest(TypedDict): + ids: Optional[IDs] + where: Optional[Where] + where_document: Optional[WhereDocument] + + +# Delete result doesn't exist. + + class IndexMetadata(TypedDict): dimensionality: int # The current number of elements in the index (total = additions - deletes) @@ -223,9 +421,8 @@ def __init_subclass__(cls) -> None: def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings: result = call(self, input) - return validate_embeddings( - normalize_embeddings(maybe_cast_one_to_many_embedding(result)) - ) + assert result is not None + return validate_embeddings(cast(Embeddings, normalize_embeddings(result))) setattr(cls, "__call__", __call__) @@ -235,15 +432,6 @@ def embed_with_retries( return cast(Embeddings, retry(**retry_kwargs)(self.__call__)(input)) -def normalize_embeddings( - embeddings: Union[ - OneOrMany[Embedding], - OneOrMany[PyEmbedding], - ] -) -> Embeddings: - return cast(Embeddings, [np.array(embedding) for embedding in embeddings]) - - def validate_embedding_function( embedding_function: EmbeddingFunction[Embeddable], ) -> None: @@ -362,7 +550,7 @@ def validate_metadatas(metadatas: Metadatas) -> Metadatas: return metadatas -def validate_where(where: Where) -> Where: +def validate_where(where: Where) -> None: """ Validates where to ensure it is a dictionary of strings to strings, ints, floats or operator expressions, or in the case of $and and $or, a list of where expressions @@ -442,10 +630,9 @@ def validate_where(where: Where) -> Where: f"Expected where operand value to be a non-empty list, and all values to be of the same type " f"got {operand}" ) - return where -def validate_where_document(where_document: WhereDocument) -> WhereDocument: +def validate_where_document(where_document: WhereDocument) -> None: """ Validates where_document to ensure it is a dictionary of WhereDocumentOperator to strings, or in the case of $and and $or, a list of where_document expressions @@ -483,7 +670,6 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument: raise ValueError( "Expected where document operand value for operator $contains to be a non-empty str" ) - return where_document def validate_include(include: Include, allow_distances: bool) -> Include: @@ -557,6 +743,36 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings: return embeddings +def validate_documents(documents: Documents) -> Documents: + """Validates documents to ensure it is a list of strings""" + if not isinstance(documents, list): + raise ValueError( + f"Expected documents to be a list, got {type(documents).__name__}" + ) + if len(documents) == 0: + raise ValueError( + f"Expected documents to be a non-empty list, got {len(documents)} documents" + ) + for document in documents: + if not is_document(document): + raise ValueError(f"Expected document to be a str, got {document}") + return documents + + +def validate_images(images: Images) -> Images: + """Validates images to ensure it is a list of numpy arrays""" + if not isinstance(images, list): + raise ValueError(f"Expected images to be a list, got {type(images).__name__}") + if len(images) == 0: + raise ValueError( + f"Expected images to be a non-empty list, got {len(images)} images" + ) + for image in images: + if not is_image(image): + raise ValueError(f"Expected image to be a numpy array, got {image}") + return images + + def validate_batch( batch: Tuple[ IDs, diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 0284cfe7b3c..23537d74b4e 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -2,6 +2,7 @@ from random import randint from typing import cast, List, Any, Dict import hypothesis +import numpy as np import pytest import hypothesis.strategies as st from hypothesis import given, settings @@ -269,12 +270,13 @@ def test_out_of_order_ids(client: ClientAPI) -> None: coll = client.create_collection( "test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore ) - embeddings: Embeddings = [[1, 2, 3] for _ in ooo_ids] + embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids] coll.add(ids=ooo_ids, embeddings=embeddings) get_ids = coll.get(ids=ooo_ids)["ids"] assert get_ids == ooo_ids +@pytest.mark.xfail(reason="Partial records aren't in our API contract...") def test_add_partial(client: ClientAPI) -> None: """Tests adding a record set with some of the fields set to None.""" reset(client) diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 0500ce135b2..f463dd91e74 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -8,7 +8,14 @@ from hypothesis import given, settings, HealthCheck from typing import Dict, Set, cast, Union, DefaultDict, Any, List from dataclasses import dataclass -from chromadb.api.types import ID, Embeddings, Include, IDs, validate_embeddings +from chromadb.api.types import ( + ID, + Embeddings, + Include, + IDs, + validate_embeddings, + normalize_embeddings, +) from chromadb.config import System import chromadb.errors as errors from chromadb.api import ClientAPI @@ -796,7 +803,12 @@ def test_autocasting_validate_embeddings_for_compatible_types( supported_types: List[Any], ) -> None: embds = strategies.create_embeddings(10, 10, supported_types) - validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds)) + validated_embeddings = validate_embeddings( + cast( + Embeddings, + normalize_embeddings(embds), + ) + ) assert all( [ isinstance(value, np.ndarray) @@ -816,7 +828,9 @@ def test_autocasting_validate_embeddings_with_ndarray( supported_types: List[Any], ) -> None: embds = strategies.create_embeddings_ndarray(10, 10, supported_types) - validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds)) + validated_embeddings = validate_embeddings( + cast(Embeddings, normalize_embeddings(embds)) + ) assert all( [ isinstance(value, np.ndarray) @@ -837,7 +851,7 @@ def test_autocasting_validate_embeddings_incompatible_types( ) -> None: embds = strategies.create_embeddings(10, 10, unsupported_types) with pytest.raises(ValueError) as e: - validate_embeddings(Collection._normalize_embeddings(embds)) + validate_embeddings(cast(Embeddings, normalize_embeddings(embds))) assert "Expected each value in the embedding to be a int or float" in str(e) diff --git a/chromadb/types.py b/chromadb/types.py index ece3e097aa7..2dc98b826fc 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -197,7 +197,7 @@ class Operation(Enum): PyVector = Union[Sequence[float], Sequence[int]] -Vector = NDArray[Union[np.int32, np.float32]] +Vector = NDArray[Union[np.int32, np.float32]] # TODO: Specify that the vector is 1D class VectorEmbeddingRecord(TypedDict):