From cd1945bea3d591e78d29a3d112d0d4b2b25cb51c Mon Sep 17 00:00:00 2001 From: iusztinpaul Date: Wed, 19 Jun 2024 17:09:14 +0300 Subject: [PATCH] feat: Flexible chunking. Add metadata to FE steps --- .../preprocessing/chunking_data_handlers.py | 32 +++++++++++++++ .../preprocessing/embedding_data_handlers.py | 27 +++++++++++-- .../preprocessing/operations/__init__.py | 2 - .../preprocessing/operations/chunking.py | 22 +++++----- .../preprocessing/operations/embeddings.py | 7 ---- llm_engineering/domain/chunks.py | 3 +- llm_engineering/domain/embedded_chunks.py | 3 +- llm_engineering/domain/queries.py | 3 +- .../steps/feature_engineering/clean.py | 17 +++++++- .../query_data_warehouse.py | 17 +++++++- .../steps/feature_engineering/rag.py | 40 +++++++++++++++++-- 11 files changed, 139 insertions(+), 34 deletions(-) delete mode 100644 llm_engineering/application/preprocessing/operations/embeddings.py diff --git a/llm_engineering/application/preprocessing/chunking_data_handlers.py b/llm_engineering/application/preprocessing/chunking_data_handlers.py index 884919a..da3c6a1 100644 --- a/llm_engineering/application/preprocessing/chunking_data_handlers.py +++ b/llm_engineering/application/preprocessing/chunking_data_handlers.py @@ -28,12 +28,26 @@ class ChunkingDataHandler(ABC, Generic[CleanedDocumentT, ChunkT]): All data transformations logic for the chunking step is done here """ + @property + def chunk_size(self) -> int: + return 500 + + @property + def chunk_overlap(self) -> int: + return 50 + @abstractmethod def chunk(self, data_model: CleanedDocumentT) -> list[ChunkT]: pass class PostChunkingHandler(ChunkingDataHandler): + def chunk_size(self) -> int: + return 250 + + def chunk_overlap(self) -> int: + return 25 + def chunk(self, data_model: CleanedPostDocument) -> list[PostChunk]: data_models_list = [] @@ -49,6 +63,10 @@ def chunk(self, data_model: CleanedPostDocument) -> list[PostChunk]: document_id=data_model.id, author_id=data_model.author_id, image=data_model.image if data_model.image else None, + metadata={ + "chunk_size": self.chunk_size, + "chunk_overlap": self.chunk_overlap, + }, ) data_models_list.append(model) @@ -71,6 +89,10 @@ def chunk(self, data_model: CleanedArticleDocument) -> list[ArticleChunk]: link=data_model.link, document_id=data_model.id, author_id=data_model.author_id, + metadata={ + "chunk_size": self.chunk_size, + "chunk_overlap": self.chunk_overlap, + }, ) data_models_list.append(model) @@ -78,6 +100,12 @@ def chunk(self, data_model: CleanedArticleDocument) -> list[ArticleChunk]: class RepositoryChunkingHandler(ChunkingDataHandler): + def chunk_size(self) -> int: + return 750 + + def chunk_overlap(self) -> int: + return 75 + def chunk(self, data_model: CleanedRepositoryDocument) -> list[RepositoryChunk]: data_models_list = [] @@ -94,6 +122,10 @@ def chunk(self, data_model: CleanedRepositoryDocument) -> list[RepositoryChunk]: link=data_model.link, document_id=data_model.id, author_id=data_model.author_id, + metadata={ + "chunk_size": self.chunk_size, + "chunk_overlap": self.chunk_overlap, + }, ) data_models_list.append(model) diff --git a/llm_engineering/application/preprocessing/embedding_data_handlers.py b/llm_engineering/application/preprocessing/embedding_data_handlers.py index ae83ca7..308ab35 100644 --- a/llm_engineering/application/preprocessing/embedding_data_handlers.py +++ b/llm_engineering/application/preprocessing/embedding_data_handlers.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar, cast +from llm_engineering.application.networks import EmbeddingModelSingleton from llm_engineering.domain.chunks import ArticleChunk, Chunk, PostChunk, RepositoryChunk from llm_engineering.domain.embedded_chunks import ( EmbeddedArticleChunk, @@ -10,11 +11,11 @@ ) from llm_engineering.domain.queries import EmbeddedQuery, Query -from .operations import embedd_text - ChunkT = TypeVar("ChunkT", bound=Chunk) EmbeddedChunkT = TypeVar("EmbeddedChunkT", bound=EmbeddedChunk) +embedding_model = EmbeddingModelSingleton() + class EmbeddingDataHandler(ABC, Generic[ChunkT, EmbeddedChunkT]): """ @@ -27,7 +28,7 @@ def embed(self, data_model: ChunkT) -> EmbeddedChunkT: def embed_batch(self, data_model: list[ChunkT]) -> list[EmbeddedChunkT]: embedding_model_input = [data_model.content for data_model in data_model] - embeddings = embedd_text(embedding_model_input) + embeddings = embedding_model(embedding_model_input, to_list=True) embedded_chunk = [ self.map_model(data_model, cast(list[float], embedding)) @@ -48,6 +49,11 @@ def map_model(self, data_model: Query, embedding: list[float]) -> EmbeddedQuery: author_id=data_model.author_id, content=data_model.content, embedding=embedding, + metadata={ + "embedding_model_id": embedding_model.model_id, + "embedding_size": embedding_model.embedding_size, + "max_input_length": embedding_model.max_input_length, + }, ) @@ -60,6 +66,11 @@ def map_model(self, data_model: PostChunk, embedding: list[float]) -> EmbeddedPo platform=data_model.platform, document_id=data_model.document_id, author_id=data_model.author_id, + metadata={ + "embedding_model_id": embedding_model.model_id, + "embedding_size": embedding_model.embedding_size, + "max_input_length": embedding_model.max_input_length, + }, ) @@ -73,6 +84,11 @@ def map_model(self, data_model: ArticleChunk, embedding: list[float]) -> Embedde link=data_model.link, document_id=data_model.document_id, author_id=data_model.author_id, + metadata={ + "embedding_model_id": embedding_model.model_id, + "embedding_size": embedding_model.embedding_size, + "max_input_length": embedding_model.max_input_length, + }, ) @@ -87,4 +103,9 @@ def map_model(self, data_model: RepositoryChunk, embedding: list[float]) -> Embe link=data_model.link, document_id=data_model.document_id, author_id=data_model.author_id, + metadata={ + "embedding_model_id": embedding_model.model_id, + "embedding_size": embedding_model.embedding_size, + "max_input_length": embedding_model.max_input_length, + }, ) diff --git a/llm_engineering/application/preprocessing/operations/__init__.py b/llm_engineering/application/preprocessing/operations/__init__.py index 85405b6..c59c187 100644 --- a/llm_engineering/application/preprocessing/operations/__init__.py +++ b/llm_engineering/application/preprocessing/operations/__init__.py @@ -1,9 +1,7 @@ from .chunking import chunk_text from .cleaning import clean_text -from .embeddings import embedd_text __all__ = [ "chunk_text", "clean_text", - "embedd_text", ] diff --git a/llm_engineering/application/preprocessing/operations/chunking.py b/llm_engineering/application/preprocessing/operations/chunking.py index 3c8784c..dfa3ac2 100644 --- a/llm_engineering/application/preprocessing/operations/chunking.py +++ b/llm_engineering/application/preprocessing/operations/chunking.py @@ -1,25 +1,21 @@ -from langchain.text_splitter import ( - RecursiveCharacterTextSplitter, - SentenceTransformersTokenTextSplitter, -) +from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter from llm_engineering.application.networks import EmbeddingModelSingleton embedding_model = EmbeddingModelSingleton() -def chunk_text(text: str) -> list[str]: - character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=500, chunk_overlap=0) - text_split = character_splitter.split_text(text) +def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> list[str]: + character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=chunk_size, chunk_overlap=0) + text_split_by_characters = character_splitter.split_text(text) token_splitter = SentenceTransformersTokenTextSplitter( - chunk_overlap=50, + chunk_overlap=chunk_overlap, tokens_per_chunk=embedding_model.max_input_length, model_name=embedding_model.model_id, ) - chunks = [] + chunks_by_tokens = [] + for section in text_split_by_characters: + chunks_by_tokens.extend(token_splitter.split_text(section)) - for section in text_split: - chunks.extend(token_splitter.split_text(section)) - - return chunks + return chunks_by_tokens diff --git a/llm_engineering/application/preprocessing/operations/embeddings.py b/llm_engineering/application/preprocessing/operations/embeddings.py deleted file mode 100644 index 309bcd8..0000000 --- a/llm_engineering/application/preprocessing/operations/embeddings.py +++ /dev/null @@ -1,7 +0,0 @@ -from llm_engineering.application.networks import EmbeddingModelSingleton - -embedding_model = EmbeddingModelSingleton() - - -def embedd_text(text: str | list[str]) -> list[float] | list[list[float]]: - return embedding_model(text, to_list=True) diff --git a/llm_engineering/domain/chunks.py b/llm_engineering/domain/chunks.py index 943bcca..5b6ad87 100644 --- a/llm_engineering/domain/chunks.py +++ b/llm_engineering/domain/chunks.py @@ -1,7 +1,7 @@ from abc import ABC from typing import Optional -from pydantic import UUID4 +from pydantic import UUID4, Field from llm_engineering.domain.base import VectorBaseDocument from llm_engineering.domain.types import DataCategory @@ -12,6 +12,7 @@ class Chunk(VectorBaseDocument, ABC): platform: str document_id: UUID4 author_id: UUID4 + metadata: dict = Field(default_factory=dict) class PostChunk(Chunk): diff --git a/llm_engineering/domain/embedded_chunks.py b/llm_engineering/domain/embedded_chunks.py index 02483bb..35a6025 100644 --- a/llm_engineering/domain/embedded_chunks.py +++ b/llm_engineering/domain/embedded_chunks.py @@ -1,6 +1,6 @@ from abc import ABC -from pydantic import UUID4 +from pydantic import UUID4, Field from llm_engineering.domain.types import DataCategory @@ -13,6 +13,7 @@ class EmbeddedChunk(VectorBaseDocument, ABC): platform: str document_id: UUID4 author_id: UUID4 + metadata: dict = Field(default_factory=dict) class EmbeddedPostChunk(EmbeddedChunk): diff --git a/llm_engineering/domain/queries.py b/llm_engineering/domain/queries.py index f0b3903..4a96b91 100644 --- a/llm_engineering/domain/queries.py +++ b/llm_engineering/domain/queries.py @@ -1,4 +1,4 @@ -from pydantic import UUID4 +from pydantic import UUID4, Field from llm_engineering.domain.base import VectorBaseDocument from llm_engineering.domain.types import DataCategory @@ -7,6 +7,7 @@ class Query(VectorBaseDocument): content: str author_id: UUID4 | None = None + metadata: dict = Field(default_factory=dict) class Config: category = DataCategory.QUERIES diff --git a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py index fd34520..3ecc892 100644 --- a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py +++ b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py @@ -1,7 +1,8 @@ from typing_extensions import Annotated -from zenml import step +from zenml import get_step_context, step from llm_engineering.application.preprocessing import CleaningDispatcher +from llm_engineering.domain.cleaned_documents import CleanedDocument @step @@ -13,4 +14,18 @@ def clean_documents( cleaned_document = CleaningDispatcher.dispatch(document) cleaned_documents.append(cleaned_document) + step_context = get_step_context() + step_context.add_output_metadata(output_name="cleaned_documents", metadata=_get_metadata(cleaned_documents)) + return cleaned_documents + + +def _get_metadata(cleaned_documents: list[CleanedDocument]) -> dict: + metadata = {"num_documents": len(cleaned_documents)} + for document in cleaned_documents: + category = document.get_category() + if category not in metadata: + metadata[category] = {} + metadata[category]["num_documents"] = metadata[category].get("num_documents", 0) + 1 + + return metadata diff --git a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py index 128ffcf..2d29b09 100644 --- a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py +++ b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py @@ -2,12 +2,13 @@ from loguru import logger from typing_extensions import Annotated -from zenml import step +from zenml import get_step_context, step from llm_engineering.application import utils from llm_engineering.domain.base.nosql import NoSQLBaseDocument from llm_engineering.domain.documents import ( ArticleDocument, + Document, PostDocument, RepositoryDocument, UserDocument, @@ -28,6 +29,9 @@ def query_data_warehouse( user_documents = [doc for query_result in results.values() for doc in query_result] + step_context = get_step_context() + step_context.add_output_metadata(output_name="raw_documents", metadata=_get_metadata(user_documents)) + return user_documents @@ -63,3 +67,14 @@ def __fetch_posts(user_id) -> list[NoSQLBaseDocument]: def __fetch_repositories(user_id) -> list[NoSQLBaseDocument]: return RepositoryDocument.bulk_find(author_id=user_id) + + +def _get_metadata(cleaned_documents: list[Document]) -> dict: + metadata = {"num_documents": len(cleaned_documents)} + for document in cleaned_documents: + collection = document.get_collection_name() + if collection not in metadata: + metadata[collection] = {} + metadata[collection]["num_documents"] = metadata[collection].get("num_documents", 0) + 1 + + return metadata diff --git a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py index 6c3cb34..681b74a 100644 --- a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py +++ b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py @@ -1,19 +1,51 @@ from typing_extensions import Annotated -from zenml import step +from zenml import get_step_context, step from llm_engineering.application import utils from llm_engineering.application.preprocessing import ChunkingDispatcher, EmbeddingDispatcher +from llm_engineering.domain.chunks import Chunk +from llm_engineering.domain.embedded_chunks import EmbeddedChunk @step def chunk_and_embed( cleaned_documents: Annotated[list, "cleaned_documents"], ) -> Annotated[list, "embedded_documents"]: - embedded_documents = [] + metadata = {"chunking": {}, "embedding": {}, "num_documents": len(cleaned_documents)} + + embedded_chunks = [] for document in cleaned_documents: chunks = ChunkingDispatcher.dispatch(document) + metadata["chunking"] = _add_chunks_metadata(chunks, metadata["chunking"]) + for batched_chunks in utils.misc.batch(chunks, 10): batched_embedded_chunks = EmbeddingDispatcher.dispatch(batched_chunks) - embedded_documents.extend(batched_embedded_chunks) + embedded_chunks.extend(batched_embedded_chunks) + + metadata["embedding"] = _add_embeddings_metadata(embedded_chunks, metadata["embedding"]) + metadata["num_chunks"] = len(embedded_chunks) + metadata["num_embedded_chunks"] = len(embedded_chunks) + + step_context = get_step_context() + step_context.add_output_metadata(output_name="embedded_documents", metadata=metadata) + + return embedded_chunks + + +def _add_chunks_metadata(chunks: list[Chunk], metadata: dict) -> dict: + for chunk in chunks: + category = chunk.get_category() + if category not in metadata: + metadata[category] = chunk.metadata + metadata[category]["num_chunks"] = metadata[category].get("num_chunks", 0) + 1 + + return metadata + + +def _add_embeddings_metadata(embedded_chunks: list[EmbeddedChunk], metadata: dict) -> dict: + for embedded_chunk in embedded_chunks: + category = embedded_chunk.get_category() + if category not in metadata: + metadata[category] = embedded_chunk.metadata - return embedded_documents + return metadata