[Feature] Add Attention Injection for unet (#1895)
* add attention injection code

* add basic usage

* run injection successfully

* fix lint

* add attention injection fork

* add readme

* add readme
liuwenran authored Jun 13, 2023
1 parent 582c1a6 commit ad442dd
Showing 6 changed files with 388 additions and 93 deletions.
27 changes: 24 additions & 3 deletions configs/controlnet_animation/README.md
@@ -10,12 +10,17 @@

<!-- [ABSTRACT] -->

It is difficult to avoid video frame flickering when using stable diffusion to generate video frame by frame.
Here we reproduce a method that effectively avoids video flickering, that is, using controlnet and multi-frame rendering.
[ControlNet](https://github.com/lllyasviel/ControlNet) is a neural network structure to control diffusion models by adding extra conditions.
It is difficult to keep consistency and avoid video frame flickering when using stable diffusion to generate video frame by frame.
Here we reproduce two methods that effectively avoid video flickering:

**Controlnet with multi-frame rendering**. [ControlNet](https://github.com/lllyasviel/ControlNet) is a neural network structure to control diffusion models by adding extra conditions.
[Multi-frame rendering](https://xanthius.itch.io/multi-frame-rendering-for-stablediffusion) is a community method to reduce flickering.
We use controlnet with hed condition and stable diffusion img2img for multi-frame rendering.

**Controlnet with attention injection**. Attention injection is widely used to generate the current frame from a reference image. There is an implementation in [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet#reference-only-control) and we use some of their code to create the animation in this repo.

You may need about 40 GB of GPU memory to run controlnet with multi-frame rendering and about 10 GB for controlnet with attention injection. If the config file is not changed, controlnet with attention injection is used by default.
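
The idea behind attention injection (reference-only control) can be sketched with a toy self-attention layer: the current frame's queries also attend over keys and values computed from the reference frame, which pulls every generated frame toward the reference appearance. The snippet below is only an illustration of this idea, not the `AttentionInjection` module added in this repo.

```python
from typing import Optional

import torch
import torch.nn as nn


class SelfAttentionWithReference(nn.Module):
    """Toy self-attention layer illustrating reference-only attention injection.

    The current frame's tokens attend over the concatenation of their own
    keys/values and the reference frame's keys/values, so the output is pulled
    toward the reference appearance. A sketch of the idea only, not the
    AttentionInjection module added in this commit.
    """

    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor,
                ref: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x:   (batch, tokens, dim) features of the current frame
        # ref: (batch, tokens, dim) features of the reference frame
        context = x if ref is None else torch.cat([x, ref], dim=1)
        out, _ = self.attn(query=x, key=context, value=context,
                           need_weights=False)
        return out


layer = SelfAttentionWithReference(dim=64)
cur = torch.randn(1, 77, 64)   # current-frame tokens
ref = torch.randn(1, 77, 64)   # reference-frame tokens
print(layer(cur, ref).shape)   # torch.Size([1, 77, 64])
```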

## Demos

prompt key words: a handsome man, silver hair, smiling, play basketball
@@ -81,6 +86,22 @@ editor.infer(video=video, prompt=prompt, negative_prompt=negative_prompt, save_p
python demo/gradio_controlnet_animation.py
```

### 3. Change config to use multi-frame rendering or attention injection.

Change "inference_method" in the [anythingv3 config](./anythingv3_config.py).

To use multi-frame rendering:

```python
inference_method = 'multi-frame rendering'
```

To use attention injection:

```python
inference_method = 'attention_injection'
```
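
With attention injection selected, inference is called the same way as in the demo above; the only new knob is the optional `reference_img` argument of the inferencer. The snippet below is a sketch that assumes `editor` is the `MMagicInferencer` created earlier and uses placeholder file paths (the `save_path` argument name is assumed from the earlier snippet). When `reference_img` is omitted, the first generated frame is used as the reference.

```python
prompt = 'a handsome man, silver hair, smiling, play basketball'
negative_prompt = 'lowres, bad anatomy, worst quality'  # placeholder negative prompt

editor.infer(
    video='input_video.mp4',        # placeholder input path
    prompt=prompt,
    negative_prompt=negative_prompt,
    save_path='output_video.mp4',   # placeholder output path (name assumed)
    reference_img='reference.png',  # optional; omit to use the first generated frame
)
```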

## Play animation with SAM

We also provide a demo to play controlnet animation with SAM. For details, please see [OpenMMLab PlayGround](https://github.com/open-mmlab/playground/blob/main/mmediting_sam/README.md).
3 changes: 3 additions & 0 deletions configs/controlnet_animation/anythingv3_config.py
@@ -4,6 +4,9 @@
control_detector = 'lllyasviel/ControlNet'
control_scheduler = 'UniPCMultistepScheduler'

# method type : 'multi-frame rendering' or 'attention_injection'
inference_method = 'attention_injection'

model = dict(
type='ControlStableDiffusionImg2Img',
vae=dict(
194 changes: 133 additions & 61 deletions mmagic/apis/inferencers/controlnet_animation_inferencer.py
@@ -77,6 +77,9 @@ def __init__(self,
**kwargs) -> None:
cfg = Config.fromfile(config)
self.hed = HEDdetector.from_pretrained(cfg.control_detector)
self.inference_method = cfg.inference_method
if self.inference_method == 'attention_injection':
cfg.model.attention_injection = True
self.pipe = MODELS.build(cfg.model).cuda()

control_scheduler_cfg = dict(
Expand All @@ -99,6 +102,7 @@ def __call__(self,
num_inference_steps=20,
seed=1,
output_fps=None,
reference_img=None,
**kwargs) -> Union[Dict, List[Dict]]:
"""Call the inferencer.
@@ -164,83 +168,151 @@

from_video = False

# first result
image = None
if from_video:
image = PIL.Image.fromarray(all_images[0])
else:
image = load_image(all_images[0])
image = image.resize((image_width, image_height))
detect_resolution = min(image_width, image_height)
hed_image = self.hed(
image,
detect_resolution=detect_resolution,
image_resolution=detect_resolution)
hed_image = hed_image.resize((image_width, image_height))

result = self.pipe.infer(
control=hed_image,
latent_image=image,
prompt=prompt,
negative_prompt=negative_prompt,
strength=strength,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=num_inference_steps,
latents=init_noise_all_frame)['samples'][0]

first_result = result
first_hed = hed_image
last_result = result
last_hed = hed_image

for ind in range(len(all_images)):
if self.inference_method == 'multi-frame rendering':
# first result
if from_video:
if ind % sample_rate > 0:
continue
image = PIL.Image.fromarray(all_images[ind])
image = PIL.Image.fromarray(all_images[0])
else:
image = load_image(all_images[ind])
print('processing frame ind ' + str(ind))

image = load_image(all_images[0])
image = image.resize((image_width, image_height))
hed_image = self.hed(image, image_resolution=image_width)

concat_img = PIL.Image.new('RGB', (image_width * 3, image_height))
concat_img.paste(last_result, (0, 0))
concat_img.paste(image, (image_width, 0))
concat_img.paste(first_result, (image_width * 2, 0))

concat_hed = PIL.Image.new('RGB', (image_width * 3, image_height),
'black')
concat_hed.paste(last_hed, (0, 0))
concat_hed.paste(hed_image, (image_width, 0))
concat_hed.paste(first_hed, (image_width * 2, 0))
detect_resolution = min(image_width, image_height)
hed_image = self.hed(
image,
detect_resolution=detect_resolution,
image_resolution=detect_resolution)
hed_image = hed_image.resize((image_width, image_height))

result = self.pipe.infer(
control=concat_hed,
latent_image=concat_img,
control=hed_image,
latent_image=image,
prompt=prompt,
negative_prompt=negative_prompt,
strength=strength,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=num_inference_steps,
latents=init_noise_all_frame_cat,
latent_mask=latent_mask,
)['samples'][0]
result = result.crop(
(image_width, 0, image_width * 2, image_height))
latents=init_noise_all_frame)['samples'][0]

first_result = result
first_hed = hed_image
last_result = result
last_hed = hed_image

for ind in range(len(all_images)):
if from_video:
if ind % sample_rate > 0:
continue
image = PIL.Image.fromarray(all_images[ind])
else:
image = load_image(all_images[ind])
print('processing frame ind ' + str(ind))

image = image.resize((image_width, image_height))
hed_image = self.hed(image, image_resolution=image_width)

concat_img = PIL.Image.new('RGB',
(image_width * 3, image_height))
concat_img.paste(last_result, (0, 0))
concat_img.paste(image, (image_width, 0))
concat_img.paste(first_result, (image_width * 2, 0))

concat_hed = PIL.Image.new('RGB',
(image_width * 3, image_height),
'black')
concat_hed.paste(last_hed, (0, 0))
concat_hed.paste(hed_image, (image_width, 0))
concat_hed.paste(first_hed, (image_width * 2, 0))

result = self.pipe.infer(
control=concat_hed,
latent_image=concat_img,
prompt=prompt,
negative_prompt=negative_prompt,
strength=strength,
controlnet_conditioning_scale= # noqa
controlnet_conditioning_scale,
num_inference_steps=num_inference_steps,
latents=init_noise_all_frame_cat,
latent_mask=latent_mask,
)['samples'][0]
result = result.crop(
(image_width, 0, image_width * 2, image_height))

last_result = result
last_hed = hed_image

if from_video:
video_writer.write(np.flip(np.asarray(result), axis=2))
else:
frame_name = frame_files[ind].split('/')[-1]
save_name = os.path.join(save_path, frame_name)
result.save(save_name)

if from_video:
video_writer.write(np.flip(np.asarray(result), axis=2))
video_writer.release()
else:
if reference_img is None:
if from_video:
image = PIL.Image.fromarray(all_images[0])
else:
image = load_image(all_images[0])
image = image.resize((image_width, image_height))
detect_resolution = min(image_width, image_height)
hed_image = self.hed(
image,
detect_resolution=detect_resolution,
image_resolution=detect_resolution)
hed_image = hed_image.resize((image_width, image_height))

result = self.pipe.infer(
control=hed_image,
latent_image=image,
prompt=prompt,
negative_prompt=negative_prompt,
strength=strength,
controlnet_conditioning_scale= # noqa
controlnet_conditioning_scale,
num_inference_steps=num_inference_steps,
latents=init_noise_all_frame)['samples'][0]

reference_img = result
else:
frame_name = frame_files[ind].split('/')[-1]
save_name = os.path.join(save_path, frame_name)
result.save(save_name)
reference_img = load_image(reference_img)
reference_img = reference_img.resize(
(image_width, image_height))

for ind in range(len(all_images)):
if from_video:
if ind % sample_rate > 0:
continue
image = PIL.Image.fromarray(all_images[ind])
else:
image = load_image(all_images[ind])
print('processing frame ind ' + str(ind))

image = image.resize((image_width, image_height))
hed_image = self.hed(image, image_resolution=image_width)

result = self.pipe.infer(
control=hed_image,
latent_image=image,
prompt=prompt,
negative_prompt=negative_prompt,
strength=strength,
controlnet_conditioning_scale= # noqa
controlnet_conditioning_scale,
num_inference_steps=num_inference_steps,
latents=init_noise_all_frame,
reference_img=reference_img,
)['samples'][0]

if from_video:
video_writer.write(np.flip(np.asarray(result), axis=2))
else:
frame_name = frame_files[ind].split('/')[-1]
save_name = os.path.join(save_path, frame_name)
result.save(save_name)

if from_video:
video_writer.release()
if from_video:
video_writer.release()

return save_path
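
To make the interleaved diff above easier to follow, here is a simplified, editorial sketch of the control flow that `__call__` ends up with after this change. The helper callables are hypothetical stand-ins for the real `self.pipe.infer` and HED calls; only the branching mirrors the code above.

```python
from typing import Callable, List, Optional, Sequence


def animate(frames: Sequence[str],
            inference_method: str,
            render_single: Callable[[str], str],
            render_multi_frame: Callable[[str, str, str], str],
            render_with_reference: Callable[[str, str], str],
            reference_img: Optional[str] = None) -> List[str]:
    """Sketch of the branching added to __call__ (hypothetical helpers)."""
    results: List[str] = []
    if inference_method == 'multi-frame rendering':
        # The first frame is generated alone; every later frame is rendered on a
        # 3-wide canvas of (previous result | current frame | first result) and
        # the middle third is cropped out, which suppresses flicker.
        first = render_single(frames[0])
        last = first
        for frame in frames:
            last = render_multi_frame(frame, last, first)
            results.append(last)
    else:  # 'attention_injection'
        # Without an explicit reference, the first generated frame becomes the
        # reference; every frame is then denoised against it so self-attention
        # can borrow its features and keep the appearance consistent.
        if reference_img is None:
            reference_img = render_single(frames[0])
        for frame in frames:
            results.append(render_with_reference(frame, reference_img))
    return results
```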
3 changes: 2 additions & 1 deletion mmagic/models/archs/__init__.py
@@ -6,6 +6,7 @@
from mmagic.utils import try_import
from .all_gather_layer import AllGatherLayer
from .aspp import ASPP
from .attention_injection import AttentionInjection
from .conv import * # noqa: F401, F403
from .downsample import pixel_unshuffle
from .ensemble import SpatialTemporalEnsemble
@@ -74,5 +75,5 @@ def gen_wrapped_cls(module, module_name):
'SimpleEncoderDecoder', 'MultiLayerDiscriminator', 'PatchDiscriminator',
'VGG16', 'ResNet', 'AllGatherLayer', 'ResidualBlockNoBN', 'LoRAWrapper',
'set_lora', 'set_lora_disable', 'set_lora_enable',
'set_only_lora_trainable', 'TokenizerWrapper'
'set_only_lora_trainable', 'TokenizerWrapper', 'AttentionInjection'
]
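
Since `AttentionInjection` is added to `__all__`, it becomes part of the public `mmagic.models.archs` namespace; its constructor arguments are not shown in this diff, so only the import is illustrated below.

```python
# The new arch component can be imported directly after this change.
from mmagic.models.archs import AttentionInjection
```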