feat: Flexible chunking. Add metadata to FE steps
iusztinpaul committed Jun 19, 2024
1 parent 3abe5d4 commit cd1945b
Showing 11 changed files with 139 additions and 34 deletions.
@@ -28,12 +28,26 @@ class ChunkingDataHandler(ABC, Generic[CleanedDocumentT, ChunkT]):
     All data transformations logic for the chunking step is done here
     """
 
+    @property
+    def chunk_size(self) -> int:
+        return 500
+
+    @property
+    def chunk_overlap(self) -> int:
+        return 50
+
     @abstractmethod
     def chunk(self, data_model: CleanedDocumentT) -> list[ChunkT]:
         pass
 
 
 class PostChunkingHandler(ChunkingDataHandler):
+    def chunk_size(self) -> int:
+        return 250
+
+    def chunk_overlap(self) -> int:
+        return 25
+
     def chunk(self, data_model: CleanedPostDocument) -> list[PostChunk]:
         data_models_list = []
 
@@ -49,6 +63,10 @@ def chunk(self, data_model: CleanedPostDocument) -> list[PostChunk]:
                 document_id=data_model.id,
                 author_id=data_model.author_id,
                 image=data_model.image if data_model.image else None,
+                metadata={
+                    "chunk_size": self.chunk_size,
+                    "chunk_overlap": self.chunk_overlap,
+                },
             )
             data_models_list.append(model)
 
@@ -71,13 +89,23 @@ def chunk(self, data_model: CleanedArticleDocument) -> list[ArticleChunk]:
                 link=data_model.link,
                 document_id=data_model.id,
                 author_id=data_model.author_id,
+                metadata={
+                    "chunk_size": self.chunk_size,
+                    "chunk_overlap": self.chunk_overlap,
+                },
             )
             data_models_list.append(model)
 
         return data_models_list
 
 
 class RepositoryChunkingHandler(ChunkingDataHandler):
+    def chunk_size(self) -> int:
+        return 750
+
+    def chunk_overlap(self) -> int:
+        return 75
+
     def chunk(self, data_model: CleanedRepositoryDocument) -> list[RepositoryChunk]:
         data_models_list = []
 
@@ -94,6 +122,10 @@ def chunk(self, data_model: CleanedRepositoryDocument) -> list[RepositoryChunk]:
                 link=data_model.link,
                 document_id=data_model.id,
                 author_id=data_model.author_id,
+                metadata={
+                    "chunk_size": self.chunk_size,
+                    "chunk_overlap": self.chunk_overlap,
+                },
            )
            data_models_list.append(model)
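A subtlety in the handlers above: the base class exposes chunk_size and chunk_overlap as properties, but the subclass overrides in this commit are plain methods, so inside PostChunkingHandler an expression like self.chunk_size resolves to a bound method rather than 250. A minimal, self-contained sketch of the override pattern with the decorator kept on both levels (class names here are illustrative, not the repository's):

    from abc import ABC

    class ChunkingHandler(ABC):
        @property
        def chunk_size(self) -> int:
            return 500  # default used by handlers that do not override

    class PostHandler(ChunkingHandler):
        @property  # keeping the decorator makes self.chunk_size evaluate to 250
        def chunk_size(self) -> int:
            return 250

    print(PostHandler().chunk_size)  # 250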
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar, cast
 
+from llm_engineering.application.networks import EmbeddingModelSingleton
 from llm_engineering.domain.chunks import ArticleChunk, Chunk, PostChunk, RepositoryChunk
 from llm_engineering.domain.embedded_chunks import (
     EmbeddedArticleChunk,
@@ -10,11 +11,11 @@
 )
 from llm_engineering.domain.queries import EmbeddedQuery, Query
 
-from .operations import embedd_text
-
 ChunkT = TypeVar("ChunkT", bound=Chunk)
 EmbeddedChunkT = TypeVar("EmbeddedChunkT", bound=EmbeddedChunk)
 
+embedding_model = EmbeddingModelSingleton()
+
 
 class EmbeddingDataHandler(ABC, Generic[ChunkT, EmbeddedChunkT]):
     """
@@ -27,7 +28,7 @@ def embed(self, data_model: ChunkT) -> EmbeddedChunkT:
 
     def embed_batch(self, data_model: list[ChunkT]) -> list[EmbeddedChunkT]:
         embedding_model_input = [data_model.content for data_model in data_model]
-        embeddings = embedd_text(embedding_model_input)
+        embeddings = embedding_model(embedding_model_input, to_list=True)
 
         embedded_chunk = [
             self.map_model(data_model, cast(list[float], embedding))
@@ -48,6 +49,11 @@ def map_model(self, data_model: Query, embedding: list[float]) -> EmbeddedQuery:
             author_id=data_model.author_id,
             content=data_model.content,
             embedding=embedding,
+            metadata={
+                "embedding_model_id": embedding_model.model_id,
+                "embedding_size": embedding_model.embedding_size,
+                "max_input_length": embedding_model.max_input_length,
+            },
         )
 
 
@@ -60,6 +66,11 @@ def map_model(self, data_model: PostChunk, embedding: list[float]) -> EmbeddedPostChunk:
             platform=data_model.platform,
             document_id=data_model.document_id,
             author_id=data_model.author_id,
+            metadata={
+                "embedding_model_id": embedding_model.model_id,
+                "embedding_size": embedding_model.embedding_size,
+                "max_input_length": embedding_model.max_input_length,
+            },
         )
 
 
@@ -73,6 +84,11 @@ def map_model(self, data_model: ArticleChunk, embedding: list[float]) -> EmbeddedArticleChunk:
             link=data_model.link,
             document_id=data_model.document_id,
             author_id=data_model.author_id,
+            metadata={
+                "embedding_model_id": embedding_model.model_id,
+                "embedding_size": embedding_model.embedding_size,
+                "max_input_length": embedding_model.max_input_length,
+            },
         )
 
 
@@ -87,4 +103,9 @@ def map_model(self, data_model: RepositoryChunk, embedding: list[float]) -> EmbeddedRepositoryChunk:
             link=data_model.link,
             document_id=data_model.document_id,
             author_id=data_model.author_id,
+            metadata={
+                "embedding_model_id": embedding_model.model_id,
+                "embedding_size": embedding_model.embedding_size,
+                "max_input_length": embedding_model.max_input_length,
+            },
         )
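The handlers now share one process-wide EmbeddingModelSingleton, defined in llm_engineering.application.networks and not shown in this diff. A minimal sketch of how such a callable singleton could look, assuming a sentence-transformers backend — the default model name and the to_list handling are illustrative assumptions, not the repository's actual implementation:

    from sentence_transformers import SentenceTransformer

    class EmbeddingModelSingleton:
        _instance = None

        def __new__(cls, model_id: str = "sentence-transformers/all-MiniLM-L6-v2"):
            # Load the model once per process; later calls reuse the same instance.
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance.model_id = model_id
                cls._instance._model = SentenceTransformer(model_id)
            return cls._instance

        @property
        def embedding_size(self) -> int:
            return self._model.get_sentence_embedding_dimension()

        @property
        def max_input_length(self) -> int:
            return self._model.max_seq_length

        def __call__(self, input_text: list[str], to_list: bool = False):
            embeddings = self._model.encode(input_text)
            return embeddings.tolist() if to_list else embeddings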
@@ -1,9 +1,7 @@
 from .chunking import chunk_text
 from .cleaning import clean_text
-from .embeddings import embedd_text
 
 __all__ = [
     "chunk_text",
     "clean_text",
-    "embedd_text",
 ]
22 changes: 9 additions & 13 deletions llm_engineering/application/preprocessing/operations/chunking.py
@@ -1,25 +1,21 @@
-from langchain.text_splitter import (
-    RecursiveCharacterTextSplitter,
-    SentenceTransformersTokenTextSplitter,
-)
+from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
 
 from llm_engineering.application.networks import EmbeddingModelSingleton
 
 embedding_model = EmbeddingModelSingleton()
 
 
-def chunk_text(text: str) -> list[str]:
-    character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=500, chunk_overlap=0)
-    text_split = character_splitter.split_text(text)
+def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> list[str]:
+    character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=chunk_size, chunk_overlap=0)
+    text_split_by_characters = character_splitter.split_text(text)
 
     token_splitter = SentenceTransformersTokenTextSplitter(
-        chunk_overlap=50,
+        chunk_overlap=chunk_overlap,
         tokens_per_chunk=embedding_model.max_input_length,
         model_name=embedding_model.model_id,
     )
-    chunks = []
-    for section in text_split:
-        chunks.extend(token_splitter.split_text(section))
-
-    return chunks
+    chunks_by_tokens = []
+    for section in text_split_by_characters:
+        chunks_by_tokens.extend(token_splitter.split_text(section))
+
+    return chunks_by_tokens
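The two new parameters act at different stages: chunk_size bounds the first, character-level split, while chunk_overlap is applied by the token-level splitter. A hypothetical call using the post-handler values from this commit:

    from llm_engineering.application.preprocessing.operations import chunk_text

    text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
    chunks = chunk_text(text, chunk_size=250, chunk_overlap=25)
    print(len(chunks), chunks[0])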

This file was deleted.

3 changes: 2 additions & 1 deletion llm_engineering/domain/chunks.py
@@ -1,7 +1,7 @@
 from abc import ABC
 from typing import Optional
 
-from pydantic import UUID4
+from pydantic import UUID4, Field
 
 from llm_engineering.domain.base import VectorBaseDocument
 from llm_engineering.domain.types import DataCategory
@@ -12,6 +12,7 @@ class Chunk(VectorBaseDocument, ABC):
     platform: str
     document_id: UUID4
     author_id: UUID4
+    metadata: dict = Field(default_factory=dict)
 
 
 class PostChunk(Chunk):
3 changes: 2 additions & 1 deletion llm_engineering/domain/embedded_chunks.py
@@ -1,6 +1,6 @@
 from abc import ABC
 
-from pydantic import UUID4
+from pydantic import UUID4, Field
 
 from llm_engineering.domain.types import DataCategory
 
@@ -13,6 +13,7 @@ class EmbeddedChunk(VectorBaseDocument, ABC):
     platform: str
     document_id: UUID4
     author_id: UUID4
+    metadata: dict = Field(default_factory=dict)
 
 
 class EmbeddedPostChunk(EmbeddedChunk):
3 changes: 2 additions & 1 deletion llm_engineering/domain/queries.py
@@ -1,4 +1,4 @@
-from pydantic import UUID4
+from pydantic import UUID4, Field
 
 from llm_engineering.domain.base import VectorBaseDocument
 from llm_engineering.domain.types import DataCategory
@@ -7,6 +7,7 @@
 class Query(VectorBaseDocument):
     content: str
     author_id: UUID4 | None = None
+    metadata: dict = Field(default_factory=dict)
 
     class Config:
         category = DataCategory.QUERIES
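All three domain models gain the same metadata: dict = Field(default_factory=dict) field. default_factory builds a fresh dict per instance, which keeps per-chunk metadata independent across objects — a quick self-contained check (the model here is a plain pydantic BaseModel for illustration, not the repository's VectorBaseDocument):

    from pydantic import BaseModel, Field

    class Chunk(BaseModel):
        content: str
        metadata: dict = Field(default_factory=dict)

    a = Chunk(content="first")
    b = Chunk(content="second")
    a.metadata["chunk_size"] = 250
    print(b.metadata)  # {} — each instance gets its own dict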
@@ -1,7 +1,8 @@
 from typing_extensions import Annotated
-from zenml import step
+from zenml import get_step_context, step
 
 from llm_engineering.application.preprocessing import CleaningDispatcher
+from llm_engineering.domain.cleaned_documents import CleanedDocument
 
 
 @step
@@ -13,4 +14,18 @@ def clean_documents(
         cleaned_document = CleaningDispatcher.dispatch(document)
         cleaned_documents.append(cleaned_document)
 
+    step_context = get_step_context()
+    step_context.add_output_metadata(output_name="cleaned_documents", metadata=_get_metadata(cleaned_documents))
+
     return cleaned_documents
+
+
+def _get_metadata(cleaned_documents: list[CleanedDocument]) -> dict:
+    metadata = {"num_documents": len(cleaned_documents)}
+    for document in cleaned_documents:
+        category = document.get_category()
+        if category not in metadata:
+            metadata[category] = {}
+        metadata[category]["num_documents"] = metadata[category].get("num_documents", 0) + 1
+
+    return metadata
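get_step_context().add_output_metadata attaches a metadata dict to a named step output so it can be inspected per run in the ZenML dashboard. A minimal standalone step built from the same two ZenML calls this commit uses (the step and output names are illustrative):

    from typing_extensions import Annotated
    from zenml import get_step_context, step

    @step
    def count_words(texts: list[str]) -> Annotated[list, "word_counts"]:
        counts = [len(text.split()) for text in texts]
        # Attach run metadata to the output named "word_counts".
        get_step_context().add_output_metadata(
            output_name="word_counts",
            metadata={"num_texts": len(texts)},
        )
        return counts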
@@ -2,12 +2,13 @@
 
 from loguru import logger
 from typing_extensions import Annotated
-from zenml import step
+from zenml import get_step_context, step
 
 from llm_engineering.application import utils
 from llm_engineering.domain.base.nosql import NoSQLBaseDocument
 from llm_engineering.domain.documents import (
     ArticleDocument,
+    Document,
     PostDocument,
     RepositoryDocument,
     UserDocument,
@@ -28,6 +29,9 @@ def query_data_warehouse(
 
     user_documents = [doc for query_result in results.values() for doc in query_result]
 
+    step_context = get_step_context()
+    step_context.add_output_metadata(output_name="raw_documents", metadata=_get_metadata(user_documents))
+
     return user_documents
 
@@ -63,3 +67,14 @@ def __fetch_posts(user_id) -> list[NoSQLBaseDocument]:
 
 def __fetch_repositories(user_id) -> list[NoSQLBaseDocument]:
     return RepositoryDocument.bulk_find(author_id=user_id)
+
+
+def _get_metadata(cleaned_documents: list[Document]) -> dict:
+    metadata = {"num_documents": len(cleaned_documents)}
+    for document in cleaned_documents:
+        collection = document.get_collection_name()
+        if collection not in metadata:
+            metadata[collection] = {}
+        metadata[collection]["num_documents"] = metadata[collection].get("num_documents", 0) + 1
+
+    return metadata
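The step flattens the per-collection query results into a single list with a nested comprehension; the same pattern in isolation (the dict values are illustrative placeholders):

    # results maps each collection to the documents fetched from it.
    results = {
        "articles": ["doc_1", "doc_2"],
        "posts": ["doc_3"],
    }
    user_documents = [doc for query_result in results.values() for doc in query_result]
    print(user_documents)  # ['doc_1', 'doc_2', 'doc_3']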
@@ -1,19 +1,51 @@
 from typing_extensions import Annotated
-from zenml import step
+from zenml import get_step_context, step
 
 from llm_engineering.application import utils
 from llm_engineering.application.preprocessing import ChunkingDispatcher, EmbeddingDispatcher
+from llm_engineering.domain.chunks import Chunk
+from llm_engineering.domain.embedded_chunks import EmbeddedChunk
 
 
 @step
 def chunk_and_embed(
     cleaned_documents: Annotated[list, "cleaned_documents"],
 ) -> Annotated[list, "embedded_documents"]:
-    embedded_documents = []
+    metadata = {"chunking": {}, "embedding": {}, "num_documents": len(cleaned_documents)}
+
+    embedded_chunks = []
     for document in cleaned_documents:
         chunks = ChunkingDispatcher.dispatch(document)
+        metadata["chunking"] = _add_chunks_metadata(chunks, metadata["chunking"])
 
         for batched_chunks in utils.misc.batch(chunks, 10):
             batched_embedded_chunks = EmbeddingDispatcher.dispatch(batched_chunks)
-            embedded_documents.extend(batched_embedded_chunks)
+            embedded_chunks.extend(batched_embedded_chunks)
 
-    return embedded_documents
+    metadata["embedding"] = _add_embeddings_metadata(embedded_chunks, metadata["embedding"])
+    metadata["num_chunks"] = len(embedded_chunks)
+    metadata["num_embedded_chunks"] = len(embedded_chunks)
+
+    step_context = get_step_context()
+    step_context.add_output_metadata(output_name="embedded_documents", metadata=metadata)
+
+    return embedded_chunks
+
+
+def _add_chunks_metadata(chunks: list[Chunk], metadata: dict) -> dict:
+    for chunk in chunks:
+        category = chunk.get_category()
+        if category not in metadata:
+            metadata[category] = chunk.metadata
+        metadata[category]["num_chunks"] = metadata[category].get("num_chunks", 0) + 1
+
+    return metadata
+
+
+def _add_embeddings_metadata(embedded_chunks: list[EmbeddedChunk], metadata: dict) -> dict:
+    for embedded_chunk in embedded_chunks:
+        category = embedded_chunk.get_category()
+        if category not in metadata:
+            metadata[category] = embedded_chunk.metadata
+
+    return metadata
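utils.misc.batch is used here to embed chunks ten at a time, but its definition is not part of this diff. A plausible minimal implementation, assuming it simply yields consecutive slices of the input list:

    from typing import Generator, TypeVar

    T = TypeVar("T")

    def batch(items: list[T], size: int) -> Generator[list[T], None, None]:
        # Yield consecutive slices of at most `size` items each.
        for i in range(0, len(items), size):
            yield items[i : i + size]

    print(list(batch([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4], [5]]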
