From ad442dde29efc7c1737a323faf0697fe3ba7013a Mon Sep 17 00:00:00 2001 From: rangoliu Date: Tue, 13 Jun 2023 15:25:58 +0800 Subject: [PATCH] [Feature] Add Attention Injection for unet (#1895) * add attention injection code * add basic useage * run injection good. * fix lint * add attention injection fork * add readme * add readme --- configs/controlnet_animation/README.md | 27 ++- .../controlnet_animation/anythingv3_config.py | 3 + .../controlnet_animation_inferencer.py | 194 ++++++++++++------ mmagic/models/archs/__init__.py | 3 +- mmagic/models/archs/attention_injection.py | 154 ++++++++++++++ .../models/editors/controlnet/controlnet.py | 100 ++++++--- 6 files changed, 388 insertions(+), 93 deletions(-) create mode 100644 mmagic/models/archs/attention_injection.py diff --git a/configs/controlnet_animation/README.md b/configs/controlnet_animation/README.md index 24386e310a..5878ff7818 100644 --- a/configs/controlnet_animation/README.md +++ b/configs/controlnet_animation/README.md @@ -10,12 +10,17 @@ -It is difficult to avoid video frame flickering when using stable diffusion to generate video frame by frame. -Here we reproduce a method that effectively avoids video flickering, that is, using controlnet and multi-frame rendering. -[ControlNet](https://github.com/lllyasviel/ControlNet) is a neural network structure to control diffusion models by adding extra conditions. +It is difficult to keep consistency and avoid video frame flickering when using stable diffusion to generate video frame by frame. +Here we reproduce two methods that effectively avoid video flickering: + +**Controlnet with multi-frame rendering**. [ControlNet](https://github.com/lllyasviel/ControlNet) is a neural network structure to control diffusion models by adding extra conditions. [Multi-frame rendering](https://xanthius.itch.io/multi-frame-rendering-for-stablediffusion) is a community method to reduce flickering. We use controlnet with hed condition and stable diffusion img2img for multi-frame rendering. +**Controlnet with attention injection**. Attention injection is widely used to generate the current frame from a reference image. There is an implementation in [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet#reference-only-control) and we use some of their code to create the animation in this repo. + +You may need 40G GPU memory to run controlnet with multi-frame rendering and 10G GPU memory for controlnet with attention injection. If the config file is not changed, it defaults to using controlnet with attention injection. + ## Demos prompt key words: a handsome man, silver hair, smiling, play basketball @@ -81,6 +86,22 @@ editor.infer(video=video, prompt=prompt, negative_prompt=negative_prompt, save_p python demo/gradio_controlnet_animation.py ``` +### 3. Change config to use multi-frame rendering or attention injection. + +change "inference_method" in [anythingv3 config](./anythingv3_config.py) + +To use multi-frame rendering. + +```python +inference_method = 'multi-frame rendering' +``` + +To use attention injection. + +```python +inference_method = 'attention_injection' +``` + ## Play animation with SAM We also provide a demo to play controlnet animation with sam, for details, please see [OpenMMLab PlayGround](https://github.com/open-mmlab/playground/blob/main/mmediting_sam/README.md). 
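Multi-frame rendering keeps frames consistent by generating them jointly: at every step the inferencer pastes the previous result, the current input frame, and the first result into one wide strip (and does the same with their HED maps), runs ControlNet img2img on the whole strip, and keeps only the middle third. The PIL sketch below mirrors that layout from `controlnet_animation_inferencer.py`; the helper names are illustrative, not part of the codebase.

```python
# Minimal sketch of the 3-frame strip used by multi-frame rendering.
from PIL import Image


def make_strip(last_result: Image.Image, current: Image.Image,
               first_result: Image.Image) -> Image.Image:
    """Stitch [previous result | current frame | first result] side by side."""
    w, h = current.size
    strip = Image.new('RGB', (w * 3, h))
    strip.paste(last_result, (0, 0))       # previous generated frame
    strip.paste(current, (w, 0))           # frame being processed now
    strip.paste(first_result, (w * 2, 0))  # first generated frame (anchor)
    return strip


def crop_middle(strip: Image.Image) -> Image.Image:
    """Keep only the middle third, i.e. the newly generated current frame."""
    w3, h = strip.size
    w = w3 // 3
    return strip.crop((w, 0, w * 2, h))
```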
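The `AttentionInjection` wrapper added in this PR patches each `BasicTransformerBlock` so that a first pass over the reference frame records the self-attention inputs (WRITE) and the pass over the current frame attends over both its own tokens and the recorded ones (READ). The standalone layer below is a simplified analogue of that idea using a plain `nn.MultiheadAttention` with illustrative names; it is not the code in `mmagic/models/archs/attention_injection.py`.

```python
# Simplified, self-contained analogue of reference-only attention injection.
import torch
import torch.nn as nn


class SelfAttentionWithReference(nn.Module):
    """Self-attention that can mix in features stored from a reference frame."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.bank = []       # features recorded from the reference frame
        self.write = False   # True while the reference frame is processed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.write:
            # WRITE pass: remember the reference frame's attention context.
            self.bank.append(x.detach().clone())
            context = x
        elif self.bank:
            # READ pass: attend over current + reference tokens, then clear.
            context = torch.cat([x] + self.bank, dim=1)
            self.bank = []
        else:
            context = x
        out, _ = self.attn(query=x, key=context, value=context)
        return out


if __name__ == '__main__':
    layer = SelfAttentionWithReference(dim=64)
    ref = torch.randn(1, 16, 64)   # tokens from the reference frame
    cur = torch.randn(1, 16, 64)   # tokens from the current frame

    layer.write = True
    layer(ref)                     # record reference features
    layer.write = False
    out = layer(cur)               # current frame attends to the reference
    print(out.shape)               # torch.Size([1, 16, 64])
```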
diff --git a/configs/controlnet_animation/anythingv3_config.py b/configs/controlnet_animation/anythingv3_config.py index e213322548..b6159dc378 100644 --- a/configs/controlnet_animation/anythingv3_config.py +++ b/configs/controlnet_animation/anythingv3_config.py @@ -4,6 +4,9 @@ control_detector = 'lllyasviel/ControlNet' control_scheduler = 'UniPCMultistepScheduler' +# method type : 'multi-frame rendering' or 'attention_injection' +inference_method = 'attention_injection' + model = dict( type='ControlStableDiffusionImg2Img', vae=dict( diff --git a/mmagic/apis/inferencers/controlnet_animation_inferencer.py b/mmagic/apis/inferencers/controlnet_animation_inferencer.py index 8cf7b46d32..56bb98fb67 100644 --- a/mmagic/apis/inferencers/controlnet_animation_inferencer.py +++ b/mmagic/apis/inferencers/controlnet_animation_inferencer.py @@ -77,6 +77,9 @@ def __init__(self, **kwargs) -> None: cfg = Config.fromfile(config) self.hed = HEDdetector.from_pretrained(cfg.control_detector) + self.inference_method = cfg.inference_method + if self.inference_method == 'attention_injection': + cfg.model.attention_injection = True self.pipe = MODELS.build(cfg.model).cuda() control_scheduler_cfg = dict( @@ -99,6 +102,7 @@ def __call__(self, num_inference_steps=20, seed=1, output_fps=None, + reference_img=None, **kwargs) -> Union[Dict, List[Dict]]: """Call the inferencer. @@ -164,83 +168,151 @@ def __call__(self, from_video = False - # first result - image = None - if from_video: - image = PIL.Image.fromarray(all_images[0]) - else: - image = load_image(all_images[0]) - image = image.resize((image_width, image_height)) - detect_resolution = min(image_width, image_height) - hed_image = self.hed( - image, - detect_resolution=detect_resolution, - image_resolution=detect_resolution) - hed_image = hed_image.resize((image_width, image_height)) - - result = self.pipe.infer( - control=hed_image, - latent_image=image, - prompt=prompt, - negative_prompt=negative_prompt, - strength=strength, - controlnet_conditioning_scale=controlnet_conditioning_scale, - num_inference_steps=num_inference_steps, - latents=init_noise_all_frame)['samples'][0] - - first_result = result - first_hed = hed_image - last_result = result - last_hed = hed_image - - for ind in range(len(all_images)): + if self.inference_method == 'multi-frame rendering': + # first result if from_video: - if ind % sample_rate > 0: - continue - image = PIL.Image.fromarray(all_images[ind]) + image = PIL.Image.fromarray(all_images[0]) else: - image = load_image(all_images[ind]) - print('processing frame ind ' + str(ind)) - + image = load_image(all_images[0]) image = image.resize((image_width, image_height)) - hed_image = self.hed(image, image_resolution=image_width) - - concat_img = PIL.Image.new('RGB', (image_width * 3, image_height)) - concat_img.paste(last_result, (0, 0)) - concat_img.paste(image, (image_width, 0)) - concat_img.paste(first_result, (image_width * 2, 0)) - - concat_hed = PIL.Image.new('RGB', (image_width * 3, image_height), - 'black') - concat_hed.paste(last_hed, (0, 0)) - concat_hed.paste(hed_image, (image_width, 0)) - concat_hed.paste(first_hed, (image_width * 2, 0)) + detect_resolution = min(image_width, image_height) + hed_image = self.hed( + image, + detect_resolution=detect_resolution, + image_resolution=detect_resolution) + hed_image = hed_image.resize((image_width, image_height)) result = self.pipe.infer( - control=concat_hed, - latent_image=concat_img, + control=hed_image, + latent_image=image, prompt=prompt, negative_prompt=negative_prompt, 
strength=strength, controlnet_conditioning_scale=controlnet_conditioning_scale, num_inference_steps=num_inference_steps, - latents=init_noise_all_frame_cat, - latent_mask=latent_mask, - )['samples'][0] - result = result.crop( - (image_width, 0, image_width * 2, image_height)) + latents=init_noise_all_frame)['samples'][0] + first_result = result + first_hed = hed_image last_result = result last_hed = hed_image + for ind in range(len(all_images)): + if from_video: + if ind % sample_rate > 0: + continue + image = PIL.Image.fromarray(all_images[ind]) + else: + image = load_image(all_images[ind]) + print('processing frame ind ' + str(ind)) + + image = image.resize((image_width, image_height)) + hed_image = self.hed(image, image_resolution=image_width) + + concat_img = PIL.Image.new('RGB', + (image_width * 3, image_height)) + concat_img.paste(last_result, (0, 0)) + concat_img.paste(image, (image_width, 0)) + concat_img.paste(first_result, (image_width * 2, 0)) + + concat_hed = PIL.Image.new('RGB', + (image_width * 3, image_height), + 'black') + concat_hed.paste(last_hed, (0, 0)) + concat_hed.paste(hed_image, (image_width, 0)) + concat_hed.paste(first_hed, (image_width * 2, 0)) + + result = self.pipe.infer( + control=concat_hed, + latent_image=concat_img, + prompt=prompt, + negative_prompt=negative_prompt, + strength=strength, + controlnet_conditioning_scale= # noqa + controlnet_conditioning_scale, + num_inference_steps=num_inference_steps, + latents=init_noise_all_frame_cat, + latent_mask=latent_mask, + )['samples'][0] + result = result.crop( + (image_width, 0, image_width * 2, image_height)) + + last_result = result + last_hed = hed_image + + if from_video: + video_writer.write(np.flip(np.asarray(result), axis=2)) + else: + frame_name = frame_files[ind].split('/')[-1] + save_name = os.path.join(save_path, frame_name) + result.save(save_name) + if from_video: - video_writer.write(np.flip(np.asarray(result), axis=2)) + video_writer.release() + else: + if reference_img is None: + if from_video: + image = PIL.Image.fromarray(all_images[0]) + else: + image = load_image(all_images[0]) + image = image.resize((image_width, image_height)) + detect_resolution = min(image_width, image_height) + hed_image = self.hed( + image, + detect_resolution=detect_resolution, + image_resolution=detect_resolution) + hed_image = hed_image.resize((image_width, image_height)) + + result = self.pipe.infer( + control=hed_image, + latent_image=image, + prompt=prompt, + negative_prompt=negative_prompt, + strength=strength, + controlnet_conditioning_scale= # noqa + controlnet_conditioning_scale, + num_inference_steps=num_inference_steps, + latents=init_noise_all_frame)['samples'][0] + + reference_img = result else: - frame_name = frame_files[ind].split('/')[-1] - save_name = os.path.join(save_path, frame_name) - result.save(save_name) + reference_img = load_image(reference_img) + reference_img = reference_img.resize( + (image_width, image_height)) + + for ind in range(len(all_images)): + if from_video: + if ind % sample_rate > 0: + continue + image = PIL.Image.fromarray(all_images[ind]) + else: + image = load_image(all_images[ind]) + print('processing frame ind ' + str(ind)) + + image = image.resize((image_width, image_height)) + hed_image = self.hed(image, image_resolution=image_width) + + result = self.pipe.infer( + control=hed_image, + latent_image=image, + prompt=prompt, + negative_prompt=negative_prompt, + strength=strength, + controlnet_conditioning_scale= # noqa + controlnet_conditioning_scale, + 
num_inference_steps=num_inference_steps, + latents=init_noise_all_frame, + reference_img=reference_img, + )['samples'][0] + + if from_video: + video_writer.write(np.flip(np.asarray(result), axis=2)) + else: + frame_name = frame_files[ind].split('/')[-1] + save_name = os.path.join(save_path, frame_name) + result.save(save_name) - if from_video: - video_writer.release() + if from_video: + video_writer.release() return save_path diff --git a/mmagic/models/archs/__init__.py b/mmagic/models/archs/__init__.py index 5c1d4c0b42..f33271b509 100644 --- a/mmagic/models/archs/__init__.py +++ b/mmagic/models/archs/__init__.py @@ -6,6 +6,7 @@ from mmagic.utils import try_import from .all_gather_layer import AllGatherLayer from .aspp import ASPP +from .attention_injection import AttentionInjection from .conv import * # noqa: F401, F403 from .downsample import pixel_unshuffle from .ensemble import SpatialTemporalEnsemble @@ -74,5 +75,5 @@ def gen_wrapped_cls(module, module_name): 'SimpleEncoderDecoder', 'MultiLayerDiscriminator', 'PatchDiscriminator', 'VGG16', 'ResNet', 'AllGatherLayer', 'ResidualBlockNoBN', 'LoRAWrapper', 'set_lora', 'set_lora_disable', 'set_lora_enable', - 'set_only_lora_trainable', 'TokenizerWrapper' + 'set_only_lora_trainable', 'TokenizerWrapper', 'AttentionInjection' ] diff --git a/mmagic/models/archs/attention_injection.py b/mmagic/models/archs/attention_injection.py new file mode 100644 index 0000000000..d96486cea8 --- /dev/null +++ b/mmagic/models/archs/attention_injection.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from enum import Enum + +import torch +import torch.nn as nn +from diffusers.models.attention import BasicTransformerBlock +from torch import Tensor + +AttentionStatus = Enum('ATTENTION_STATUS', 'READ WRITE DISABLE') + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +class AttentionInjection(nn.Module): + """Wrapper for stable diffusion unet. + + Args: + module (nn.Module): The module to be wrapped. 
+ """ + + def __init__(self, module: nn.Module, injection_weight=5): + super().__init__() + self.attention_status = AttentionStatus.READ + self.style_cfgs = [] + self.unet = module + + attn_inject = self + + def transformer_forward_replacement( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + class_labels=None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( # noqa + hidden_states, + timestep, + class_labels, + hidden_dtype=hidden_states.dtype) + else: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = None + self_attention_context = norm_hidden_states + if attn_inject.attention_status == AttentionStatus.WRITE: + self.bank.append(self_attention_context.detach().clone()) + if attn_inject.attention_status == AttentionStatus.READ: + if len(self.bank) > 0: + self.bank = self.bank * injection_weight + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=torch.cat( + [self_attention_context] + self.bank, dim=1)) + # attn_output = self.attn1( + # norm_hidden_states, + # encoder_hidden_states=self.bank[0]) + self.bank = [] + if attn_output is None: + attn_output = self.attn1(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + cross_attention_kwargs = cross_attention_kwargs if \ + cross_attention_kwargs is not None else {} + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm else self.norm2(hidden_states)) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + all_modules = torch_dfs(self.unet) + + attn_modules = [ + module for module in all_modules + if isinstance(module, BasicTransformerBlock) + ] + for i, module in enumerate(attn_modules): + if getattr(module, '_original_inner_forward', None) is None: + module._original_inner_forward = module.forward + module.forward = transformer_forward_replacement.__get__( + module, BasicTransformerBlock) + module.bank = [] + + def forward(self, + x: Tensor, + t, + encoder_hidden_states=None, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + ref_x=None) -> Tensor: + """Forward and add LoRA mapping. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: The output tensor. 
+ """ + if ref_x is not None: + self.attention_status = AttentionStatus.WRITE + self.unet( + ref_x, + t, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals= # noqa + down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual) + self.attention_status = AttentionStatus.READ + output = self.unet( + x, + t, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals= # noqa + down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual) + + return output diff --git a/mmagic/models/editors/controlnet/controlnet.py b/mmagic/models/editors/controlnet/controlnet.py index 77a61346f4..f752d58ecf 100644 --- a/mmagic/models/editors/controlnet/controlnet.py +++ b/mmagic/models/editors/controlnet/controlnet.py @@ -13,6 +13,7 @@ from torch import Tensor from tqdm import tqdm +from mmagic.models.archs import AttentionInjection from mmagic.models.utils import build_module from mmagic.registry import MODELS from mmagic.structures import DataSample @@ -69,7 +70,8 @@ def __init__(self, noise_offset_weight: float = 0, tomesd_cfg: Optional[dict] = None, data_preprocessor=dict(type='DataPreprocessor'), - init_cfg: Optional[dict] = None): + init_cfg: Optional[dict] = None, + attention_injection=False): super().__init__(vae, text_encoder, tokenizer, unet, scheduler, test_scheduler, dtype, enable_xformers, noise_offset_weight, tomesd_cfg, data_preprocessor, @@ -86,6 +88,9 @@ def __init__(self, self.text_encoder.requires_grad_(False) self.unet.requires_grad_(False) + if attention_injection: + self.unet = AttentionInjection(self.unet) + def init_weights(self): """Initialize the weights. Noted that this function will only be called at train. If you want to inference with a different unet model, you can @@ -668,26 +673,30 @@ def prepare_latent_image(self, image, dtype): return image @torch.no_grad() - def infer(self, - prompt: Union[str, List[str]], - latent_image: Union[torch.FloatTensor, Image.Image, - List[torch.FloatTensor], - List[Image.Image]] = None, - latent_mask: torch.FloatTensor = None, - strength: float = 1.0, - height: Optional[int] = None, - width: Optional[int] = None, - control: Optional[Union[str, np.ndarray, torch.Tensor]] = None, - controlnet_conditioning_scale: float = 1.0, - num_inference_steps: int = 20, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - return_type='image', - show_progress=True): + def infer( + self, + prompt: Union[str, List[str]], + latent_image: Union[torch.FloatTensor, Image.Image, + List[torch.FloatTensor], List[Image.Image]] = None, + latent_mask: torch.FloatTensor = None, + strength: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + control: Optional[Union[str, np.ndarray, torch.Tensor]] = None, + controlnet_conditioning_scale: float = 1.0, + num_inference_steps: int = 20, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + return_type='image', + show_progress=True, + reference_img: Union[torch.FloatTensor, Image.Image, + List[torch.FloatTensor], + List[Image.Image]] = None, + ): """Function invoked when calling the pipeline for 
generation. Args: @@ -774,6 +783,10 @@ def infer(self, latent_image = self.prepare_latent_image(latent_image, self.controlnet.dtype) + if reference_img is not None: + reference_img = self.prepare_latent_image(reference_img, + self.controlnet.dtype) + # 3. Encode input prompt text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, @@ -802,6 +815,17 @@ def infer(self, generator, noise=latents) + if reference_img is not None: + _, ref_img_vae_latents = self.prepare_latents( + reference_img, + latent_timestep, + batch_size, + num_images_per_prompt, + text_embeddings.dtype, + device, + generator, + noise=latents) + # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_test_scheduler_extra_step_kwargs( generator, eta) @@ -816,6 +840,17 @@ def infer(self, latent_model_input = self.test_scheduler.scale_model_input( latent_model_input, t) + if reference_img is not None: + ref_img_vae_latents_t = self.scheduler.add_noise( + ref_img_vae_latents, torch.randn_like(ref_img_vae_latents), + t) + ref_img_vae_latents_model_input = torch.cat( + [ref_img_vae_latents_t] * 2) if \ + do_classifier_free_guidance else ref_img_vae_latents_t + ref_img_vae_latents_model_input = \ + self.test_scheduler.scale_model_input( + ref_img_vae_latents_model_input, t) + down_block_res_samples, mid_block_res_sample = self.controlnet( latent_model_input, t, @@ -831,13 +866,22 @@ def infer(self, mid_block_res_sample *= controlnet_conditioning_scale # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - )['sample'] + if reference_img is not None: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ref_x=ref_img_vae_latents_model_input)['sample'] + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + )['sample'] # perform guidance if do_classifier_free_guidance:
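To summarize how `infer` uses the reference image once attention injection is enabled: the reference is VAE-encoded once, re-noised to the current timestep at every denoising step, duplicated for classifier-free guidance, and handed to the wrapped UNet as `ref_x` so the WRITE pass runs before the READ pass on the actual latents. The condensed sketch below is illustrative only; `unet` and `scheduler` stand in for the `AttentionInjection`-wrapped UNet and a diffusers-style scheduler, and text embeddings plus ControlNet residuals are omitted for brevity.

```python
# Illustrative sketch of one denoising step with a reference image.
import torch


def one_step(unet, scheduler, latents, ref_latents, t, guidance_scale=7.5):
    # Re-noise the clean reference latents to the current timestep so they
    # sit on the same noise level as the latents being denoised.
    ref_t = scheduler.add_noise(ref_latents, torch.randn_like(ref_latents), t)

    # Duplicate both inputs for classifier-free guidance.
    latent_in = torch.cat([latents] * 2)
    ref_in = torch.cat([ref_t] * 2)

    # The wrapped UNet first runs a WRITE pass on ref_in (filling the
    # attention banks), then a READ pass on latent_in.
    noise_pred = unet(latent_in, t, ref_x=ref_in)['sample']

    # Standard classifier-free guidance combination.
    uncond, cond = noise_pred.chunk(2)
    noise_pred = uncond + guidance_scale * (cond - uncond)
    return scheduler.step(noise_pred, t, latents).prev_sample
```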