Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular backend - overrides #6692

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

StAlKeR7779
Copy link
Contributor

@StAlKeR7779 StAlKeR7779 commented Jul 28, 2024

Summary

Initial implementation of overrides in modular backend, should be used in inpaint and tiled extensions(also in preview extension after preview event rewrite).
Created PR now just to have ability to discuss.
To be precise - need to decide how better to implement arguments in overrides.

Related Issues / Discussions

#6606
https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

QA Instructions

Run with set USE_MODULAR_DENOISE environment.

Merge Plan

Discuss, then merge.

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • Documentation added / updated (if applicable)

@github-actions github-actions bot added python PRs that change python files invocations PRs that change invocations backend PRs that change backend files labels Jul 28, 2024
@StAlKeR7779 StAlKeR7779 marked this pull request as ready for review July 30, 2024 01:37
@StAlKeR7779 StAlKeR7779 changed the title [WIP] Modular backend - overrides Modular backend - overrides Jul 30, 2024
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

# pass extensions manager as arg to allow override access it
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would we want this? Seems like it just opens the door for a bunch of messiness.

Copy link
Contributor Author

@StAlKeR7779 StAlKeR7779 Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How else tiled denoise will be able to call original step function or callbacks?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this will be easier to discuss in the context of the tiled denoise PR? It seems to me that if we can avoid passing the ext_manager down to callbacks/overrides then that would keep things quite a bit simpler.

@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext
for cb in callbacks:
cb.function(ctx)

def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have typed function signatures for each override type given that the signatures are known and there aren't very many of them (instead of passing *args and **kwargs).

@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext
for cb in callbacks:
cb.function(ctx)

def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for passing orig_function? If the orig_function needs to be called, it feels like those use cases could be solved with callbacks.

Copy link
Contributor Author

@StAlKeR7779 StAlKeR7779 Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least to allow extension manager to run original non-overriden implementation.
And also simply because it more flexible, you don't need to implement underlying logic if you only patch it slightly.
Also tiled decode will use orig function of step on each tile.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including it in the function signature implies that the function / override should handle it.

For the case that you're describing, I imagined that it would just look like this:

from ... import unet_forward

class AnExtension(ExtensionBase):
    @override(ExtensionOverrideType.UNET_FORWARD)
    def custom_unet_forward(self, ...):
        # Do some stuff...
        unet_forward(...)
        

What do you think?

Comment on lines 62 to +63
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docs explaining the difference between _overrides and _callbacks. Include guidance for developers for how they should decide between using a callback vs. an override. In some cases, both could work, so highlight the things that should be considered to decide between them.

@@ -21,11 +26,19 @@ def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
# A list of extensions in the order that they were added to the ExtensionsManager.
self._extensions: List[ExtensionBase] = []
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add unit tests for the new ExtensionBase/ExtensionsManager functionality given it's core role. It should be straightforward - you can use the existing tests for reference. I think we'd roughly want tests for each of the following:

  • When an override is registered, it get's called
  • Calling an override type with no override registered behaves as expected
  • When duplicate overrides are registered, a meaningful error is raised

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend PRs that change backend files invocations PRs that change invocations python PRs that change python files
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants