
Add CogVideoX text-to-video generation model #9082

Merged
107 commits merged into huggingface:main on Aug 7, 2024

Conversation

zRzRzRzRzRzRzR
Contributor

What does this PR do?

This PR converts the CogVideoX model into a Diffusers-supported inference model, including a complete pipeline and the corresponding modules. The paper is still being written, so the documentation may temporarily omit references to it.

Who can review?

@yiyixuxu
@stevhliu and @sayakpaul

hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
clear_fake_cp_cache: bool = False,
Member
Curiosity: why is this fake?

Contributor Author
Because this is executed serially on a single GPU.
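A minimal sketch of the idea, under the assumption that the "fake" context-parallel cache simply carries the last `time_kernel_size - 1` frames of the previous chunk, so that chunks processed serially on one GPU see the same temporal context they would under true context parallelism (the module and parameter names below are illustrative, not the PR's actual classes):

```python
import torch
import torch.nn as nn

class CausalTemporalConvSketch(nn.Module):
    """Illustrative causal temporal conv with a "fake" context-parallel cache."""

    def __init__(self, channels: int = 8, time_kernel_size: int = 3):
        super().__init__()
        self.time_kernel_size = time_kernel_size
        self.conv = nn.Conv3d(channels, channels, kernel_size=(time_kernel_size, 1, 1))
        self.conv_cache = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, frames, height, width), one chunk of a longer video
        if self.conv_cache is None:
            # First chunk: causal padding by repeating the first frame.
            context = x[:, :, :1].repeat(1, 1, self.time_kernel_size - 1, 1, 1)
        else:
            # Later chunks: reuse the tail frames cached from the previous chunk.
            context = self.conv_cache
        x = torch.cat([context, x], dim=2)
        # Cache the tail frames for the next chunk (kept on the same device here).
        self.conv_cache = x[:, :, -self.time_kernel_size + 1 :].contiguous().clone()
        return self.conv(x)
```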

a-r-r-o-w and others added 3 commits August 6, 2024 07:02
Co-Authored-By: YiYi Xu <yixu310@gmail.com>
Co-Authored-By: YiYi Xu <yixu310@gmail.com>
Co-Authored-By: YiYi Xu <yixu310@gmail.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Member

@zRzRzRzRzRzRzR Just a heads-up: I have removed the clear_fake_cp_cache parameter and switched to something I thought was cleaner and more understandable. It should be consistent with the old implementation AFAICT. Now trying to debug why we use twice the memory of the SAT implementation.

input_parallel = self.fake_cp_pass_from_previous_rank(inputs)

self._clear_fake_context_parallel_cache()
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
Member
@a-r-r-o-w Aug 6, 2024

@zRzRzRzRzRzRzR Is there any reason why we're moving the latents to CPU here, but then moving them back to inputs.device in fake_cp_pass_from_previous_rank? Since this implementation is tailored to a single GPU, and since I don't think the latents will reside on multiple devices in our implementation, we can get rid of the .contiguous().detach().clone().cpu() part and just keep them on the same device. cc @yiyixuxu

Edit: Tried it, and it looks like .contiguous().clone() will be required, but I don't think we need to move to CPU here for the single-GPU case.

Collaborator
@yiyixuxu left a comment

let's merge this soon!


if self.chunk_dim == 1:
    # It is a bit weird that the order is "shift, scale" here and "scale, shift" in the
    # other if-branch. This branch is specific to CogVideoX for now.
Collaborator

@a-r-r-o-w we normally just adjust the weights (ok to keep this! I don't think we need to update the weights now, just FYI): https://github.com/huggingface/diffusers/blob/main/scripts/convert_sd3_to_diffusers.py#L37
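For context, a minimal sketch of what "adjusting the weights" looks like, modeled on the referenced SD3 conversion helper (treat this as an illustration rather than the PR's final conversion code):

```python
import torch

def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor:
    # Convert a checkpoint that stores parameters in (shift, scale) order into the
    # (scale, shift) order expected by the other if-branch, so no runtime special-casing is needed.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)
```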

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Member
@stevhliu left a comment

Great work, thanks!

docs/source/en/api/pipelines/cogvideox.md

First, load the pipeline:

Member

Do we need to include this code block to demonstrate torch.compile, or is it to show inference time without torch.compile? If it's not necessary, I'm more in favor of just showing the code below to keep it simpler.

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# create pipeline
pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")

# set to channels_last
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)

# compile
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)

# inference
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)

@yiyixuxu
Collaborator

yiyixuxu commented Aug 6, 2024

I made a PR to update the scheduler config on the Hub: https://huggingface.co/THUDM/CogVideoX-2b/discussions/2. We can merge that after this PR is merged here.

You can test this PR with revision="refs/pr/2". Here is the script I used to run tests on both schedulers, with and without dynamic CFG:

from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
import torch
import numpy as np
import PIL

import tempfile
import imageio

def export_to_video_imageio(video_frames, output_video_path: str = None, fps: int = 8):
    """
    Export the video frames to a video file using the imageio library, to avoid the "green screen" issue (seen with CogVideoX, for example).
    """
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    if isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    with imageio.get_writer(output_video_path, fps=fps) as writer:
        for frame in video_frames:
            writer.append_data(frame)

    return output_video_path



prompts = [
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.",
    "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.",
    "The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.",
    "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.",
    "In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict."
    ]


pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16, revision="refs/pr/2")
pipe.enable_model_cpu_offload()

for prompt in prompts:
    for seed in [3]:
        # test ddim
        pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config)
        assert pipe.scheduler.config._class_name == "CogVideoXDDIMScheduler" and pipe.scheduler.config.timestep_spacing == "trailing"

        generator = torch.Generator(device="cpu").manual_seed(seed)
        video = pipe(prompt, guidance_scale=6, num_inference_steps=50, generator=generator).frames[0]
        export_to_video_imageio(video, f"{prompt[:10]}_{seed}_ddim.mp4", fps=8)

        assert pipe.scheduler.config._class_name == "CogVideoXDDIMScheduler" and pipe.scheduler.config.timestep_spacing == "trailing"
        generator = torch.Generator(device="cpu").manual_seed(seed)
        video = pipe(prompt, guidance_scale=6, num_inference_steps=50, generator=generator, use_dynamic_cfg=True).frames[0]
        export_to_video_imageio(video, f"{prompt[:10]}_{seed}_ddim_dynamic_cfg.mp4", fps=8)

        # test dpm
        pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
        assert pipe.scheduler.config._class_name == "CogVideoXDPMScheduler" and pipe.scheduler.config.timestep_spacing == "trailing"

        generator = torch.Generator(device="cpu").manual_seed(seed)
        video = pipe(prompt, guidance_scale=6, num_inference_steps=50, generator=generator).frames[0]
        export_to_video_imageio(video, f"{prompt[:10]}_{seed}_dpm.mp4", fps=8)

        assert pipe.scheduler.config._class_name == "CogVideoXDPMScheduler" and pipe.scheduler.config.timestep_spacing == "trailing"
        generator = torch.Generator(device="cpu").manual_seed(seed)
        video = pipe(prompt, guidance_scale=6, num_inference_steps=50, generator=generator, use_dynamic_cfg=True).frames[0]
        export_to_video_imageio(video, f"{prompt[:10]}_{seed}_dpm_dynamic_cfg.mp4", fps=8)

sayakpaul and others added 3 commits August 7, 2024 07:49
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Comment on lines +288 to +298
if prompt is not None and type(prompt) is not type(negative_prompt):
    raise TypeError(
        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
        f" {type(prompt)}."
    )
elif batch_size != len(negative_prompt):
    raise ValueError(
        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
        " the batch size of `prompt`."
    )
Member

This should have ideally gone to check_inputs().
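As an illustration, a minimal sketch of what that validation could look like if it lived in a check_inputs-style helper (the standalone function name and signature below are assumptions for demonstration, not the pipeline's actual method):

```python
from typing import List, Optional, Union

def check_negative_prompt(
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]],
    batch_size: int,
) -> None:
    # Mirrors the checks quoted above, but raised from a dedicated validation helper.
    if negative_prompt is None:
        return
    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type as `prompt`, but got "
            f"{type(negative_prompt)} != {type(prompt)}."
        )
    if isinstance(negative_prompt, list) and batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt` has batch size {len(negative_prompt)}, but `prompt` has "
            f"batch size {batch_size}. Please make sure the batch sizes match."
        )
```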

Member

Ah cool, then we should do it here too:

if prompt is not None and type(prompt) is not type(negative_prompt):

as it was copied over from there

Comment on lines +537 to +539
assert (
    num_frames <= 48 and num_frames % fps == 0 and fps == 8
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
Member

Should have been raised as a ValueError.
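A minimal sketch of the same check surfaced as a ValueError instead of an assert (the helper name and its placement, e.g. inside check_inputs, are assumptions):

```python
def check_frames(num_frames: int, fps: int) -> None:
    # Same condition as the assert above, but raised as a proper exception.
    if not (num_frames <= 48 and num_frames % fps == 0 and fps == 8):
        raise ValueError(
            f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). "
            "Other values are not supported in CogVideoX."
        )
```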


# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
num_frames += 1
Member

It wouldn't hurt to have a comment explaining why this addition is required.

@yiyixuxu merged commit 2dad462 into huggingface:main on Aug 7, 2024
15 checks passed
    zq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    for resnet in self.resnets:
        if self.training and self.gradient_checkpointing:

@zRzRzRzRzRzRzR @a-r-r-o-w @sayakpaul Hi, what is the concern that requires gradient checkpointing to be performed only while the model is in training mode? I see the same pattern at https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py#L98 as well.

Member

Not sure I understand what you mean? Gradient checkpointing is used to save only the inputs instead of all the intermediate activations; the activations are then recomputed on the fly during the backward pass, which results in large memory savings. You would only perform a backward pass when training, which is why there are two conditions here.
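For reference, a minimal sketch of the pattern being discussed, using a toy module rather than the actual CogVideoX block: checkpointing is gated on both self.training and a user-set flag.

```python
import torch
from torch.utils.checkpoint import checkpoint

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(3)])
        self.gradient_checkpointing = False  # toggled by the user, e.g. via enable_gradient_checkpointing()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                # Recompute this layer's activations during the backward pass instead of storing them.
                hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states
```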

@hkunzhe Oct 18, 2024

@a-r-r-o-w Thanks for your reply! Assume the VAE is in eval mode and we forward a tensor with requires_grad=True through it without torch.no_grad, because we need the autograd graph through the VAE for backpropagation. In that case, vae.enable_gradient_checkpointing() can still save some VRAM. In summary, I believe whether a model can benefit from gradient checkpointing is unrelated to whether it is in train or eval mode, so the "and" condition may be redundant.
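To make that scenario concrete, a small sketch with a toy frozen encoder standing in for the VAE (nothing here is CogVideoX-specific):

```python
import torch
from torch.utils.checkpoint import checkpoint

# A stand-in "frozen encoder": eval mode, parameters frozen.
encoder = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(4)]).eval()
for p in encoder.parameters():
    p.requires_grad_(False)

x = torch.randn(2, 64, requires_grad=True)

# Checkpointing still trades compute for activation memory here, even though encoder.training is False:
# activations inside `encoder` are recomputed during backward instead of being stored.
y = checkpoint(encoder, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 64]) -- gradients still flow back to the input
```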

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* add CogVideoX

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>