From cd1945bea3d591e78d29a3d112d0d4b2b25cb51c Mon Sep 17 00:00:00 2001
From: iusztinpaul
Date: Wed, 19 Jun 2024 17:09:14 +0300
Subject: [PATCH] feat: Flexible chunking. Add metadata to FE steps
---
.../preprocessing/chunking_data_handlers.py | 32 +++++++++++++++
.../preprocessing/embedding_data_handlers.py | 27 +++++++++++--
.../preprocessing/operations/__init__.py | 2 -
.../preprocessing/operations/chunking.py | 22 +++++-----
.../preprocessing/operations/embeddings.py | 7 ----
llm_engineering/domain/chunks.py | 3 +-
llm_engineering/domain/embedded_chunks.py | 3 +-
llm_engineering/domain/queries.py | 3 +-
.../steps/feature_engineering/clean.py | 17 +++++++-
.../query_data_warehouse.py | 17 +++++++-
.../steps/feature_engineering/rag.py | 40 +++++++++++++++++--
11 files changed, 139 insertions(+), 34 deletions(-)
delete mode 100644 llm_engineering/application/preprocessing/operations/embeddings.py
diff --git a/llm_engineering/application/preprocessing/chunking_data_handlers.py b/llm_engineering/application/preprocessing/chunking_data_handlers.py
index 884919a..da3c6a1 100644
--- a/llm_engineering/application/preprocessing/chunking_data_handlers.py
+++ b/llm_engineering/application/preprocessing/chunking_data_handlers.py
@@ -28,12 +28,26 @@ class ChunkingDataHandler(ABC, Generic[CleanedDocumentT, ChunkT]):
All data transformations logic for the chunking step is done here
"""
+ @property
+ def chunk_size(self) -> int:
+ return 500
+
+ @property
+ def chunk_overlap(self) -> int:
+ return 50
+
@abstractmethod
def chunk(self, data_model: CleanedDocumentT) -> list[ChunkT]:
pass
class PostChunkingHandler(ChunkingDataHandler):
+    @property
+    def chunk_size(self) -> int:
+        return 250
+    @property
+    def chunk_overlap(self) -> int:
+        return 25
def chunk(self, data_model: CleanedPostDocument) -> list[PostChunk]:
data_models_list = []
@@ -49,6 +63,10 @@ def chunk(self, data_model: CleanedPostDocument) -> list[PostChunk]:
document_id=data_model.id,
author_id=data_model.author_id,
image=data_model.image if data_model.image else None,
+ metadata={
+ "chunk_size": self.chunk_size,
+ "chunk_overlap": self.chunk_overlap,
+ },
)
data_models_list.append(model)
@@ -71,6 +89,10 @@ def chunk(self, data_model: CleanedArticleDocument) -> list[ArticleChunk]:
link=data_model.link,
document_id=data_model.id,
author_id=data_model.author_id,
+ metadata={
+ "chunk_size": self.chunk_size,
+ "chunk_overlap": self.chunk_overlap,
+ },
)
data_models_list.append(model)
@@ -78,6 +100,12 @@ def chunk(self, data_model: CleanedArticleDocument) -> list[ArticleChunk]:
class RepositoryChunkingHandler(ChunkingDataHandler):
+    @property
+    def chunk_size(self) -> int:
+        return 750
+    @property
+    def chunk_overlap(self) -> int:
+        return 75
def chunk(self, data_model: CleanedRepositoryDocument) -> list[RepositoryChunk]:
data_models_list = []
@@ -94,6 +122,10 @@ def chunk(self, data_model: CleanedRepositoryDocument) -> list[RepositoryChunk]:
link=data_model.link,
document_id=data_model.id,
author_id=data_model.author_id,
+ metadata={
+ "chunk_size": self.chunk_size,
+ "chunk_overlap": self.chunk_overlap,
+ },
)
data_models_list.append(model)
diff --git a/llm_engineering/application/preprocessing/embedding_data_handlers.py b/llm_engineering/application/preprocessing/embedding_data_handlers.py
index ae83ca7..308ab35 100644
--- a/llm_engineering/application/preprocessing/embedding_data_handlers.py
+++ b/llm_engineering/application/preprocessing/embedding_data_handlers.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, cast
+from llm_engineering.application.networks import EmbeddingModelSingleton
from llm_engineering.domain.chunks import ArticleChunk, Chunk, PostChunk, RepositoryChunk
from llm_engineering.domain.embedded_chunks import (
EmbeddedArticleChunk,
@@ -10,11 +11,11 @@
)
from llm_engineering.domain.queries import EmbeddedQuery, Query
-from .operations import embedd_text
-
ChunkT = TypeVar("ChunkT", bound=Chunk)
EmbeddedChunkT = TypeVar("EmbeddedChunkT", bound=EmbeddedChunk)
+embedding_model = EmbeddingModelSingleton()
+
class EmbeddingDataHandler(ABC, Generic[ChunkT, EmbeddedChunkT]):
"""
@@ -27,7 +28,7 @@ def embed(self, data_model: ChunkT) -> EmbeddedChunkT:
def embed_batch(self, data_model: list[ChunkT]) -> list[EmbeddedChunkT]:
embedding_model_input = [data_model.content for data_model in data_model]
- embeddings = embedd_text(embedding_model_input)
+ embeddings = embedding_model(embedding_model_input, to_list=True)
embedded_chunk = [
self.map_model(data_model, cast(list[float], embedding))
@@ -48,6 +49,11 @@ def map_model(self, data_model: Query, embedding: list[float]) -> EmbeddedQuery:
author_id=data_model.author_id,
content=data_model.content,
embedding=embedding,
+ metadata={
+ "embedding_model_id": embedding_model.model_id,
+ "embedding_size": embedding_model.embedding_size,
+ "max_input_length": embedding_model.max_input_length,
+ },
)
@@ -60,6 +66,11 @@ def map_model(self, data_model: PostChunk, embedding: list[float]) -> EmbeddedPo
platform=data_model.platform,
document_id=data_model.document_id,
author_id=data_model.author_id,
+ metadata={
+ "embedding_model_id": embedding_model.model_id,
+ "embedding_size": embedding_model.embedding_size,
+ "max_input_length": embedding_model.max_input_length,
+ },
)
@@ -73,6 +84,11 @@ def map_model(self, data_model: ArticleChunk, embedding: list[float]) -> Embedde
link=data_model.link,
document_id=data_model.document_id,
author_id=data_model.author_id,
+ metadata={
+ "embedding_model_id": embedding_model.model_id,
+ "embedding_size": embedding_model.embedding_size,
+ "max_input_length": embedding_model.max_input_length,
+ },
)
@@ -87,4 +103,9 @@ def map_model(self, data_model: RepositoryChunk, embedding: list[float]) -> Embe
link=data_model.link,
document_id=data_model.document_id,
author_id=data_model.author_id,
+ metadata={
+ "embedding_model_id": embedding_model.model_id,
+ "embedding_size": embedding_model.embedding_size,
+ "max_input_length": embedding_model.max_input_length,
+ },
)
diff --git a/llm_engineering/application/preprocessing/operations/__init__.py b/llm_engineering/application/preprocessing/operations/__init__.py
index 85405b6..c59c187 100644
--- a/llm_engineering/application/preprocessing/operations/__init__.py
+++ b/llm_engineering/application/preprocessing/operations/__init__.py
@@ -1,9 +1,7 @@
from .chunking import chunk_text
from .cleaning import clean_text
-from .embeddings import embedd_text
__all__ = [
"chunk_text",
"clean_text",
- "embedd_text",
]
diff --git a/llm_engineering/application/preprocessing/operations/chunking.py b/llm_engineering/application/preprocessing/operations/chunking.py
index 3c8784c..dfa3ac2 100644
--- a/llm_engineering/application/preprocessing/operations/chunking.py
+++ b/llm_engineering/application/preprocessing/operations/chunking.py
@@ -1,25 +1,21 @@
-from langchain.text_splitter import (
- RecursiveCharacterTextSplitter,
- SentenceTransformersTokenTextSplitter,
-)
+from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from llm_engineering.application.networks import EmbeddingModelSingleton
embedding_model = EmbeddingModelSingleton()
-def chunk_text(text: str) -> list[str]:
- character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=500, chunk_overlap=0)
- text_split = character_splitter.split_text(text)
+def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> list[str]:
+ character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=chunk_size, chunk_overlap=0)
+ text_split_by_characters = character_splitter.split_text(text)
token_splitter = SentenceTransformersTokenTextSplitter(
- chunk_overlap=50,
+ chunk_overlap=chunk_overlap,
tokens_per_chunk=embedding_model.max_input_length,
model_name=embedding_model.model_id,
)
- chunks = []
+ chunks_by_tokens = []
+ for section in text_split_by_characters:
+ chunks_by_tokens.extend(token_splitter.split_text(section))
- for section in text_split:
- chunks.extend(token_splitter.split_text(section))
-
- return chunks
+ return chunks_by_tokens
diff --git a/llm_engineering/application/preprocessing/operations/embeddings.py b/llm_engineering/application/preprocessing/operations/embeddings.py
deleted file mode 100644
index 309bcd8..0000000
--- a/llm_engineering/application/preprocessing/operations/embeddings.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from llm_engineering.application.networks import EmbeddingModelSingleton
-
-embedding_model = EmbeddingModelSingleton()
-
-
-def embedd_text(text: str | list[str]) -> list[float] | list[list[float]]:
- return embedding_model(text, to_list=True)
diff --git a/llm_engineering/domain/chunks.py b/llm_engineering/domain/chunks.py
index 943bcca..5b6ad87 100644
--- a/llm_engineering/domain/chunks.py
+++ b/llm_engineering/domain/chunks.py
@@ -1,7 +1,7 @@
from abc import ABC
from typing import Optional
-from pydantic import UUID4
+from pydantic import UUID4, Field
from llm_engineering.domain.base import VectorBaseDocument
from llm_engineering.domain.types import DataCategory
@@ -12,6 +12,7 @@ class Chunk(VectorBaseDocument, ABC):
platform: str
document_id: UUID4
author_id: UUID4
+ metadata: dict = Field(default_factory=dict)
class PostChunk(Chunk):
diff --git a/llm_engineering/domain/embedded_chunks.py b/llm_engineering/domain/embedded_chunks.py
index 02483bb..35a6025 100644
--- a/llm_engineering/domain/embedded_chunks.py
+++ b/llm_engineering/domain/embedded_chunks.py
@@ -1,6 +1,6 @@
from abc import ABC
-from pydantic import UUID4
+from pydantic import UUID4, Field
from llm_engineering.domain.types import DataCategory
@@ -13,6 +13,7 @@ class EmbeddedChunk(VectorBaseDocument, ABC):
platform: str
document_id: UUID4
author_id: UUID4
+ metadata: dict = Field(default_factory=dict)
class EmbeddedPostChunk(EmbeddedChunk):
diff --git a/llm_engineering/domain/queries.py b/llm_engineering/domain/queries.py
index f0b3903..4a96b91 100644
--- a/llm_engineering/domain/queries.py
+++ b/llm_engineering/domain/queries.py
@@ -1,4 +1,4 @@
-from pydantic import UUID4
+from pydantic import UUID4, Field
from llm_engineering.domain.base import VectorBaseDocument
from llm_engineering.domain.types import DataCategory
@@ -7,6 +7,7 @@
class Query(VectorBaseDocument):
content: str
author_id: UUID4 | None = None
+ metadata: dict = Field(default_factory=dict)
class Config:
category = DataCategory.QUERIES
diff --git a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py
index fd34520..3ecc892 100644
--- a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py
+++ b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/clean.py
@@ -1,7 +1,8 @@
from typing_extensions import Annotated
-from zenml import step
+from zenml import get_step_context, step
from llm_engineering.application.preprocessing import CleaningDispatcher
+from llm_engineering.domain.cleaned_documents import CleanedDocument
@step
@@ -13,4 +14,18 @@ def clean_documents(
cleaned_document = CleaningDispatcher.dispatch(document)
cleaned_documents.append(cleaned_document)
+ step_context = get_step_context()
+ step_context.add_output_metadata(output_name="cleaned_documents", metadata=_get_metadata(cleaned_documents))
+
return cleaned_documents
+
+
+def _get_metadata(cleaned_documents: list[CleanedDocument]) -> dict:
+ metadata = {"num_documents": len(cleaned_documents)}
+ for document in cleaned_documents:
+ category = document.get_category()
+ if category not in metadata:
+ metadata[category] = {}
+ metadata[category]["num_documents"] = metadata[category].get("num_documents", 0) + 1
+
+ return metadata
diff --git a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py
index 128ffcf..2d29b09 100644
--- a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py
+++ b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/query_data_warehouse.py
@@ -2,12 +2,13 @@
from loguru import logger
from typing_extensions import Annotated
-from zenml import step
+from zenml import get_step_context, step
from llm_engineering.application import utils
from llm_engineering.domain.base.nosql import NoSQLBaseDocument
from llm_engineering.domain.documents import (
ArticleDocument,
+ Document,
PostDocument,
RepositoryDocument,
UserDocument,
@@ -28,6 +29,9 @@ def query_data_warehouse(
user_documents = [doc for query_result in results.values() for doc in query_result]
+ step_context = get_step_context()
+ step_context.add_output_metadata(output_name="raw_documents", metadata=_get_metadata(user_documents))
+
return user_documents
@@ -63,3 +67,14 @@ def __fetch_posts(user_id) -> list[NoSQLBaseDocument]:
def __fetch_repositories(user_id) -> list[NoSQLBaseDocument]:
return RepositoryDocument.bulk_find(author_id=user_id)
+
+
+def _get_metadata(cleaned_documents: list[Document]) -> dict:
+ metadata = {"num_documents": len(cleaned_documents)}
+ for document in cleaned_documents:
+ collection = document.get_collection_name()
+ if collection not in metadata:
+ metadata[collection] = {}
+ metadata[collection]["num_documents"] = metadata[collection].get("num_documents", 0) + 1
+
+ return metadata
diff --git a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py
index 6c3cb34..681b74a 100644
--- a/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py
+++ b/llm_engineering/interfaces/orchestrator/steps/feature_engineering/rag.py
@@ -1,19 +1,51 @@
from typing_extensions import Annotated
-from zenml import step
+from zenml import get_step_context, step
from llm_engineering.application import utils
from llm_engineering.application.preprocessing import ChunkingDispatcher, EmbeddingDispatcher
+from llm_engineering.domain.chunks import Chunk
+from llm_engineering.domain.embedded_chunks import EmbeddedChunk
@step
def chunk_and_embed(
cleaned_documents: Annotated[list, "cleaned_documents"],
) -> Annotated[list, "embedded_documents"]:
- embedded_documents = []
+ metadata = {"chunking": {}, "embedding": {}, "num_documents": len(cleaned_documents)}
+
+ embedded_chunks = []
for document in cleaned_documents:
chunks = ChunkingDispatcher.dispatch(document)
+ metadata["chunking"] = _add_chunks_metadata(chunks, metadata["chunking"])
+
for batched_chunks in utils.misc.batch(chunks, 10):
batched_embedded_chunks = EmbeddingDispatcher.dispatch(batched_chunks)
- embedded_documents.extend(batched_embedded_chunks)
+ embedded_chunks.extend(batched_embedded_chunks)
+
+ metadata["embedding"] = _add_embeddings_metadata(embedded_chunks, metadata["embedding"])
+ metadata["num_chunks"] = len(embedded_chunks)
+ metadata["num_embedded_chunks"] = len(embedded_chunks)
+
+ step_context = get_step_context()
+ step_context.add_output_metadata(output_name="embedded_documents", metadata=metadata)
+
+ return embedded_chunks
+
+
+def _add_chunks_metadata(chunks: list[Chunk], metadata: dict) -> dict:
+ for chunk in chunks:
+ category = chunk.get_category()
+ if category not in metadata:
+            metadata[category] = chunk.metadata.copy()
+ metadata[category]["num_chunks"] = metadata[category].get("num_chunks", 0) + 1
+
+ return metadata
+
+
+def _add_embeddings_metadata(embedded_chunks: list[EmbeddedChunk], metadata: dict) -> dict:
+ for embedded_chunk in embedded_chunks:
+ category = embedded_chunk.get_category()
+ if category not in metadata:
+            metadata[category] = embedded_chunk.metadata.copy()
- return embedded_documents
+ return metadata