From c8e5491be004f0a49a83cfb8340daf5b7839bd51 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 30 Jul 2024 13:15:58 +0800 Subject: [PATCH 01/94] Create autoencoder_kl3d.py --- .../models/autoencoders/autoencoder_kl3d.py | 920 ++++++++++++++++++ 1 file changed, 920 insertions(+) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl3d.py diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py new file mode 100644 index 000000000000..83c73e92ae68 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -0,0 +1,920 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from beartype import beartype +from beartype.typing import Callable, Optional, Tuple, Union +from einops import rearrange + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +def normalize3d(in_channels, z_ch, add_conv): + return SpatialNorm3D( + in_channels, + z_ch, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, + affine=True + ) + + +class SafeConv3d(torch.nn.Conv3d): + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3 + if memory_count > 2: # Set to 2GB + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2) for i in + range(1, len(input_chunks))] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super(SafeConv3d, self).forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super(SafeConv3d, self).forward(input) + + +class OriginCausalConv3d(nn.Module): + @beartype + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode='constant', + **kwargs + ): + super().__init__() + + def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + dilation = kwargs.pop('dilation', 1) + stride = kwargs.pop('stride', 1) + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = SafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + if self.pad_mode == 'constant': + causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_3d, mode='constant', value=0) + elif self.pad_mode == 'first': + pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) + x = torch.cat([pad_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode='constant', value=0) + elif self.pad_mode == 'reflect': + # reflect padding + reflect_x = x[:, :, 1:self.time_pad + 1, :, :].flip(dims=[2]) + if reflect_x.shape[2] < self.time_pad: + reflect_x = torch.cat( + [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2) + x = torch.cat([reflect_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode='constant', value=0) + else: + raise ValueError("Invalid pad mode") + return self.conv(x) + + +class CausalConv3d(OriginCausalConv3d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_cache = None + + def forward(self, x): + if self.time_pad == 0: + return super().forward(x) + if self.conv_cache is None: + self.conv_cache = x[:, :, -self.time_pad:].detach().clone().cpu() + return super().forward(x) + else: + # print(self.conv_cache.shape, x.shape) + x = torch.cat([self.conv_cache.to(x.device), x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode='constant', value=0) + self.conv_cache = None + return self.conv(x) + + +class SpatialNorm3D(nn.Module): + def __init__(self, f_channels, z_channels, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False, + pad_mode='constant', **norm_layer_params): + super().__init__() + self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + self.add_conv = add_conv + if self.add_conv: + self.conv = CausalConv3d(z_channels, z_channels, kernel_size=3, pad_mode=pad_mode) + + self.conv_y = CausalConv3d(z_channels, f_channels, kernel_size=1, pad_mode=pad_mode) + self.conv_b = CausalConv3d(z_channels, f_channels, kernel_size=1, pad_mode=pad_mode) + + def forward(self, f, z): + if z.shape[2] > 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = z[:, :, :1], z[:, :, 1:] + z_first = torch.nn.functional.interpolate(z_first, size=f_first_size, mode="nearest") + z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size, mode="nearest") + z = torch.cat([z_first, z_rest], dim=2) + else: + z = torch.nn.functional.interpolate(z, size=f.shape[-3:], mode="nearest") + if self.add_conv: + z = self.conv(z) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(z) + self.conv_b(z) + return new_f + + +class UpSample3D(nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + compress_time: bool = False + ): + super(UpSample3D, self).__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") + x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + elif x.shape[2] > 1: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + else: + x = x.squeeze(2) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = x[:, :, None, :, :] + else: + # only interpolate 2D + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.with_conv: + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.conv(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + return x + + +class DownSample3D(nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool = False, + compress_time: bool = False, + out_channels: Optional[int] = None + ): + super(DownSample3D, self).__init__() + self.with_conv = with_conv + if out_channels is None: + out_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=0 + ) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + h, w = x.shape[-2:] + x = rearrange(x, 'b c t h w -> (b h w) c t') + + if x.shape[-1] % 2 == 1: + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] + + if x_rest.shape[-1] > 0: + x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w) + else: + x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w) + + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.conv(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + else: + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + return x + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int, + conv_shortcut: bool = False, + dropout: float, + act_fn: str = "silu", + temb_channels: int = 512, + z_ch: Optional[int] = None, + add_conv: bool = False, + pad_mode: str = 'constant', + norm_num_groups: int = 32, + normalization: Callable = None + ): + super(ResnetBlock3D, self).__init__() + self.in_channels = in_channels + self.act_fn = get_activation(act_fn) + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if normalization is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=norm_num_groups, eps=1e-6) + else: + self.norm1 = normalization( + in_channels, + z_ch=z_ch, + add_conv=add_conv, + ) + self.norm2 = normalization( + out_channels, + z_ch=z_ch, + add_conv=add_conv + ) + + self.conv1 = CausalConv3d( + in_channels, + out_channels, + kernel_size=3, + pad_mode=pad_mode + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + else: + self.nin_shortcut = SafeConv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x, temb, z=None): + h = x + if z is not None: + h = self.norm1(h, z) + else: + h = self.norm1(h) + h = self.act_fn(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.act_fn(temb))[:, :, None, None, None] + + if z is not None: + h = self.norm2(h, z) + else: + h = self.norm2(h) + h = self.act_fn(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock2D(nn.Module): + def __init__(self, in_channels, norm_num_groups): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + + t = h_.shape[2] + h_ = rearrange(h_, "b c t h w -> (b t) c h w") + + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + + # # original version, nan in fp16 + # w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + # w_ = w_ * (int(c)**(-0.5)) + # # implement c**-0.5 on q + + q = q * (int(c) ** (-0.5)) + w_ = torch.bmm(q, k) + # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t) + + return x + h_ + + +class Encoder3D(nn.Module): + def __init__( + self, + *, + ch: int, + in_channels: int = 3, + out_channels: int = 16, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + act_fn: str = "silu", + norm_num_groups: int = 32, + attn_resolutions=None, + dropout: float = 0.0, + resamp_with_conv: bool = True, + resolution: int, + z_channels: int, + double_z: bool = True, + pad_mode: str = 'first', + temporal_compress_times: int = 4, + ): + super(Encoder3D, self).__init__() + if attn_resolutions is None: + attn_resolutions = [] + self.act_fn = get_activation(act_fn) + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.attn_resolutions = attn_resolutions + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + self.conv_in = CausalConv3d(in_channels, self.ch, kernel_size=3, pad_mode=pad_mode) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=0, + act_fn=act_fn, + dropout=dropout, + norm_num_groups=norm_num_groups, + pad_mode=pad_mode + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + AttnBlock2D(block_in) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level < self.temporal_compress_level: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) + else: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + act_fn=act_fn, + temb_channels=0, + norm_num_groups=norm_num_groups, + dropout=dropout, pad_mode=pad_mode + ) + if len(attn_resolutions) > 0: + self.mid.attn_1 = AttnBlock2D(block_in) + self.mid.block_2 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + act_fn=act_fn, + temb_channels=0, + norm_num_groups=norm_num_groups, + dropout=dropout, pad_mode=pad_mode + ) + + self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=norm_num_groups, eps=1e-6) + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = CausalConv3d( + block_in, conv_out_channels if double_z else z_channels, + kernel_size=3, + pad_mode=pad_mode + ) + + def forward(self, x): + + # timestep embedding + temb = None + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h, temb) + + if len(self.attn_resolutions): + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = self.act_fn(h) + h = self.conv_out(h) + return h + + +class Decoder3D(nn.Module): + def __init__( + self, *, + ch: int, + in_channels: int = 16, + out_channels: int = 3, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions=None, + act_fn: str = "silu", + dropout: float = 0.0, + resamp_with_conv: bool = True, + resolution: int, + z_channels: int, + give_pre_end: bool = False, + z_ch: Optional[int] = None, + add_conv: bool = False, + pad_mode: str = 'first', + temporal_compress_times: int = 4, + norm_num_groups=32, + ): + super(Decoder3D, self).__init__() + if attn_resolutions is None: + attn_resolutions = [] + self.ch = ch + self.act_fn = get_activation(act_fn) + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.attn_resolutions = attn_resolutions + self.norm_num_groups = norm_num_groups + + # log2 of temporal_compress_times + + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if z_ch is None: + z_ch = z_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=0, + dropout=dropout, + act_fn=act_fn, + z_ch=z_ch, + add_conv=add_conv, + normalization=normalize3d, + norm_num_groups=norm_num_groups, + pad_mode=pad_mode + ) + if len(attn_resolutions) > 0: + self.mid.attn_1 = AttnBlock2D(block_in) + self.mid.block_2 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=0, + dropout=dropout, + act_fn=act_fn, + z_ch=z_ch, + add_conv=add_conv, + normalization=normalize3d, + norm_num_groups=norm_num_groups, + pad_mode=pad_mode + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=0, + act_fn=act_fn, + dropout=dropout, + z_ch=z_ch, + add_conv=add_conv, + normalization=normalize3d, + norm_num_groups=norm_num_groups, + pad_mode=pad_mode + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + AttnBlock2D + (block_in=block_in, norm_num_groups=norm_num_groups) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = UpSample3D(block_in, resamp_with_conv, compress_time=False) + else: + up.upsample = UpSample3D(block_in, resamp_with_conv, compress_time=True) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + self.norm_out = normalize3d(block_in, z_ch, add_conv=add_conv) + + self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, z): + + # timestep embedding + temb = None + + # z to block_in + + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, z) + if len(self.attn_resolutions) > 0: + h = self.mid.attn_1(h, z) + h = self.mid.block_2(h, temb, z) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, z) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, z) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, z) + h = self.act_fn(h) + h = self.conv_out(h) + + return h + + +class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["ResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + ch: int = 128, + block_out_channels: Tuple[int] = (1, 2, 2, 4), + layers_per_block: int = 3, + act_fn: str = "silu", + latent_channels: int = 16, + norm_num_groups: int = 32, + sample_size: int = 256, + + # Do Not Know how to use + scaling_factor: float = 0.13025, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + mid_block_add_attention: bool = True + ): + super().__init__() + + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + ch_mult=block_out_channels, + ch=ch, + num_res_blocks=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + resolution=sample_size, + z_channels=latent_channels, + ) + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + ch=ch, + ch_mult=block_out_channels, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + num_res_blocks=layers_per_block, + resolution=sample_size, + z_channels=latent_channels, + ) + self.quant_conv = SafeConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = SafeConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = False + + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + # self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder3D, Decoder3D)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + h = self.encoder(x) + if self.quant_conv is not None: + h = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(h) + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + dec = self.decoder(z) + if not return_dict: + return (dec,) + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec From c341786f3eac63357b1ce6ed17bc9c40d7b4d077 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 30 Jul 2024 13:25:08 +0800 Subject: [PATCH 02/94] vae draft --- .../geodiff_molecule_conformation.ipynb | 7222 ++++++++--------- src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/autoencoders/__init__.py | 1 + .../models/autoencoders/autoencoder_kl3d.py | 411 +- src/diffusers/utils/dummy_pt_objects.py | 15 + 6 files changed, 3805 insertions(+), 3846 deletions(-) diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb index bde093802a5d..19b87bc18012 100644 --- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb +++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb @@ -1,3652 +1,3652 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "F88mignPnalS" - }, - "source": [ - "# Introduction\n", - "\n", - "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", - "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", - "\n", - "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", - "\n", - "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", - "\n", - "> Colab made by [natolambert](https://twitter.com/natolambert).\n", - "\n", - "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7cnwXMocnuzB" - }, - "source": [ - "## Installations\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Install Conda" - ], - "metadata": { - "id": "ff9SxWnaNId9" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1g_6zOabItDk" - }, - "source": [ - "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "K0ofXobG5Y-X", - "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "nvcc: NVIDIA (R) Cuda compiler driver\n", - "Copyright (c) 2005-2021 NVIDIA Corporation\n", - "Built on Sun_Feb_14_21:12:58_PST_2021\n", - "Cuda compilation tools, release 11.2, V11.2.152\n", - "Build cuda_11.2.r11.2/compiler.29618528_0\n" - ] - } - ], - "source": [ - "!nvcc --version" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VfthW90vI0nw" - }, - "source": [ - "Install Conda for some more complex dependencies for geometric networks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2WNFzSnbiE0k", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install -q condacolab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NUsbWYCUI7Km" - }, - "source": [ - "Setup Conda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FZelreINdmd0", - "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✨🍰✨ Everything looks OK!\n" - ] - } - ], - "source": [ - "import condacolab\n", - "condacolab.install()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JzDHaPU7I9Sn" - }, - "source": [ - "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMxRjHhL7w8V", - "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", - "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - cudatoolkit=11.1\n", - " - pytorch\n", - " - torchaudio\n", - " - torchvision\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 960 KB\n", - "\n", - "The following packages will be UPDATED:\n", - "\n", - " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", - "Preparing transaction: / \b\bdone\n", - "Verifying transaction: \\ \b\bdone\n", - "Executing transaction: / \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", - "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" - ] - }, - { - "cell_type": "markdown", - "source": [ - "Need to remove a pathspec for colab that specifies the incorrect cuda version." - ], - "metadata": { - "id": "QDS6FPZ0Tu5b" - } - }, - { - "cell_type": "code", - "source": [ - "!rm /usr/local/conda-meta/pinned" - ], - "metadata": { - "id": "dq1lxR10TtrR", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z1L3DdZOJB30" - }, - "source": [ - "Install torch geometric (used in the model later)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D5ukfCOWfjzK", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - pytorch-geometric=1.7.2\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " decorator-4.4.2 | py_0 11 KB conda-forge\n", - " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", - " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", - " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", - " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", - " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", - " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", - " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", - " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", - " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", - " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", - " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", - " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", - " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", - " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", - " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", - " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", - " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", - " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", - " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 55.9 MB\n", - "\n", - "The following NEW packages will be INSTALLED:\n", - "\n", - " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", - " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", - " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", - " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", - " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", - " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", - " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", - " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", - " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", - " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", - " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", - " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", - " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", - " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", - " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", - " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", - " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", - " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", - " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", - "\n", - "The following packages will be DOWNGRADED:\n", - "\n", - " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", - "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", - "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", - "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", - "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", - "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", - "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", - "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", - "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", - "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", - "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", - "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", - "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", - "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", - "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", - "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", - "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", - "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", - "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", - "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", - "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", - "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install -c rusty1s pytorch-geometric=1.7.2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppxv6Mdkalbc" - }, - "source": [ - "### Install Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mgQA_XN-XGY2", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content\n", - "Cloning into 'diffusers'...\n", - "remote: Enumerating objects: 9298, done.\u001b[K\n", - "remote: Counting objects: 100% (40/40), done.\u001b[K\n", - "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", - "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", - "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", - "Resolving deltas: 100% (6168/6168), done.\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "%cd /content\n", - "\n", - "# install latest HF diffusers (will update to the release once added)\n", - "!git clone https://github.com/huggingface/diffusers.git\n", - "!pip install -q /content/diffusers\n", - "\n", - "# dependencies for diffusers\n", - "!pip install -q datasets transformers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LZO6AJKuJKO8" - }, - "source": [ - "Check that torch is installed correctly and utilizing the GPU in the colab" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gZt7BNi1e1PA", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 53 - }, - "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "True\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "'1.8.2'" - ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 8 - } - ], - "source": [ - "import torch\n", - "print(torch.cuda.is_available())\n", - "torch.__version__" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KLE7CqlfJNUO" - }, - "source": [ - "### Install Chemistry-specific Dependencies\n", - "\n", - "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0CPv_NvehRz3", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting rdkit\n", - " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", - "Installing collected packages: rdkit\n", - "Successfully installed rdkit-2022.3.5\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install rdkit" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88GaDbDPxJ5I" - }, - "source": [ - "### Get viewer from nglview\n", - "\n", - "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", - "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", - "The rdmol in this object is a source of ground truth for the generated molecules.\n", - "\n", - "You will use one rendering function from nglviewer later!\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jcl8GCS2mz6t", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting nglview\n", - " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", - "Collecting jupyterlab-widgets\n", - " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipywidgets>=7\n", - " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting widgetsnbextension~=4.0\n", - " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipython>=6.1.0\n", - " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipykernel>=4.5.1\n", - " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting traitlets>=4.3.1\n", - " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", - "Collecting pyzmq>=17\n", - " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting matplotlib-inline>=0.1\n", - " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", - "Collecting tornado>=6.1\n", - " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting nest-asyncio\n", - " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", - "Collecting debugpy>=1.0\n", - " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting psutil\n", - " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jupyter-client>=6.1.12\n", - " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pickleshare\n", - " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", - "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", - "Collecting backcall\n", - " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pexpect>4.3\n", - " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pygments\n", - " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jedi>=0.16\n", - " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", - " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", - "Collecting parso<0.9.0,>=0.8.0\n", - " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", - "Collecting entrypoints\n", - " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", - "Collecting jupyter-core>=4.9.2\n", - " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ptyprocess>=0.5\n", - " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", - "Collecting wcwidth\n", - " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", - "Building wheels for collected packages: nglview\n", - " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", - " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", - "Successfully built nglview\n", - "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", - "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "pexpect", - "pickleshare", - "wcwidth" - ] - } - } - }, - "metadata": {} - } - ], - "source": [ - "!pip install nglview" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F88mignPnalS" + }, + "source": [ + "# Introduction\n", + "\n", + "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", + "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", + "\n", + "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", + "\n", + "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", + "\n", + "> Colab made by [natolambert](https://twitter.com/natolambert).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7cnwXMocnuzB" + }, + "source": [ + "## Installations\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Install Conda" + ], + "metadata": { + "id": "ff9SxWnaNId9" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1g_6zOabItDk" + }, + "source": [ + "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K0ofXobG5Y-X", + "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2021 NVIDIA Corporation\n", + "Built on Sun_Feb_14_21:12:58_PST_2021\n", + "Cuda compilation tools, release 11.2, V11.2.152\n", + "Build cuda_11.2.r11.2/compiler.29618528_0\n" + ] + } + ], + "source": [ + "!nvcc --version" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VfthW90vI0nw" + }, + "source": [ + "Install Conda for some more complex dependencies for geometric networks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2WNFzSnbiE0k", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", + "\u001B[0m" + ] + } + ], + "source": [ + "!pip install -q condacolab" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" + }, + "source": [ + "Setup Conda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] + } + ], + "source": [ + "import condacolab\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" + }, + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." + ], + "metadata": { + "id": "QDS6FPZ0Tu5b" + } + }, + { + "cell_type": "code", + "source": [ + "!rm /usr/local/conda-meta/pinned" + ], + "metadata": { + "id": "dq1lxR10TtrR", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1L3DdZOJB30" + }, + "source": [ + "Install torch geometric (used in the model later)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D5ukfCOWfjzK", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - pytorch-geometric=1.7.2\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " decorator-4.4.2 | py_0 11 KB conda-forge\n", + " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", + " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", + " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", + " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", + " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", + " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", + " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", + " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", + " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", + " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", + " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", + " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", + " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", + " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", + " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", + " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", + " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mgQA_XN-XGY2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001B[K\n", + "remote: Counting objects: 100% (40/40), done.\u001B[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001B[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001B[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", + " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m757.0/757.0 kB\u001B[0m \u001B[31m52.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m163.5/163.5 kB\u001B[0m \u001B[31m21.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m40.8/40.8 kB\u001B[0m \u001B[31m5.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m596.3/596.3 kB\u001B[0m \u001B[31m51.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25h Building wheel for diffusers (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", + "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m432.7/432.7 kB\u001B[0m \u001B[31m36.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m5.3/5.3 MB\u001B[0m \u001B[31m90.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m35.3/35.3 MB\u001B[0m \u001B[31m39.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m115.1/115.1 kB\u001B[0m \u001B[31m16.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m948.0/948.0 kB\u001B[0m \u001B[31m63.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m212.2/212.2 kB\u001B[0m \u001B[31m21.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m95.8/95.8 kB\u001B[0m \u001B[31m12.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m140.8/140.8 kB\u001B[0m \u001B[31m18.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m7.6/7.6 MB\u001B[0m \u001B[31m104.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m148.0/148.0 kB\u001B[0m \u001B[31m20.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m231.3/231.3 kB\u001B[0m \u001B[31m30.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m94.8/94.8 kB\u001B[0m \u001B[31m14.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m58.8/58.8 kB\u001B[0m \u001B[31m8.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25h\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", + "\u001B[0m" + ] + } + ], + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers\n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" + }, + "source": [ + "Check that torch is installed correctly and utilizing the GPU in the colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gZt7BNi1e1PA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 }, + "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" + }, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "## Create a diffusion model" - ], - "metadata": { - "id": "8t8_e_uVLdKB" - } + "output_type": "stream", + "name": "stdout", + "text": [ + "True\n" + ] }, { - "cell_type": "markdown", - "source": [ - "### Model class(es)" + "output_type": "execute_result", + "data": { + "text/plain": [ + "'1.8.2'" ], - "metadata": { - "id": "G0rMncVtNSqU" + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" } - }, - { - "cell_type": "markdown", - "source": [ - "Imports" - ], - "metadata": { - "id": "L5FEXz5oXkzt" + }, + "metadata": {}, + "execution_count": 8 + } + ], + "source": [ + "import torch\n", + "print(torch.cuda.is_available())\n", + "torch.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLE7CqlfJNUO" + }, + "source": [ + "### Install Chemistry-specific Dependencies\n", + "\n", + "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0CPv_NvehRz3", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting rdkit\n", + " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m36.8/36.8 MB\u001B[0m \u001B[31m34.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", + "Installing collected packages: rdkit\n", + "Successfully installed rdkit-2022.3.5\n", + "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", + "\u001B[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jcl8GCS2mz6t", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m5.7/5.7 MB\u001B[0m \u001B[31m91.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25h Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", + " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m384.1/384.1 kB\u001B[0m \u001B[31m40.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m134.4/134.4 kB\u001B[0m \u001B[31m21.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m2.0/2.0 MB\u001B[0m \u001B[31m84.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m793.8/793.8 kB\u001B[0m \u001B[31m60.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m138.4/138.4 kB\u001B[0m \u001B[31m20.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m107.1/107.1 kB\u001B[0m \u001B[31m17.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.1/1.1 MB\u001B[0m \u001B[31m68.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m424.0/424.0 kB\u001B[0m \u001B[31m41.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.8/1.8 MB\u001B[0m \u001B[31m83.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m281.3/281.3 kB\u001B[0m \u001B[31m33.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m132.2/132.2 kB\u001B[0m \u001B[31m19.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m59.0/59.0 kB\u001B[0m \u001B[31m7.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.1/1.1 MB\u001B[0m \u001B[31m70.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.6/1.6 MB\u001B[0m \u001B[31m83.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m382.3/382.3 kB\u001B[0m \u001B[31m40.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + "Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m100.8/100.8 kB\u001B[0m \u001B[31m14.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m88.4/88.4 kB\u001B[0m \u001B[31m14.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", + "\u001B[0m" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } } - }, - { - "cell_type": "code", - "source": [ - "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", - "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", - "from dataclasses import dataclass\n", - "from typing import Callable, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch import Tensor, nn\n", - "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", - "\n", - "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", - "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", - "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", - "from torch_scatter import scatter_add\n", - "from torch_sparse import SparseTensor, coalesce\n", - "\n", - "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", - "from diffusers.modeling_utils import ModelMixin\n", - "from diffusers.utils import BaseOutput\n" - ], - "metadata": { - "id": "-3-P4w5sXkRU" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Helper classes" + }, + "metadata": {} + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a diffusion model" + ], + "metadata": { + "id": "8t8_e_uVLdKB" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Model class(es)" + ], + "metadata": { + "id": "G0rMncVtNSqU" + } + }, + { + "cell_type": "markdown", + "source": [ + "Imports" + ], + "metadata": { + "id": "L5FEXz5oXkzt" + } + }, + { + "cell_type": "code", + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput\n" + ], + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Helper classes" + ], + "metadata": { + "id": "EzJQXPN_XrMX" + } + }, + { + "cell_type": "code", + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.Tensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class SchNetEncoder(Module):\n", + " def __init__(\n", + " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.hidden_channels = hidden_channels\n", + " self.num_filters = num_filters\n", + " self.num_interactions = num_interactions\n", + " self.cutoff = cutoff\n", + "\n", + " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", + "\n", + " self.interactions = ModuleList()\n", + " for _ in range(num_interactions):\n", + " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", + " self.interactions.append(block)\n", + "\n", + " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", + " if embed_node:\n", + " assert z.dim() == 1 and z.dtype == torch.long\n", + " h = self.embedding(z)\n", + " else:\n", + " h = z\n", + " for interaction in self.interactions:\n", + " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", + "\n", + " return h\n", + "\n", + "\n", + "class GINEConv(MessagePassing):\n", + " \"\"\"\n", + " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", + " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", + " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", + " self.nn = mlp\n", + " self.initial_eps = eps\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " if train_eps:\n", + " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", + " else:\n", + " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " x: OptPairTensor = (x, x)\n", + "\n", + " # Node and edge feature dimensionalites need to match.\n", + " if isinstance(edge_index, torch.Tensor):\n", + " assert edge_attr is not None\n", + " assert x[0].size(-1) == edge_attr.size(-1)\n", + " elif isinstance(edge_index, SparseTensor):\n", + " assert x[0].size(-1) == edge_index.size(-1)\n", + "\n", + " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", + " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", + "\n", + " x_r = x[1]\n", + " if x_r is not None:\n", + " out += (1 + self.eps) * x_r\n", + "\n", + " return self.nn(out)\n", + "\n", + " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", + " if self.activation:\n", + " return self.activation(x_j + edge_attr)\n", + " else:\n", + " return x_j + edge_attr\n", + "\n", + " def __repr__(self):\n", + " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", + "\n", + "\n", + "class GINEncoder(torch.nn.Module):\n", + " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", + " super().__init__()\n", + "\n", + " self.hidden_dim = hidden_dim\n", + " self.num_convs = num_convs\n", + " self.short_cut = short_cut\n", + " self.concat_hidden = concat_hidden\n", + " self.node_emb = nn.Embedding(100, hidden_dim)\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " self.convs = nn.ModuleList()\n", + " for i in range(self.num_convs):\n", + " self.convs.append(\n", + " GINEConv(\n", + " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", + " activation=activation,\n", + " )\n", + " )\n", + "\n", + " def forward(self, z, edge_index, edge_attr):\n", + " \"\"\"\n", + " Input:\n", + " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", + " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", + " Output:\n", + " node_feature: graph feature\n", + " \"\"\"\n", + "\n", + " node_attr = self.node_emb(z) # (num_node, hidden)\n", + "\n", + " hiddens = []\n", + " conv_input = node_attr # (num_node, hidden)\n", + "\n", + " for conv_idx, conv in enumerate(self.convs):\n", + " hidden = conv(conv_input, edge_index, edge_attr)\n", + " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", + " hidden = self.activation(hidden)\n", + " assert hidden.shape == conv_input.shape\n", + " if self.short_cut and hidden.shape == conv_input.shape:\n", + " hidden += conv_input\n", + "\n", + " hiddens.append(hidden)\n", + " conv_input = hidden\n", + "\n", + " if self.concat_hidden:\n", + " node_feature = torch.cat(hiddens, dim=-1)\n", + " else:\n", + " node_feature = hiddens[-1]\n", + "\n", + " return node_feature\n", + "\n", + "\n", + "class MLPEdgeEncoder(Module):\n", + " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", + " super().__init__()\n", + " self.hidden_dim = hidden_dim\n", + " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", + " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", + "\n", + " @property\n", + " def out_channels(self):\n", + " return self.hidden_dim\n", + "\n", + " def forward(self, edge_length, edge_type):\n", + " \"\"\"\n", + " Input:\n", + " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", + " Returns:\n", + " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", + " \"\"\"\n", + " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", + " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", + " return d_emb * edge_attr # (num_edge, hidden)\n", + "\n", + "\n", + "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", + " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", + " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", + " return h_pair\n", + "\n", + "\n", + "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", + " \"\"\"\n", + " Args:\n", + " num_nodes: Number of atoms.\n", + " edge_index: Bond indices of the original graph.\n", + " edge_type: Bond types of the original graph.\n", + " order: Extension order.\n", + " Returns:\n", + " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", + " \"\"\"\n", + "\n", + " def binarize(x):\n", + " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", + "\n", + " def get_higher_order_adj_matrix(adj, order):\n", + " \"\"\"\n", + " Args:\n", + " adj: (N, N)\n", + " type_mat: (N, N)\n", + " Returns:\n", + " Following attributes will be updated:\n", + " - edge_index\n", + " - edge_type\n", + " Following attributes will be added to the data object:\n", + " - bond_edge_index: Original edge_index.\n", + " \"\"\"\n", + " adj_mats = [\n", + " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", + " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", + " ]\n", + "\n", + " for i in range(2, order + 1):\n", + " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", + " order_mat = torch.zeros_like(adj)\n", + "\n", + " for i in range(1, order + 1):\n", + " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", + "\n", + " return order_mat\n", + "\n", + " num_types = 22\n", + " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", + " # from rdkit.Chem.rdchem import BondType as BT\n", + " N = num_nodes\n", + " adj = to_dense_adj(edge_index).squeeze(0)\n", + " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", + "\n", + " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", + " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", + " assert (type_mat * type_highorder == 0).all()\n", + " type_new = type_mat + type_highorder\n", + "\n", + " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", + " _, edge_order = dense_to_sparse(adj_order)\n", + "\n", + " # data.bond_edge_index = data.edge_index # Save original edges\n", + " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", + " assert edge_type.dim() == 1\n", + " N = pos.size(0)\n", + "\n", + " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", + "\n", + " if is_sidechain is None:\n", + " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", + " else:\n", + " # fetch sidechain and its batch index\n", + " is_sidechain = is_sidechain.bool()\n", + " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", + " sidechain_pos = pos[is_sidechain]\n", + " sidechain_index = dummy_index[is_sidechain]\n", + " sidechain_batch = batch[is_sidechain]\n", + "\n", + " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", + " r_edge_index_x = assign_index[1]\n", + " r_edge_index_y = assign_index[0]\n", + " r_edge_index_y = sidechain_index[r_edge_index_y]\n", + "\n", + " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", + " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", + " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", + " # delete self loop\n", + " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", + "\n", + " rgraph_adj = torch.sparse.LongTensor(\n", + " rgraph_edge_index,\n", + " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", + " torch.Size([N, N]),\n", + " )\n", + "\n", + " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", + "\n", + " new_edge_index = composed_adj.indices()\n", + " new_edge_type = composed_adj.values().long()\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def extend_graph_order_radius(\n", + " num_nodes,\n", + " pos,\n", + " edge_index,\n", + " edge_type,\n", + " batch,\n", + " order=3,\n", + " cutoff=10.0,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + "):\n", + " if extend_order:\n", + " edge_index, edge_type = _extend_graph_order(\n", + " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", + " )\n", + "\n", + " if extend_radius:\n", + " edge_index, edge_type = _extend_to_radius_graph(\n", + " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", + " )\n", + "\n", + " return edge_index, edge_type\n", + "\n", + "\n", + "def get_distance(pos, edge_index):\n", + " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", + "\n", + "\n", + "def graph_field_network(score_d, pos, edge_index, edge_length):\n", + " \"\"\"\n", + " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", + " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", + " \"\"\"\n", + " N = pos.size(0)\n", + " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", + " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", + " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", + " ) # (N, 3)\n", + " return score_pos\n", + "\n", + "\n", + "def clip_norm(vec, limit, p=2):\n", + " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", + " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", + " return vec * denom\n", + "\n", + "\n", + "def is_local_edge(edge_type):\n", + " return edge_type > 0\n" + ], + "metadata": { + "id": "oR1Y56QiLY90" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Main model class!" + ], + "metadata": { + "id": "QWrHJFcYXyUB" + } + }, + { + "cell_type": "code", + "source": [ + "class MoleculeGNN(ModelMixin, ConfigMixin):\n", + " @register_to_config\n", + " def __init__(\n", + " self,\n", + " hidden_dim=128,\n", + " num_convs=6,\n", + " num_convs_local=4,\n", + " cutoff=10.0,\n", + " mlp_act=\"relu\",\n", + " edge_order=3,\n", + " edge_encoder=\"mlp\",\n", + " smooth_conv=True,\n", + " ):\n", + " super().__init__()\n", + " self.cutoff = cutoff\n", + " self.edge_encoder = edge_encoder\n", + " self.edge_order = edge_order\n", + "\n", + " \"\"\"\n", + " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", + " in SchNetEncoder\n", + " \"\"\"\n", + " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + "\n", + " \"\"\"\n", + " The graph neural network that extracts node-wise features.\n", + " \"\"\"\n", + " self.encoder_global = SchNetEncoder(\n", + " hidden_channels=hidden_dim,\n", + " num_filters=hidden_dim,\n", + " num_interactions=num_convs,\n", + " edge_channels=self.edge_encoder_global.out_channels,\n", + " cutoff=cutoff,\n", + " smooth=smooth_conv,\n", + " )\n", + " self.encoder_local = GINEncoder(\n", + " hidden_dim=hidden_dim,\n", + " num_convs=num_convs_local,\n", + " )\n", + "\n", + " \"\"\"\n", + " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", + " gradients w.r.t. edge_length (out_dim = 1).\n", + " \"\"\"\n", + " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " \"\"\"\n", + " Incorporate parameters together\n", + " \"\"\"\n", + " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", + " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", + "\n", + " def _forward(\n", + " self,\n", + " atom_type,\n", + " pos,\n", + " bond_index,\n", + " bond_type,\n", + " batch,\n", + " time_step, # NOTE, model trained without timestep performed best\n", + " edge_index=None,\n", + " edge_type=None,\n", + " edge_length=None,\n", + " return_edges=False,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " atom_type: Types of atoms, (N, ).\n", + " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", + " bond_type: Bond types, (E, ).\n", + " batch: Node index to graph index, (N, ).\n", + " \"\"\"\n", + " N = atom_type.size(0)\n", + " if edge_index is None or edge_type is None or edge_length is None:\n", + " edge_index, edge_type = extend_graph_order_radius(\n", + " num_nodes=N,\n", + " pos=pos,\n", + " edge_index=bond_index,\n", + " edge_type=bond_type,\n", + " batch=batch,\n", + " order=self.edge_order,\n", + " cutoff=self.cutoff,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " is_sidechain=is_sidechain,\n", + " )\n", + " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", + " local_edge_mask = is_local_edge(edge_type) # (E, )\n", + "\n", + " # with the parameterization of NCSNv2\n", + " # DDPM loss implicit handle the noise variance scale conditioning\n", + " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", + "\n", + " # Encoding global\n", + " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + "\n", + " # Global\n", + " node_attr_global = self.encoder_global(\n", + " z=atom_type,\n", + " edge_index=edge_index,\n", + " edge_length=edge_length,\n", + " edge_attr=edge_attr_global,\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_global = assemble_atom_pair_feature(\n", + " node_attr=node_attr_global,\n", + " edge_index=edge_index,\n", + " edge_attr=edge_attr_global,\n", + " ) # (E_global, 2H)\n", + " # Invariant features of edges (radius graph, global)\n", + " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", + "\n", + " # Encoding local\n", + " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + " # edge_attr += temb_edge\n", + "\n", + " # Local\n", + " node_attr_local = self.encoder_local(\n", + " z=atom_type,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_local = assemble_atom_pair_feature(\n", + " node_attr=node_attr_local,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " ) # (E_local, 2H)\n", + "\n", + " # Invariant features of edges (bond graph, local)\n", + " if isinstance(sigma_edge, torch.Tensor):\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", + " 1.0 / sigma_edge[local_edge_mask]\n", + " ) # (E_local, 1)\n", + " else:\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", + "\n", + " if return_edges:\n", + " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", + " else:\n", + " return edge_inv_global, edge_inv_local\n", + "\n", + " def forward(\n", + " self,\n", + " sample,\n", + " timestep: Union[torch.Tensor, float, int],\n", + " return_dict: bool = True,\n", + " sigma=1.0,\n", + " global_start_sigma=0.5,\n", + " w_global=1.0,\n", + " extend_order=False,\n", + " extend_radius=True,\n", + " clip_local=None,\n", + " clip_global=1000.0,\n", + " ) -> Union[MoleculeGNNOutput, Tuple]:\n", + " r\"\"\"\n", + " Args:\n", + " sample: packed torch geometric object\n", + " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", + " return_dict (`bool`, *optional*, defaults to `True`):\n", + " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", + " Returns:\n", + " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", + " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", + " \"\"\"\n", + "\n", + " # unpack sample\n", + " atom_type = sample.atom_type\n", + " bond_index = sample.edge_index\n", + " bond_type = sample.edge_type\n", + " num_graphs = sample.num_graphs\n", + " pos = sample.pos\n", + "\n", + " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", + "\n", + " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", + " atom_type=atom_type,\n", + " pos=sample.pos,\n", + " bond_index=bond_index,\n", + " bond_type=bond_type,\n", + " batch=sample.batch,\n", + " time_step=timesteps,\n", + " return_edges=True,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " ) # (E_global, 1), (E_local, 1)\n", + "\n", + " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", + " node_eq_local = graph_field_network(\n", + " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", + " )\n", + " if clip_local is not None:\n", + " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", + "\n", + " # Global\n", + " if sigma < global_start_sigma:\n", + " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", + " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", + " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", + " else:\n", + " node_eq_global = 0\n", + "\n", + " # Sum\n", + " eps_pos = node_eq_local + node_eq_global * w_global\n", + "\n", + " if not return_dict:\n", + " return (-eps_pos,)\n", + "\n", + " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" + ], + "metadata": { + "id": "MCeZA1qQXzoK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CCIrPYSJj9wd" + }, + "source": [ + "### Load pretrained model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YdrAr6Ch--Ab" + }, + "source": [ + "#### Load a model\n", + "The model used is a design an\n", + "equivariant convolutional layer, named graph field network (GFN).\n", + "\n", + "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DyCo0nsqjbml", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00 0:\n", - " self.dropout = nn.Dropout(dropout)\n", - " else:\n", - " self.dropout = None\n", - "\n", - " self.layers = nn.ModuleList()\n", - " for i in range(len(self.dims) - 1):\n", - " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\"\"\"\n", - " for i, layer in enumerate(self.layers):\n", - " x = layer(x)\n", - " if i < len(self.layers) - 1:\n", - " if self.activation:\n", - " x = self.activation(x)\n", - " if self.dropout:\n", - " x = self.dropout(x)\n", - " return x\n", - "\n", - "\n", - "class ShiftedSoftplus(torch.nn.Module):\n", - " def __init__(self):\n", - " super(ShiftedSoftplus, self).__init__()\n", - " self.shift = torch.log(torch.tensor(2.0)).item()\n", - "\n", - " def forward(self, x):\n", - " return F.softplus(x) - self.shift\n", - "\n", - "\n", - "class CFConv(MessagePassing):\n", - " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", - " super(CFConv, self).__init__(aggr=\"add\")\n", - " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", - " self.lin2 = Linear(num_filters, out_channels)\n", - " self.nn = mlp\n", - " self.cutoff = cutoff\n", - " self.smooth = smooth\n", - "\n", - " self.reset_parameters()\n", - "\n", - " def reset_parameters(self):\n", - " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", - " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", - " self.lin2.bias.data.fill_(0)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " if self.smooth:\n", - " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", - " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", - " else:\n", - " C = (edge_length <= self.cutoff).float()\n", - " W = self.nn(edge_attr) * C.view(-1, 1)\n", - "\n", - " x = self.lin1(x)\n", - " x = self.propagate(edge_index, x=x, W=W)\n", - " x = self.lin2(x)\n", - " return x\n", - "\n", - " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", - " return x_j * W\n", - "\n", - "\n", - "class InteractionBlock(torch.nn.Module):\n", - " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", - " super(InteractionBlock, self).__init__()\n", - " mlp = Sequential(\n", - " Linear(num_gaussians, num_filters),\n", - " ShiftedSoftplus(),\n", - " Linear(num_filters, num_filters),\n", - " )\n", - " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", - " self.act = ShiftedSoftplus()\n", - " self.lin = Linear(hidden_channels, hidden_channels)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " x = self.conv(x, edge_index, edge_length, edge_attr)\n", - " x = self.act(x)\n", - " x = self.lin(x)\n", - " return x\n", - "\n", - "\n", - "class SchNetEncoder(Module):\n", - " def __init__(\n", - " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.hidden_channels = hidden_channels\n", - " self.num_filters = num_filters\n", - " self.num_interactions = num_interactions\n", - " self.cutoff = cutoff\n", - "\n", - " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", - "\n", - " self.interactions = ModuleList()\n", - " for _ in range(num_interactions):\n", - " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", - " self.interactions.append(block)\n", - "\n", - " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", - " if embed_node:\n", - " assert z.dim() == 1 and z.dtype == torch.long\n", - " h = self.embedding(z)\n", - " else:\n", - " h = z\n", - " for interaction in self.interactions:\n", - " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", - "\n", - " return h\n", - "\n", - "\n", - "class GINEConv(MessagePassing):\n", - " \"\"\"\n", - " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", - " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", - " \"\"\"\n", - "\n", - " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", - " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", - " self.nn = mlp\n", - " self.initial_eps = eps\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " if train_eps:\n", - " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", - " else:\n", - " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", - "\n", - " def forward(\n", - " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", - " ) -> torch.Tensor:\n", - " \"\"\"\"\"\"\n", - " if isinstance(x, torch.Tensor):\n", - " x: OptPairTensor = (x, x)\n", - "\n", - " # Node and edge feature dimensionalites need to match.\n", - " if isinstance(edge_index, torch.Tensor):\n", - " assert edge_attr is not None\n", - " assert x[0].size(-1) == edge_attr.size(-1)\n", - " elif isinstance(edge_index, SparseTensor):\n", - " assert x[0].size(-1) == edge_index.size(-1)\n", - "\n", - " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", - " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", - "\n", - " x_r = x[1]\n", - " if x_r is not None:\n", - " out += (1 + self.eps) * x_r\n", - "\n", - " return self.nn(out)\n", - "\n", - " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", - " if self.activation:\n", - " return self.activation(x_j + edge_attr)\n", - " else:\n", - " return x_j + edge_attr\n", - "\n", - " def __repr__(self):\n", - " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", - "\n", - "\n", - "class GINEncoder(torch.nn.Module):\n", - " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", - " super().__init__()\n", - "\n", - " self.hidden_dim = hidden_dim\n", - " self.num_convs = num_convs\n", - " self.short_cut = short_cut\n", - " self.concat_hidden = concat_hidden\n", - " self.node_emb = nn.Embedding(100, hidden_dim)\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " self.convs = nn.ModuleList()\n", - " for i in range(self.num_convs):\n", - " self.convs.append(\n", - " GINEConv(\n", - " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", - " activation=activation,\n", - " )\n", - " )\n", - "\n", - " def forward(self, z, edge_index, edge_attr):\n", - " \"\"\"\n", - " Input:\n", - " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", - " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", - " Output:\n", - " node_feature: graph feature\n", - " \"\"\"\n", - "\n", - " node_attr = self.node_emb(z) # (num_node, hidden)\n", - "\n", - " hiddens = []\n", - " conv_input = node_attr # (num_node, hidden)\n", - "\n", - " for conv_idx, conv in enumerate(self.convs):\n", - " hidden = conv(conv_input, edge_index, edge_attr)\n", - " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", - " hidden = self.activation(hidden)\n", - " assert hidden.shape == conv_input.shape\n", - " if self.short_cut and hidden.shape == conv_input.shape:\n", - " hidden += conv_input\n", - "\n", - " hiddens.append(hidden)\n", - " conv_input = hidden\n", - "\n", - " if self.concat_hidden:\n", - " node_feature = torch.cat(hiddens, dim=-1)\n", - " else:\n", - " node_feature = hiddens[-1]\n", - "\n", - " return node_feature\n", - "\n", - "\n", - "class MLPEdgeEncoder(Module):\n", - " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", - " super().__init__()\n", - " self.hidden_dim = hidden_dim\n", - " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", - " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", - "\n", - " @property\n", - " def out_channels(self):\n", - " return self.hidden_dim\n", - "\n", - " def forward(self, edge_length, edge_type):\n", - " \"\"\"\n", - " Input:\n", - " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", - " Returns:\n", - " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", - " \"\"\"\n", - " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", - " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", - " return d_emb * edge_attr # (num_edge, hidden)\n", - "\n", - "\n", - "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", - " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", - " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", - " return h_pair\n", - "\n", - "\n", - "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", - " \"\"\"\n", - " Args:\n", - " num_nodes: Number of atoms.\n", - " edge_index: Bond indices of the original graph.\n", - " edge_type: Bond types of the original graph.\n", - " order: Extension order.\n", - " Returns:\n", - " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", - " \"\"\"\n", - "\n", - " def binarize(x):\n", - " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", - "\n", - " def get_higher_order_adj_matrix(adj, order):\n", - " \"\"\"\n", - " Args:\n", - " adj: (N, N)\n", - " type_mat: (N, N)\n", - " Returns:\n", - " Following attributes will be updated:\n", - " - edge_index\n", - " - edge_type\n", - " Following attributes will be added to the data object:\n", - " - bond_edge_index: Original edge_index.\n", - " \"\"\"\n", - " adj_mats = [\n", - " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", - " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", - " ]\n", - "\n", - " for i in range(2, order + 1):\n", - " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", - " order_mat = torch.zeros_like(adj)\n", - "\n", - " for i in range(1, order + 1):\n", - " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", - "\n", - " return order_mat\n", - "\n", - " num_types = 22\n", - " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", - " # from rdkit.Chem.rdchem import BondType as BT\n", - " N = num_nodes\n", - " adj = to_dense_adj(edge_index).squeeze(0)\n", - " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", - "\n", - " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", - " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", - " assert (type_mat * type_highorder == 0).all()\n", - " type_new = type_mat + type_highorder\n", - "\n", - " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", - " _, edge_order = dense_to_sparse(adj_order)\n", - "\n", - " # data.bond_edge_index = data.edge_index # Save original edges\n", - " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", - " assert edge_type.dim() == 1\n", - " N = pos.size(0)\n", - "\n", - " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", - "\n", - " if is_sidechain is None:\n", - " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", - " else:\n", - " # fetch sidechain and its batch index\n", - " is_sidechain = is_sidechain.bool()\n", - " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", - " sidechain_pos = pos[is_sidechain]\n", - " sidechain_index = dummy_index[is_sidechain]\n", - " sidechain_batch = batch[is_sidechain]\n", - "\n", - " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", - " r_edge_index_x = assign_index[1]\n", - " r_edge_index_y = assign_index[0]\n", - " r_edge_index_y = sidechain_index[r_edge_index_y]\n", - "\n", - " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", - " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", - " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", - " # delete self loop\n", - " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", - "\n", - " rgraph_adj = torch.sparse.LongTensor(\n", - " rgraph_edge_index,\n", - " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", - " torch.Size([N, N]),\n", - " )\n", - "\n", - " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", - "\n", - " new_edge_index = composed_adj.indices()\n", - " new_edge_type = composed_adj.values().long()\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def extend_graph_order_radius(\n", - " num_nodes,\n", - " pos,\n", - " edge_index,\n", - " edge_type,\n", - " batch,\n", - " order=3,\n", - " cutoff=10.0,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - "):\n", - " if extend_order:\n", - " edge_index, edge_type = _extend_graph_order(\n", - " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", - " )\n", - "\n", - " if extend_radius:\n", - " edge_index, edge_type = _extend_to_radius_graph(\n", - " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", - " )\n", - "\n", - " return edge_index, edge_type\n", - "\n", - "\n", - "def get_distance(pos, edge_index):\n", - " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", - "\n", - "\n", - "def graph_field_network(score_d, pos, edge_index, edge_length):\n", - " \"\"\"\n", - " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", - " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", - " \"\"\"\n", - " N = pos.size(0)\n", - " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", - " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", - " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", - " ) # (N, 3)\n", - " return score_pos\n", - "\n", - "\n", - "def clip_norm(vec, limit, p=2):\n", - " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", - " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", - " return vec * denom\n", - "\n", - "\n", - "def is_local_edge(edge_type):\n", - " return edge_type > 0\n" - ], - "metadata": { - "id": "oR1Y56QiLY90" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Main model class!" + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading: 0%| | 0.00/401 [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] + } + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load('/content/molecules.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JVjz6iH_H6Eh", + "colab": { + "base_uri": "https://localhost:8080/" }, + "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "class MoleculeGNN(ModelMixin, ConfigMixin):\n", - " @register_to_config\n", - " def __init__(\n", - " self,\n", - " hidden_dim=128,\n", - " num_convs=6,\n", - " num_convs_local=4,\n", - " cutoff=10.0,\n", - " mlp_act=\"relu\",\n", - " edge_order=3,\n", - " edge_encoder=\"mlp\",\n", - " smooth_conv=True,\n", - " ):\n", - " super().__init__()\n", - " self.cutoff = cutoff\n", - " self.edge_encoder = edge_encoder\n", - " self.edge_order = edge_order\n", - "\n", - " \"\"\"\n", - " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", - " in SchNetEncoder\n", - " \"\"\"\n", - " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - "\n", - " \"\"\"\n", - " The graph neural network that extracts node-wise features.\n", - " \"\"\"\n", - " self.encoder_global = SchNetEncoder(\n", - " hidden_channels=hidden_dim,\n", - " num_filters=hidden_dim,\n", - " num_interactions=num_convs,\n", - " edge_channels=self.edge_encoder_global.out_channels,\n", - " cutoff=cutoff,\n", - " smooth=smooth_conv,\n", - " )\n", - " self.encoder_local = GINEncoder(\n", - " hidden_dim=hidden_dim,\n", - " num_convs=num_convs_local,\n", - " )\n", - "\n", - " \"\"\"\n", - " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", - " gradients w.r.t. edge_length (out_dim = 1).\n", - " \"\"\"\n", - " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " \"\"\"\n", - " Incorporate parameters together\n", - " \"\"\"\n", - " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", - " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", - "\n", - " def _forward(\n", - " self,\n", - " atom_type,\n", - " pos,\n", - " bond_index,\n", - " bond_type,\n", - " batch,\n", - " time_step, # NOTE, model trained without timestep performed best\n", - " edge_index=None,\n", - " edge_type=None,\n", - " edge_length=None,\n", - " return_edges=False,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " atom_type: Types of atoms, (N, ).\n", - " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", - " bond_type: Bond types, (E, ).\n", - " batch: Node index to graph index, (N, ).\n", - " \"\"\"\n", - " N = atom_type.size(0)\n", - " if edge_index is None or edge_type is None or edge_length is None:\n", - " edge_index, edge_type = extend_graph_order_radius(\n", - " num_nodes=N,\n", - " pos=pos,\n", - " edge_index=bond_index,\n", - " edge_type=bond_type,\n", - " batch=batch,\n", - " order=self.edge_order,\n", - " cutoff=self.cutoff,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " is_sidechain=is_sidechain,\n", - " )\n", - " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", - " local_edge_mask = is_local_edge(edge_type) # (E, )\n", - "\n", - " # with the parameterization of NCSNv2\n", - " # DDPM loss implicit handle the noise variance scale conditioning\n", - " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", - "\n", - " # Encoding global\n", - " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - "\n", - " # Global\n", - " node_attr_global = self.encoder_global(\n", - " z=atom_type,\n", - " edge_index=edge_index,\n", - " edge_length=edge_length,\n", - " edge_attr=edge_attr_global,\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_global = assemble_atom_pair_feature(\n", - " node_attr=node_attr_global,\n", - " edge_index=edge_index,\n", - " edge_attr=edge_attr_global,\n", - " ) # (E_global, 2H)\n", - " # Invariant features of edges (radius graph, global)\n", - " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", - "\n", - " # Encoding local\n", - " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - " # edge_attr += temb_edge\n", - "\n", - " # Local\n", - " node_attr_local = self.encoder_local(\n", - " z=atom_type,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_local = assemble_atom_pair_feature(\n", - " node_attr=node_attr_local,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " ) # (E_local, 2H)\n", - "\n", - " # Invariant features of edges (bond graph, local)\n", - " if isinstance(sigma_edge, torch.Tensor):\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", - " 1.0 / sigma_edge[local_edge_mask]\n", - " ) # (E_local, 1)\n", - " else:\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", - "\n", - " if return_edges:\n", - " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", - " else:\n", - " return edge_inv_global, edge_inv_local\n", - "\n", - " def forward(\n", - " self,\n", - " sample,\n", - " timestep: Union[torch.Tensor, float, int],\n", - " return_dict: bool = True,\n", - " sigma=1.0,\n", - " global_start_sigma=0.5,\n", - " w_global=1.0,\n", - " extend_order=False,\n", - " extend_radius=True,\n", - " clip_local=None,\n", - " clip_global=1000.0,\n", - " ) -> Union[MoleculeGNNOutput, Tuple]:\n", - " r\"\"\"\n", - " Args:\n", - " sample: packed torch geometric object\n", - " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", - " return_dict (`bool`, *optional*, defaults to `True`):\n", - " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", - " Returns:\n", - " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", - " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", - " \"\"\"\n", - "\n", - " # unpack sample\n", - " atom_type = sample.atom_type\n", - " bond_index = sample.edge_index\n", - " bond_type = sample.edge_type\n", - " num_graphs = sample.num_graphs\n", - " pos = sample.pos\n", - "\n", - " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", - "\n", - " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", - " atom_type=atom_type,\n", - " pos=sample.pos,\n", - " bond_index=bond_index,\n", - " bond_type=bond_type,\n", - " batch=sample.batch,\n", - " time_step=timesteps,\n", - " return_edges=True,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " ) # (E_global, 1), (E_local, 1)\n", - "\n", - " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", - " node_eq_local = graph_field_network(\n", - " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", - " )\n", - " if clip_local is not None:\n", - " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", - "\n", - " # Global\n", - " if sigma < global_start_sigma:\n", - " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", - " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", - " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", - " else:\n", - " node_eq_global = 0\n", - "\n", - " # Sum\n", - " eps_pos = node_eq_local + node_eq_global * w_global\n", - "\n", - " if not return_dict:\n", - " return (-eps_pos,)\n", - "\n", - " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" - ], - "metadata": { - "id": "MCeZA1qQXzoK" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CCIrPYSJj9wd" - }, - "source": [ - "### Load pretrained model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YdrAr6Ch--Ab" - }, - "source": [ - "#### Load a model\n", - "The model used is a design an\n", - "equivariant convolutional layer, named graph field network (GFN).\n", - "\n", - "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." + "output_type": "execute_result", + "data": { + "text/plain": [ + "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" ] + }, + "metadata": {}, + "execution_count": 20 + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Run the diffusion process" + ], + "metadata": { + "id": "vHNiZAUxNgoy" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZ1KZrxKqENg" + }, + "source": [ + "#### Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s240tYueqKKf" + }, + "outputs": [], + "source": [ + "from torch_geometric.data import Data, Batch\n", + "from torch_scatter import scatter_add, scatter_mean\n", + "from tqdm import tqdm\n", + "import copy\n", + "import os\n", + "\n", + "def repeat_data(data: Data, num_repeat) -> Batch:\n", + " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", + " return Batch.from_data_list(datas)\n", + "\n", + "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", + " datas = batch.to_data_list()\n", + " new_data = []\n", + " for i in range(num_repeat):\n", + " new_data += copy.deepcopy(datas)\n", + " return Batch.from_data_list(new_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMnQTk0eqT7Z" + }, + "source": [ + "#### Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WYGkzqgzrHmF" + }, + "outputs": [], + "source": [ + "num_samples = 1 # solutions per molecule\n", + "num_molecules = 3\n", + "\n", + "DEVICE = 'cuda'\n", + "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", + "# constants for inference\n", + "w_global = 0.5 #0,.3 for qm9\n", + "global_start_sigma = 0.5\n", + "eta = 1.0\n", + "clip_local = None\n", + "clip_pos = None\n", + "\n", + "# constands for data handling\n", + "save_traj = False\n", + "save_data = False\n", + "output_dir = '/content/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-xD5bJ3SqM7t" + }, + "source": [ + "#### Generate samples!\n", + "Note that the 3d representation of a molecule is referred to as the **conformation**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x9xuLUNg26z1", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " after removing the cwd from sys.path.\n", + "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" + ] + } + ], + "source": [ + "results = []\n", + "\n", + "# define sigmas\n", + "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", + "sigmas = sigmas.to(DEVICE)\n", + "\n", + "for count, data in enumerate(tqdm(dataset)):\n", + " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", + "\n", + " data_input = data.clone()\n", + " data_input['pos_ref'] = None\n", + " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", + "\n", + " # initial configuration\n", + " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", + "\n", + " # for logging animation of denoising\n", + " pos_traj = []\n", + " with torch.no_grad():\n", + "\n", + " # scale initial sample\n", + " pos = pos_init * sigmas[-1]\n", + " for t in scheduler.timesteps:\n", + " batch.pos = pos\n", + "\n", + " # generate geometry with model, then filter it\n", + " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", + "\n", + " # Update\n", + " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", + "\n", + " pos = reconstructed_pos\n", + "\n", + " if torch.isnan(pos).any():\n", + " print(\"NaN detected. Please restart.\")\n", + " raise FloatingPointError()\n", + "\n", + " # recenter graph of positions for next iteration\n", + " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", + "\n", + " # optional clipping\n", + " if clip_pos is not None:\n", + " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", + " pos_traj.append(pos.clone().cpu())\n", + "\n", + " pos_gen = pos.cpu()\n", + " if save_traj:\n", + " pos_gen_traj = pos_traj.cpu()\n", + " data.pos_gen = torch.stack(pos_gen_traj)\n", + " else:\n", + " data.pos_gen = pos_gen\n", + " results.append(data)\n", + "\n", + "\n", + "if save_data:\n", + " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", + "\n", + " with open(save_path, 'wb') as f:\n", + " pickle.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Render the results!" + ], + "metadata": { + "id": "fSApwSaZNndW" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d47Zxo2OKdgZ" + }, + "source": [ + "This function allows us to render 3d in colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9Cd0kCAv9b8" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "output.enable_custom_widget_manager()" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Helper functions" + ], + "metadata": { + "id": "RjaVuR15NqzF" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "28rBYa9NKhlz" + }, + "source": [ + "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LKdKdwxcyTQ6" + }, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "def set_rdmol_positions(rdkit_mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " mol = deepcopy(rdkit_mol)\n", + " set_rdmol_positions_(mol, pos)\n", + " return mol\n", + "\n", + "def set_rdmol_positions_(mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " for i in range(pos.shape[0]):\n", + " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", + " return mol\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NuE10hcpKmzK" + }, + "source": [ + "Process the generated data to make it easy to view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KieVE1vc0_Vs", + "colab": { + "base_uri": "https://localhost:8080/" }, + "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DyCo0nsqjbml", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 172, - "referenced_widgets": [ - "d90f304e9560472eacfbdd11e46765eb", - "1c6246f15b654f4daa11c9bcf997b78c", - "c2321b3bff6f490ca12040a20308f555", - "b7feb522161f4cf4b7cc7c1a078ff12d", - "e2d368556e494ae7ae4e2e992af2cd4f", - "bbef741e76ec41b7ab7187b487a383df", - "561f742d418d4721b0670cc8dd62e22c", - "872915dd1bb84f538c44e26badabafdd", - "d022575f1fa2446d891650897f187b4d", - "fdc393f3468c432aa0ada05e238a5436", - "2c9362906e4b40189f16d14aa9a348da", - "6010fc8daa7a44d5aec4b830ec2ebaa1", - "7e0bb1b8d65249d3974200686b193be2", - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "6526646be5ed415c84d1245b040e629b", - "24d31fc3576e43dd9f8301d2ef3a37ab", - "2918bfaadc8d4b1a9832522c40dfefb8", - "a4bfdca35cc54dae8812720f1b276a08", - "e4901541199b45c6a18824627692fc39", - "f915cf874246446595206221e900b2fe", - "a9e388f22a9742aaaf538e22575c9433", - "42f6c3db29d7484ba6b4f73590abd2f4" - ] - }, - "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Downloading: 0%| | 0.00/3.27M [00:00" ], - "metadata": { - "id": "HdclRaqoUWUD" + "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" + }, + "metadata": {} + } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", + "molSize=(450,300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace('svg:','')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" + }, + "source": [ + "Generate the 3d molecule!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aT1Bkb8YxJfV", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "695ab5bbf30a4ab19df1f9f33469f314" } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PlOkPySoJ1m9" - }, - "source": [ - "#### Create scheduler\n", - "Note, other schedulers are used in the paper for slightly improved performance over DDPM." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nNHnIk9CkAb2" - }, - "outputs": [], - "source": [ - "from diffusers import DDPMScheduler" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RnDJdDBztjFF" - }, - "outputs": [], - "source": [ - "num_timesteps = 1000\n", - "scheduler = DDPMScheduler(num_train_timesteps=num_timesteps,beta_schedule=\"sigmoid\",beta_start=1e-7, beta_end=2e-3, clip_sample=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1vh3fpSAflkL" - }, - "source": [ - "### Get a dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B6qzaGjVKFVk" - }, - "source": [ - "Grab a google tool so we can upload our data directly. Note you need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\n", - "\n", - "(direct downloading from the hub does not yet work for this datatype)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jbLl3EJdgj3x" - }, - "outputs": [], - "source": [ - "# from google.colab import files" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "E591lVuTgxPE" - }, - "outputs": [], - "source": [ - "# uploaded = files.upload()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KUNxfK3ln98Q" - }, - "source": [ - "Load the dataset with torch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7L4iOShTpcQX", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "7f2dcd29-493e-44de-98d1-3ad50f109a4a" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2022-10-12 18:32:19-- https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...\n", - "Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 127774 (125K) [application/octet-stream]\n", - "Saving to: ‘molecules.pkl’\n", - "\n", - "molecules.pkl 100%[===================>] 124.78K 180KB/s in 0.7s \n", - "\n", - "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", - "\n" - ] - } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "dataset = torch.load('/content/molecules.pkl')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QZcmy1EvKQRk" - }, - "source": [ - "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JVjz6iH_H6Eh", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" - ] - }, - "metadata": {}, - "execution_count": 20 + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" } + } + } + } + } + ], + "source": [ + "from nglview import show_rdkit as show" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxtq8I-I18C-", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "NGLWidget()" ], - "source": [ - "dataset[0]" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Run the diffusion process" - ], - "metadata": { - "id": "vHNiZAUxNgoy" + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "be446195da2b4ff2aec21ec5ff963a54" } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jZ1KZrxKqENg" - }, - "source": [ - "#### Helper Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s240tYueqKKf" - }, - "outputs": [], - "source": [ - "from torch_geometric.data import Data, Batch\n", - "from torch_scatter import scatter_add, scatter_mean\n", - "from tqdm import tqdm\n", - "import copy\n", - "import os\n", - "\n", - "def repeat_data(data: Data, num_repeat) -> Batch:\n", - " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", - " return Batch.from_data_list(datas)\n", - "\n", - "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", - " datas = batch.to_data_list()\n", - " new_data = []\n", - " for i in range(num_repeat):\n", - " new_data += copy.deepcopy(datas)\n", - " return Batch.from_data_list(new_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AMnQTk0eqT7Z" - }, - "source": [ - "#### Constants" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WYGkzqgzrHmF" - }, - "outputs": [], - "source": [ - "num_samples = 1 # solutions per molecule\n", - "num_molecules = 3\n", - "\n", - "DEVICE = 'cuda'\n", - "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", - "# constants for inference\n", - "w_global = 0.5 #0,.3 for qm9\n", - "global_start_sigma = 0.5\n", - "eta = 1.0\n", - "clip_local = None\n", - "clip_pos = None\n", - "\n", - "# constands for data handling\n", - "save_traj = False\n", - "save_data = False\n", - "output_dir = '/content/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-xD5bJ3SqM7t" - }, - "source": [ - "#### Generate samples!\n", - "Note that the 3d representation of a molecule is referred to as the **conformation**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "x9xuLUNg26z1", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " after removing the cwd from sys.path.\n", - "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" - ] + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" } - ], - "source": [ - "results = []\n", - "\n", - "# define sigmas\n", - "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", - "sigmas = sigmas.to(DEVICE)\n", - "\n", - "for count, data in enumerate(tqdm(dataset)):\n", - " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", - "\n", - " data_input = data.clone()\n", - " data_input['pos_ref'] = None\n", - " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", - "\n", - " # initial configuration\n", - " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", - "\n", - " # for logging animation of denoising\n", - " pos_traj = []\n", - " with torch.no_grad():\n", - "\n", - " # scale initial sample\n", - " pos = pos_init * sigmas[-1]\n", - " for t in scheduler.timesteps:\n", - " batch.pos = pos\n", - "\n", - " # generate geometry with model, then filter it\n", - " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", - "\n", - " # Update\n", - " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", - "\n", - " pos = reconstructed_pos\n", - "\n", - " if torch.isnan(pos).any():\n", - " print(\"NaN detected. Please restart.\")\n", - " raise FloatingPointError()\n", - "\n", - " # recenter graph of positions for next iteration\n", - " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", - "\n", - " # optional clipping\n", - " if clip_pos is not None:\n", - " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", - " pos_traj.append(pos.clone().cpu())\n", - "\n", - " pos_gen = pos.cpu()\n", - " if save_traj:\n", - " pos_gen_traj = pos_traj.cpu()\n", - " data.pos_gen = torch.stack(pos_gen_traj)\n", - " else:\n", - " data.pos_gen = pos_gen\n", - " results.append(data)\n", - "\n", - "\n", - "if save_data:\n", - " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", - "\n", - " with open(save_path, 'wb') as f:\n", - " pickle.dump(results, f)" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Render the results!" - ], - "metadata": { - "id": "fSApwSaZNndW" + } } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d47Zxo2OKdgZ" - }, - "source": [ - "This function allows us to render 3d in colab." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e9Cd0kCAv9b8" - }, - "outputs": [], - "source": [ - "from google.colab import output\n", - "output.enable_custom_widget_manager()" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Helper functions" + } + } + ], + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" ], - "metadata": { - "id": "RjaVuR15NqzF" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "28rBYa9NKhlz" - }, - "source": [ - "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LKdKdwxcyTQ6" - }, - "outputs": [], - "source": [ - "from copy import deepcopy\n", - "def set_rdmol_positions(rdkit_mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " mol = deepcopy(rdkit_mol)\n", - " set_rdmol_positions_(mol, pos)\n", - " return mol\n", - "\n", - "def set_rdmol_positions_(mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " for i in range(pos.shape[0]):\n", - " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", - " return mol\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NuE10hcpKmzK" - }, - "source": [ - "Process the generated data to make it easy to view." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KieVE1vc0_Vs", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "collect 5 generated molecules in `mols`\n" - ] - } + "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "e2d368556e494ae7ae4e2e992af2cd4f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fdc393f3468c432aa0ada05e238a5436": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" ], - "source": [ - "# the model can generate multiple conformations per 2d geometry\n", - "num_gen = results[0]['pos_gen'].shape[0]\n", - "\n", - "# init storage objects\n", - "mols_gen = []\n", - "mols_orig = []\n", - "for to_process in results:\n", - "\n", - " # store the reference 3d position\n", - " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # store the generated 3d position\n", - " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # copy data to new object\n", - " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", - "\n", - " # append results\n", - " mols_gen.append(new_mol)\n", - " mols_orig.append(to_process.rdmol)\n", - "\n", - "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tin89JwMKp4v" - }, - "source": [ - "Import tools to visualize the 2d chemical diagram of the molecule." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yqV6gllSZn38" - }, - "outputs": [], - "source": [ - "from rdkit.Chem import AllChem\n", - "from rdkit import Chem\n", - "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", - "from IPython.display import SVG, display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TFNKmGddVoOk" - }, - "source": [ - "Select molecule to visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KzuwLlrrVaGc" - }, - "outputs": [], - "source": [ - "idx = 0\n", - "assert idx < len(results), \"selected molecule that was not generated\"" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Viewing" + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e4901541199b45c6a18824627692fc39": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f915cf874246446595206221e900b2fe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_name": "ColormakerRegistryModel", + "model_module_version": "3.0.1", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + "eac6a8dcdc9d4335a2e51031793ead29": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_name": "NGLModel", + "model_module_version": "3.0.1", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292777, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 ], - "metadata": { - "id": "hkb8w0_SNtU8" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I3R4QBQeKttN" + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" }, - "source": [ - "This 2D rendering is the equivalent of the **input to the model**!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gkQRWjraaKex", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 321 - }, - "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" - }, - "metadata": {} + "_ngl_msg_archive": [ + { + "target": "Stage", + "type": "call_method", + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "args": [ + { + "type": "blob", + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "binary": false + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" } + } ], - "source": [ - "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", - "molSize=(450,300)\n", - "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", - "drawer.DrawMolecule(mc)\n", - "drawer.FinishDrawing()\n", - "svg = drawer.GetDrawingText()\n", - "display(SVG(svg.replace('svg:','')))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z4FDMYMxKw2I" + "_ngl_original_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" }, - "source": [ - "Generate the 3d molecule!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aT1Bkb8YxJfV", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "695ab5bbf30a4ab19df1f9f33469f314", - "eac6a8dcdc9d4335a2e51031793ead29" - ] - }, - "outputId": "b98870ae-049d-4386-b676-166e9526bda2" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "695ab5bbf30a4ab19df1f9f33469f314" - } + "_ngl_repr_dict": { + "0": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } } - ], - "source": [ - "from nglview import show_rdkit as show" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pxtq8I-I18C-", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "be446195da2b4ff2aec21ec5ff963a54", - "c6596896148b4a8a9c57963b67c7782f", - "2489b5e5648541fbbdceadb05632a050", - "01e0ba4e5da04914b4652b8d58565d7b", - "c30e6c2f3e2a44dbbb3d63bd519acaa4", - "f31c6e40e9b2466a9064a2669933ecd5", - "19308ccac642498ab8b58462e3f1b0bb", - "4a081cdc2ec3421ca79dd933b7e2b0c4", - "e5c0d75eb5e1447abd560c8f2c6017e1", - "5146907ef6764654ad7d598baebc8b58", - "144ec959b7604a2cabb5ca46ae5e5379", - "abce2a80e6304df3899109c6d6cac199", - "65195cb7a4134f4887e9dd19f3676462" - ] - }, - "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "NGLWidget()" - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "be446195da2b4ff2aec21ec5ff963a54" - } + }, + "1": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } } - ], - "source": [ - "# new molecule\n", - "show(mols_gen[idx])" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "KJr4h2mwXeTo" + } }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "d90f304e9560472eacfbdd11e46765eb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", - "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", - "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" - ], - "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" - } - }, - "1c6246f15b654f4daa11c9bcf997b78c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", - "placeholder": "​", - "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", - "value": "Downloading: 100%" - } - }, - "c2321b3bff6f490ca12040a20308f555": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", - "max": 3271865, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", - "value": 3271865 - } - }, - "b7feb522161f4cf4b7cc7c1a078ff12d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", - "placeholder": "​", - "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", - "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" - } - }, - "e2d368556e494ae7ae4e2e992af2cd4f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "bbef741e76ec41b7ab7187b487a383df": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "561f742d418d4721b0670cc8dd62e22c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "872915dd1bb84f538c44e26badabafdd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d022575f1fa2446d891650897f187b4d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "fdc393f3468c432aa0ada05e238a5436": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2c9362906e4b40189f16d14aa9a348da": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6010fc8daa7a44d5aec4b830ec2ebaa1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", - "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "IPY_MODEL_6526646be5ed415c84d1245b040e629b" - ], - "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" - } - }, - "7e0bb1b8d65249d3974200686b193be2": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", - "placeholder": "​", - "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", - "value": "Downloading: 100%" - } - }, - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", - "max": 401, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", - "value": 401 - } - }, - "6526646be5ed415c84d1245b040e629b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", - "placeholder": "​", - "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", - "value": " 401/401 [00:00<00:00, 13.5kB/s]" - } - }, - "24d31fc3576e43dd9f8301d2ef3a37ab": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2918bfaadc8d4b1a9832522c40dfefb8": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a4bfdca35cc54dae8812720f1b276a08": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e4901541199b45c6a18824627692fc39": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f915cf874246446595206221e900b2fe": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "a9e388f22a9742aaaf538e22575c9433": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "42f6c3db29d7484ba6b4f73590abd2f4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "695ab5bbf30a4ab19df1f9f33469f314": { - "model_module": "nglview-js-widgets", - "model_name": "ColormakerRegistryModel", - "model_module_version": "3.0.1", - "state": { - "_dom_classes": [], - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "ColormakerRegistryModel", - "_msg_ar": [], - "_msg_q": [], - "_ready": false, - "_view_count": null, - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "ColormakerRegistryView", - "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" - } - }, - "eac6a8dcdc9d4335a2e51031793ead29": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be446195da2b4ff2aec21ec5ff963a54": { - "model_module": "nglview-js-widgets", - "model_name": "NGLModel", - "model_module_version": "3.0.1", - "state": { - "_camera_orientation": [ - -15.519693580202304, - -14.065056548036177, - -23.53197484807691, - 0, - -23.357853515109753, - 20.94055073042662, - 2.888695042134944, - 0, - 14.352363398292777, - 18.870825741878015, - -20.744689572909344, - 0, - 0.2724999189376831, - 0.6940000057220459, - -0.3734999895095825, - 1 - ], - "_camera_str": "orthographic", - "_dom_classes": [], - "_gui_theme": null, - "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", - "_igui": null, - "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "NGLModel", - "_ngl_color_dict": {}, - "_ngl_coordinate_resource": {}, - "_ngl_full_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" - }, - "_ngl_msg_archive": [ - { - "target": "Stage", - "type": "call_method", - "methodName": "loadFile", - "reconstruc_color_scheme": false, - "args": [ - { - "type": "blob", - "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", - "binary": false - } - ], - "kwargs": { - "defaultRepresentation": true, - "ext": "pdb" - } - } - ], - "_ngl_original_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" - }, - "_ngl_repr_dict": { - "0": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 - }, - "flatShaded": false, - "opacity": 1, - "depthWrite": true, - "side": "double", - "wireframe": false, - "colorScheme": "element", - "colorScale": "", - "colorReverse": false, - "colorValue": 9474192, - "colorMode": "hcl", - "roughness": 0.4, - "metalness": 0, - "diffuse": 16777215, - "diffuseInterior": false, - "useInteriorColor": true, - "interiorColor": 2236962, - "interiorDarkening": 0, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] - }, - "disablePicking": false, - "sele": "" - } - } - }, - "1": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 - }, - "flatShaded": false, - "opacity": 1, - "depthWrite": true, - "side": "double", - "wireframe": false, - "colorScheme": "element", - "colorScale": "", - "colorReverse": false, - "colorValue": 9474192, - "colorMode": "hcl", - "roughness": 0.4, - "metalness": 0, - "diffuse": 16777215, - "diffuseInterior": false, - "useInteriorColor": true, - "interiorColor": 2236962, - "interiorDarkening": 0, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] - }, - "disablePicking": false, - "sele": "" - } - } - } - }, - "_ngl_serialize": false, - "_ngl_version": "", - "_ngl_view_id": [ - "FB989FD1-5B9C-446B-8914-6B58AF85446D" - ], - "_player_dict": {}, - "_scene_position": {}, - "_scene_rotation": {}, - "_synced_model_ids": [], - "_synced_repr_model_ids": [], - "_view_count": null, - "_view_height": "", - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "NGLView", - "_view_width": "", - "background": "white", - "frame": 0, - "gui_style": null, - "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", - "max_frame": 0, - "n_components": 2, - "picked": {} - } - }, - "c6596896148b4a8a9c57963b67c7782f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2489b5e5648541fbbdceadb05632a050": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ButtonView", - "button_style": "", - "description": "", - "disabled": false, - "icon": "compress", - "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", - "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", - "tooltip": "" - } - }, - "01e0ba4e5da04914b4652b8d58565d7b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", - "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" - ], - "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" - } - }, - "c30e6c2f3e2a44dbbb3d63bd519acaa4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f31c6e40e9b2466a9064a2669933ecd5": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "19308ccac642498ab8b58462e3f1b0bb": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4a081cdc2ec3421ca79dd933b7e2b0c4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "SliderStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "", - "handle_color": null - } - }, - "e5c0d75eb5e1447abd560c8f2c6017e1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "PlayModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "PlayModel", - "_playing": false, - "_repeat": false, - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "PlayView", - "description": "", - "description_tooltip": null, - "disabled": false, - "interval": 100, - "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", - "max": 0, - "min": 0, - "show_repeat": true, - "step": 1, - "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", - "value": 0 - } - }, - "5146907ef6764654ad7d598baebc8b58": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntSliderModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "IntSliderModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "IntSliderView", - "continuous_update": true, - "description": "", - "description_tooltip": null, - "disabled": false, - "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", - "max": 0, - "min": 0, - "orientation": "horizontal", - "readout": true, - "readout_format": "d", - "step": 1, - "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", - "value": 0 - } - }, - "144ec959b7604a2cabb5ca46ae5e5379": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "abce2a80e6304df3899109c6d6cac199": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "34px" - } - }, - "65195cb7a4134f4887e9dd19f3676462": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "button_color": null, - "font_weight": "" - } - } - } + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" + ], + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + ], + "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f31c6e40e9b2466a9064a2669933ecd5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "SliderStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "e5c0d75eb5e1447abd560c8f2c6017e1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "PlayModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PlayModel", + "_playing": false, + "_repeat": false, + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PlayView", + "description": "", + "description_tooltip": null, + "disabled": false, + "interval": 100, + "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", + "max": 0, + "min": 0, + "show_repeat": true, + "step": 1, + "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", + "value": 0 + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntSliderModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f42ccc064624..3e472606e1e7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -78,6 +78,7 @@ "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", + "AutoencoderKL3D", "AutoencoderKLTemporalDecoder", "AutoencoderTiny", "ConsistencyDecoderVAE", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d35786ee7642..f07e906c87d3 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -74,6 +74,7 @@ from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKL3D, AutoencoderKLTemporalDecoder, AutoencoderTiny, ConsistencyDecoderVAE, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 5c47748d62e0..60624865cddf 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,5 +1,6 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl3d import AutoencoderKL3D from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index 83c73e92ae68..de40596b643f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -24,21 +24,22 @@ def normalize3d(in_channels, z_ch, add_conv): add_conv=add_conv, num_groups=32, eps=1e-6, - affine=True + affine=True, ) class SafeConv3d(torch.nn.Conv3d): def forward(self, input): - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3 + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 if memory_count > 2: # Set to 2GB kernel_size = self.kernel_size[0] part_num = int(memory_count / 2) + 1 input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW if kernel_size > 1: input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2) for i in - range(1, len(input_chunks))] + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] output_chunks = [] for input_chunk in input_chunks: @@ -52,12 +53,7 @@ def forward(self, input): class OriginCausalConv3d(nn.Module): @beartype def __init__( - self, - chan_in, - chan_out, - kernel_size: Union[int, Tuple[int, int, int]], - pad_mode='constant', - **kwargs + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs ): super().__init__() @@ -68,8 +64,8 @@ def cast_tuple(t, length=1): time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - dilation = kwargs.pop('dilation', 1) - stride = kwargs.pop('stride', 1) + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) self.pad_mode = pad_mode time_pad = dilation * (time_kernel_size - 1) + (1 - stride) @@ -86,23 +82,24 @@ def cast_tuple(t, length=1): self.conv = SafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) def forward(self, x): - if self.pad_mode == 'constant': + if self.pad_mode == "constant": causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_3d, mode='constant', value=0) - elif self.pad_mode == 'first': + x = F.pad(x, causal_padding_3d, mode="constant", value=0) + elif self.pad_mode == "first": pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) x = torch.cat([pad_x, x], dim=2) causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode='constant', value=0) - elif self.pad_mode == 'reflect': + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + elif self.pad_mode == "reflect": # reflect padding - reflect_x = x[:, :, 1:self.time_pad + 1, :, :].flip(dims=[2]) + reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) if reflect_x.shape[2] < self.time_pad: reflect_x = torch.cat( - [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2) + [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 + ) x = torch.cat([reflect_x, x], dim=2) causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode='constant', value=0) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) else: raise ValueError("Invalid pad mode") return self.conv(x) @@ -117,20 +114,28 @@ def forward(self, x): if self.time_pad == 0: return super().forward(x) if self.conv_cache is None: - self.conv_cache = x[:, :, -self.time_pad:].detach().clone().cpu() + self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() return super().forward(x) else: # print(self.conv_cache.shape, x.shape) x = torch.cat([self.conv_cache.to(x.device), x], dim=2) causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode='constant', value=0) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) self.conv_cache = None return self.conv(x) class SpatialNorm3D(nn.Module): - def __init__(self, f_channels, z_channels, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False, - pad_mode='constant', **norm_layer_params): + def __init__( + self, + f_channels, + z_channels, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=False, + pad_mode="constant", + **norm_layer_params, + ): super().__init__() self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) if freeze_norm_layer: @@ -161,22 +166,11 @@ def forward(self, f, z): class UpSample3D(nn.Module): - def __init__( - self, - in_channels: int, - with_conv: bool, - compress_time: bool = False - ): + def __init__(self, in_channels: int, with_conv: bool, compress_time: bool = False): super(UpSample3D, self).__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1 - ) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.compress_time = compress_time def forward(self, x): @@ -197,25 +191,25 @@ def forward(self, x): else: # only interpolate 2D t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) if self.with_conv: t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.conv(x) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) return x class DownSample3D(nn.Module): def __init__( - self, - in_channels: int, - with_conv: bool = False, - compress_time: bool = False, - out_channels: Optional[int] = None + self, + in_channels: int, + with_conv: bool = False, + compress_time: bool = False, + out_channels: Optional[int] = None, ): super(DownSample3D, self).__init__() self.with_conv = with_conv @@ -223,19 +217,13 @@ def __init__( out_channels = in_channels if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=2, - padding=0 - ) + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.compress_time = compress_time def forward(self, x): if self.compress_time: h, w = x.shape[-2:] - x = rearrange(x, 'b c t h w -> (b h w) c t') + x = rearrange(x, "b c t h w -> (b h w) c t") if x.shape[-1] % 2 == 1: # split first frame @@ -244,41 +232,41 @@ def forward(self, x): if x_rest.shape[-1] > 0: x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) else: x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) - x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) if self.with_conv: pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.conv(x) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) else: t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) return x class ResnetBlock3D(nn.Module): def __init__( - self, - *, - in_channels: int, - out_channels: int, - conv_shortcut: bool = False, - dropout: float, - act_fn: str = "silu", - temb_channels: int = 512, - z_ch: Optional[int] = None, - add_conv: bool = False, - pad_mode: str = 'constant', - norm_num_groups: int = 32, - normalization: Callable = None + self, + *, + in_channels: int, + out_channels: int, + conv_shortcut: bool = False, + dropout: float, + act_fn: str = "silu", + temb_channels: int = 512, + z_ch: Optional[int] = None, + add_conv: bool = False, + pad_mode: str = "constant", + norm_num_groups: int = 32, + normalization: Callable = None, ): super(ResnetBlock3D, self).__init__() self.in_channels = in_channels @@ -296,18 +284,9 @@ def __init__( z_ch=z_ch, add_conv=add_conv, ) - self.norm2 = normalization( - out_channels, - z_ch=z_ch, - add_conv=add_conv - ) + self.norm2 = normalization(out_channels, z_ch=z_ch, add_conv=add_conv) - self.conv1 = CausalConv3d( - in_channels, - out_channels, - kernel_size=3, - pad_mode=pad_mode - ) + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) @@ -317,13 +296,7 @@ def __init__( if self.use_conv_shortcut: self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) else: - self.nin_shortcut = SafeConv3d( - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0 - ) + self.nin_shortcut = SafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb, z=None): h = x @@ -360,34 +333,10 @@ def __init__(self, in_channels, norm_num_groups): self.in_channels = in_channels self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) - self.q = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -432,23 +381,23 @@ def forward(self, x): class Encoder3D(nn.Module): def __init__( - self, - *, - ch: int, - in_channels: int = 3, - out_channels: int = 16, - ch_mult: Tuple[int, ...] = (1, 2, 4, 8), - num_res_blocks: int, - act_fn: str = "silu", - norm_num_groups: int = 32, - attn_resolutions=None, - dropout: float = 0.0, - resamp_with_conv: bool = True, - resolution: int, - z_channels: int, - double_z: bool = True, - pad_mode: str = 'first', - temporal_compress_times: int = 4, + self, + *, + ch: int, + in_channels: int = 3, + out_channels: int = 16, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + act_fn: str = "silu", + norm_num_groups: int = 32, + attn_resolutions=None, + dropout: float = 0.0, + resamp_with_conv: bool = True, + resolution: int, + z_channels: int, + double_z: bool = True, + pad_mode: str = "first", + temporal_compress_times: int = 4, ): super(Encoder3D, self).__init__() if attn_resolutions is None: @@ -483,14 +432,12 @@ def __init__( act_fn=act_fn, dropout=dropout, norm_num_groups=norm_num_groups, - pad_mode=pad_mode + pad_mode=pad_mode, ) ) block_in = block_out if curr_res in attn_resolutions: - attn.append( - AttnBlock2D(block_in) - ) + attn.append(AttnBlock2D(block_in)) down = nn.Module() down.block = block down.attn = attn @@ -510,7 +457,8 @@ def __init__( act_fn=act_fn, temb_channels=0, norm_num_groups=norm_num_groups, - dropout=dropout, pad_mode=pad_mode + dropout=dropout, + pad_mode=pad_mode, ) if len(attn_resolutions) > 0: self.mid.attn_1 = AttnBlock2D(block_in) @@ -520,19 +468,17 @@ def __init__( act_fn=act_fn, temb_channels=0, norm_num_groups=norm_num_groups, - dropout=dropout, pad_mode=pad_mode + dropout=dropout, + pad_mode=pad_mode, ) self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=norm_num_groups, eps=1e-6) conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = CausalConv3d( - block_in, conv_out_channels if double_z else z_channels, - kernel_size=3, - pad_mode=pad_mode + block_in, conv_out_channels if double_z else z_channels, kernel_size=3, pad_mode=pad_mode ) def forward(self, x): - # timestep embedding temb = None @@ -562,24 +508,25 @@ def forward(self, x): class Decoder3D(nn.Module): def __init__( - self, *, - ch: int, - in_channels: int = 16, - out_channels: int = 3, - ch_mult: Tuple[int, ...] = (1, 2, 4, 8), - num_res_blocks: int, - attn_resolutions=None, - act_fn: str = "silu", - dropout: float = 0.0, - resamp_with_conv: bool = True, - resolution: int, - z_channels: int, - give_pre_end: bool = False, - z_ch: Optional[int] = None, - add_conv: bool = False, - pad_mode: str = 'first', - temporal_compress_times: int = 4, - norm_num_groups=32, + self, + *, + ch: int, + in_channels: int = 16, + out_channels: int = 3, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions=None, + act_fn: str = "silu", + dropout: float = 0.0, + resamp_with_conv: bool = True, + resolution: int, + z_channels: int, + give_pre_end: bool = False, + z_ch: Optional[int] = None, + add_conv: bool = False, + pad_mode: str = "first", + temporal_compress_times: int = 4, + norm_num_groups=32, ): super(Decoder3D, self).__init__() if attn_resolutions is None: @@ -605,8 +552,7 @@ def __init__( block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode) @@ -622,7 +568,7 @@ def __init__( add_conv=add_conv, normalization=normalize3d, norm_num_groups=norm_num_groups, - pad_mode=pad_mode + pad_mode=pad_mode, ) if len(attn_resolutions) > 0: self.mid.attn_1 = AttnBlock2D(block_in) @@ -636,7 +582,7 @@ def __init__( add_conv=add_conv, normalization=normalize3d, norm_num_groups=norm_num_groups, - pad_mode=pad_mode + pad_mode=pad_mode, ) # upsampling @@ -657,15 +603,12 @@ def __init__( add_conv=add_conv, normalization=normalize3d, norm_num_groups=norm_num_groups, - pad_mode=pad_mode + pad_mode=pad_mode, ) ) block_in = block_out if curr_res in attn_resolutions: - attn.append( - AttnBlock2D - (block_in=block_in, norm_num_groups=norm_num_groups) - ) + attn.append(AttnBlock2D(block_in=block_in, norm_num_groups=norm_num_groups)) up = nn.Module() up.block = block up.attn = attn @@ -682,7 +625,6 @@ def __init__( self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) def forward(self, z): - # timestep embedding temb = None @@ -718,66 +660,65 @@ def forward(self, z): class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" - A VAE model with KL loss for encoding images into latents and decoding latent representations into images. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): - Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): - Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): Sample input size. - scaling_factor (`float`, *optional*, defaults to 0.18215): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - force_upcast (`bool`, *optional*, default to `True`): - If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE - can be fine-tuned / trained to a lower range without loosing too much precision in which case - `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - mid_block_add_attention (`bool`, *optional*, default to `True`): - If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the - mid_block will only have resnet blocks - """ + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ _supports_gradient_checkpointing = True _no_split_modules = ["ResnetBlock3D"] @register_to_config def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock3D",), - up_block_types: Tuple[str] = ("UpDecoderBlock3D",), - ch: int = 128, - block_out_channels: Tuple[int] = (1, 2, 2, 4), - layers_per_block: int = 3, - act_fn: str = "silu", - latent_channels: int = 16, - norm_num_groups: int = 32, - sample_size: int = 256, - - # Do Not Know how to use - scaling_factor: float = 0.13025, - shift_factor: Optional[float] = None, - latents_mean: Optional[Tuple[float]] = None, - latents_std: Optional[Tuple[float]] = None, - force_upcast: float = True, - use_quant_conv: bool = False, - use_post_quant_conv: bool = False, - mid_block_add_attention: bool = True + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + ch: int = 128, + block_out_channels: Tuple[int] = (1, 2, 2, 4), + layers_per_block: int = 3, + act_fn: str = "silu", + latent_channels: int = 16, + norm_num_groups: int = 32, + sample_size: int = 256, + # Do Not Know how to use + scaling_factor: float = 0.13025, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + mid_block_add_attention: bool = True, ): super().__init__() @@ -854,7 +795,7 @@ def disable_slicing(self): @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -878,7 +819,7 @@ def encode( @apply_forward_hook def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None + self, z: torch.FloatTensor, return_dict: bool = True, generator=None ) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. @@ -902,11 +843,11 @@ def decode( return dec def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, ) -> Union[torch.Tensor, torch.Tensor]: x = sample posterior = self.encode(x).latent_dist diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3ead6fd99d10..4d5ec21fd196 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKL3D(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] From bd6efd5fe44231a1239494925c7798267aec01de Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 09:02:06 +0200 Subject: [PATCH 03/94] initial draft of cogvideo transformer --- scripts/convert_cogvideox_to_diffusers.py | 119 +++++ src/diffusers/models/embeddings.py | 85 ++++ .../transformers/cogvideox_transformer_3d.py | 468 ++++++++++++++++++ 3 files changed, 672 insertions(+) create mode 100644 scripts/convert_cogvideox_to_diffusers.py create mode 100644 src/diffusers/models/transformers/cogvideox_transformer_3d.py diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py new file mode 100644 index 000000000000..f334924a92a3 --- /dev/null +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -0,0 +1,119 @@ +import argparse +from typing import Any, Dict, List, Tuple + +import torch +import torch.nn as nn + +from diffusers import CogVideoXTransformer3D + + +def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tuple[str, nn.Module]]: + to_q_key = key.replace("query_key_value", "to_q") + to_k_key = key.replace("query_key_value", "to_k") + to_v_key = key.replace("query_key_value", "to_v") + to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0) + state_dict[to_q_key] = to_q + state_dict[to_k_key] = to_k + state_dict[to_v_key] = to_v + state_dict.pop(key) + + +def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tuple[str, nn.Module]]: + layer_id, weight_or_bias = key.split(".")[-2:] + + if "query" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}" + elif "key" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}" + + state_dict[new_key] = state_dict.pop(key) + + +def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tuple[str, nn.Module]]: + layer_id, _, weight_or_bias = key.split(".")[-3:] + new_key = f"transformer_blocks.{layer_id}.norm0.linear.{weight_or_bias}" + state_dict[new_key] = state_dict.pop(key) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "transformer.final_layernorm": "norm_final", + "transformer": "transformer_blocks", + "attention": "attn1", + "mlp": "ff.net", + "dense_h_to_4h": "0.proj", + "dense_4h_to_h": "2", + ".layers": "", + "dense": "to_out.0", + "input_layernorm": "norm1", + "post_attn1_layernorm": "norm2", + "time_embed.0": "time_embedding.linear_1", + "time_embed.2": "time_embedding.linear_2", + "mixins.patch_embed": "patch_embed", + "mixins.final_layer.norm_final": "norm_out", + "mixins.final_layer.linear": "proj_out", + "mixins.final_layer.adaLN_modulation.1": "adaln_out.linear", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "query_key_value": reassign_query_key_value_inplace, + "query_layernorm_list": reassign_query_key_layernorm_inplace, + "key_layernorm_list": reassign_query_key_layernorm_inplace, + "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, +} + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def convert_transformer(ckpt_path: str, output_path: str, fp16: bool = False, push_to_hub: bool = False) -> None: + PREFIX_KEY = "model.diffusion_model." + + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + transformer = CogVideoXTransformer3D() + + for key in list(original_state_dict.keys()): + new_key = key[len(PREFIX_KEY) :] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True) + transformer.save_pretrained(output_path) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformercheckpoint" + ) + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument( + "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + if args.transformer_ckpt_path is not None: + convert_transformer(args.transformer_ckpt_path, args.output_path, args.fp16, args.push_to_hub) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 7684fdf9cd6c..cad9e18b3a50 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -78,6 +78,53 @@ def get_timestep_embedding( return emb +def get_3d_sincos_pos_embed( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, +) -> np.ndarray: + r""" + Args: + embed_dim (`int`): + spatial_size (`int` or `Tuple[int, int]`): + temporal_size (`int`): + spatial_interpolation_scale (`float`, defaults to 1.0): + temporal_interpolation_scale (`float`, defaults to 1.0): + """ + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. Spatial + grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32) / temporal_interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # 2. Temporal + grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # 3. Concat + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] + return pos_embed + + def get_2d_sincos_pos_embed( embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): @@ -287,6 +334,44 @@ def forward(self, x, freqs_cis): ) +class CogVideoXPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, + ) -> None: + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + r""" + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + + B, F, C, H, W = image_embeds.shape + image_embeds = image_embeds.view(-1, C, H, W) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(B, F, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [B, F, H x W, C] + image_embeds = image_embeds.flatten(1, 2) # [B, F x H x W, C] + + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + return embeds + + def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py new file mode 100644 index 000000000000..31cdfb2881cd --- /dev/null +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -0,0 +1,468 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, FeedForward +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, output_dim: int): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + + def forward(self, emb: torch.Tensor) -> torch.Tensor: + x = self.silu(emb.to(torch.float32)).to(emb.dtype) + x = self.linear(x) + return x + + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in CogVideoX model. TODO: add link to CogVideoX upon release + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + self.norm0 = AdaLayerNorm(time_embed_dim, 12 * dim) + + # 1. Self Attention + # TODO: verify if this should actually be FP32LayerNorm or if nn.LayerNorm is okay + self.norm1 = FP32LayerNorm(dim, norm_eps) + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm=norm_type if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + # 2. Feed Forward + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def _modulate(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ( + shift_msa, + scale_msa, + gate_msa, + shift_ff, + scale_ff, + gate_mlp, + enc_shift_msa, + enc_scale_msa, + enc_gate_msa, + enc_shift_ff, + enc_scale_ff, + enc_gate_mlp, + ) = self.norm0(temb).chunk(12, dim=1) + gate_msa, gate_mlp, enc_gate_msa, enc_gate_mlp = ( + gate_msa.unsqueeze(1), + gate_mlp.unsqueeze(1), + enc_gate_msa.unsqueeze(1), + enc_gate_mlp.unsqueeze(1), + ) + + # norm & modulate + norm_hidden_states = self.norm1(hidden_states) + norm_encoder_hidden_states = self.norm1(encoder_hidden_states) + + norm_hidden_states = self._modulate(norm_hidden_states, shift_msa, scale_msa) + norm_encoder_hidden_states = self._modulate(norm_encoder_hidden_states, enc_shift_msa, enc_scale_msa) + + # attention + text_length = norm_encoder_hidden_states.size(1) + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + attn_output = self.attn1(norm_hidden_states, attention_mask=attention_mask) + + hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] + + # norm & modulate + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2(encoder_hidden_states) + + norm_hidden_states = self._modulate(norm_hidden_states, shift_ff, scale_ff) + norm_encoder_hidden_states = self._modulate(norm_encoder_hidden_states, enc_shift_ff, enc_scale_ff) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_mlp * ff_output[:, text_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_mlp * ff_output[:, :text_length] + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3D(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input. + out_channels (`int`, *optional*): + The number of channels in the output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + patch_size (`int`, *optional*): + The size of the patches to use in the patch embedding layer. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. During inference, you can denoise for up to but not more steps than + `num_embeds_ada_norm`. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. + caption_channels (`int`, *optional*): + The number of channels in the caption embeddings. + video_length (`int`, *optional*): + The number of frames in the video-like data. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: Optional[int] = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + time_compression: int = 4, + max_text_seq_length: int = 225, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.height = sample_height + self.width = sample_width + self.frames = sample_frames + + post_patch_height = sample_height // patch_size + post_patch_width = sample_width // patch_size + post_time_compression_frames = (sample_frames - 1) // time_compression + 1 + self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. 3D positional embeddings + spatial_pos_embedding = get_3d_sincos_pos_embed( + inner_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + spatial_interpolation_scale, + temporal_interpolation_scale, + ) + spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) + pos_embedding = nn.Parameter( + torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim), requires_grad=False + ) + pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) + self.register_buffer("pos_embedding", pos_embedding, persistent=False) + + # 3. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 4. Define spatial transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = FP32LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 5. Output blocks + self.adaln_out = AdaLayerNorm(time_embed_dim, 2 * inner_dim) + self.norm_out = FP32LayerNorm(inner_dim, 1e-06, norm_elementwise_affine) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def _modulate(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + def forward( + self, + sample: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + attention_mask: Optional[Union[int, torch.Tensor]] = None, + timestep_cond: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + batch_size, channels, num_frames, height, width = sample.shape + + # 1. Time embedding + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 3. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, sample) + + # 4. Position embedding + seq_length = height * width * num_frames // (self.config.patch_size**2) + text_seq_length = encoder_hidden_states.size(1) + + pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] + hidden_states = hidden_states + pos_embeds + hidden_states = self.embedding_dropout(hidden_states) + + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # 2. Prepare attention mask + if attention_mask is None: + attention_mask = torch.ones( + batch_size, + self.num_patches + self.config.max_text_seq_length, + self.num_patches + self.config.max_text_seq_length, + ) + attention_mask = attention_mask.to(device=sample.device, dtype=sample.dtype) + + # 5. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + attention_mask, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + attention_mask=attention_mask, + ) + + hidden_states = self.norm_final(hidden_states) + + # 6. Final block + shift, scale = self.adaln_out(emb).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + hidden_states = self._modulate(hidden_states, shift, scale) + hidden_states = self.norm_out(hidden_states) + hidden_states = self.proj_out(hidden_states) + + # 7. Unpatchify + p = self.config.patch_size + output = hidden_states.reshape(-1, height // p, width // p, p, p, self.config.out_channels) + output = output.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return output + return Transformer2DModelOutput(sample=output) From bb917755ee304cfedec712d5840032e03ad8a5a9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 09:04:45 +0200 Subject: [PATCH 04/94] add imports --- src/diffusers/__init__.py | 2 ++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/models/transformers/__init__.py | 1 + src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 4 files changed, 20 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f42ccc064624..3b81f86a4f1f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -80,6 +80,7 @@ "AutoencoderKL", "AutoencoderKLTemporalDecoder", "AutoencoderTiny", + "CogVideoXTransformer3D", "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetXSAdapter", @@ -516,6 +517,7 @@ AutoencoderKL, AutoencoderKLTemporalDecoder, AutoencoderTiny, + CogVideoXTransformer3D, ConsistencyDecoderVAE, ControlNetModel, ControlNetXSAdapter, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d35786ee7642..226261213abd 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -40,6 +40,7 @@ _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] + _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3D"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] @@ -88,6 +89,7 @@ from .modeling_utils import ModelMixin from .transformers import ( AuraFlowTransformer2DModel, + CogVideoXTransformer3D, DiTTransformer2DModel, DualTransformer2DModel, HunyuanDiT2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ae5103160790..f68a35044b11 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -3,6 +3,7 @@ if is_torch_available(): from .auraflow_transformer_2d import AuraFlowTransformer2DModel + from .cogvideox_transformer_3d import CogVideoXTransformer3D from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3ead6fd99d10..78d5720ba87f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -77,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CogVideoXTransformer3D(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ConsistencyDecoderVAE(metaclass=DummyObject): _backends = ["torch"] From 59e6669f6d247461524e92a5470ac44e4a05e066 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 11:44:42 +0200 Subject: [PATCH 05/94] fix attention mask --- src/diffusers/models/embeddings.py | 2 +- .../models/transformers/cogvideox_transformer_3d.py | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index cad9e18b3a50..76e007b09af3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -368,7 +368,7 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): image_embeds = image_embeds.flatten(3).transpose(2, 3) # [B, F, H x W, C] image_embeds = image_embeds.flatten(1, 2) # [B, F x H x W, C] - embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() # [B, S + F x H x W, C] return embeds diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 31cdfb2881cd..f1b1c8a431c7 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -372,7 +372,7 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, return_dict: bool = True, ): - batch_size, channels, num_frames, height, width = sample.shape + batch_size, num_frames, channels, height, width = sample.shape # 1. Time embedding timesteps = timestep @@ -415,11 +415,7 @@ def forward( # 2. Prepare attention mask if attention_mask is None: - attention_mask = torch.ones( - batch_size, - self.num_patches + self.config.max_text_seq_length, - self.num_patches + self.config.max_text_seq_length, - ) + attention_mask = torch.ones(batch_size, self.num_patches + self.config.max_text_seq_length) attention_mask = attention_mask.to(device=sample.device, dtype=sample.dtype) # 5. Transformer blocks @@ -455,13 +451,12 @@ def custom_forward(*inputs): shift, scale = self.adaln_out(emb).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) hidden_states = self._modulate(hidden_states, shift, scale) - hidden_states = self.norm_out(hidden_states) hidden_states = self.proj_out(hidden_states) # 7. Unpatchify p = self.config.patch_size - output = hidden_states.reshape(-1, height // p, width // p, p, p, self.config.out_channels) - output = output.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, p, p, self.config.out_channels) + output = output.permute(0, 1, 6, 2, 4, 3, 5).flatten(5, 6).flatten(3, 4) if not return_dict: return output From 45cb1f92d354efcb49bb18622f2560fdccee528a Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 14:57:57 +0200 Subject: [PATCH 06/94] fix layernorms --- .../models/transformers/cogvideox_transformer_3d.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index f1b1c8a431c7..303d25043ea5 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -24,7 +24,6 @@ from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -125,8 +124,7 @@ def __init__( self.norm0 = AdaLayerNorm(time_embed_dim, 12 * dim) # 1. Self Attention - # TODO: verify if this should actually be FP32LayerNorm or if nn.LayerNorm is okay - self.norm1 = FP32LayerNorm(dim, norm_eps) + self.norm1 = nn.LayerNorm(dim, norm_eps) self.attn1 = Attention( query_dim=dim, @@ -140,7 +138,7 @@ def __init__( ) # 2. Feed Forward - self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.ff = FeedForward( dim, @@ -348,11 +346,11 @@ def __init__( for _ in range(num_layers) ] ) - self.norm_final = FP32LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) # 5. Output blocks self.adaln_out = AdaLayerNorm(time_embed_dim, 2 * inner_dim) - self.norm_out = FP32LayerNorm(inner_dim, 1e-06, norm_elementwise_affine) + self.norm_out = nn.LayerNorm(inner_dim, 1e-6, norm_elementwise_affine) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) self.gradient_checkpointing = False @@ -390,7 +388,6 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors From 84ff56eb90c461ad0b98a5213b34bb0bb41df7ca Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 30 Jul 2024 21:05:04 +0800 Subject: [PATCH 07/94] fix with some review guide --- src/diffusers/models/attention_processor.py | 485 ++++++----- .../models/autoencoders/autoencoder_kl3d.py | 762 +++++++++++------- src/diffusers/models/resnet.py | 295 +++++-- src/diffusers/utils/dummy_pt_objects.py | 15 - 4 files changed, 919 insertions(+), 638 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6669222c695d..9636f34e087a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -90,33 +90,33 @@ class Attention(nn.Module): """ def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - qk_norm: Optional[str] = None, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - out_dim: int = None, - context_pre_only=None, + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, ): super().__init__() @@ -143,7 +143,7 @@ def __init__( self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 self.heads = out_dim // dim_head if out_dim is not None else heads # for slice_size > 0 the attention score computation @@ -267,7 +267,7 @@ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: self.set_processor(processor) def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ) -> None: r""" Set whether to use memory efficient attention from `xformers` or not. @@ -412,9 +412,9 @@ def set_processor(self, processor: "AttnProcessor") -> None: # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) ): logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") self._modules.pop("processor") @@ -436,11 +436,11 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce return self.processor def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **cross_attention_kwargs, + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, ) -> torch.Tensor: r""" The forward method of the `Attention` class. @@ -526,7 +526,7 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten return tensor def get_attention_scores( - self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None ) -> torch.Tensor: r""" Compute the attention scores. @@ -573,7 +573,7 @@ def get_attention_scores( return attention_probs def prepare_attention_mask( - self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> torch.Tensor: r""" Prepare the attention mask for the attention computation. @@ -701,14 +701,14 @@ class AttnProcessor: """ def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -787,13 +787,13 @@ class CustomDiffusionAttnProcessor(nn.Module): """ def __init__( - self, - train_kv: bool = True, - train_q_out: bool = True, - hidden_size: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - dropout: float = 0.0, + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, ): super().__init__() self.train_kv = train_kv @@ -813,11 +813,11 @@ def __init__( self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -878,13 +878,13 @@ class AttnAddedKVProcessor: """ def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -951,13 +951,13 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1022,13 +1022,13 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, ) -> torch.FloatTensor: residual = hidden_states @@ -1071,7 +1071,7 @@ def __call__( # Split the attention outputs. hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], + hidden_states[:, residual.shape[1]:], ) # linear proj @@ -1097,13 +1097,13 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, ) -> torch.FloatTensor: residual = hidden_states @@ -1150,7 +1150,7 @@ def __call__( # Split the attention outputs. hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], + hidden_states[:, residual.shape[1]:], ) # linear proj @@ -1178,12 +1178,12 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, ) -> torch.FloatTensor: batch_size = hidden_states.shape[0] @@ -1244,7 +1244,7 @@ def __call__( # Split the attention outputs. if encoder_hidden_states is not None: hidden_states, encoder_hidden_states = ( - hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, encoder_hidden_states.shape[1]:], hidden_states[:, : encoder_hidden_states.shape[1]], ) @@ -1277,11 +1277,11 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) @@ -1348,14 +1348,14 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1439,14 +1439,14 @@ def __init__(self): raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1545,14 +1545,14 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1635,13 +1635,13 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb @@ -1736,13 +1736,13 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb @@ -1836,14 +1836,14 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[torch.Tensor] = None, - key_rotary_emb: Optional[torch.Tensor] = None, - base_sequence_length: Optional[int] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[torch.Tensor] = None, + key_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb @@ -1941,14 +1941,14 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -2046,14 +2046,14 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): """ def __init__( - self, - train_kv: bool = True, - train_q_out: bool = False, - hidden_size: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - dropout: float = 0.0, - attention_op: Optional[Callable] = None, + self, + train_kv: bool = True, + train_q_out: bool = False, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + attention_op: Optional[Callable] = None, ): super().__init__() self.train_kv = train_kv @@ -2074,11 +2074,11 @@ def __init__( self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2159,13 +2159,13 @@ class CustomDiffusionAttnProcessor2_0(nn.Module): """ def __init__( - self, - train_kv: bool = True, - train_q_out: bool = True, - hidden_size: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - dropout: float = 0.0, + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, ): super().__init__() self.train_kv = train_kv @@ -2185,11 +2185,11 @@ def __init__( self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -2266,11 +2266,11 @@ def __init__(self, slice_size: int): self.slice_size = slice_size def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: residual = hidden_states @@ -2353,12 +2353,12 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__( - self, - attn: "Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, ) -> torch.Tensor: residual = hidden_states @@ -2443,9 +2443,9 @@ class SpatialNorm(nn.Module): """ def __init__( - self, - f_channels: int, - zq_channels: int, + self, + f_channels: int, + zq_channels: int, ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) @@ -2459,6 +2459,43 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f +class SpatialNorm3D(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) + self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + if zq.shape[2] > 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) + z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + class IPAdapterAttnProcessor(nn.Module): r""" @@ -2499,14 +2536,14 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - scale: float = 1.0, - ip_adapter_masks: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, ): residual = hidden_states @@ -2594,7 +2631,7 @@ def __call__( # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): skip = False if isinstance(scale, list): @@ -2702,14 +2739,14 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - scale: float = 1.0, - ip_adapter_masks: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, ): residual = hidden_states @@ -2811,7 +2848,7 @@ def __call__( # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): skip = False if isinstance(scale, list): @@ -2901,12 +2938,12 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: residual = hidden_states if attn.spatial_norm is not None: @@ -3000,12 +3037,12 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: residual = hidden_states if attn.spatial_norm is not None: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index de40596b643f..c13d562f7397 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -1,59 +1,63 @@ +from typing import Optional, Tuple, Union + import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from beartype import beartype -from beartype.typing import Callable, Optional, Tuple, Union -from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config from ...loaders.single_file_model import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation +from ..attention_processor import SpatialNorm3D from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin +from ..resnet import ResnetBlock3D from .vae import DecoderOutput, DiagonalGaussianDistribution -def normalize3d(in_channels, z_ch, add_conv): - return SpatialNorm3D( - in_channels, - z_ch, - norm_layer=nn.GroupNorm, - freeze_norm_layer=False, - add_conv=add_conv, - num_groups=32, - eps=1e-6, - affine=True, - ) +class SaveConv3d(torch.nn.Conv3d): + """ + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3 -class SafeConv3d(torch.nn.Conv3d): - def forward(self, input): - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 - if memory_count > 2: # Set to 2GB + # Set to 2GB, Suit for CuDNN + if memory_count > 2: kernel_size = self.kernel_size[0] part_num = int(memory_count / 2) + 1 - input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + input_chunks = torch.chunk(input, part_num, dim=2) + if kernel_size > 1: input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2) for i in range(1, len(input_chunks)) ] output_chunks = [] for input_chunk in input_chunks: - output_chunks.append(super(SafeConv3d, self).forward(input_chunk)) + output_chunks.append(super(SaveConv3d, self).forward(input_chunk)) output = torch.cat(output_chunks, dim=2) return output else: - return super(SafeConv3d, self).forward(input) + return super(SaveConv3d, self).forward(input) -class OriginCausalConv3d(nn.Module): - @beartype +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXCausalConv3d(nn.Module): + """ + A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + """ + def __init__( - self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode: str = "constant", + **kwargs ): super().__init__() @@ -79,7 +83,16 @@ def cast_tuple(t, length=1): stride = (stride, 1, 1) dilation = (dilation, 1, 1) - self.conv = SafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.conv = SaveConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + **kwargs + ) + + self.conv_cache = None def forward(self, x): if self.pad_mode == "constant": @@ -91,8 +104,7 @@ def forward(self, x): causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) x = F.pad(x, causal_padding_2d, mode="constant", value=0) elif self.pad_mode == "reflect": - # reflect padding - reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) + reflect_x = x[:, :, 1: self.time_pad + 1, :, :].flip(dims=[2]) if reflect_x.shape[2] < self.time_pad: reflect_x = torch.cat( [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 @@ -102,75 +114,126 @@ def forward(self, x): x = F.pad(x, causal_padding_2d, mode="constant", value=0) else: raise ValueError("Invalid pad mode") - return self.conv(x) - - -class CausalConv3d(OriginCausalConv3d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.conv_cache = None - - def forward(self, x): - if self.time_pad == 0: - return super().forward(x) - if self.conv_cache is None: - self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() - return super().forward(x) - else: - # print(self.conv_cache.shape, x.shape) + if self.time_pad != 0 and self.conv_cache is None: + self.conv_cache = x[:, :, -self.time_pad:].detach().clone().cpu() + return self.conv(x) + elif self.time_pad != 0 and self.conv_cache is not None: x = torch.cat([self.conv_cache.to(x.device), x], dim=2) causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) x = F.pad(x, causal_padding_2d, mode="constant", value=0) self.conv_cache = None return self.conv(x) + return self.conv(x) + + +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXSpatialNorm3D(SpatialNorm3D): + """ + Use SaveConv3d instead of nn.Conv3d to avoid OOM in CogVideoX Model + """ -class SpatialNorm3D(nn.Module): def __init__( - self, - f_channels, - z_channels, - norm_layer=nn.GroupNorm, - freeze_norm_layer=False, - add_conv=False, - pad_mode="constant", - **norm_layer_params, + self, + f_channels: int, + zq_channels: int, ): - super().__init__() - self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) - if freeze_norm_layer: - for p in self.norm_layer.parameters: - p.requires_grad = False - self.add_conv = add_conv - if self.add_conv: - self.conv = CausalConv3d(z_channels, z_channels, kernel_size=3, pad_mode=pad_mode) - - self.conv_y = CausalConv3d(z_channels, f_channels, kernel_size=1, pad_mode=pad_mode) - self.conv_b = CausalConv3d(z_channels, f_channels, kernel_size=1, pad_mode=pad_mode) - - def forward(self, f, z): - if z.shape[2] > 1: + super().__init__(f_channels, zq_channels) + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv = SaveConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) + self.conv_y = SaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = SaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + if zq.shape[2] > 1: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] - z_first, z_rest = z[:, :, :1], z[:, :, 1:] - z_first = torch.nn.functional.interpolate(z_first, size=f_first_size, mode="nearest") - z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size, mode="nearest") - z = torch.cat([z_first, z_rest], dim=2) + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) + z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) else: - z = torch.nn.functional.interpolate(z, size=f.shape[-3:], mode="nearest") - if self.add_conv: - z = self.conv(z) + zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) + zq = self.conv(zq) norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(z) + self.conv_b(z) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f +# Todo: Create vae_3d.py such as vae.py file? class UpSample3D(nn.Module): - def __init__(self, in_channels: int, with_conv: bool, compress_time: bool = False): - super(UpSample3D, self).__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `UpSample` class.""" + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + + return x + + +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXUpzSample3D(UpSample3D): + r""" + Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + compress_time: bool = False + ): + super().__init__(in_channels, out_channels) + + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1 + ) self.compress_time = compress_time def forward(self, x): @@ -179,124 +242,168 @@ def forward(self, x): # split first frame x_first, x_rest = x[:, :, 0], x[:, :, 1:] - x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") - x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") - x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0) + x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + x = torch.cat([x_first, x_rest], dim=2) elif x.shape[2] > 1: - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = torch.nn.functional.interpolate(x, scale_factor=2.0) else: x = x.squeeze(2) - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = torch.nn.functional.interpolate(x, scale_factor=2.0) x = x[:, :, None, :, :] else: # only interpolate 2D - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - - if self.with_conv: - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.conv(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x +# Todo: Create vae_3d.py such as vae.py file? class DownSample3D(nn.Module): + r""" + Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in CogVideoX Model. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ + def __init__( - self, - in_channels: int, - with_conv: bool = False, - compress_time: bool = False, - out_channels: Optional[int] = None, + self, + in_channels: int, + out_channels: int, + compress_time: bool = False, + ): super(DownSample3D, self).__init__() - self.with_conv = with_conv - if out_channels is None: - out_channels = in_channels - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.compress_time = compress_time - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.compress_time: - h, w = x.shape[-2:] - x = rearrange(x, "b c t h w -> (b h w) c t") + + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) if x.shape[-1] % 2 == 1: + # split first frame x_first, x_rest = x[..., 0], x[..., 1:] if x_rest.shape[-1] > 0: x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) - x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) - - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.conv(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - else: - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x -class ResnetBlock3D(nn.Module): +class CogVideoXResnetBlock3D(ResnetBlock3D): def __init__( - self, - *, - in_channels: int, - out_channels: int, - conv_shortcut: bool = False, - dropout: float, - act_fn: str = "silu", - temb_channels: int = 512, - z_ch: Optional[int] = None, - add_conv: bool = False, - pad_mode: str = "constant", - norm_num_groups: int = 32, - normalization: Callable = None, + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + latent_channels: Optional[int] = None, + pad_mode: str = "first", ): - super(ResnetBlock3D, self).__init__() - self.in_channels = in_channels - self.act_fn = get_activation(act_fn) + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + eps=eps, + non_linearity=non_linearity, + conv_shortcut=conv_shortcut, + latent_channels=latent_channels + ) + out_channels = in_channels if out_channels is None else out_channels + + self.in_channels = in_channels self.out_channels = out_channels + self.act_fn = get_activation(non_linearity) self.use_conv_shortcut = conv_shortcut - if normalization is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=norm_num_groups, eps=1e-6) + if latent_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) else: - self.norm1 = normalization( - in_channels, - z_ch=z_ch, - add_conv=add_conv, + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=latent_channels, ) - self.norm2 = normalization(out_channels, z_ch=z_ch, add_conv=add_conv) - - self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=latent_channels, + ) + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + pad_mode=pad_mode + ) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.temb_proj = torch.nn.Linear( + in_features=temb_channels, + out_features=out_channels + ) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + pad_mode=pad_mode + ) + if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + self.conv_shortcut = CogVideoXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + pad_mode=pad_mode + ) else: - self.nin_shortcut = SafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = SaveConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0) def forward(self, x, temb, z=None): h = x @@ -326,11 +433,10 @@ def forward(self, x, temb, z=None): return x + h - +#Todo: Need refactor?@a-r-r-o-w class AttnBlock2D(nn.Module): def __init__(self, in_channels, norm_num_groups): super().__init__() - self.in_channels = in_channels self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) @@ -342,140 +448,147 @@ def forward(self, x): h_ = x h_ = self.norm(h_) - t = h_.shape[2] - h_ = rearrange(h_, "b c t h w -> (b t) c h w") + b, c, t, h, w = h_.shape + h_ = h_.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - - # # original version, nan in fp16 - # w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - # w_ = w_ * (int(c)**(-0.5)) - # # implement c**-0.5 on q + b_t, c, h, w = q.shape + q = q.reshape(b_t, c, h * w) + q = q.permute(0, 2, 1) # b_t, hw, c + k = k.reshape(b_t, c, h * w) # b_t, c, hw + # implement c**-0.5 on q q = q * (int(c) ** (-0.5)) - w_ = torch.bmm(q, k) - # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = torch.bmm(q, k) # b_t, hw, hw w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) + v = v.reshape(b_t, c, h * w) + w_ = w_.permute(0, 2, 1) # b_t, hw, hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b_t, c, hw (hw of q) + h_ = h_.reshape(b_t, c, h, w) h_ = self.proj_out(h_) - h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t) - - return x + h_ + h_ = h_.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) + return h_ +# Todo: Create vae_3d.py such as vae.py file? class Encoder3D(nn.Module): def __init__( - self, - *, - ch: int, - in_channels: int = 3, - out_channels: int = 16, - ch_mult: Tuple[int, ...] = (1, 2, 4, 8), - num_res_blocks: int, - act_fn: str = "silu", - norm_num_groups: int = 32, - attn_resolutions=None, - dropout: float = 0.0, - resamp_with_conv: bool = True, - resolution: int, - z_channels: int, - double_z: bool = True, - pad_mode: str = "first", - temporal_compress_times: int = 4, + self, + *, + in_channels: int = 3, + out_channels: int = 16, + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + num_res_blocks: int, + act_fn: str = "silu", + norm_num_groups: int = 32, + attn_resolutions=None, + dropout: float = 0.0, + resolution: int, + latent_channels: int, + double_z: bool = True, + pad_mode: str = "first", + temporal_compress_times: int = 4, ): super(Encoder3D, self).__init__() if attn_resolutions is None: attn_resolutions = [] self.act_fn = get_activation(act_fn) - self.ch = ch - self.num_resolutions = len(ch_mult) + self.num_resolutions = len(block_out_channels) self.num_res_blocks = num_res_blocks self.resolution = resolution - self.in_channels = in_channels self.attn_resolutions = attn_resolutions # log2 of temporal_compress_times self.temporal_compress_level = int(np.log2(temporal_compress_times)) - self.conv_in = CausalConv3d(in_channels, self.ch, kernel_size=3, pad_mode=pad_mode) + self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) + in_ch_mult = (block_out_channels[0],) + tuple(block_out_channels) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] + + block_in = in_ch_mult[i_level] + block_out = block_out_channels[i_level] + for i_block in range(self.num_res_blocks): block.append( - ResnetBlock3D( + CogVideoXResnetBlock3D( in_channels=block_in, out_channels=block_out, temb_channels=0, - act_fn=act_fn, + non_linearity=act_fn, dropout=dropout, - norm_num_groups=norm_num_groups, + groups=norm_num_groups, pad_mode=pad_mode, ) ) block_in = block_out if curr_res in attn_resolutions: - attn.append(AttnBlock2D(block_in)) + attn.append( + AttnBlock2D( + in_channels=block_in, + norm_num_groups=norm_num_groups + ) + ) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: if i_level < self.temporal_compress_level: - down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) + down.downsample = DownSample3D( + in_channels=block_in, + out_channels=block_in, + compress_time=True + ) else: - down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) + down.downsample = DownSample3D( + in_channels=block_in, + out_channels=block_in, + compress_time=False + ) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock3D( + block_in = in_ch_mult[-1] + self.mid.block_1 = CogVideoXResnetBlock3D( in_channels=block_in, out_channels=block_in, - act_fn=act_fn, + non_linearity=act_fn, temb_channels=0, - norm_num_groups=norm_num_groups, + groups=norm_num_groups, dropout=dropout, pad_mode=pad_mode, ) if len(attn_resolutions) > 0: self.mid.attn_1 = AttnBlock2D(block_in) - self.mid.block_2 = ResnetBlock3D( + self.mid.block_2 = CogVideoXResnetBlock3D( in_channels=block_in, out_channels=block_in, - act_fn=act_fn, + non_linearity=act_fn, temb_channels=0, - norm_num_groups=norm_num_groups, + groups=norm_num_groups, dropout=dropout, pad_mode=pad_mode, ) self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=norm_num_groups, eps=1e-6) conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = CausalConv3d( - block_in, conv_out_channels if double_z else z_channels, kernel_size=3, pad_mode=pad_mode + self.conv_out = CogVideoXCausalConv3d( + block_in, conv_out_channels if double_z else latent_channels, kernel_size=3, pad_mode=pad_mode ) def forward(self, x): @@ -505,38 +618,54 @@ def forward(self, x): h = self.conv_out(h) return h - +# Todo: Create vae_3d.py such as vae.py file? class Decoder3D(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + def __init__( - self, - *, - ch: int, - in_channels: int = 16, - out_channels: int = 3, - ch_mult: Tuple[int, ...] = (1, 2, 4, 8), - num_res_blocks: int, - attn_resolutions=None, - act_fn: str = "silu", - dropout: float = 0.0, - resamp_with_conv: bool = True, - resolution: int, - z_channels: int, - give_pre_end: bool = False, - z_ch: Optional[int] = None, - add_conv: bool = False, - pad_mode: str = "first", - temporal_compress_times: int = 4, - norm_num_groups=32, + self, + *, + in_channels: int = 16, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + num_res_blocks: int, + attn_resolutions=None, + act_fn: str = "silu", + dropout: float = 0.0, + resolution: int, + latent_channels: int, + give_pre_end: bool = False, + pad_mode: str = "first", + temporal_compress_times: int = 4, + norm_num_groups=32, ): super(Decoder3D, self).__init__() if attn_resolutions is None: attn_resolutions = [] - self.ch = ch self.act_fn = get_activation(act_fn) - self.num_resolutions = len(ch_mult) + self.num_resolutions = len(block_out_channels) self.num_res_blocks = num_res_blocks self.resolution = resolution - self.in_channels = in_channels self.give_pre_end = give_pre_end self.attn_resolutions = attn_resolutions self.norm_num_groups = norm_num_groups @@ -545,90 +674,102 @@ def __init__( self.temporal_compress_level = int(np.log2(temporal_compress_times)) - if z_ch is None: - z_ch = z_channels - # compute in_ch_mult, block_in and curr_res at lowest res - block_in = ch * ch_mult[self.num_resolutions - 1] + block_in = block_out_channels[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) + self.z_shape = (1, latent_channels, curr_res, curr_res) print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode) + self.conv_in = CogVideoXCausalConv3d(latent_channels, block_in, kernel_size=3, pad_mode=pad_mode) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock3D( + self.mid.block_1 = CogVideoXResnetBlock3D( in_channels=block_in, out_channels=block_in, temb_channels=0, dropout=dropout, - act_fn=act_fn, - z_ch=z_ch, - add_conv=add_conv, - normalization=normalize3d, - norm_num_groups=norm_num_groups, + non_linearity=act_fn, + latent_channels=latent_channels, + groups=norm_num_groups, pad_mode=pad_mode, ) if len(attn_resolutions) > 0: - self.mid.attn_1 = AttnBlock2D(block_in) - self.mid.block_2 = ResnetBlock3D( + self.mid.attn_1 = AttnBlock2D( + in_channels=block_in, + norm_num_groups=norm_num_groups + ) + + self.mid.block_2 = CogVideoXResnetBlock3D( in_channels=block_in, out_channels=block_in, temb_channels=0, dropout=dropout, - act_fn=act_fn, - z_ch=z_ch, - add_conv=add_conv, - normalization=normalize3d, - norm_num_groups=norm_num_groups, + non_linearity=act_fn, + latent_channels=latent_channels, + groups=norm_num_groups, pad_mode=pad_mode, ) - # upsampling + # UpSampling + self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] + block_out = block_out_channels[i_level] for i_block in range(self.num_res_blocks + 1): block.append( - ResnetBlock3D( + CogVideoXResnetBlock3D( in_channels=block_in, out_channels=block_out, temb_channels=0, - act_fn=act_fn, + non_linearity=act_fn, dropout=dropout, - z_ch=z_ch, - add_conv=add_conv, - normalization=normalize3d, - norm_num_groups=norm_num_groups, + latent_channels=latent_channels, + groups=norm_num_groups, pad_mode=pad_mode, ) ) block_in = block_out if curr_res in attn_resolutions: - attn.append(AttnBlock2D(block_in=block_in, norm_num_groups=norm_num_groups)) + attn.append( + AttnBlock2D( + in_channels=block_in, + norm_num_groups=norm_num_groups + ) + ) up = nn.Module() up.block = block up.attn = attn if i_level != 0: if i_level < self.num_resolutions - self.temporal_compress_level: - up.upsample = UpSample3D(block_in, resamp_with_conv, compress_time=False) + up.upsample = CogVideoXUpzSample3D( + in_channels=block_in, + out_channels=block_in, + compress_time=False + ) else: - up.upsample = UpSample3D(block_in, resamp_with_conv, compress_time=True) + up.upsample = CogVideoXUpzSample3D( + in_channels=block_in, + out_channels=block_in, + compress_time=True + ) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - self.norm_out = normalize3d(block_in, z_ch, add_conv=add_conv) + self.up.insert(0, up) + + self.norm_out = CogVideoXSpatialNorm3D( + f_channels=block_in, + zq_channels=latent_channels + ) - self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) + self.conv_out = CogVideoXCausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) def forward(self, z): # timestep embedding - temb = None - # z to block_in + temb = None h = self.conv_in(z) @@ -638,7 +779,8 @@ def forward(self, z): h = self.mid.attn_1(h, z) h = self.mid.block_2(h, temb, z) - # upsampling + # UpSampling + for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb, z) @@ -694,59 +836,55 @@ class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): """ _supports_gradient_checkpointing = True - _no_split_modules = ["ResnetBlock3D"] + _no_split_modules = ["CogVideoXResnetBlock3D"] @register_to_config def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock3D",), - up_block_types: Tuple[str] = ("UpDecoderBlock3D",), - ch: int = 128, - block_out_channels: Tuple[int] = (1, 2, 2, 4), - layers_per_block: int = 3, - act_fn: str = "silu", - latent_channels: int = 16, - norm_num_groups: int = 32, - sample_size: int = 256, - # Do Not Know how to use - scaling_factor: float = 0.13025, - shift_factor: Optional[float] = None, - latents_mean: Optional[Tuple[float]] = None, - latents_std: Optional[Tuple[float]] = None, - force_upcast: float = True, - use_quant_conv: bool = False, - use_post_quant_conv: bool = False, - mid_block_add_attention: bool = True, + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + latent_channels: int = 16, + norm_num_groups: int = 32, + sample_size: int = 256, + scaling_factor: float = 1.15258426, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + mid_block_add_attention: bool = True, ): super().__init__() self.encoder = Encoder3D( in_channels=in_channels, out_channels=latent_channels, - ch_mult=block_out_channels, - ch=ch, + block_out_channels=block_out_channels, num_res_blocks=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, resolution=sample_size, - z_channels=latent_channels, + latent_channels=latent_channels, ) self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, - ch=ch, - ch_mult=block_out_channels, + block_out_channels=block_out_channels, norm_num_groups=norm_num_groups, act_fn=act_fn, num_res_blocks=layers_per_block, resolution=sample_size, - z_channels=latent_channels, + latent_channels=latent_channels, ) - self.quant_conv = SafeConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None - self.post_quant_conv = SafeConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + self.quant_conv = SaveConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = SaveConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None self.use_slicing = False self.use_tiling = False @@ -757,7 +895,7 @@ def __init__( if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size ) - # self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 def _set_gradient_checkpointing(self, module, value=False): @@ -795,7 +933,7 @@ def disable_slicing(self): @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -819,7 +957,7 @@ def encode( @apply_forward_hook def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None + self, z: torch.FloatTensor, return_dict: bool = True, generator=None ) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. @@ -843,11 +981,11 @@ def decode( return dec def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, ) -> Union[torch.Tensor, torch.Tensor]: x = sample posterior = self.encode(x).latent_dist diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 00b55cd9c9d6..ec8c003e5a3d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -22,7 +22,7 @@ from ..utils import deprecate from .activations import get_activation -from .attention_processor import SpatialNorm +from .attention_processor import SpatialNorm, SpatialNorm3D from .downsampling import ( # noqa Downsample1D, Downsample2D, @@ -72,24 +72,24 @@ class ResnetBlockCondNorm2D(nn.Module): """ def __init__( - self, - *, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - eps: float = 1e-6, - non_linearity: str = "swish", - time_embedding_norm: str = "ada_group", # ada_group, spatial - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + time_embedding_norm: str = "ada_group", # ada_group, spatial + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, ): super().__init__() self.in_channels = in_channels @@ -218,27 +218,27 @@ class ResnetBlock2D(nn.Module): """ def __init__( - self, - *, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - pre_norm: bool = True, - eps: float = 1e-6, - non_linearity: str = "swish", - skip_time_act: bool = False, - time_embedding_norm: str = "default", # default, scale_shift, - kernel: Optional[torch.Tensor] = None, - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, + kernel: Optional[torch.Tensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, ): super().__init__() if time_embedding_norm == "ada_group": @@ -373,6 +373,127 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg return output_tensor +class ResnetBlock3D(nn.Module): + r""" + A Resnet3D block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + latent_channels: Optional[int] = None, + ): + super().__init__() + out_channels = in_channels if out_channels is None else out_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.non_linearity = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut + + if latent_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = SpatialNorm3D( + f_channels=in_channels, + zq_channels=latent_channels, + ) + self.norm2 = SpatialNorm3D( + f_channels=out_channels, + zq_channels=latent_channels, + ) + self.conv1 = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear( + in_features=temb_channels, + out_features=out_channels + ) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + ) + else: + self.nin_shortcut = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states,*args) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] + + hidden_states = self.norm2(hidden_states,*args) + input_tensor = self.non_linearity(input_tensor) + input_tensor = self.dropout(input_tensor) + input_tensor = self.conv2(input_tensor) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + input_tensor = self.conv_shortcut(input_tensor) + else: + input_tensor = self.nin_shortcut(input_tensor) + + return input_tensor + hidden_states + + # unet_rl.py def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor: if len(tensor.shape) == 2: @@ -398,12 +519,12 @@ class Conv1dBlock(nn.Module): """ def __init__( - self, - inp_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - n_groups: int = 8, - activation: str = "mish", + self, + inp_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + n_groups: int = 8, + activation: str = "mish", ): super().__init__() @@ -434,12 +555,12 @@ class ResidualTemporalBlock1D(nn.Module): """ def __init__( - self, - inp_channels: int, - out_channels: int, - embed_dim: int, - kernel_size: Union[int, Tuple[int, int]] = 5, - activation: str = "mish", + self, + inp_channels: int, + out_channels: int, + embed_dim: int, + kernel_size: Union[int, Tuple[int, int]] = 5, + activation: str = "mish", ): super().__init__() self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) @@ -480,11 +601,11 @@ class TemporalConvLayer(nn.Module): """ def __init__( - self, - in_dim: int, - out_dim: Optional[int] = None, - dropout: float = 0.0, - norm_num_groups: int = 32, + self, + in_dim: int, + out_dim: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, ): super().__init__() out_dim = out_dim or in_dim @@ -493,24 +614,24 @@ def __init__( # conv layers self.conv1 = nn.Sequential( - nn.GroupNorm(norm_num_groups, in_dim), + nn.GroupNorm(groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv2 = nn.Sequential( - nn.GroupNorm(norm_num_groups, out_dim), + nn.GroupNorm(groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv3 = nn.Sequential( - nn.GroupNorm(norm_num_groups, out_dim), + nn.GroupNorm(groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv4 = nn.Sequential( - nn.GroupNorm(norm_num_groups, out_dim), + nn.GroupNorm(groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), @@ -552,11 +673,11 @@ class TemporalResnetBlock(nn.Module): """ def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - temb_channels: int = 512, - eps: float = 1e-6, + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, ): super().__init__() self.in_channels = in_channels @@ -651,15 +772,15 @@ class SpatioTemporalResBlock(nn.Module): """ def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - temb_channels: int = 512, - eps: float = 1e-6, - temporal_eps: Optional[float] = None, - merge_factor: float = 0.5, - merge_strategy="learned_with_images", - switch_spatial_to_temporal_mix: bool = False, + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, + temporal_eps: Optional[float] = None, + merge_factor: float = 0.5, + merge_strategy="learned_with_images", + switch_spatial_to_temporal_mix: bool = False, ): super().__init__() @@ -684,10 +805,10 @@ def __init__( ) def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - image_only_indicator: Optional[torch.Tensor] = None, + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, ): num_frames = image_only_indicator.shape[-1] hidden_states = self.spatial_res_block(hidden_states, temb) @@ -731,10 +852,10 @@ class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] def __init__( - self, - alpha: float, - merge_strategy: str = "learned_with_images", - switch_spatial_to_temporal_mix: bool = False, + self, + alpha: float, + merge_strategy: str = "learned_with_images", + switch_spatial_to_temporal_mix: bool = False, ): super().__init__() self.merge_strategy = merge_strategy @@ -782,10 +903,10 @@ def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Ten return alpha def forward( - self, - x_spatial: torch.Tensor, - x_temporal: torch.Tensor, - image_only_indicator: Optional[torch.Tensor] = None, + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: alpha = self.get_alpha(image_only_indicator, x_spatial.ndim) alpha = alpha.to(x_spatial.dtype) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 4d5ec21fd196..3ead6fd99d10 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -47,21 +47,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class AutoencoderKL3D(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] From a3d827fb8de713e054f865d3753c18344c3a64b9 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 30 Jul 2024 22:00:53 +0800 Subject: [PATCH 08/94] rename --- src/diffusers/models/attention_processor.py | 469 ++++++++--------- .../models/autoencoders/autoencoder_kl3d.py | 475 ++++++++---------- src/diffusers/models/resnet.py | 266 +++++----- 3 files changed, 579 insertions(+), 631 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9636f34e087a..9915ad745be6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -90,33 +90,33 @@ class Attention(nn.Module): """ def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - qk_norm: Optional[str] = None, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - out_dim: int = None, - context_pre_only=None, + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, ): super().__init__() @@ -143,7 +143,7 @@ def __init__( self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk - self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = out_dim // dim_head if out_dim is not None else heads # for slice_size > 0 the attention score computation @@ -267,7 +267,7 @@ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: self.set_processor(processor) def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ) -> None: r""" Set whether to use memory efficient attention from `xformers` or not. @@ -412,9 +412,9 @@ def set_processor(self, processor: "AttnProcessor") -> None: # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) ): logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") self._modules.pop("processor") @@ -436,11 +436,11 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce return self.processor def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **cross_attention_kwargs, + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, ) -> torch.Tensor: r""" The forward method of the `Attention` class. @@ -526,7 +526,7 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten return tensor def get_attention_scores( - self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None ) -> torch.Tensor: r""" Compute the attention scores. @@ -573,7 +573,7 @@ def get_attention_scores( return attention_probs def prepare_attention_mask( - self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> torch.Tensor: r""" Prepare the attention mask for the attention computation. @@ -701,14 +701,14 @@ class AttnProcessor: """ def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -787,13 +787,13 @@ class CustomDiffusionAttnProcessor(nn.Module): """ def __init__( - self, - train_kv: bool = True, - train_q_out: bool = True, - hidden_size: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - dropout: float = 0.0, + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, ): super().__init__() self.train_kv = train_kv @@ -813,11 +813,11 @@ def __init__( self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -878,13 +878,13 @@ class AttnAddedKVProcessor: """ def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -951,13 +951,13 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1022,13 +1022,13 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, ) -> torch.FloatTensor: residual = hidden_states @@ -1071,7 +1071,7 @@ def __call__( # Split the attention outputs. hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1]:], + hidden_states[:, residual.shape[1] :], ) # linear proj @@ -1097,13 +1097,13 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, ) -> torch.FloatTensor: residual = hidden_states @@ -1150,7 +1150,7 @@ def __call__( # Split the attention outputs. hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1]:], + hidden_states[:, residual.shape[1] :], ) # linear proj @@ -1178,12 +1178,12 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, ) -> torch.FloatTensor: batch_size = hidden_states.shape[0] @@ -1244,7 +1244,7 @@ def __call__( # Split the attention outputs. if encoder_hidden_states is not None: hidden_states, encoder_hidden_states = ( - hidden_states[:, encoder_hidden_states.shape[1]:], + hidden_states[:, encoder_hidden_states.shape[1] :], hidden_states[:, : encoder_hidden_states.shape[1]], ) @@ -1277,11 +1277,11 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) @@ -1348,14 +1348,14 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1439,14 +1439,14 @@ def __init__(self): raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1545,14 +1545,14 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -1635,13 +1635,13 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb @@ -1736,13 +1736,13 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb @@ -1836,14 +1836,14 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[torch.Tensor] = None, - key_rotary_emb: Optional[torch.Tensor] = None, - base_sequence_length: Optional[int] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[torch.Tensor] = None, + key_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb @@ -1941,14 +1941,14 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." @@ -2046,14 +2046,14 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): """ def __init__( - self, - train_kv: bool = True, - train_q_out: bool = False, - hidden_size: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - dropout: float = 0.0, - attention_op: Optional[Callable] = None, + self, + train_kv: bool = True, + train_q_out: bool = False, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + attention_op: Optional[Callable] = None, ): super().__init__() self.train_kv = train_kv @@ -2074,11 +2074,11 @@ def __init__( self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2159,13 +2159,13 @@ class CustomDiffusionAttnProcessor2_0(nn.Module): """ def __init__( - self, - train_kv: bool = True, - train_q_out: bool = True, - hidden_size: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - dropout: float = 0.0, + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, ): super().__init__() self.train_kv = train_kv @@ -2185,11 +2185,11 @@ def __init__( self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -2266,11 +2266,11 @@ def __init__(self, slice_size: int): self.slice_size = slice_size def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: residual = hidden_states @@ -2353,12 +2353,12 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__( - self, - attn: "Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, ) -> torch.Tensor: residual = hidden_states @@ -2443,9 +2443,9 @@ class SpatialNorm(nn.Module): """ def __init__( - self, - f_channels: int, - zq_channels: int, + self, + f_channels: int, + zq_channels: int, ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) @@ -2459,24 +2459,25 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f + class SpatialNorm3D(nn.Module): """ - Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. - Args: - f_channels (`int`): - The number of channels for input to group normalization layer, and output of the spatial norm layer. - zq_channels (`int`): - The number of channels for the quantized vector as described in the paper. + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. """ def __init__( - self, - f_channels: int, - zq_channels: int, + self, + f_channels: int, + zq_channels: int, ): super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) @@ -2536,14 +2537,14 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - scale: float = 1.0, - ip_adapter_masks: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, ): residual = hidden_states @@ -2631,7 +2632,7 @@ def __call__( # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): skip = False if isinstance(scale, list): @@ -2739,14 +2740,14 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale ) def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - scale: float = 1.0, - ip_adapter_masks: Optional[torch.Tensor] = None, + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, ): residual = hidden_states @@ -2848,7 +2849,7 @@ def __call__( # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): skip = False if isinstance(scale, list): @@ -2938,12 +2939,12 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: residual = hidden_states if attn.spatial_norm is not None: @@ -3037,12 +3038,12 @@ def __init__(self): ) def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: residual = hidden_states if attn.spatial_norm is not None: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index c13d562f7397..3c2f11662321 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -22,7 +22,7 @@ class SaveConv3d(torch.nn.Conv3d): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3 + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 # Set to 2GB, Suit for CuDNN if memory_count > 2: @@ -32,7 +32,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if kernel_size > 1: input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2) + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) for i in range(1, len(input_chunks)) ] @@ -52,12 +52,12 @@ class CogVideoXCausalConv3d(nn.Module): """ def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - pad_mode: str = "constant", - **kwargs + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode: str = "constant", + **kwargs, ): super().__init__() @@ -89,7 +89,7 @@ def cast_tuple(t, length=1): kernel_size=kernel_size, stride=stride, dilation=dilation, - **kwargs + **kwargs, ) self.conv_cache = None @@ -104,7 +104,7 @@ def forward(self, x): causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) x = F.pad(x, causal_padding_2d, mode="constant", value=0) elif self.pad_mode == "reflect": - reflect_x = x[:, :, 1: self.time_pad + 1, :, :].flip(dims=[2]) + reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) if reflect_x.shape[2] < self.time_pad: reflect_x = torch.cat( [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 @@ -115,7 +115,7 @@ def forward(self, x): else: raise ValueError("Invalid pad mode") if self.time_pad != 0 and self.conv_cache is None: - self.conv_cache = x[:, :, -self.time_pad:].detach().clone().cpu() + self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() return self.conv(x) elif self.time_pad != 0 and self.conv_cache is not None: x = torch.cat([self.conv_cache.to(x.device), x], dim=2) @@ -130,13 +130,13 @@ def forward(self, x): # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXSpatialNorm3D(SpatialNorm3D): """ - Use SaveConv3d instead of nn.Conv3d to avoid OOM in CogVideoX Model + Use SaveConv3d instead of nn.Conv3d to avoid OOM in CogVideoX Model """ def __init__( - self, - f_channels: int, - zq_channels: int, + self, + f_channels: int, + zq_channels: int, ): super().__init__(f_channels, zq_channels) self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) @@ -163,31 +163,25 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: # Todo: Create vae_3d.py such as vae.py file? class UpSample3D(nn.Module): r""" - The `UpSample` layer of a variational autoencoder that upsamples its input. + The `UpSample` layer of a variational autoencoder that upsamples its input. - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - """ + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + """ def __init__( - self, - in_channels: int, - out_channels: int, + self, + in_channels: int, + out_channels: int, ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels - self.conv = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1 - ) + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: r"""The forward method of the `UpSample` class.""" @@ -208,32 +202,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXUpzSample3D(UpSample3D): r""" - Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. - """ + Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ - def __init__( - self, - in_channels: int, - out_channels: int, - compress_time: bool = False - ): + def __init__(self, in_channels: int, out_channels: int, compress_time: bool = False): super().__init__(in_channels, out_channels) - self.conv = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1 - ) + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.compress_time = compress_time def forward(self, x): @@ -270,23 +253,22 @@ def forward(self, x): # Todo: Create vae_3d.py such as vae.py file? class DownSample3D(nn.Module): r""" - Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in CogVideoX Model. - - Args: - in_channels (`int`, *optional*): - The number of input channels. - out_channels (`int`, *optional*): - The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. - """ + Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in CogVideoX Model. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ def __init__( - self, - in_channels: int, - out_channels: int, - compress_time: bool = False, - + self, + in_channels: int, + out_channels: int, + compress_time: bool = False, ): super(DownSample3D, self).__init__() @@ -295,12 +277,10 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if self.compress_time: - b, c, t, h, w = x.shape x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) if x.shape[-1] % 2 == 1: - # split first frame x_first, x_rest = x[..., 0], x[..., 1:] @@ -325,17 +305,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class CogVideoXResnetBlock3D(ResnetBlock3D): def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - eps: float = 1e-6, - non_linearity: str = "swish", - conv_shortcut: bool = False, - latent_channels: Optional[int] = None, - pad_mode: str = "first", + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + latent_channels: Optional[int] = None, + pad_mode: str = "first", ): super().__init__( in_channels=in_channels, @@ -346,7 +326,7 @@ def __init__( eps=eps, non_linearity=non_linearity, conv_shortcut=conv_shortcut, - latent_channels=latent_channels + latent_channels=latent_channels, ) out_channels = in_channels if out_channels is None else out_channels @@ -369,73 +349,63 @@ def __init__( zq_channels=latent_channels, ) self.conv1 = CogVideoXCausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - pad_mode=pad_mode + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) if temb_channels > 0: - self.temb_proj = torch.nn.Linear( - in_features=temb_channels, - out_features=out_channels - ) + self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) self.dropout = torch.nn.Dropout(dropout) self.conv2 = CogVideoXCausalConv3d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - pad_mode=pad_mode + in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = CogVideoXCausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - pad_mode=pad_mode + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) else: self.nin_shortcut = SaveConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x, temb, z=None): - h = x - if z is not None: - h = self.norm1(h, z) + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward( + self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, **kwargs + ) -> torch.Tensor: + hidden_states = input_tensor + if zq is not None: + hidden_states = self.norm1(hidden_states, zq) else: - h = self.norm1(h) - h = self.act_fn(h) - h = self.conv1(h) + hidden_states = self.norm1(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv1(hidden_states) if temb is not None: - h = h + self.temb_proj(self.act_fn(temb))[:, :, None, None, None] + hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] - if z is not None: - h = self.norm2(h, z) + if zq is not None: + hidden_states = self.norm2(hidden_states, zq) else: - h = self.norm2(h) - h = self.act_fn(h) - h = self.dropout(h) - h = self.conv2(h) + hidden_states = self.norm2(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + input_tensor = self.conv_shortcut(input_tensor) else: - x = self.nin_shortcut(x) + input_tensor = self.nin_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor - return x + h -#Todo: Need refactor?@a-r-r-o-w +# Todo: Need refactor? @a-r-r-o-w class AttnBlock2D(nn.Module): - def __init__(self, in_channels, norm_num_groups): + def __init__(self, in_channels: int, norm_num_groups: int): super().__init__() self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) @@ -479,24 +449,25 @@ def forward(self, x): return h_ + # Todo: Create vae_3d.py such as vae.py file? class Encoder3D(nn.Module): def __init__( - self, - *, - in_channels: int = 3, - out_channels: int = 16, - block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), - num_res_blocks: int, - act_fn: str = "silu", - norm_num_groups: int = 32, - attn_resolutions=None, - dropout: float = 0.0, - resolution: int, - latent_channels: int, - double_z: bool = True, - pad_mode: str = "first", - temporal_compress_times: int = 4, + self, + *, + in_channels: int = 3, + out_channels: int = 16, + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + num_res_blocks: int, + act_fn: str = "silu", + norm_num_groups: int = 32, + attn_resolutions=None, + dropout: float = 0.0, + resolution: int, + latent_channels: int, + double_z: bool = True, + pad_mode: str = "first", + temporal_compress_times: int = 4, ): super(Encoder3D, self).__init__() if attn_resolutions is None: @@ -536,28 +507,15 @@ def __init__( ) block_in = block_out if curr_res in attn_resolutions: - attn.append( - AttnBlock2D( - in_channels=block_in, - norm_num_groups=norm_num_groups - ) - ) + attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: if i_level < self.temporal_compress_level: - down.downsample = DownSample3D( - in_channels=block_in, - out_channels=block_in, - compress_time=True - ) + down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) else: - down.downsample = DownSample3D( - in_channels=block_in, - out_channels=block_in, - compress_time=False - ) + down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=False) curr_res = curr_res // 2 self.down.append(down) @@ -591,73 +549,76 @@ def __init__( block_in, conv_out_channels if double_z else latent_channels, kernel_size=3, pad_mode=pad_mode ) - def forward(self, x): + def forward(self, sample: torch.Tensor) -> torch.Tensor: # timestep embedding temb = None # downsampling - h = self.conv_in(x) + sample = self.conv_in(sample) for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) + sample = self.down[i_level].block[i_block](sample, temb) if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) + sample = self.down[i_level].attn[i_block](sample) if i_level != self.num_resolutions - 1: - h = self.down[i_level].downsample(h) + sample = self.down[i_level].downsample(sample) # middle - h = self.mid.block_1(h, temb) + sample = self.mid.block_1(sample, temb) if len(self.attn_resolutions): - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + sample = self.mid.attn_1(sample) + + sample = self.mid.block_2(sample, temb) + + # post-process + sample = self.norm_out(sample) + sample = self.act_fn(sample) + sample = self.conv_out(sample) + + return sample - # end - h = self.norm_out(h) - h = self.act_fn(h) - h = self.conv_out(h) - return h # Todo: Create vae_3d.py such as vae.py file? class Decoder3D(nn.Module): r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - norm_type (`str`, *optional*, defaults to `"group"`): - The normalization type to use. Can be either `"group"` or `"spatial"`. - """ + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ def __init__( - self, - *, - in_channels: int = 16, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), - num_res_blocks: int, - attn_resolutions=None, - act_fn: str = "silu", - dropout: float = 0.0, - resolution: int, - latent_channels: int, - give_pre_end: bool = False, - pad_mode: str = "first", - temporal_compress_times: int = 4, - norm_num_groups=32, + self, + *, + in_channels: int = 16, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + num_res_blocks: int, + attn_resolutions=None, + act_fn: str = "silu", + dropout: float = 0.0, + resolution: int, + latent_channels: int, + give_pre_end: bool = False, + pad_mode: str = "first", + temporal_compress_times: int = 4, + norm_num_groups=32, ): super(Decoder3D, self).__init__() if attn_resolutions is None: @@ -695,10 +656,7 @@ def __init__( pad_mode=pad_mode, ) if len(attn_resolutions) > 0: - self.mid.attn_1 = AttnBlock2D( - in_channels=block_in, - norm_num_groups=norm_num_groups - ) + self.mid.attn_1 = AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups) self.mid.block_2 = CogVideoXResnetBlock3D( in_channels=block_in, @@ -733,71 +691,58 @@ def __init__( ) block_in = block_out if curr_res in attn_resolutions: - attn.append( - AttnBlock2D( - in_channels=block_in, - norm_num_groups=norm_num_groups - ) - ) + attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: if i_level < self.num_resolutions - self.temporal_compress_level: up.upsample = CogVideoXUpzSample3D( - in_channels=block_in, - out_channels=block_in, - compress_time=False + in_channels=block_in, out_channels=block_in, compress_time=False ) else: - up.upsample = CogVideoXUpzSample3D( - in_channels=block_in, - out_channels=block_in, - compress_time=True - ) + up.upsample = CogVideoXUpzSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) curr_res = curr_res * 2 self.up.insert(0, up) - self.norm_out = CogVideoXSpatialNorm3D( - f_channels=block_in, - zq_channels=latent_channels - ) + self.norm_out = CogVideoXSpatialNorm3D(f_channels=block_in, zq_channels=latent_channels) self.conv_out = CogVideoXCausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) - def forward(self, z): + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" # timestep embedding temb = None - h = self.conv_in(z) + hidden_states = self.conv_in(sample) # middle - h = self.mid.block_1(h, temb, z) + hidden_states = self.mid.block_1(hidden_states, temb, sample) if len(self.attn_resolutions) > 0: - h = self.mid.attn_1(h, z) - h = self.mid.block_2(h, temb, z) + hidden_states = self.mid.attn_1(hidden_states, sample) + hidden_states = self.mid.block_2(hidden_states, temb, sample) # UpSampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb, z) + hidden_states = self.up[i_level].block[i_block](hidden_states, temb, sample) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, z) + hidden_states = self.up[i_level].attn[i_block](hidden_states, sample) if i_level != 0: - h = self.up[i_level].upsample(h) + hidden_states = self.up[i_level].upsample(hidden_states) # end if self.give_pre_end: - return h + return hidden_states - h = self.norm_out(h, z) - h = self.act_fn(h) - h = self.conv_out(h) + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.conv_out(hidden_states) - return h + return hidden_states class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): @@ -840,25 +785,25 @@ class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): @register_to_config def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock3D",), - up_block_types: Tuple[str] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int] = (128, 256, 256, 512), - layers_per_block: int = 3, - act_fn: str = "silu", - latent_channels: int = 16, - norm_num_groups: int = 32, - sample_size: int = 256, - scaling_factor: float = 1.15258426, - shift_factor: Optional[float] = None, - latents_mean: Optional[Tuple[float]] = None, - latents_std: Optional[Tuple[float]] = None, - force_upcast: float = True, - use_quant_conv: bool = False, - use_post_quant_conv: bool = False, - mid_block_add_attention: bool = True, + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + latent_channels: int = 16, + norm_num_groups: int = 32, + sample_size: int = 256, + scaling_factor: float = 1.15258426, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + mid_block_add_attention: bool = True, ): super().__init__() @@ -933,7 +878,7 @@ def disable_slicing(self): @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -957,7 +902,7 @@ def encode( @apply_forward_hook def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None + self, z: torch.FloatTensor, return_dict: bool = True, generator=None ) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. @@ -981,11 +926,11 @@ def decode( return dec def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, ) -> Union[torch.Tensor, torch.Tensor]: x = sample posterior = self.encode(x).latent_dist diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ec8c003e5a3d..6acc59a5f3b3 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -72,24 +72,24 @@ class ResnetBlockCondNorm2D(nn.Module): """ def __init__( - self, - *, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - eps: float = 1e-6, - non_linearity: str = "swish", - time_embedding_norm: str = "ada_group", # ada_group, spatial - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + time_embedding_norm: str = "ada_group", # ada_group, spatial + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, ): super().__init__() self.in_channels = in_channels @@ -218,27 +218,27 @@ class ResnetBlock2D(nn.Module): """ def __init__( - self, - *, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - pre_norm: bool = True, - eps: float = 1e-6, - non_linearity: str = "swish", - skip_time_act: bool = False, - time_embedding_norm: str = "default", # default, scale_shift, - kernel: Optional[torch.Tensor] = None, - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, + kernel: Optional[torch.Tensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, ): super().__init__() if time_embedding_norm == "ada_group": @@ -375,41 +375,41 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg class ResnetBlock3D(nn.Module): r""" - A Resnet3D block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv2d layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - groups_out (`int`, *optional*, default to None): - The number of groups to use for the second normalization layer. if set to None, same as `groups`. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. - use_in_shortcut (`bool`, *optional*, default to `True`): - If `True`, add a 1x1 nn.conv2d layer for skip-connection. - up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. - down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. - conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the - `conv_shortcut` output. - conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. - If None, same as `out_channels`. - """ + A Resnet3D block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ def __init__( - self, - *, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - eps: float = 1e-6, - non_linearity: str = "swish", - conv_shortcut: bool = False, - latent_channels: Optional[int] = None, + self, + *, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + latent_channels: Optional[int] = None, ): super().__init__() out_channels = in_channels if out_channels is None else out_channels @@ -437,10 +437,7 @@ def __init__( kernel_size=3, ) if temb_channels > 0: - self.temb_proj = torch.nn.Linear( - in_features=temb_channels, - out_features=out_channels - ) + self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) self.dropout = torch.nn.Dropout(dropout) @@ -459,31 +456,34 @@ def __init__( ) else: self.nin_shortcut = nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - padding=0 + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 ) - def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def forward( + self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, **kwargs + ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states,*args) + if zq is not None: + hidden_states = self.norm1(hidden_states, zq) + else: + hidden_states = self.norm1(hidden_states) hidden_states = self.non_linearity(hidden_states) hidden_states = self.conv1(hidden_states) if temb is not None: hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] - hidden_states = self.norm2(hidden_states,*args) - input_tensor = self.non_linearity(input_tensor) - input_tensor = self.dropout(input_tensor) - input_tensor = self.conv2(input_tensor) + if zq is not None: + hidden_states = self.norm2(hidden_states, zq) + else: + hidden_states = self.norm2(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -491,7 +491,9 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg else: input_tensor = self.nin_shortcut(input_tensor) - return input_tensor + hidden_states + output_tensor = input_tensor + hidden_states + + return output_tensor # unet_rl.py @@ -519,12 +521,12 @@ class Conv1dBlock(nn.Module): """ def __init__( - self, - inp_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - n_groups: int = 8, - activation: str = "mish", + self, + inp_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + n_groups: int = 8, + activation: str = "mish", ): super().__init__() @@ -555,12 +557,12 @@ class ResidualTemporalBlock1D(nn.Module): """ def __init__( - self, - inp_channels: int, - out_channels: int, - embed_dim: int, - kernel_size: Union[int, Tuple[int, int]] = 5, - activation: str = "mish", + self, + inp_channels: int, + out_channels: int, + embed_dim: int, + kernel_size: Union[int, Tuple[int, int]] = 5, + activation: str = "mish", ): super().__init__() self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) @@ -601,11 +603,11 @@ class TemporalConvLayer(nn.Module): """ def __init__( - self, - in_dim: int, - out_dim: Optional[int] = None, - dropout: float = 0.0, - groups: int = 32, + self, + in_dim: int, + out_dim: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, ): super().__init__() out_dim = out_dim or in_dim @@ -673,11 +675,11 @@ class TemporalResnetBlock(nn.Module): """ def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - temb_channels: int = 512, - eps: float = 1e-6, + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, ): super().__init__() self.in_channels = in_channels @@ -772,15 +774,15 @@ class SpatioTemporalResBlock(nn.Module): """ def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - temb_channels: int = 512, - eps: float = 1e-6, - temporal_eps: Optional[float] = None, - merge_factor: float = 0.5, - merge_strategy="learned_with_images", - switch_spatial_to_temporal_mix: bool = False, + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, + temporal_eps: Optional[float] = None, + merge_factor: float = 0.5, + merge_strategy="learned_with_images", + switch_spatial_to_temporal_mix: bool = False, ): super().__init__() @@ -805,10 +807,10 @@ def __init__( ) def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - image_only_indicator: Optional[torch.Tensor] = None, + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, ): num_frames = image_only_indicator.shape[-1] hidden_states = self.spatial_res_block(hidden_states, temb) @@ -852,10 +854,10 @@ class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] def __init__( - self, - alpha: float, - merge_strategy: str = "learned_with_images", - switch_spatial_to_temporal_mix: bool = False, + self, + alpha: float, + merge_strategy: str = "learned_with_images", + switch_spatial_to_temporal_mix: bool = False, ): super().__init__() self.merge_strategy = merge_strategy @@ -903,10 +905,10 @@ def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Ten return alpha def forward( - self, - x_spatial: torch.Tensor, - x_temporal: torch.Tensor, - image_only_indicator: Optional[torch.Tensor] = None, + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: alpha = self.get_alpha(image_only_indicator, x_spatial.ndim) alpha = alpha.to(x_spatial.dtype) From dc7e6e814fcdeaac28cd3058ddb423295ac2c35d Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 30 Jul 2024 22:33:34 +0800 Subject: [PATCH 09/94] fix error --- .../models/autoencoders/autoencoder_kl3d.py | 1192 +++++++++-------- src/diffusers/models/resnet.py | 2 +- 2 files changed, 610 insertions(+), 584 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index 3c2f11662321..8f4244581e98 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -16,285 +16,352 @@ from .vae import DecoderOutput, DiagonalGaussianDistribution -class SaveConv3d(torch.nn.Conv3d): - """ - A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 - - # Set to 2GB, Suit for CuDNN - if memory_count > 2: - kernel_size = self.kernel_size[0] - part_num = int(memory_count / 2) + 1 - input_chunks = torch.chunk(input, part_num, dim=2) - - if kernel_size > 1: - input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) - for i in range(1, len(input_chunks)) - ] +## == Basic Block of 3D VAE Model design in CogVideoX === ### - output_chunks = [] - for input_chunk in input_chunks: - output_chunks.append(super(SaveConv3d, self).forward(input_chunk)) - output = torch.cat(output_chunks, dim=2) - return output - else: - return super(SaveConv3d, self).forward(input) +class Encoder3D(nn.Module): + r""" + The `Encoder3D` layer of a variational autoencoder that encodes its input into a latent representation. -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXCausalConv3d(nn.Module): - """ - A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. """ def __init__( self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - pad_mode: str = "constant", - **kwargs, + *, + in_channels: int = 3, + out_channels: int = 16, + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + num_res_blocks: int, + act_fn: str = "silu", + norm_num_groups: int = 32, + attn_resolutions=None, + dropout: float = 0.0, + resolution: int, + latent_channels: int, + double_z: bool = True, + pad_mode: str = "first", + temporal_compress_times: int = 4, ): - super().__init__() + super(Encoder3D, self).__init__() + if attn_resolutions is None: + attn_resolutions = [] + self.act_fn = get_activation(act_fn) + self.num_resolutions = len(block_out_channels) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions - def cast_tuple(t, length=1): - return t if isinstance(t, tuple) else ((t,) * length) + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) - kernel_size = cast_tuple(kernel_size, 3) + self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + curr_res = resolution + in_ch_mult = (block_out_channels[0],) + tuple(block_out_channels) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() - dilation = kwargs.pop("dilation", 1) - stride = kwargs.pop("stride", 1) + block_in = in_ch_mult[i_level] + block_out = block_out_channels[i_level] - self.pad_mode = pad_mode - time_pad = dilation * (time_kernel_size - 1) + (1 - stride) - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 + for i_block in range(self.num_res_blocks): + block.append( + CogVideoXResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=0, + non_linearity=act_fn, + dropout=dropout, + groups=norm_num_groups, + pad_mode=pad_mode, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level < self.temporal_compress_level: + down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) + else: + down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=False) + curr_res = curr_res // 2 + self.down.append(down) - self.height_pad = height_pad - self.width_pad = width_pad - self.time_pad = time_pad - self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + # middle + self.mid = nn.Module() + block_in = in_ch_mult[-1] + self.mid.block_1 = CogVideoXResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + non_linearity=act_fn, + temb_channels=0, + groups=norm_num_groups, + dropout=dropout, + pad_mode=pad_mode, + ) + if len(attn_resolutions) > 0: + self.mid.attn_1 = AttnBlock2D(block_in) + self.mid.block_2 = CogVideoXResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + non_linearity=act_fn, + temb_channels=0, + groups=norm_num_groups, + dropout=dropout, + pad_mode=pad_mode, + ) - stride = (stride, 1, 1) - dilation = (dilation, 1, 1) - self.conv = SaveConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - **kwargs, + self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=norm_num_groups, eps=1e-6) + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = CogVideoXCausalConv3d( + block_in, conv_out_channels if double_z else latent_channels, kernel_size=3, pad_mode=pad_mode ) - self.conv_cache = None + def forward(self, sample: torch.Tensor) -> torch.Tensor: + # timestep embedding - def forward(self, x): - if self.pad_mode == "constant": - causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_3d, mode="constant", value=0) - elif self.pad_mode == "first": - pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) - x = torch.cat([pad_x, x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - elif self.pad_mode == "reflect": - reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) - if reflect_x.shape[2] < self.time_pad: - reflect_x = torch.cat( - [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 - ) - x = torch.cat([reflect_x, x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - else: - raise ValueError("Invalid pad mode") - if self.time_pad != 0 and self.conv_cache is None: - self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() - return self.conv(x) - elif self.time_pad != 0 and self.conv_cache is not None: - x = torch.cat([self.conv_cache.to(x.device), x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - self.conv_cache = None - return self.conv(x) + temb = None - return self.conv(x) + # DownSampling + sample = self.conv_in(sample) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + sample = self.down[i_level].block[i_block](sample, temb) + if len(self.down[i_level].attn) > 0: + sample = self.down[i_level].attn[i_block](sample) + if i_level != self.num_resolutions - 1: + sample = self.down[i_level].downsample(sample) + # middle + sample = self.mid.block_1(sample, temb) -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXSpatialNorm3D(SpatialNorm3D): - """ - Use SaveConv3d instead of nn.Conv3d to avoid OOM in CogVideoX Model - """ + if len(self.attn_resolutions): + sample = self.mid.attn_1(sample) - def __init__( - self, - f_channels: int, - zq_channels: int, - ): - super().__init__(f_channels, zq_channels) - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv = SaveConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) - self.conv_y = SaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = SaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + sample = self.mid.block_2(sample, temb) - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - if zq.shape[2] > 1: - f_first, f_rest = f[:, :, :1], f[:, :, 1:] - f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] - z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] - z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) - z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) - zq = torch.cat([z_first, z_rest], dim=2) - else: - zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) - zq = self.conv(zq) - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f + # post-process + sample = self.norm_out(sample) + sample = self.act_fn(sample) + sample = self.conv_out(sample) + + return sample -# Todo: Create vae_3d.py such as vae.py file? -class UpSample3D(nn.Module): +class Decoder3D(nn.Module): r""" - The `UpSample` layer of a variational autoencoder that upsamples its input. + The `Decoder3D` layer of a variational autoencoder that decodes its latent representation into an output sample. Args: in_channels (`int`, *optional*, defaults to 3): The number of input channels. out_channels (`int`, *optional*, defaults to 3): The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. """ def __init__( self, - in_channels: int, - out_channels: int, - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + *, + in_channels: int = 16, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + num_res_blocks: int, + attn_resolutions=None, + act_fn: str = "silu", + dropout: float = 0.0, + resolution: int, + latent_channels: int, + give_pre_end: bool = False, + pad_mode: str = "first", + temporal_compress_times: int = 4, + norm_num_groups=32, + ): + super(Decoder3D, self).__init__() + if attn_resolutions is None: + attn_resolutions = [] + self.act_fn = get_activation(act_fn) + self.num_resolutions = len(block_out_channels) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.give_pre_end = give_pre_end + self.attn_resolutions = attn_resolutions + self.norm_num_groups = norm_num_groups + self.temporal_compress_level = int(np.log2(temporal_compress_times)) - def forward(self, x: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `UpSample` class.""" + block_in = block_out_channels[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, latent_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = torch.nn.functional.interpolate(x, scale_factor=2.0) - x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + self.conv_in = CogVideoXCausalConv3d(latent_channels, block_in, kernel_size=3, pad_mode=pad_mode) - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + # middle + self.mid = nn.Module() + self.mid.block_1 = CogVideoXResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=0, + dropout=dropout, + non_linearity=act_fn, + latent_channels=latent_channels, + groups=norm_num_groups, + pad_mode=pad_mode, + ) + if len(attn_resolutions) > 0: + self.mid.attn_1 = AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups) - return x + self.mid.block_2 = CogVideoXResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=0, + dropout=dropout, + non_linearity=act_fn, + latent_channels=latent_channels, + groups=norm_num_groups, + pad_mode=pad_mode, + ) + # UpSampling -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXUpzSample3D(UpSample3D): - r""" - Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = block_out_channels[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + CogVideoXResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=0, + non_linearity=act_fn, + dropout=dropout, + latent_channels=latent_channels, + groups=norm_num_groups, + pad_mode=pad_mode, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = CogVideoXUpzSample3D( + in_channels=block_in, out_channels=block_in, compress_time=False + ) + else: + up.upsample = CogVideoXUpzSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) + curr_res = curr_res * 2 - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. - """ + self.up.insert(0, up) - def __init__(self, in_channels: int, out_channels: int, compress_time: bool = False): - super().__init__(in_channels, out_channels) + self.norm_out = CogVideoXSpatialNorm3D(f_channels=block_in, zq_channels=latent_channels) - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.compress_time = compress_time + self.conv_out = CogVideoXCausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) - def forward(self, x): - if self.compress_time: - if x.shape[2] > 1 and x.shape[2] % 2 == 1: - # split first frame - x_first, x_rest = x[:, :, 0], x[:, :, 1:] + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + # timestep embedding - x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0) - x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0) - x_first = x_first[:, :, None, :, :] - x = torch.cat([x_first, x_rest], dim=2) - elif x.shape[2] > 1: - x = torch.nn.functional.interpolate(x, scale_factor=2.0) - else: - x = x.squeeze(2) - x = torch.nn.functional.interpolate(x, scale_factor=2.0) - x = x[:, :, None, :, :] - else: - # only interpolate 2D - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = torch.nn.functional.interpolate(x, scale_factor=2.0) - x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + temb = None - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + hidden_states = self.conv_in(sample) - return x + # middle + hidden_states = self.mid.block_1(hidden_states, temb, sample) + if len(self.attn_resolutions) > 0: + hidden_states = self.mid.attn_1(hidden_states, sample) + hidden_states = self.mid.block_2(hidden_states, temb, sample) + # UpSampling -# Todo: Create vae_3d.py such as vae.py file? -class DownSample3D(nn.Module): + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden_states = self.up[i_level].block[i_block](hidden_states, temb, sample) + if len(self.up[i_level].attn) > 0: + hidden_states = self.up[i_level].attn[i_block](hidden_states, sample) + if i_level != 0: + hidden_states = self.up[i_level].upsample(hidden_states) + + # end + if self.give_pre_end: + return hidden_states + + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class UpSample3D(nn.Module): r""" - Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in CogVideoX Model. + The `UpSample` layer of a variational autoencoder that upsamples its input. Args: - in_channels (`int`, *optional*): + in_channels (`int`, *optional*, defaults to 3): The number of input channels. - out_channels (`int`, *optional*): + out_channels (`int`, *optional*, defaults to 3): The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. """ def __init__( self, in_channels: int, out_channels: int, - compress_time: bool = False, - ): - super(DownSample3D, self).__init__() + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) - self.compress_time = compress_time + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.compress_time: - b, c, t, h, w = x.shape - x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) - - if x.shape[-1] % 2 == 1: - # split first frame - x_first, x_rest = x[..., 0], x[..., 1:] - - if x_rest.shape[-1] > 0: - x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) - x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + r"""The forward method of the `UpSample` class.""" - else: - x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) b, c, t, h, w = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) x = self.conv(x) @@ -303,446 +370,403 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class CogVideoXResnetBlock3D(ResnetBlock3D): +## ==== After this is the special code of CogVideoX ==== ## + + +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXSaveConv3d(torch.nn.Conv3d): + """ + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + + # Set to 2GB, Suit for CuDNN + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) + + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super(CogVideoXSaveConv3d, self).forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super(CogVideoXSaveConv3d, self).forward(input) + + +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXCausalConv3d(nn.Module): + """ + A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + """ + def __init__( self, in_channels: int, out_channels: int, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - eps: float = 1e-6, - non_linearity: str = "swish", - conv_shortcut: bool = False, - latent_channels: Optional[int] = None, - pad_mode: str = "first", + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode: str = "constant", + **kwargs, ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - eps=eps, - non_linearity=non_linearity, - conv_shortcut=conv_shortcut, - latent_channels=latent_channels, - ) + super().__init__() - out_channels = in_channels if out_channels is None else out_channels + def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) - self.in_channels = in_channels - self.out_channels = out_channels - self.act_fn = get_activation(non_linearity) - self.use_conv_shortcut = conv_shortcut + kernel_size = cast_tuple(kernel_size, 3) - if latent_channels is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) - else: - self.norm1 = CogVideoXSpatialNorm3D( - f_channels=in_channels, - zq_channels=latent_channels, - ) - self.norm2 = CogVideoXSpatialNorm3D( - f_channels=out_channels, - zq_channels=latent_channels, - ) - self.conv1 = CogVideoXCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - self.dropout = torch.nn.Dropout(dropout) + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) - self.conv2 = CogVideoXCausalConv3d( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = CogVideoXCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - else: - self.nin_shortcut = SaveConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 - ) + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - def forward( - self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, **kwargs - ) -> torch.Tensor: - hidden_states = input_tensor - if zq is not None: - hidden_states = self.norm1(hidden_states, zq) - else: - hidden_states = self.norm1(hidden_states) - hidden_states = self.non_linearity(hidden_states) - hidden_states = self.conv1(hidden_states) + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = CogVideoXSaveConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + **kwargs, + ) - if temb is not None: - hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] + self.conv_cache = None - if zq is not None: - hidden_states = self.norm2(hidden_states, zq) + def forward(self, x): + if self.pad_mode == "constant": + causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_3d, mode="constant", value=0) + elif self.pad_mode == "first": + pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) + x = torch.cat([pad_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + elif self.pad_mode == "reflect": + reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) + if reflect_x.shape[2] < self.time_pad: + reflect_x = torch.cat( + [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 + ) + x = torch.cat([reflect_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) else: - hidden_states = self.norm2(hidden_states) - hidden_states = self.non_linearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + raise ValueError("Invalid pad mode") + if self.time_pad != 0 and self.conv_cache is None: + self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() + return self.conv(x) + elif self.time_pad != 0 and self.conv_cache is not None: + x = torch.cat([self.conv_cache.to(x.device), x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + self.conv_cache = None + return self.conv(x) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - input_tensor = self.conv_shortcut(input_tensor) - else: - input_tensor = self.nin_shortcut(input_tensor) + return self.conv(x) - output_tensor = input_tensor + hidden_states - return output_tensor +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXSpatialNorm3D(SpatialNorm3D): + """ + Use CogVideoXSaveConv3d instead of nn.Conv3d to avoid OOM in CogVideoX Model + """ + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__(f_channels, zq_channels) + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv = CogVideoXSaveConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) + self.conv_y = CogVideoXSaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = CogVideoXSaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) -# Todo: Need refactor? @a-r-r-o-w -class AttnBlock2D(nn.Module): - def __init__(self, in_channels: int, norm_num_groups: int): - super().__init__() + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + if zq.shape[2] > 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) + z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f - self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - def forward(self, x): - h_ = x - h_ = self.norm(h_) +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXUpzSample3D(UpSample3D): + r""" + Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. - b, c, t, h, w = h_.shape - h_ = h_.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) + def __init__(self, in_channels: int, out_channels: int, compress_time: bool = False): + super().__init__(in_channels, out_channels) - # compute attention - b_t, c, h, w = q.shape - q = q.reshape(b_t, c, h * w) - q = q.permute(0, 2, 1) # b_t, hw, c - k = k.reshape(b_t, c, h * w) # b_t, c, hw + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.compress_time = compress_time - # implement c**-0.5 on q - q = q * (int(c) ** (-0.5)) - w_ = torch.bmm(q, k) # b_t, hw, hw + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = x[:, :, 0], x[:, :, 1:] - w_ = torch.nn.functional.softmax(w_, dim=2) + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0) + x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + x = torch.cat([x_first, x_rest], dim=2) + elif x.shape[2] > 1: + x = torch.nn.functional.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + # only interpolate 2D + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - # attend to values - v = v.reshape(b_t, c, h * w) - w_ = w_.permute(0, 2, 1) # b_t, hw, hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b_t, c, hw (hw of q) - h_ = h_.reshape(b_t, c, h, w) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - h_ = self.proj_out(h_) + return x - h_ = h_.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) - return h_ +# Todo: Create vae_3d.py such as vae.py file? +class DownSample3D(nn.Module): + r""" + Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in CogVideoX Model. + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ -# Todo: Create vae_3d.py such as vae.py file? -class Encoder3D(nn.Module): def __init__( self, - *, - in_channels: int = 3, - out_channels: int = 16, - block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), - num_res_blocks: int, - act_fn: str = "silu", - norm_num_groups: int = 32, - attn_resolutions=None, - dropout: float = 0.0, - resolution: int, - latent_channels: int, - double_z: bool = True, - pad_mode: str = "first", - temporal_compress_times: int = 4, + in_channels: int, + out_channels: int, + compress_time: bool = False, ): - super(Encoder3D, self).__init__() - if attn_resolutions is None: - attn_resolutions = [] - self.act_fn = get_activation(act_fn) - self.num_resolutions = len(block_out_channels) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.attn_resolutions = attn_resolutions - - # log2 of temporal_compress_times - self.temporal_compress_level = int(np.log2(temporal_compress_times)) - - self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) - - curr_res = resolution - in_ch_mult = (block_out_channels[0],) + tuple(block_out_channels) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - - block_in = in_ch_mult[i_level] - block_out = block_out_channels[i_level] - - for i_block in range(self.num_res_blocks): - block.append( - CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_out, - temb_channels=0, - non_linearity=act_fn, - dropout=dropout, - groups=norm_num_groups, - pad_mode=pad_mode, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - if i_level < self.temporal_compress_level: - down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) - else: - down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=False) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - block_in = in_ch_mult[-1] - self.mid.block_1 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - non_linearity=act_fn, - temb_channels=0, - groups=norm_num_groups, - dropout=dropout, - pad_mode=pad_mode, - ) - if len(attn_resolutions) > 0: - self.mid.attn_1 = AttnBlock2D(block_in) - self.mid.block_2 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - non_linearity=act_fn, - temb_channels=0, - groups=norm_num_groups, - dropout=dropout, - pad_mode=pad_mode, - ) - - self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=norm_num_groups, eps=1e-6) - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = CogVideoXCausalConv3d( - block_in, conv_out_channels if double_z else latent_channels, kernel_size=3, pad_mode=pad_mode - ) - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - # timestep embedding - temb = None - - # downsampling - sample = self.conv_in(sample) - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - sample = self.down[i_level].block[i_block](sample, temb) - if len(self.down[i_level].attn) > 0: - sample = self.down[i_level].attn[i_block](sample) - if i_level != self.num_resolutions - 1: - sample = self.down[i_level].downsample(sample) + super(DownSample3D, self).__init__() - # middle - sample = self.mid.block_1(sample, temb) + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.compress_time = compress_time - if len(self.attn_resolutions): - sample = self.mid.attn_1(sample) + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) - sample = self.mid.block_2(sample, temb) + if x.shape[-1] % 2 == 1: + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] - # post-process - sample = self.norm_out(sample) - sample = self.act_fn(sample) - sample = self.conv_out(sample) + if x_rest.shape[-1] > 0: + x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - return sample + else: + x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) -# Todo: Create vae_3d.py such as vae.py file? -class Decoder3D(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + return x - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - norm_type (`str`, *optional*, defaults to `"group"`): - The normalization type to use. Can be either `"group"` or `"spatial"`. - """ +class CogVideoXResnetBlock3D(ResnetBlock3D): def __init__( self, - *, - in_channels: int = 16, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), - num_res_blocks: int, - attn_resolutions=None, - act_fn: str = "silu", + in_channels: int, + out_channels: int, dropout: float = 0.0, - resolution: int, - latent_channels: int, - give_pre_end: bool = False, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + latent_channels: Optional[int] = None, pad_mode: str = "first", - temporal_compress_times: int = 4, - norm_num_groups=32, ): - super(Decoder3D, self).__init__() - if attn_resolutions is None: - attn_resolutions = [] - self.act_fn = get_activation(act_fn) - self.num_resolutions = len(block_out_channels) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.give_pre_end = give_pre_end - self.attn_resolutions = attn_resolutions - self.norm_num_groups = norm_num_groups + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + eps=eps, + non_linearity=non_linearity, + conv_shortcut=conv_shortcut, + latent_channels=latent_channels, + ) - # log2 of temporal_compress_times + out_channels = in_channels if out_channels is None else out_channels - self.temporal_compress_level = int(np.log2(temporal_compress_times)) + self.in_channels = in_channels + self.out_channels = out_channels + self.act_fn = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut - # compute in_ch_mult, block_in and curr_res at lowest res - block_in = block_out_channels[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, latent_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + if latent_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=latent_channels, + ) + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=latent_channels, + ) + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) - self.conv_in = CogVideoXCausalConv3d(latent_channels, block_in, kernel_size=3, pad_mode=pad_mode) + self.dropout = torch.nn.Dropout(dropout) - # middle - self.mid = nn.Module() - self.mid.block_1 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - temb_channels=0, - dropout=dropout, - non_linearity=act_fn, - latent_channels=latent_channels, - groups=norm_num_groups, - pad_mode=pad_mode, + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) - if len(attn_resolutions) > 0: - self.mid.attn_1 = AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups) - self.mid.block_2 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - temb_channels=0, - dropout=dropout, - non_linearity=act_fn, - latent_channels=latent_channels, - groups=norm_num_groups, - pad_mode=pad_mode, - ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + else: + self.nin_shortcut = CogVideoXSaveConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) - # UpSampling + def forward( + self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, *args, **kwargs + ) -> torch.Tensor: + hidden_states = input_tensor + if zq is not None: + hidden_states = self.norm1(hidden_states, zq) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv1(hidden_states) - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = block_out_channels[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_out, - temb_channels=0, - non_linearity=act_fn, - dropout=dropout, - latent_channels=latent_channels, - groups=norm_num_groups, - pad_mode=pad_mode, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - if i_level < self.num_resolutions - self.temporal_compress_level: - up.upsample = CogVideoXUpzSample3D( - in_channels=block_in, out_channels=block_in, compress_time=False - ) - else: - up.upsample = CogVideoXUpzSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) - curr_res = curr_res * 2 + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] - self.up.insert(0, up) + if zq is not None: + hidden_states = self.norm2(hidden_states, zq) + else: + hidden_states = self.norm2(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) - self.norm_out = CogVideoXSpatialNorm3D(f_channels=block_in, zq_channels=latent_channels) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + input_tensor = self.conv_shortcut(input_tensor) + else: + input_tensor = self.nin_shortcut(input_tensor) - self.conv_out = CogVideoXCausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) + output_tensor = input_tensor + hidden_states - def forward(self, sample: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `Decoder` class.""" - # timestep embedding + return output_tensor - temb = None - hidden_states = self.conv_in(sample) +# Todo: Need refactor? @a-r-r-o-w +class AttnBlock2D(nn.Module): + def __init__(self, in_channels: int, norm_num_groups: int): + super().__init__() - # middle - hidden_states = self.mid.block_1(hidden_states, temb, sample) - if len(self.attn_resolutions) > 0: - hidden_states = self.mid.attn_1(hidden_states, sample) - hidden_states = self.mid.block_2(hidden_states, temb, sample) + self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - # UpSampling + def forward(self, x): + h_ = x + h_ = self.norm(h_) - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - hidden_states = self.up[i_level].block[i_block](hidden_states, temb, sample) - if len(self.up[i_level].attn) > 0: - hidden_states = self.up[i_level].attn[i_block](hidden_states, sample) - if i_level != 0: - hidden_states = self.up[i_level].upsample(hidden_states) + b, c, t, h, w = h_.shape + h_ = h_.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - # end - if self.give_pre_end: - return hidden_states + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) - hidden_states = self.norm_out(hidden_states, sample) - hidden_states = self.act_fn(hidden_states) - hidden_states = self.conv_out(hidden_states) + # compute attention + b_t, c, h, w = q.shape + q = q.reshape(b_t, c, h * w) + q = q.permute(0, 2, 1) # b_t, hw, c + k = k.reshape(b_t, c, h * w) # b_t, c, hw - return hidden_states + # implement c**-0.5 on q + q = q * (int(c) ** (-0.5)) + w_ = torch.bmm(q, k) # b_t, hw, hw + + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b_t, c, h * w) + w_ = w_.permute(0, 2, 1) # b_t, hw, hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b_t, c, hw (hw of q) + h_ = h_.reshape(b_t, c, h, w) + + h_ = self.proj_out(h_) + + h_ = h_.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) + + return h_ class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): @@ -828,8 +852,10 @@ def __init__( resolution=sample_size, latent_channels=latent_channels, ) - self.quant_conv = SaveConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None - self.post_quant_conv = SaveConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + self.quant_conv = CogVideoXSaveConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = ( + CogVideoXSaveConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + ) self.use_slicing = False self.use_tiling = False diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 6acc59a5f3b3..58c99ac48564 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -460,7 +460,7 @@ def __init__( ) def forward( - self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, **kwargs + self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, *args, **kwargs ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." From aff72ec5dc5e459ca108cb23af2f68fe719f9296 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 30 Jul 2024 22:43:09 +0800 Subject: [PATCH 10/94] Update autoencoder_kl3d.py --- src/diffusers/models/autoencoders/autoencoder_kl3d.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index 8f4244581e98..e250fcca80a1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -521,7 +521,8 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXUpzSample3D(UpSample3D): r""" - Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. + Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX + Model. Args: in_channels (`int`, *optional*, defaults to 3): @@ -572,7 +573,8 @@ def forward(self, x): # Todo: Create vae_3d.py such as vae.py file? class DownSample3D(nn.Module): r""" - Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in CogVideoX Model. + Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in + CogVideoX Model. Args: in_channels (`int`, *optional*): From cb5348a0c2162232a832619c338a53a2d42c5f5b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 01:26:32 +0200 Subject: [PATCH 11/94] fix nasty bug in 3d sincos pos embeds --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 76e007b09af3..ba25c104fdfe 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -103,7 +103,7 @@ def get_3d_sincos_pos_embed( # 1. Spatial grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale - grid_w = np.arange(spatial_size[0], dtype=np.float32) / temporal_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) From e9828817161637634826e873ab83dcc881f8cfb0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 02:52:53 +0200 Subject: [PATCH 12/94] refactor --- src/diffusers/models/normalization.py | 24 ++++ .../transformers/cogvideox_transformer_3d.py | 117 ++++++------------ 2 files changed, 65 insertions(+), 76 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 4e532f3fc990..2bb84ad8a2e6 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -289,6 +289,30 @@ def forward( return x +class CogVideoXLayerNormZero(nn.Module): + def __init__( + self, + conditioning_dim: int, + embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shift_msa, scale_msa, gate_msa, enc_shift_msa, enc_scale_msa, enc_gate_msa = self.linear(self.silu(temb)).chunk(6, dim=1) + print("adaln's:", shift_msa.float().sum(), scale_msa.float().sum(), gate_msa.float().sum()) + print("adaln's:", enc_shift_msa.float().sum(), enc_scale_msa.float().sum(), enc_gate_msa.float().sum()) + hidden_states = self.norm(hidden_states) * (1 + scale_msa)[:, None, :] + shift_msa[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale_msa)[:, None, :] + enc_shift_msa[:, None, :] + return hidden_states, encoder_hidden_states, gate_msa[:, None, :], enc_gate_msa[:, None, :] + + if is_torch_version(">=", "2.1.0"): LayerNorm = nn.LayerNorm else: diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 303d25043ea5..c3c2e6fb6e19 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -24,6 +24,7 @@ from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from ..normalization import CogVideoXLayerNormZero logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -44,26 +45,6 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output -class AdaLayerNorm(nn.Module): - r""" - Norm layer modified to incorporate timestep embeddings. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the embeddings dictionary. - """ - - def __init__(self, embedding_dim: int, output_dim: int): - super().__init__() - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, output_dim) - - def forward(self, emb: torch.Tensor) -> torch.Tensor: - x = self.silu(emb.to(torch.float32)).to(emb.dtype) - x = self.linear(x) - return x - - @maybe_allow_in_graph class CogVideoXBlock(nn.Module): r""" @@ -121,10 +102,8 @@ def __init__( ): super().__init__() - self.norm0 = AdaLayerNorm(time_embed_dim, 12 * dim) - # 1. Self Attention - self.norm1 = nn.LayerNorm(dim, norm_eps) + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.attn1 = Attention( query_dim=dim, @@ -138,7 +117,7 @@ def __init__( ) # 2. Feed Forward - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.ff = FeedForward( dim, @@ -158,9 +137,6 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): self._chunk_size = chunk_size self._chunk_dim = dim - def _modulate(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - def forward( self, hidden_states: torch.Tensor, @@ -168,59 +144,35 @@ def forward( temb: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - ( - shift_msa, - scale_msa, - gate_msa, - shift_ff, - scale_ff, - gate_mlp, - enc_shift_msa, - enc_scale_msa, - enc_gate_msa, - enc_shift_ff, - enc_scale_ff, - enc_gate_mlp, - ) = self.norm0(temb).chunk(12, dim=1) - gate_msa, gate_mlp, enc_gate_msa, enc_gate_mlp = ( - gate_msa.unsqueeze(1), - gate_mlp.unsqueeze(1), - enc_gate_msa.unsqueeze(1), - enc_gate_mlp.unsqueeze(1), - ) - - # norm & modulate - norm_hidden_states = self.norm1(hidden_states) - norm_encoder_hidden_states = self.norm1(encoder_hidden_states) - - norm_hidden_states = self._modulate(norm_hidden_states, shift_msa, scale_msa) - norm_encoder_hidden_states = self._modulate(norm_encoder_hidden_states, enc_shift_msa, enc_scale_msa) + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb) + print("norm and modulate 1:", norm_hidden_states.float().sum(), norm_encoder_hidden_states.float().sum()) # attention text_length = norm_encoder_hidden_states.size(1) norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + print("attention_input:", norm_hidden_states.float().sum()) attn_output = self.attn1(norm_hidden_states, attention_mask=attention_mask) + print("attention_output:", attn_output.float().sum()) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] # norm & modulate - norm_hidden_states = self.norm2(hidden_states) - norm_encoder_hidden_states = self.norm2(encoder_hidden_states) - - norm_hidden_states = self._modulate(norm_hidden_states, shift_ff, scale_ff) - norm_encoder_hidden_states = self._modulate(norm_encoder_hidden_states, enc_shift_ff, enc_scale_ff) + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb) + print("norm and modulate 2:", norm_hidden_states.float().sum(), norm_encoder_hidden_states.float().sum()) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + print("ff_input:", norm_hidden_states.float().sum()) if self._chunk_size is not None: ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) + print("ff_output:", ff_output.float().sum()) - hidden_states = hidden_states + gate_mlp * ff_output[:, text_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_mlp * ff_output[:, :text_length] + hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] return hidden_states, encoder_hidden_states @@ -317,9 +269,7 @@ def __init__( temporal_interpolation_scale, ) spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) - pos_embedding = nn.Parameter( - torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim), requires_grad=False - ) + pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) self.register_buffer("pos_embedding", pos_embedding, persistent=False) @@ -349,8 +299,11 @@ def __init__( self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) # 5. Output blocks - self.adaln_out = AdaLayerNorm(time_embed_dim, 2 * inner_dim) - self.norm_out = nn.LayerNorm(inner_dim, 1e-6, norm_elementwise_affine) + self.adaln_out = nn.Sequential( + nn.SiLU(), + nn.Linear(time_embed_dim, 2 * inner_dim) + ) + self.norm_out = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) self.gradient_checkpointing = False @@ -358,9 +311,6 @@ def __init__( def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value - def _modulate(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - def forward( self, sample: torch.Tensor, @@ -396,21 +346,30 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) - # 3. Patch embedding + print("temb:", emb.float().sum()) + + # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, sample) - # 4. Position embedding + print("hidden_states patch_embeds:", hidden_states.float().sum()) + + # 3. Position embedding seq_length = height * width * num_frames // (self.config.patch_size**2) text_seq_length = encoder_hidden_states.size(1) - pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] + pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] hidden_states = hidden_states + pos_embeds hidden_states = self.embedding_dropout(hidden_states) + print("hidden_states pos_embeds", hidden_states.float().sum(), pos_embeds.float().sum()) + encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] - # 2. Prepare attention mask + print("encoder_hidden_states:", encoder_hidden_states.float().sum()) + print("hidden_states:", hidden_states.float().sum()) + + # 4. Prepare attention mask if attention_mask is None: attention_mask = torch.ones(batch_size, self.num_patches + self.config.max_text_seq_length) attention_mask = attention_mask.to(device=sample.device, dtype=sample.dtype) @@ -442,19 +401,25 @@ def custom_forward(*inputs): attention_mask=attention_mask, ) + print("loop i:", i, torch.cat([encoder_hidden_states, hidden_states], dim=1).float().sum()) + hidden_states = self.norm_final(hidden_states) + print("norm_final:", hidden_states.float().sum()) # 6. Final block shift, scale = self.adaln_out(emb).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - hidden_states = self._modulate(hidden_states, shift, scale) + print("adaln_out:", shift.float().sum(), scale.float().sum()) + hidden_states = self.norm_out(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + print("modulate out:", hidden_states.float().sum()) hidden_states = self.proj_out(hidden_states) + print("proj_out:", hidden_states.float().sum()) # 7. Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, p, p, self.config.out_channels) output = output.permute(0, 1, 6, 2, 4, 3, 5).flatten(5, 6).flatten(3, 4) + print("output:", output.float().sum()) if not return_dict: - return output + return (output,) return Transformer2DModelOutput(sample=output) From d963b1aaa4a082cc54bdc5fbc41a1350309e807e Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 02:54:32 +0200 Subject: [PATCH 13/94] update conversion script for latest modeling changes --- scripts/convert_cogvideox_to_diffusers.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index f334924a92a3..084a8cce48cb 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -31,8 +31,18 @@ def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]) - def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tuple[str, nn.Module]]: layer_id, _, weight_or_bias = key.split(".")[-3:] - new_key = f"transformer_blocks.{layer_id}.norm0.linear.{weight_or_bias}" - state_dict[new_key] = state_dict.pop(key) + + weights_or_biases = state_dict[key].chunk(12, dim=0) + norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9]) + norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12]) + + norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}" + state_dict[norm1_key] = norm1_weights_or_biases + + norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}" + state_dict[norm2_key] = norm2_weights_or_biases + + state_dict.pop(key) TRANSFORMER_KEYS_RENAME_DICT = { @@ -44,14 +54,14 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tu "dense_4h_to_h": "2", ".layers": "", "dense": "to_out.0", - "input_layernorm": "norm1", - "post_attn1_layernorm": "norm2", + "input_layernorm": "norm1.norm", + "post_attn1_layernorm": "norm2.norm", "time_embed.0": "time_embedding.linear_1", "time_embed.2": "time_embedding.linear_2", "mixins.patch_embed": "patch_embed", "mixins.final_layer.norm_final": "norm_out", "mixins.final_layer.linear": "proj_out", - "mixins.final_layer.adaLN_modulation.1": "adaln_out.linear", + "mixins.final_layer.adaLN_modulation.1": "adaln_out.1", } TRANSFORMER_SPECIAL_KEYS_REMAP = { From 16967589d83bc8dde4132098fb2a45de993fa0de Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 02:54:41 +0200 Subject: [PATCH 14/94] remove debug prints --- src/diffusers/models/normalization.py | 2 -- .../transformers/cogvideox_transformer_3d.py | 22 ------------------- 2 files changed, 24 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 2bb84ad8a2e6..4045701aed53 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -306,8 +306,6 @@ def __init__( def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: shift_msa, scale_msa, gate_msa, enc_shift_msa, enc_scale_msa, enc_gate_msa = self.linear(self.silu(temb)).chunk(6, dim=1) - print("adaln's:", shift_msa.float().sum(), scale_msa.float().sum(), gate_msa.float().sum()) - print("adaln's:", enc_shift_msa.float().sum(), enc_scale_msa.float().sum(), enc_gate_msa.float().sum()) hidden_states = self.norm(hidden_states) * (1 + scale_msa)[:, None, :] + shift_msa[:, None, :] encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale_msa)[:, None, :] + enc_shift_msa[:, None, :] return hidden_states, encoder_hidden_states, gate_msa[:, None, :], enc_gate_msa[:, None, :] diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index c3c2e6fb6e19..ad29445c9b6a 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -145,31 +145,25 @@ def forward( attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb) - print("norm and modulate 1:", norm_hidden_states.float().sum(), norm_encoder_hidden_states.float().sum()) # attention text_length = norm_encoder_hidden_states.size(1) norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - print("attention_input:", norm_hidden_states.float().sum()) attn_output = self.attn1(norm_hidden_states, attention_mask=attention_mask) - print("attention_output:", attn_output.float().sum()) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb) - print("norm and modulate 2:", norm_hidden_states.float().sum(), norm_encoder_hidden_states.float().sum()) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - print("ff_input:", norm_hidden_states.float().sum()) if self._chunk_size is not None: ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) - print("ff_output:", ff_output.float().sum()) hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] @@ -346,13 +340,9 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) - print("temb:", emb.float().sum()) - # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, sample) - print("hidden_states patch_embeds:", hidden_states.float().sum()) - # 3. Position embedding seq_length = height * width * num_frames // (self.config.patch_size**2) text_seq_length = encoder_hidden_states.size(1) @@ -361,14 +351,9 @@ def forward( hidden_states = hidden_states + pos_embeds hidden_states = self.embedding_dropout(hidden_states) - print("hidden_states pos_embeds", hidden_states.float().sum(), pos_embeds.float().sum()) - encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] - print("encoder_hidden_states:", encoder_hidden_states.float().sum()) - print("hidden_states:", hidden_states.float().sum()) - # 4. Prepare attention mask if attention_mask is None: attention_mask = torch.ones(batch_size, self.num_patches + self.config.max_text_seq_length) @@ -401,24 +386,17 @@ def custom_forward(*inputs): attention_mask=attention_mask, ) - print("loop i:", i, torch.cat([encoder_hidden_states, hidden_states], dim=1).float().sum()) - hidden_states = self.norm_final(hidden_states) - print("norm_final:", hidden_states.float().sum()) # 6. Final block shift, scale = self.adaln_out(emb).chunk(2, dim=1) - print("adaln_out:", shift.float().sum(), scale.float().sum()) hidden_states = self.norm_out(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] - print("modulate out:", hidden_states.float().sum()) hidden_states = self.proj_out(hidden_states) - print("proj_out:", hidden_states.float().sum()) # 7. Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, p, p, self.config.out_channels) output = output.permute(0, 1, 6, 2, 4, 3, 5).flatten(5, 6).flatten(3, 4) - print("output:", output.float().sum()) if not return_dict: return (output,) From 21a0fc1b0d494ac6306050b3c78cafc260418432 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 02:55:08 +0200 Subject: [PATCH 15/94] make style --- scripts/convert_cogvideox_to_diffusers.py | 4 ++-- src/diffusers/models/normalization.py | 14 ++++++++++---- .../transformers/cogvideox_transformer_3d.py | 13 +++++++------ 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 084a8cce48cb..a80e4ecdf155 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -38,10 +38,10 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tu norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}" state_dict[norm1_key] = norm1_weights_or_biases - + norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}" state_dict[norm2_key] = norm2_weights_or_biases - + state_dict.pop(key) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 4045701aed53..ecb831a3390d 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -303,11 +303,17 @@ def __init__( self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) - - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - shift_msa, scale_msa, gate_msa, enc_shift_msa, enc_scale_msa, enc_gate_msa = self.linear(self.silu(temb)).chunk(6, dim=1) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift_msa, scale_msa, gate_msa, enc_shift_msa, enc_scale_msa, enc_gate_msa = self.linear( + self.silu(temb) + ).chunk(6, dim=1) hidden_states = self.norm(hidden_states) * (1 + scale_msa)[:, None, :] + shift_msa[:, None, :] - encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale_msa)[:, None, :] + enc_shift_msa[:, None, :] + encoder_hidden_states = ( + self.norm(encoder_hidden_states) * (1 + enc_scale_msa)[:, None, :] + enc_shift_msa[:, None, :] + ) return hidden_states, encoder_hidden_states, gate_msa[:, None, :], enc_gate_msa[:, None, :] diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index ad29445c9b6a..8263afa1ed0d 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -144,7 +144,9 @@ def forward( temb: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb) + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) # attention text_length = norm_encoder_hidden_states.size(1) @@ -155,7 +157,9 @@ def forward( encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb) + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) @@ -293,10 +297,7 @@ def __init__( self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) # 5. Output blocks - self.adaln_out = nn.Sequential( - nn.SiLU(), - nn.Linear(time_embed_dim, 2 * inner_dim) - ) + self.adaln_out = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * inner_dim)) self.norm_out = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) From d83c1f844782ecaae5c0b9a6576c71deae79780c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 29 Jul 2024 12:37:47 +0530 Subject: [PATCH 16/94] add workflow to rebase with upstream main nightly. --- .github/workflows/upstream.yml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/upstream.yml diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml new file mode 100644 index 000000000000..331e81e5742f --- /dev/null +++ b/.github/workflows/upstream.yml @@ -0,0 +1,32 @@ +name: Rebase Upstream + +on: + schedule: + - cron: '0 0 * * *' # This runs the job nightly at midnight UTC + workflow_dispatch: + pull_request: + +permissions: + contents: write + +jobs: + rebase: + runs-on: ubuntu-latest + + steps: + - name: Checkout private repository + uses: actions/checkout@v2 + with: + ref: main + + - name: Fetch upstream changes + run: git fetch upstream + + - name: Rebase onto upstream main + run: git rebase upstream/main + + - name: Push changes to private main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + git push origin main --force From dfeb32975d0fda8933a355f02a24e9ceeaaf49e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 29 Jul 2024 12:39:27 +0530 Subject: [PATCH 17/94] add upstream --- .github/workflows/upstream.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml index 331e81e5742f..1965b9e9492b 100644 --- a/.github/workflows/upstream.yml +++ b/.github/workflows/upstream.yml @@ -19,6 +19,9 @@ jobs: with: ref: main + - name: Add upstream repository + run: git remote add upstream https://github.com/huggingface/diffusers.git + - name: Fetch upstream changes run: git fetch upstream From 71bcb1e1c54becb958cee98499d48e13d28fbd1b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 29 Jul 2024 12:46:46 +0530 Subject: [PATCH 18/94] Revert "add workflow to rebase with upstream main nightly." --- .github/workflows/upstream.yml | 35 ---------------------------------- 1 file changed, 35 deletions(-) delete mode 100644 .github/workflows/upstream.yml diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml deleted file mode 100644 index 1965b9e9492b..000000000000 --- a/.github/workflows/upstream.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Rebase Upstream - -on: - schedule: - - cron: '0 0 * * *' # This runs the job nightly at midnight UTC - workflow_dispatch: - pull_request: - -permissions: - contents: write - -jobs: - rebase: - runs-on: ubuntu-latest - - steps: - - name: Checkout private repository - uses: actions/checkout@v2 - with: - ref: main - - - name: Add upstream repository - run: git remote add upstream https://github.com/huggingface/diffusers.git - - - name: Fetch upstream changes - run: git fetch upstream - - - name: Rebase onto upstream main - run: git rebase upstream/main - - - name: Push changes to private main - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - git push origin main --force From 0980f4dcd2627bcaa20c4d5ce6503864f6f6c433 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 29 Jul 2024 12:51:46 +0530 Subject: [PATCH 19/94] add workflow for rebasing with upstream automatically. --- .github/workflows/upstream.yaml | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 .github/workflows/upstream.yaml diff --git a/.github/workflows/upstream.yaml b/.github/workflows/upstream.yaml new file mode 100644 index 000000000000..124cd50361d5 --- /dev/null +++ b/.github/workflows/upstream.yaml @@ -0,0 +1,36 @@ + +name: Rebase Upstream + +on: + schedule: + - cron: '0 0 * * *' # This runs the job nightly at midnight UTC + workflow_dispatch: + +permissions: + contents: write + +jobs: + rebase: + runs-on: ubuntu-latest + + steps: + - name: Checkout private repository + uses: actions/checkout@v2 + with: + ref: main + + - name: Add upstream repository + run: git remote add upstream https://github.com/huggingface/diffusers.git + + - name: Fetch upstream changes + run: git fetch upstream + + - name: Rebase onto upstream main + run: git rebase upstream/main + + - name: Push changes to private main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + git push origin main --force + From ee40f0e1ca6011fc976de37f137c3e6a7df788d7 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Wed, 31 Jul 2024 14:31:21 +0800 Subject: [PATCH 20/94] follow review guide --- .../models/autoencoders/autoencoder_kl3d.py | 125 ++++++++---------- 1 file changed, 57 insertions(+), 68 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index e250fcca80a1..9f469293ec4a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -45,17 +45,16 @@ class Encoder3D(nn.Module): def __init__( self, - *, in_channels: int = 3, out_channels: int = 16, block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), - num_res_blocks: int, + layers_per_block: int = 3, act_fn: str = "silu", norm_num_groups: int = 32, attn_resolutions=None, dropout: float = 0.0, - resolution: int, - latent_channels: int, + resolution: int = 256, + latent_channels: int = 16, double_z: bool = True, pad_mode: str = "first", temporal_compress_times: int = 4, @@ -65,7 +64,7 @@ def __init__( attn_resolutions = [] self.act_fn = get_activation(act_fn) self.num_resolutions = len(block_out_channels) - self.num_res_blocks = num_res_blocks + self.layers_per_block = layers_per_block self.resolution = resolution self.attn_resolutions = attn_resolutions @@ -84,7 +83,7 @@ def __init__( block_in = in_ch_mult[i_level] block_out = block_out_channels[i_level] - for i_block in range(self.num_res_blocks): + for i_block in range(self.layers_per_block): block.append( CogVideoXResnetBlock3D( in_channels=block_in, @@ -148,7 +147,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: # DownSampling sample = self.conv_in(sample) for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): + for i_block in range(self.layers_per_block): sample = self.down[i_level].block[i_block](sample, temb) if len(self.down[i_level].attn) > 0: sample = self.down[i_level].attn[i_block](sample) @@ -196,16 +195,15 @@ class Decoder3D(nn.Module): def __init__( self, - *, in_channels: int = 16, out_channels: int = 3, block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), - num_res_blocks: int, + layers_per_block: int = 3, attn_resolutions=None, act_fn: str = "silu", dropout: float = 0.0, - resolution: int, - latent_channels: int, + resolution: int = 256, + latent_channels: int = 16, give_pre_end: bool = False, pad_mode: str = "first", temporal_compress_times: int = 4, @@ -216,7 +214,7 @@ def __init__( attn_resolutions = [] self.act_fn = get_activation(act_fn) self.num_resolutions = len(block_out_channels) - self.num_res_blocks = num_res_blocks + self.layers_per_block = layers_per_block self.resolution = resolution self.give_pre_end = give_pre_end self.attn_resolutions = attn_resolutions @@ -263,7 +261,7 @@ def __init__( block = nn.ModuleList() attn = nn.ModuleList() block_out = block_out_channels[i_level] - for i_block in range(self.num_res_blocks + 1): + for i_block in range(self.layers_per_block + 1): block.append( CogVideoXResnetBlock3D( in_channels=block_in, @@ -314,7 +312,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: # UpSampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): + for i_block in range(self.layers_per_block + 1): hidden_states = self.up[i_level].block[i_block](hidden_states, temb, sample) if len(self.up[i_level].attn) > 0: hidden_states = self.up[i_level].attn[i_block](hidden_states, sample) @@ -359,7 +357,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, t, h, w = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = F.interpolate(x, scale_factor=2.0) x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) b, c, t, h, w = x.shape @@ -507,11 +505,11 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] - z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) - z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) + z_first = F.interpolate(z_first, size=f_first_size) + z_rest = F.interpolate(z_rest, size=f_rest_size) zq = torch.cat([z_first, z_rest], dim=2) else: - zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) + zq = F.interpolate(zq, size=f.shape[-3:]) zq = self.conv(zq) norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) @@ -545,21 +543,21 @@ def forward(self, x): # split first frame x_first, x_rest = x[:, :, 0], x[:, :, 1:] - x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0) - x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0) + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) x_first = x_first[:, :, None, :, :] x = torch.cat([x_first, x_rest], dim=2) elif x.shape[2] > 1: - x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = F.interpolate(x, scale_factor=2.0) else: x = x.squeeze(2) - x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = F.interpolate(x, scale_factor=2.0) x = x[:, :, None, :, :] else: # only interpolate 2D b, c, t, h, w = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = torch.nn.functional.interpolate(x, scale_factor=2.0) + x = F.interpolate(x, scale_factor=2.0) x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) b, c, t, h, w = x.shape @@ -606,16 +604,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_first, x_rest = x[..., 0], x[..., 1:] if x_rest.shape[-1] > 0: - x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) x = torch.cat([x_first[..., None], x_rest], dim=-1) x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) else: - x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + x = F.avg_pool1d(x, kernel_size=2, stride=2) x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = F.pad(x, pad, mode="constant", value=0) b, c, t, h, w = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) x = self.conv(x) @@ -730,50 +728,41 @@ def __init__(self, in_channels: int, norm_num_groups: int): super().__init__() self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - - b, c, t, h, w = h_.shape - h_ = h_.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b_t, c, h, w = q.shape - q = q.reshape(b_t, c, h * w) - q = q.permute(0, 2, 1) # b_t, hw, c - k = k.reshape(b_t, c, h * w) # b_t, c, hw - - # implement c**-0.5 on q - q = q * (int(c) ** (-0.5)) - w_ = torch.bmm(q, k) # b_t, hw, hw - - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b_t, c, h * w) - w_ = w_.permute(0, 2, 1) # b_t, hw, hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b_t, c, hw (hw of q) - h_ = h_.reshape(b_t, c, h, w) - - h_ = self.proj_out(h_) - - h_ = h_.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) - - return h_ + self.to_q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.to_k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.to_v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + +def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + # compute attention + batch_frames, num_channels, height, width = query.shape + query = query.reshape(batch_frames, num_channels, height * width) + query = query.permute(0, 2, 1) # b_t, hw, c + key = key.reshape(batch_frames, num_channels, height * width) # b_t, c, hw + # implement c**-0.5 on q + query = query * (int(num_channels) ** (-0.5)) + context = torch.bmm(query, key) # b_t, hw, hw + context = F.softmax(context, dim=2) + # attend to values + value = value.reshape(batch_frames, num_channels, height * width) + context = context.permute(0, 2, 1) # b_t, hw, hw (first hw of k, second of q) + hidden_states = torch.bmm(value, context) # b_t, c, hw (hw of q) + hidden_states = hidden_states.reshape(batch_frames, num_channels, height, width) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_size, num_frames, num_channels, height, width).permute(0, 2, 1, 3, 4) + return hidden_states class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" - A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + A VAE model with KL loss for encodfing images into latents and decoding latent representations into images. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -837,7 +826,7 @@ def __init__( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, - num_res_blocks=layers_per_block, + layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, @@ -850,7 +839,7 @@ def __init__( block_out_channels=block_out_channels, norm_num_groups=norm_num_groups, act_fn=act_fn, - num_res_blocks=layers_per_block, + layers_per_block=layers_per_block, resolution=sample_size, latent_channels=latent_channels, ) From 8fe54bcd266b11fc158e64f0a962b77806fa4da1 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Wed, 31 Jul 2024 16:42:47 +0800 Subject: [PATCH 21/94] add --- .../models/autoencoders/autoencoder_kl3d.py | 269 ++++++++++++------ src/diffusers/models/resnet.py | 7 +- 2 files changed, 179 insertions(+), 97 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index 9f469293ec4a..5a434ac16e71 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -19,6 +19,155 @@ ## == Basic Block of 3D VAE Model design in CogVideoX === ### +## Draft of block +# class DownEncoderBlock3D(nn.Module): +# def __init__( +# self, +# in_channels: int, +# out_channels: int, +# dropout: float = 0.0, +# num_layers: int = 1, +# resnet_eps: float = 1e-6, +# resnet_act_fn: str = "swish", +# resnet_groups: int = 32, +# resnet_pre_norm: bool = True, +# pad_mode: str = "first", +# ): +# super().__init__() +# resnets = [] +# +# for i in range(num_layers): +# resnets.append( +# CogVideoXResnetBlock3D( +# in_channels=in_channels if i == 0 else out_channels, +# out_channels=out_channels, +# temb_channels=0, +# eps=resnet_eps, +# groups=resnet_groups, +# dropout=dropout, +# non_linearity=resnet_act_fn, +# conv_shortcut=resnet_pre_norm, +# pad_mode=pad_mode, +# ) +# ) +# in_channels = out_channels +# +# self.resnets = nn.ModuleList(resnets) +# self.downsampler = DownSample3D(in_channels=out_channels, out_channels=out_channels) if num_layers > 0 else None +# +# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: +# for resnet in self.resnets: +# hidden_states = resnet(hidden_states, temb=None) +# +# if self.downsampler is not None: +# hidden_states = self.downsampler(hidden_states) +# +# return hidden_states +# +# +# class Encoder3D(nn.Module): +# def __init__( +# self, +# in_channels: int = 3, +# out_channels: int = 16, +# down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), +# block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), +# layers_per_block: int = 3, +# act_fn: str = "silu", +# norm_num_groups: int = 32, +# dropout: float = 0.0, +# resolution: int = 256, +# double_z: bool = True, +# pad_mode: str = "first", +# temporal_compress_times: int = 4, +# ): +# super().__init__() +# self.act_fn = get_activation(act_fn) +# self.num_resolutions = len(block_out_channels) +# self.layers_per_block = layers_per_block +# self.resolution = resolution +# +# # log2 of temporal_compress_times +# self.temporal_compress_level = int(np.log2(temporal_compress_times)) +# +# self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) +# +# self.down_blocks = nn.ModuleList() +# self.downsamples = nn.ModuleList() +# +# for i_level in range(self.num_resolutions): +# block_in = block_out_channels[i_level - 1] if i_level > 0 else block_out_channels[0] +# block_out = block_out_channels[i_level] +# is_final_block = i_level == self.num_resolutions - 1 +# +# down_block = DownEncoderBlock3D( +# in_channels=block_in, +# out_channels=block_out, +# num_layers=self.layers_per_block, +# dropout=dropout, +# resnet_eps=1e-6, +# resnet_act_fn=act_fn, +# resnet_groups=norm_num_groups, +# resnet_pre_norm=True, +# pad_mode=pad_mode, +# ) +# self.down_blocks.append(down_block) +# +# if not is_final_block: +# compress_time = i_level < self.temporal_compress_level +# self.downsamples.append( +# DownSample3D(in_channels=block_out, out_channels=block_out, compress_time=compress_time) +# ) +# +# # middle +# block_in = block_out_channels[-1] +# self.mid_block_1 = CogVideoXResnetBlock3D( +# in_channels=block_in, +# out_channels=block_in, +# non_linearity=act_fn, +# temb_channels=0, +# groups=norm_num_groups, +# dropout=dropout, +# pad_mode=pad_mode, +# ) +# self.mid_block_2 = CogVideoXResnetBlock3D( +# in_channels=block_in, +# out_channels=block_in, +# non_linearity=act_fn, +# temb_channels=0, +# groups=norm_num_groups, +# dropout=dropout, +# pad_mode=pad_mode, +# ) +# +# # out +# self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) +# self.conv_act = get_activation(act_fn) +# +# conv_out_channels = 2 * out_channels if double_z else out_channels +# self.conv_out = CogVideoXCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, pad_mode=pad_mode) +# +# def forward(self, sample: torch.Tensor) -> torch.Tensor: +# temb = None +# +# # DownSampling +# sample = self.conv_in(sample) +# for i_level in range(self.num_resolutions): +# sample = self.down_blocks[i_level](sample) +# if i_level < len(self.downsamples): +# sample = self.downsamples[i_level](sample) +# +# sample = self.mid_block_1(sample, temb) +# sample = self.mid_block_2(sample, temb) +# +# # post-process +# sample = self.norm_out(sample) +# sample = self.conv_act(sample) +# sample = self.conv_out(sample) +# +# return sample + + class Encoder3D(nn.Module): r""" The `Encoder3D` layer of a variational autoencoder that encodes its input into a latent representation. @@ -47,26 +196,22 @@ def __init__( self, in_channels: int = 3, out_channels: int = 16, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), layers_per_block: int = 3, act_fn: str = "silu", norm_num_groups: int = 32, - attn_resolutions=None, dropout: float = 0.0, resolution: int = 256, - latent_channels: int = 16, double_z: bool = True, pad_mode: str = "first", temporal_compress_times: int = 4, ): super(Encoder3D, self).__init__() - if attn_resolutions is None: - attn_resolutions = [] self.act_fn = get_activation(act_fn) self.num_resolutions = len(block_out_channels) self.layers_per_block = layers_per_block self.resolution = resolution - self.attn_resolutions = attn_resolutions # log2 of temporal_compress_times self.temporal_compress_level = int(np.log2(temporal_compress_times)) @@ -78,7 +223,6 @@ def __init__( self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() - attn = nn.ModuleList() block_in = in_ch_mult[i_level] block_out = block_out_channels[i_level] @@ -96,11 +240,9 @@ def __init__( ) ) block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) down = nn.Module() down.block = block - down.attn = attn + if i_level != self.num_resolutions - 1: if i_level < self.temporal_compress_level: down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) @@ -121,8 +263,6 @@ def __init__( dropout=dropout, pad_mode=pad_mode, ) - if len(attn_resolutions) > 0: - self.mid.attn_1 = AttnBlock2D(block_in) self.mid.block_2 = CogVideoXResnetBlock3D( in_channels=block_in, out_channels=block_in, @@ -136,7 +276,7 @@ def __init__( self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=norm_num_groups, eps=1e-6) conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = CogVideoXCausalConv3d( - block_in, conv_out_channels if double_z else latent_channels, kernel_size=3, pad_mode=pad_mode + block_in, conv_out_channels if double_z else out_channels, kernel_size=3, pad_mode=pad_mode ) def forward(self, sample: torch.Tensor) -> torch.Tensor: @@ -149,17 +289,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: for i_level in range(self.num_resolutions): for i_block in range(self.layers_per_block): sample = self.down[i_level].block[i_block](sample, temb) - if len(self.down[i_level].attn) > 0: - sample = self.down[i_level].attn[i_block](sample) + if i_level != self.num_resolutions - 1: sample = self.down[i_level].downsample(sample) # middle sample = self.mid.block_1(sample, temb) - - if len(self.attn_resolutions): - sample = self.mid.attn_1(sample) - sample = self.mid.block_2(sample, temb) # post-process @@ -199,34 +334,30 @@ def __init__( out_channels: int = 3, block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), layers_per_block: int = 3, - attn_resolutions=None, act_fn: str = "silu", dropout: float = 0.0, resolution: int = 256, - latent_channels: int = 16, give_pre_end: bool = False, pad_mode: str = "first", temporal_compress_times: int = 4, norm_num_groups=32, ): super(Decoder3D, self).__init__() - if attn_resolutions is None: - attn_resolutions = [] + self.act_fn = get_activation(act_fn) self.num_resolutions = len(block_out_channels) self.layers_per_block = layers_per_block self.resolution = resolution self.give_pre_end = give_pre_end - self.attn_resolutions = attn_resolutions self.norm_num_groups = norm_num_groups self.temporal_compress_level = int(np.log2(temporal_compress_times)) block_in = block_out_channels[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, latent_channels, curr_res, curr_res) + self.z_shape = (1, in_channels, curr_res, curr_res) print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - self.conv_in = CogVideoXCausalConv3d(latent_channels, block_in, kernel_size=3, pad_mode=pad_mode) + self.conv_in = CogVideoXCausalConv3d(in_channels, block_in, kernel_size=3, pad_mode=pad_mode) # middle self.mid = nn.Module() @@ -236,12 +367,10 @@ def __init__( temb_channels=0, dropout=dropout, non_linearity=act_fn, - latent_channels=latent_channels, + latent_channels=in_channels, groups=norm_num_groups, pad_mode=pad_mode, ) - if len(attn_resolutions) > 0: - self.mid.attn_1 = AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups) self.mid.block_2 = CogVideoXResnetBlock3D( in_channels=block_in, @@ -249,7 +378,7 @@ def __init__( temb_channels=0, dropout=dropout, non_linearity=act_fn, - latent_channels=latent_channels, + latent_channels=in_channels, groups=norm_num_groups, pad_mode=pad_mode, ) @@ -259,7 +388,7 @@ def __init__( self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() - attn = nn.ModuleList() + block_out = block_out_channels[i_level] for i_block in range(self.layers_per_block + 1): block.append( @@ -269,17 +398,16 @@ def __init__( temb_channels=0, non_linearity=act_fn, dropout=dropout, - latent_channels=latent_channels, + latent_channels=in_channels, groups=norm_num_groups, pad_mode=pad_mode, ) ) block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock2D(in_channels=block_in, norm_num_groups=norm_num_groups)) + up = nn.Module() up.block = block - up.attn = attn + if i_level != 0: if i_level < self.num_resolutions - self.temporal_compress_level: up.upsample = CogVideoXUpzSample3D( @@ -291,7 +419,7 @@ def __init__( self.up.insert(0, up) - self.norm_out = CogVideoXSpatialNorm3D(f_channels=block_in, zq_channels=latent_channels) + self.norm_out = CogVideoXSpatialNorm3D(f_channels=block_in, zq_channels=in_channels) self.conv_out = CogVideoXCausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) @@ -305,8 +433,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: # middle hidden_states = self.mid.block_1(hidden_states, temb, sample) - if len(self.attn_resolutions) > 0: - hidden_states = self.mid.attn_1(hidden_states, sample) + hidden_states = self.mid.block_2(hidden_states, temb, sample) # UpSampling @@ -314,8 +441,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.layers_per_block + 1): hidden_states = self.up[i_level].block[i_block](hidden_states, temb, sample) - if len(self.up[i_level].attn) > 0: - hidden_states = self.up[i_level].attn[i_block](hidden_states, sample) + if i_level != 0: hidden_states = self.up[i_level].upsample(hidden_states) @@ -496,9 +622,9 @@ def __init__( ): super().__init__(f_channels, zq_channels) self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv = CogVideoXSaveConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) - self.conv_y = CogVideoXSaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = CogVideoXSaveConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv = CogVideoXCausalConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: if zq.shape[2] > 1: @@ -645,7 +771,6 @@ def __init__( eps=eps, non_linearity=non_linearity, conv_shortcut=conv_shortcut, - latent_channels=latent_channels, ) out_channels = in_channels if out_channels is None else out_channels @@ -722,44 +847,6 @@ def forward( return output_tensor -# Todo: Need refactor? @a-r-r-o-w -class AttnBlock2D(nn.Module): - def __init__(self, in_channels: int, norm_num_groups: int): - super().__init__() - - self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6) - self.to_q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.to_k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.to_v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - -def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.norm(hidden_states) - batch_size, num_channels, num_frames, height, width = hidden_states.shape - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - query = self.to_q(hidden_states) - key = self.to_k(hidden_states) - value = self.to_v(hidden_states) - # compute attention - batch_frames, num_channels, height, width = query.shape - query = query.reshape(batch_frames, num_channels, height * width) - query = query.permute(0, 2, 1) # b_t, hw, c - key = key.reshape(batch_frames, num_channels, height * width) # b_t, c, hw - # implement c**-0.5 on q - query = query * (int(num_channels) ** (-0.5)) - context = torch.bmm(query, key) # b_t, hw, hw - context = F.softmax(context, dim=2) - # attend to values - value = value.reshape(batch_frames, num_channels, height * width) - context = context.permute(0, 2, 1) # b_t, hw, hw (first hw of k, second of q) - hidden_states = torch.bmm(value, context) # b_t, c, hw (hw of q) - hidden_states = hidden_states.reshape(batch_frames, num_channels, height, width) - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size, num_frames, num_channels, height, width).permute(0, 2, 1, 3, 4) - return hidden_states - - class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encodfing images into latents and decoding latent representations into images. @@ -777,7 +864,6 @@ class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the @@ -802,13 +888,13 @@ class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): def __init__( self, in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock3D",), - up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + out_channels: int = 16, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int] = (128, 256, 256, 512), + latent_channels: int = 16, layers_per_block: int = 3, act_fn: str = "silu", - latent_channels: int = 16, norm_num_groups: int = 32, sample_size: int = 256, scaling_factor: float = 1.15258426, @@ -826,27 +912,24 @@ def __init__( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, + down_block_types=down_block_types, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, resolution=sample_size, - latent_channels=latent_channels, ) self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, block_out_channels=block_out_channels, + layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, - layers_per_block=layers_per_block, resolution=sample_size, - latent_channels=latent_channels, - ) - self.quant_conv = CogVideoXSaveConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None - self.post_quant_conv = ( - CogVideoXSaveConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None ) + self.quant_conv = CogVideoXSaveConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSaveConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None self.use_slicing = False self.use_tiling = False diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 58c99ac48564..4511817d0731 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -409,7 +409,6 @@ def __init__( eps: float = 1e-6, non_linearity: str = "swish", conv_shortcut: bool = False, - latent_channels: Optional[int] = None, ): super().__init__() out_channels = in_channels if out_channels is None else out_channels @@ -419,17 +418,17 @@ def __init__( self.non_linearity = get_activation(non_linearity) self.use_conv_shortcut = conv_shortcut - if latent_channels is None: + if out_channels is None: self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) else: self.norm1 = SpatialNorm3D( f_channels=in_channels, - zq_channels=latent_channels, + zq_channels=out_channels, ) self.norm2 = SpatialNorm3D( f_channels=out_channels, - zq_channels=latent_channels, + zq_channels=out_channels, ) self.conv1 = nn.Conv3d( in_channels=in_channels, From 1c661ce3d4bff37f3e9aa4d8e3e52344d38cf08b Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Wed, 31 Jul 2024 17:19:22 +0800 Subject: [PATCH 22/94] remove deriving and using nn.module --- .../models/autoencoders/autoencoder_kl3d.py | 65 +++---------------- 1 file changed, 8 insertions(+), 57 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index 5a434ac16e71..d1c69e8a7da1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -9,10 +9,8 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation -from ..attention_processor import SpatialNorm3D from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..resnet import ResnetBlock3D from .vae import DecoderOutput, DiagonalGaussianDistribution @@ -456,49 +454,11 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: return hidden_states -class UpSample3D(nn.Module): - r""" - The `UpSample` layer of a variational autoencoder that upsamples its input. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `UpSample` class.""" - - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = F.interpolate(x, scale_factor=2.0) - x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - - return x - - ## ==== After this is the special code of CogVideoX ==== ## # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXSaveConv3d(torch.nn.Conv3d): +class CogVideoXSaveConv3d(nn.Conv3d): """ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. """ @@ -610,7 +570,7 @@ def forward(self, x): # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXSpatialNorm3D(SpatialNorm3D): +class CogVideoXSpatialNorm3D(nn.Module): """ Use CogVideoXSaveConv3d instead of nn.Conv3d to avoid OOM in CogVideoX Model """ @@ -620,7 +580,7 @@ def __init__( f_channels: int, zq_channels: int, ): - super().__init__(f_channels, zq_channels) + super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) self.conv = CogVideoXCausalConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) @@ -643,7 +603,7 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXUpzSample3D(UpSample3D): +class CogVideoXUpzSample3D(nn.Module): r""" Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. @@ -658,7 +618,7 @@ class CogVideoXUpzSample3D(UpSample3D): """ def __init__(self, in_channels: int, out_channels: int, compress_time: bool = False): - super().__init__(in_channels, out_channels) + super().__init__() self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.compress_time = compress_time @@ -748,7 +708,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class CogVideoXResnetBlock3D(ResnetBlock3D): +class CogVideoXResnetBlock3D(nn.Module): def __init__( self, in_channels: int, @@ -762,22 +722,13 @@ def __init__( latent_channels: Optional[int] = None, pad_mode: str = "first", ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - eps=eps, - non_linearity=non_linearity, - conv_shortcut=conv_shortcut, - ) + super().__init__() out_channels = in_channels if out_channels is None else out_channels self.in_channels = in_channels self.out_channels = out_channels - self.act_fn = get_activation(non_linearity) + self.non_linearity = get_activation(non_linearity) self.use_conv_shortcut = conv_shortcut if latent_channels is None: From b3052807e5f06ed3198599fd0a4d69370974c7d6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 12:02:38 +0200 Subject: [PATCH 23/94] add skeleton for pipeline --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/cogvideo/__init__.py | 48 ++++ .../pipelines/cogvideo/pipeline_cogvideox.py | 228 ++++++++++++++++++ 4 files changed, 280 insertions(+) create mode 100644 src/diffusers/pipelines/cogvideo/__init__.py create mode 100644 src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 97d86201ccff..58660f88e904 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -246,6 +246,7 @@ "ChatGLMModel", "ChatGLMTokenizer", "CLIPImageProjection", + "CogVideoXPipeline", "CycleDiffusionPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPipeline", @@ -661,6 +662,7 @@ ChatGLMModel, ChatGLMTokenizer, CLIPImageProjection, + CogVideoXPipeline, CycleDiffusionPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7bc50b297566..0a7e68de08eb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -129,6 +129,7 @@ "AudioLDM2UNet2DConditionModel", ] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] + _import_structure["cogvideo"] = ["CogVideoXPipeline"] _import_structure["controlnet"].extend( [ "BlipDiffusionControlNetPipeline", @@ -428,6 +429,7 @@ ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline + from .cogvideo import CogVideoXPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, diff --git a/src/diffusers/pipelines/cogvideo/__init__.py b/src/diffusers/pipelines/cogvideo/__init__.py new file mode 100644 index 000000000000..d155d3ef51b7 --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_cogvideox import CogVideoXPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py new file mode 100644 index 000000000000..ae3b42390e06 --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -0,0 +1,228 @@ +# TODO: Ask collaborators about license +# Copyright 2024 The CogVideoX Team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...models import AutoencoderKL3D, CogVideoXTransformer3D +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...video_processor import VideoProcessor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + # TODO: update example + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class CogVideoXPipelineOutput(BaseOutput): + frames: torch.Tensor + + +class CogVideoXPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3D`]): + A text conditioned `CogVideoXTransformer3D` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL3D, + transformer: CogVideoXTransformer3D, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + def encode_prompt( + self, + ): + r""" + TODO: implement encode prompt with T5 XXL + """ + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + ): + # TODO: implement check_inputs + pass + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # TODO: implement prepare_latents + pass + + def decode_latents(self, latents: torch.Tensor, video_length: int, vae_batch_size: int = 16): + # TODO: implement decode_latents + pass + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + TODO: implement forward pass + """ + pass From 6bcafcbaa62dcc701ce18e4999d591f5174322bd Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 12:02:58 +0200 Subject: [PATCH 24/94] make fix-copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8e40e5128854..8cd1d54fb237 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -257,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CogVideoXPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 3ae94139669ce54961767212efac7d41b807a16a Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 12:42:03 +0200 Subject: [PATCH 25/94] undo unnecessary changes added on cogvideo-vae by mistake --- .../geodiff_molecule_conformation.ipynb | 7222 ++++++++--------- 1 file changed, 3611 insertions(+), 3611 deletions(-) diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb index 19b87bc18012..bde093802a5d 100644 --- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb +++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb @@ -1,3652 +1,3652 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "F88mignPnalS" - }, - "source": [ - "# Introduction\n", - "\n", - "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", - "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", - "\n", - "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", - "\n", - "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", - "\n", - "> Colab made by [natolambert](https://twitter.com/natolambert).\n", - "\n", - "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7cnwXMocnuzB" - }, - "source": [ - "## Installations\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Install Conda" - ], - "metadata": { - "id": "ff9SxWnaNId9" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1g_6zOabItDk" - }, - "source": [ - "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "K0ofXobG5Y-X", - "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "nvcc: NVIDIA (R) Cuda compiler driver\n", - "Copyright (c) 2005-2021 NVIDIA Corporation\n", - "Built on Sun_Feb_14_21:12:58_PST_2021\n", - "Cuda compilation tools, release 11.2, V11.2.152\n", - "Build cuda_11.2.r11.2/compiler.29618528_0\n" - ] - } - ], - "source": [ - "!nvcc --version" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VfthW90vI0nw" - }, - "source": [ - "Install Conda for some more complex dependencies for geometric networks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2WNFzSnbiE0k", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", - "\u001B[0m" - ] - } - ], - "source": [ - "!pip install -q condacolab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NUsbWYCUI7Km" - }, - "source": [ - "Setup Conda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FZelreINdmd0", - "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✨🍰✨ Everything looks OK!\n" - ] - } - ], - "source": [ - "import condacolab\n", - "condacolab.install()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JzDHaPU7I9Sn" - }, - "source": [ - "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMxRjHhL7w8V", - "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", - "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - cudatoolkit=11.1\n", - " - pytorch\n", - " - torchaudio\n", - " - torchvision\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 960 KB\n", - "\n", - "The following packages will be UPDATED:\n", - "\n", - " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", - "Preparing transaction: / \b\bdone\n", - "Verifying transaction: \\ \b\bdone\n", - "Executing transaction: / \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", - "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" - ] - }, - { - "cell_type": "markdown", - "source": [ - "Need to remove a pathspec for colab that specifies the incorrect cuda version." - ], - "metadata": { - "id": "QDS6FPZ0Tu5b" - } - }, - { - "cell_type": "code", - "source": [ - "!rm /usr/local/conda-meta/pinned" - ], - "metadata": { - "id": "dq1lxR10TtrR", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z1L3DdZOJB30" - }, - "source": [ - "Install torch geometric (used in the model later)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D5ukfCOWfjzK", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - pytorch-geometric=1.7.2\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " decorator-4.4.2 | py_0 11 KB conda-forge\n", - " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", - " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", - " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", - " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", - " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", - " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", - " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", - " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", - " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", - " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", - " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", - " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", - " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", - " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", - " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", - " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", - " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", - " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", - " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 55.9 MB\n", - "\n", - "The following NEW packages will be INSTALLED:\n", - "\n", - " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", - " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", - " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", - " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", - " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", - " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", - " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", - " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", - " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", - " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", - " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", - " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", - " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", - " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", - " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", - " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", - " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", - " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", - " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", - "\n", - "The following packages will be DOWNGRADED:\n", - "\n", - " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", - "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", - "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", - "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", - "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", - "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", - "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", - "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", - "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", - "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", - "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", - "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", - "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", - "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", - "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", - "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", - "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", - "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", - "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", - "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", - "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", - "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install -c rusty1s pytorch-geometric=1.7.2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppxv6Mdkalbc" - }, - "source": [ - "### Install Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mgQA_XN-XGY2", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content\n", - "Cloning into 'diffusers'...\n", - "remote: Enumerating objects: 9298, done.\u001B[K\n", - "remote: Counting objects: 100% (40/40), done.\u001B[K\n", - "remote: Compressing objects: 100% (23/23), done.\u001B[K\n", - "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001B[K\n", - "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", - "Resolving deltas: 100% (6168/6168), done.\n", - " Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", - " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m757.0/757.0 kB\u001B[0m \u001B[31m52.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m163.5/163.5 kB\u001B[0m \u001B[31m21.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m40.8/40.8 kB\u001B[0m \u001B[31m5.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m596.3/596.3 kB\u001B[0m \u001B[31m51.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25h Building wheel for diffusers (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", - "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m432.7/432.7 kB\u001B[0m \u001B[31m36.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m5.3/5.3 MB\u001B[0m \u001B[31m90.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m35.3/35.3 MB\u001B[0m \u001B[31m39.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m115.1/115.1 kB\u001B[0m \u001B[31m16.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m948.0/948.0 kB\u001B[0m \u001B[31m63.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m212.2/212.2 kB\u001B[0m \u001B[31m21.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m95.8/95.8 kB\u001B[0m \u001B[31m12.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m140.8/140.8 kB\u001B[0m \u001B[31m18.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m7.6/7.6 MB\u001B[0m \u001B[31m104.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m148.0/148.0 kB\u001B[0m \u001B[31m20.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m231.3/231.3 kB\u001B[0m \u001B[31m30.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m94.8/94.8 kB\u001B[0m \u001B[31m14.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m58.8/58.8 kB\u001B[0m \u001B[31m8.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25h\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", - "\u001B[0m" - ] - } - ], - "source": [ - "%cd /content\n", - "\n", - "# install latest HF diffusers (will update to the release once added)\n", - "!git clone https://github.com/huggingface/diffusers.git\n", - "!pip install -q /content/diffusers\n", - "\n", - "# dependencies for diffusers\n", - "!pip install -q datasets transformers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LZO6AJKuJKO8" - }, - "source": [ - "Check that torch is installed correctly and utilizing the GPU in the colab" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gZt7BNi1e1PA", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 53 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F88mignPnalS" + }, + "source": [ + "# Introduction\n", + "\n", + "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", + "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", + "\n", + "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", + "\n", + "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", + "\n", + "> Colab made by [natolambert](https://twitter.com/natolambert).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] }, - "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "True\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "7cnwXMocnuzB" + }, + "source": [ + "## Installations\n", + "\n" + ] }, { - "output_type": "execute_result", - "data": { - "text/plain": [ - "'1.8.2'" + "cell_type": "markdown", + "source": [ + "### Install Conda" ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 8 - } - ], - "source": [ - "import torch\n", - "print(torch.cuda.is_available())\n", - "torch.__version__" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KLE7CqlfJNUO" - }, - "source": [ - "### Install Chemistry-specific Dependencies\n", - "\n", - "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0CPv_NvehRz3", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting rdkit\n", - " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m36.8/36.8 MB\u001B[0m \u001B[31m34.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", - "Installing collected packages: rdkit\n", - "Successfully installed rdkit-2022.3.5\n", - "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", - "\u001B[0m" - ] - } - ], - "source": [ - "!pip install rdkit" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88GaDbDPxJ5I" - }, - "source": [ - "### Get viewer from nglview\n", - "\n", - "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", - "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", - "The rdmol in this object is a source of ground truth for the generated molecules.\n", - "\n", - "You will use one rendering function from nglviewer later!\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jcl8GCS2mz6t", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting nglview\n", - " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m5.7/5.7 MB\u001B[0m \u001B[31m91.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25h Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", - " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", - "Collecting jupyterlab-widgets\n", - " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m384.1/384.1 kB\u001B[0m \u001B[31m40.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting ipywidgets>=7\n", - " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m134.4/134.4 kB\u001B[0m \u001B[31m21.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting widgetsnbextension~=4.0\n", - " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m2.0/2.0 MB\u001B[0m \u001B[31m84.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting ipython>=6.1.0\n", - " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m793.8/793.8 kB\u001B[0m \u001B[31m60.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting ipykernel>=4.5.1\n", - " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m138.4/138.4 kB\u001B[0m \u001B[31m20.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting traitlets>=4.3.1\n", - " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m107.1/107.1 kB\u001B[0m \u001B[31m17.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", - "Collecting pyzmq>=17\n", - " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.1/1.1 MB\u001B[0m \u001B[31m68.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting matplotlib-inline>=0.1\n", - " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", - "Collecting tornado>=6.1\n", - " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m424.0/424.0 kB\u001B[0m \u001B[31m41.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting nest-asyncio\n", - " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", - "Collecting debugpy>=1.0\n", - " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.8/1.8 MB\u001B[0m \u001B[31m83.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting psutil\n", - " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m281.3/281.3 kB\u001B[0m \u001B[31m33.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting jupyter-client>=6.1.12\n", - " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m132.2/132.2 kB\u001B[0m \u001B[31m19.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting pickleshare\n", - " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", - "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", - "Collecting backcall\n", - " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pexpect>4.3\n", - " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m59.0/59.0 kB\u001B[0m \u001B[31m7.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting pygments\n", - " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.1/1.1 MB\u001B[0m \u001B[31m70.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting jedi>=0.16\n", - " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.6/1.6 MB\u001B[0m \u001B[31m83.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", - " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m382.3/382.3 kB\u001B[0m \u001B[31m40.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", - "Collecting parso<0.9.0,>=0.8.0\n", - " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m100.8/100.8 kB\u001B[0m \u001B[31m14.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", - "Collecting entrypoints\n", - " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", - "Collecting jupyter-core>=4.9.2\n", - " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m88.4/88.4 kB\u001B[0m \u001B[31m14.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting ptyprocess>=0.5\n", - " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", - "Collecting wcwidth\n", - " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", - "Building wheels for collected packages: nglview\n", - " Building wheel for nglview (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", - " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", - " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", - "Successfully built nglview\n", - "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", - "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", - "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n", - "\u001B[0m" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "pexpect", - "pickleshare", - "wcwidth" - ] - } + "metadata": { + "id": "ff9SxWnaNId9" } - }, - "metadata": {} - } - ], - "source": [ - "!pip install nglview" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Create a diffusion model" - ], - "metadata": { - "id": "8t8_e_uVLdKB" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Model class(es)" - ], - "metadata": { - "id": "G0rMncVtNSqU" - } - }, - { - "cell_type": "markdown", - "source": [ - "Imports" - ], - "metadata": { - "id": "L5FEXz5oXkzt" - } - }, - { - "cell_type": "code", - "source": [ - "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", - "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", - "from dataclasses import dataclass\n", - "from typing import Callable, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch import Tensor, nn\n", - "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", - "\n", - "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", - "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", - "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", - "from torch_scatter import scatter_add\n", - "from torch_sparse import SparseTensor, coalesce\n", - "\n", - "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", - "from diffusers.modeling_utils import ModelMixin\n", - "from diffusers.utils import BaseOutput\n" - ], - "metadata": { - "id": "-3-P4w5sXkRU" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Helper classes" - ], - "metadata": { - "id": "EzJQXPN_XrMX" - } - }, - { - "cell_type": "code", - "source": [ - "@dataclass\n", - "class MoleculeGNNOutput(BaseOutput):\n", - " \"\"\"\n", - " Args:\n", - " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", - " Hidden states output. Output of last layer of model.\n", - " \"\"\"\n", - "\n", - " sample: torch.Tensor\n", - "\n", - "\n", - "class MultiLayerPerceptron(nn.Module):\n", - " \"\"\"\n", - " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", - " Args:\n", - " input_dim (int): input dimension\n", - " hidden_dim (list of int): hidden dimensions\n", - " activation (str or function, optional): activation function\n", - " dropout (float, optional): dropout rate\n", - " \"\"\"\n", - "\n", - " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", - " super(MultiLayerPerceptron, self).__init__()\n", - "\n", - " self.dims = [input_dim] + hidden_dims\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", - " self.activation = None\n", - " if dropout > 0:\n", - " self.dropout = nn.Dropout(dropout)\n", - " else:\n", - " self.dropout = None\n", - "\n", - " self.layers = nn.ModuleList()\n", - " for i in range(len(self.dims) - 1):\n", - " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\"\"\"\n", - " for i, layer in enumerate(self.layers):\n", - " x = layer(x)\n", - " if i < len(self.layers) - 1:\n", - " if self.activation:\n", - " x = self.activation(x)\n", - " if self.dropout:\n", - " x = self.dropout(x)\n", - " return x\n", - "\n", - "\n", - "class ShiftedSoftplus(torch.nn.Module):\n", - " def __init__(self):\n", - " super(ShiftedSoftplus, self).__init__()\n", - " self.shift = torch.log(torch.tensor(2.0)).item()\n", - "\n", - " def forward(self, x):\n", - " return F.softplus(x) - self.shift\n", - "\n", - "\n", - "class CFConv(MessagePassing):\n", - " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", - " super(CFConv, self).__init__(aggr=\"add\")\n", - " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", - " self.lin2 = Linear(num_filters, out_channels)\n", - " self.nn = mlp\n", - " self.cutoff = cutoff\n", - " self.smooth = smooth\n", - "\n", - " self.reset_parameters()\n", - "\n", - " def reset_parameters(self):\n", - " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", - " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", - " self.lin2.bias.data.fill_(0)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " if self.smooth:\n", - " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", - " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", - " else:\n", - " C = (edge_length <= self.cutoff).float()\n", - " W = self.nn(edge_attr) * C.view(-1, 1)\n", - "\n", - " x = self.lin1(x)\n", - " x = self.propagate(edge_index, x=x, W=W)\n", - " x = self.lin2(x)\n", - " return x\n", - "\n", - " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", - " return x_j * W\n", - "\n", - "\n", - "class InteractionBlock(torch.nn.Module):\n", - " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", - " super(InteractionBlock, self).__init__()\n", - " mlp = Sequential(\n", - " Linear(num_gaussians, num_filters),\n", - " ShiftedSoftplus(),\n", - " Linear(num_filters, num_filters),\n", - " )\n", - " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", - " self.act = ShiftedSoftplus()\n", - " self.lin = Linear(hidden_channels, hidden_channels)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " x = self.conv(x, edge_index, edge_length, edge_attr)\n", - " x = self.act(x)\n", - " x = self.lin(x)\n", - " return x\n", - "\n", - "\n", - "class SchNetEncoder(Module):\n", - " def __init__(\n", - " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.hidden_channels = hidden_channels\n", - " self.num_filters = num_filters\n", - " self.num_interactions = num_interactions\n", - " self.cutoff = cutoff\n", - "\n", - " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", - "\n", - " self.interactions = ModuleList()\n", - " for _ in range(num_interactions):\n", - " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", - " self.interactions.append(block)\n", - "\n", - " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", - " if embed_node:\n", - " assert z.dim() == 1 and z.dtype == torch.long\n", - " h = self.embedding(z)\n", - " else:\n", - " h = z\n", - " for interaction in self.interactions:\n", - " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", - "\n", - " return h\n", - "\n", - "\n", - "class GINEConv(MessagePassing):\n", - " \"\"\"\n", - " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", - " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", - " \"\"\"\n", - "\n", - " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", - " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", - " self.nn = mlp\n", - " self.initial_eps = eps\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " if train_eps:\n", - " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", - " else:\n", - " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", - "\n", - " def forward(\n", - " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", - " ) -> torch.Tensor:\n", - " \"\"\"\"\"\"\n", - " if isinstance(x, torch.Tensor):\n", - " x: OptPairTensor = (x, x)\n", - "\n", - " # Node and edge feature dimensionalites need to match.\n", - " if isinstance(edge_index, torch.Tensor):\n", - " assert edge_attr is not None\n", - " assert x[0].size(-1) == edge_attr.size(-1)\n", - " elif isinstance(edge_index, SparseTensor):\n", - " assert x[0].size(-1) == edge_index.size(-1)\n", - "\n", - " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", - " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", - "\n", - " x_r = x[1]\n", - " if x_r is not None:\n", - " out += (1 + self.eps) * x_r\n", - "\n", - " return self.nn(out)\n", - "\n", - " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", - " if self.activation:\n", - " return self.activation(x_j + edge_attr)\n", - " else:\n", - " return x_j + edge_attr\n", - "\n", - " def __repr__(self):\n", - " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", - "\n", - "\n", - "class GINEncoder(torch.nn.Module):\n", - " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", - " super().__init__()\n", - "\n", - " self.hidden_dim = hidden_dim\n", - " self.num_convs = num_convs\n", - " self.short_cut = short_cut\n", - " self.concat_hidden = concat_hidden\n", - " self.node_emb = nn.Embedding(100, hidden_dim)\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " self.convs = nn.ModuleList()\n", - " for i in range(self.num_convs):\n", - " self.convs.append(\n", - " GINEConv(\n", - " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", - " activation=activation,\n", - " )\n", - " )\n", - "\n", - " def forward(self, z, edge_index, edge_attr):\n", - " \"\"\"\n", - " Input:\n", - " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", - " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", - " Output:\n", - " node_feature: graph feature\n", - " \"\"\"\n", - "\n", - " node_attr = self.node_emb(z) # (num_node, hidden)\n", - "\n", - " hiddens = []\n", - " conv_input = node_attr # (num_node, hidden)\n", - "\n", - " for conv_idx, conv in enumerate(self.convs):\n", - " hidden = conv(conv_input, edge_index, edge_attr)\n", - " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", - " hidden = self.activation(hidden)\n", - " assert hidden.shape == conv_input.shape\n", - " if self.short_cut and hidden.shape == conv_input.shape:\n", - " hidden += conv_input\n", - "\n", - " hiddens.append(hidden)\n", - " conv_input = hidden\n", - "\n", - " if self.concat_hidden:\n", - " node_feature = torch.cat(hiddens, dim=-1)\n", - " else:\n", - " node_feature = hiddens[-1]\n", - "\n", - " return node_feature\n", - "\n", - "\n", - "class MLPEdgeEncoder(Module):\n", - " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", - " super().__init__()\n", - " self.hidden_dim = hidden_dim\n", - " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", - " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", - "\n", - " @property\n", - " def out_channels(self):\n", - " return self.hidden_dim\n", - "\n", - " def forward(self, edge_length, edge_type):\n", - " \"\"\"\n", - " Input:\n", - " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", - " Returns:\n", - " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", - " \"\"\"\n", - " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", - " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", - " return d_emb * edge_attr # (num_edge, hidden)\n", - "\n", - "\n", - "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", - " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", - " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", - " return h_pair\n", - "\n", - "\n", - "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", - " \"\"\"\n", - " Args:\n", - " num_nodes: Number of atoms.\n", - " edge_index: Bond indices of the original graph.\n", - " edge_type: Bond types of the original graph.\n", - " order: Extension order.\n", - " Returns:\n", - " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", - " \"\"\"\n", - "\n", - " def binarize(x):\n", - " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", - "\n", - " def get_higher_order_adj_matrix(adj, order):\n", - " \"\"\"\n", - " Args:\n", - " adj: (N, N)\n", - " type_mat: (N, N)\n", - " Returns:\n", - " Following attributes will be updated:\n", - " - edge_index\n", - " - edge_type\n", - " Following attributes will be added to the data object:\n", - " - bond_edge_index: Original edge_index.\n", - " \"\"\"\n", - " adj_mats = [\n", - " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", - " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", - " ]\n", - "\n", - " for i in range(2, order + 1):\n", - " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", - " order_mat = torch.zeros_like(adj)\n", - "\n", - " for i in range(1, order + 1):\n", - " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", - "\n", - " return order_mat\n", - "\n", - " num_types = 22\n", - " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", - " # from rdkit.Chem.rdchem import BondType as BT\n", - " N = num_nodes\n", - " adj = to_dense_adj(edge_index).squeeze(0)\n", - " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", - "\n", - " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", - " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", - " assert (type_mat * type_highorder == 0).all()\n", - " type_new = type_mat + type_highorder\n", - "\n", - " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", - " _, edge_order = dense_to_sparse(adj_order)\n", - "\n", - " # data.bond_edge_index = data.edge_index # Save original edges\n", - " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", - " assert edge_type.dim() == 1\n", - " N = pos.size(0)\n", - "\n", - " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", - "\n", - " if is_sidechain is None:\n", - " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", - " else:\n", - " # fetch sidechain and its batch index\n", - " is_sidechain = is_sidechain.bool()\n", - " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", - " sidechain_pos = pos[is_sidechain]\n", - " sidechain_index = dummy_index[is_sidechain]\n", - " sidechain_batch = batch[is_sidechain]\n", - "\n", - " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", - " r_edge_index_x = assign_index[1]\n", - " r_edge_index_y = assign_index[0]\n", - " r_edge_index_y = sidechain_index[r_edge_index_y]\n", - "\n", - " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", - " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", - " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", - " # delete self loop\n", - " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", - "\n", - " rgraph_adj = torch.sparse.LongTensor(\n", - " rgraph_edge_index,\n", - " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", - " torch.Size([N, N]),\n", - " )\n", - "\n", - " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", - "\n", - " new_edge_index = composed_adj.indices()\n", - " new_edge_type = composed_adj.values().long()\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def extend_graph_order_radius(\n", - " num_nodes,\n", - " pos,\n", - " edge_index,\n", - " edge_type,\n", - " batch,\n", - " order=3,\n", - " cutoff=10.0,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - "):\n", - " if extend_order:\n", - " edge_index, edge_type = _extend_graph_order(\n", - " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", - " )\n", - "\n", - " if extend_radius:\n", - " edge_index, edge_type = _extend_to_radius_graph(\n", - " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", - " )\n", - "\n", - " return edge_index, edge_type\n", - "\n", - "\n", - "def get_distance(pos, edge_index):\n", - " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", - "\n", - "\n", - "def graph_field_network(score_d, pos, edge_index, edge_length):\n", - " \"\"\"\n", - " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", - " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", - " \"\"\"\n", - " N = pos.size(0)\n", - " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", - " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", - " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", - " ) # (N, 3)\n", - " return score_pos\n", - "\n", - "\n", - "def clip_norm(vec, limit, p=2):\n", - " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", - " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", - " return vec * denom\n", - "\n", - "\n", - "def is_local_edge(edge_type):\n", - " return edge_type > 0\n" - ], - "metadata": { - "id": "oR1Y56QiLY90" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Main model class!" - ], - "metadata": { - "id": "QWrHJFcYXyUB" - } - }, - { - "cell_type": "code", - "source": [ - "class MoleculeGNN(ModelMixin, ConfigMixin):\n", - " @register_to_config\n", - " def __init__(\n", - " self,\n", - " hidden_dim=128,\n", - " num_convs=6,\n", - " num_convs_local=4,\n", - " cutoff=10.0,\n", - " mlp_act=\"relu\",\n", - " edge_order=3,\n", - " edge_encoder=\"mlp\",\n", - " smooth_conv=True,\n", - " ):\n", - " super().__init__()\n", - " self.cutoff = cutoff\n", - " self.edge_encoder = edge_encoder\n", - " self.edge_order = edge_order\n", - "\n", - " \"\"\"\n", - " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", - " in SchNetEncoder\n", - " \"\"\"\n", - " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - "\n", - " \"\"\"\n", - " The graph neural network that extracts node-wise features.\n", - " \"\"\"\n", - " self.encoder_global = SchNetEncoder(\n", - " hidden_channels=hidden_dim,\n", - " num_filters=hidden_dim,\n", - " num_interactions=num_convs,\n", - " edge_channels=self.edge_encoder_global.out_channels,\n", - " cutoff=cutoff,\n", - " smooth=smooth_conv,\n", - " )\n", - " self.encoder_local = GINEncoder(\n", - " hidden_dim=hidden_dim,\n", - " num_convs=num_convs_local,\n", - " )\n", - "\n", - " \"\"\"\n", - " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", - " gradients w.r.t. edge_length (out_dim = 1).\n", - " \"\"\"\n", - " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " \"\"\"\n", - " Incorporate parameters together\n", - " \"\"\"\n", - " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", - " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", - "\n", - " def _forward(\n", - " self,\n", - " atom_type,\n", - " pos,\n", - " bond_index,\n", - " bond_type,\n", - " batch,\n", - " time_step, # NOTE, model trained without timestep performed best\n", - " edge_index=None,\n", - " edge_type=None,\n", - " edge_length=None,\n", - " return_edges=False,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " atom_type: Types of atoms, (N, ).\n", - " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", - " bond_type: Bond types, (E, ).\n", - " batch: Node index to graph index, (N, ).\n", - " \"\"\"\n", - " N = atom_type.size(0)\n", - " if edge_index is None or edge_type is None or edge_length is None:\n", - " edge_index, edge_type = extend_graph_order_radius(\n", - " num_nodes=N,\n", - " pos=pos,\n", - " edge_index=bond_index,\n", - " edge_type=bond_type,\n", - " batch=batch,\n", - " order=self.edge_order,\n", - " cutoff=self.cutoff,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " is_sidechain=is_sidechain,\n", - " )\n", - " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", - " local_edge_mask = is_local_edge(edge_type) # (E, )\n", - "\n", - " # with the parameterization of NCSNv2\n", - " # DDPM loss implicit handle the noise variance scale conditioning\n", - " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", - "\n", - " # Encoding global\n", - " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - "\n", - " # Global\n", - " node_attr_global = self.encoder_global(\n", - " z=atom_type,\n", - " edge_index=edge_index,\n", - " edge_length=edge_length,\n", - " edge_attr=edge_attr_global,\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_global = assemble_atom_pair_feature(\n", - " node_attr=node_attr_global,\n", - " edge_index=edge_index,\n", - " edge_attr=edge_attr_global,\n", - " ) # (E_global, 2H)\n", - " # Invariant features of edges (radius graph, global)\n", - " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", - "\n", - " # Encoding local\n", - " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - " # edge_attr += temb_edge\n", - "\n", - " # Local\n", - " node_attr_local = self.encoder_local(\n", - " z=atom_type,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_local = assemble_atom_pair_feature(\n", - " node_attr=node_attr_local,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " ) # (E_local, 2H)\n", - "\n", - " # Invariant features of edges (bond graph, local)\n", - " if isinstance(sigma_edge, torch.Tensor):\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", - " 1.0 / sigma_edge[local_edge_mask]\n", - " ) # (E_local, 1)\n", - " else:\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", - "\n", - " if return_edges:\n", - " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", - " else:\n", - " return edge_inv_global, edge_inv_local\n", - "\n", - " def forward(\n", - " self,\n", - " sample,\n", - " timestep: Union[torch.Tensor, float, int],\n", - " return_dict: bool = True,\n", - " sigma=1.0,\n", - " global_start_sigma=0.5,\n", - " w_global=1.0,\n", - " extend_order=False,\n", - " extend_radius=True,\n", - " clip_local=None,\n", - " clip_global=1000.0,\n", - " ) -> Union[MoleculeGNNOutput, Tuple]:\n", - " r\"\"\"\n", - " Args:\n", - " sample: packed torch geometric object\n", - " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", - " return_dict (`bool`, *optional*, defaults to `True`):\n", - " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", - " Returns:\n", - " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", - " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", - " \"\"\"\n", - "\n", - " # unpack sample\n", - " atom_type = sample.atom_type\n", - " bond_index = sample.edge_index\n", - " bond_type = sample.edge_type\n", - " num_graphs = sample.num_graphs\n", - " pos = sample.pos\n", - "\n", - " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", - "\n", - " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", - " atom_type=atom_type,\n", - " pos=sample.pos,\n", - " bond_index=bond_index,\n", - " bond_type=bond_type,\n", - " batch=sample.batch,\n", - " time_step=timesteps,\n", - " return_edges=True,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " ) # (E_global, 1), (E_local, 1)\n", - "\n", - " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", - " node_eq_local = graph_field_network(\n", - " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", - " )\n", - " if clip_local is not None:\n", - " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", - "\n", - " # Global\n", - " if sigma < global_start_sigma:\n", - " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", - " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", - " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", - " else:\n", - " node_eq_global = 0\n", - "\n", - " # Sum\n", - " eps_pos = node_eq_local + node_eq_global * w_global\n", - "\n", - " if not return_dict:\n", - " return (-eps_pos,)\n", - "\n", - " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" - ], - "metadata": { - "id": "MCeZA1qQXzoK" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CCIrPYSJj9wd" - }, - "source": [ - "### Load pretrained model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YdrAr6Ch--Ab" - }, - "source": [ - "#### Load a model\n", - "The model used is a design an\n", - "equivariant convolutional layer, named graph field network (GFN).\n", - "\n", - "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DyCo0nsqjbml", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 172, - "referenced_widgets": [ - "d90f304e9560472eacfbdd11e46765eb", - "1c6246f15b654f4daa11c9bcf997b78c", - "c2321b3bff6f490ca12040a20308f555", - "b7feb522161f4cf4b7cc7c1a078ff12d", - "e2d368556e494ae7ae4e2e992af2cd4f", - "bbef741e76ec41b7ab7187b487a383df", - "561f742d418d4721b0670cc8dd62e22c", - "872915dd1bb84f538c44e26badabafdd", - "d022575f1fa2446d891650897f187b4d", - "fdc393f3468c432aa0ada05e238a5436", - "2c9362906e4b40189f16d14aa9a348da", - "6010fc8daa7a44d5aec4b830ec2ebaa1", - "7e0bb1b8d65249d3974200686b193be2", - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "6526646be5ed415c84d1245b040e629b", - "24d31fc3576e43dd9f8301d2ef3a37ab", - "2918bfaadc8d4b1a9832522c40dfefb8", - "a4bfdca35cc54dae8812720f1b276a08", - "e4901541199b45c6a18824627692fc39", - "f915cf874246446595206221e900b2fe", - "a9e388f22a9742aaaf538e22575c9433", - "42f6c3db29d7484ba6b4f73590abd2f4" - ] - }, - "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", - "\n", - "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", - "\n" - ] - } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "dataset = torch.load('/content/molecules.pkl')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QZcmy1EvKQRk" - }, - "source": [ - "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JVjz6iH_H6Eh", - "colab": { - "base_uri": "https://localhost:8080/" + "source": [ + "!pip install -q condacolab" + ] }, - "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" - }, - "outputs": [ { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" + }, + "source": [ + "Setup Conda" ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "dataset[0]" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Run the diffusion process" - ], - "metadata": { - "id": "vHNiZAUxNgoy" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jZ1KZrxKqENg" - }, - "source": [ - "#### Helper Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s240tYueqKKf" - }, - "outputs": [], - "source": [ - "from torch_geometric.data import Data, Batch\n", - "from torch_scatter import scatter_add, scatter_mean\n", - "from tqdm import tqdm\n", - "import copy\n", - "import os\n", - "\n", - "def repeat_data(data: Data, num_repeat) -> Batch:\n", - " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", - " return Batch.from_data_list(datas)\n", - "\n", - "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", - " datas = batch.to_data_list()\n", - " new_data = []\n", - " for i in range(num_repeat):\n", - " new_data += copy.deepcopy(datas)\n", - " return Batch.from_data_list(new_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AMnQTk0eqT7Z" - }, - "source": [ - "#### Constants" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WYGkzqgzrHmF" - }, - "outputs": [], - "source": [ - "num_samples = 1 # solutions per molecule\n", - "num_molecules = 3\n", - "\n", - "DEVICE = 'cuda'\n", - "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", - "# constants for inference\n", - "w_global = 0.5 #0,.3 for qm9\n", - "global_start_sigma = 0.5\n", - "eta = 1.0\n", - "clip_local = None\n", - "clip_pos = None\n", - "\n", - "# constands for data handling\n", - "save_traj = False\n", - "save_data = False\n", - "output_dir = '/content/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-xD5bJ3SqM7t" - }, - "source": [ - "#### Generate samples!\n", - "Note that the 3d representation of a molecule is referred to as the **conformation**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "x9xuLUNg26z1", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " after removing the cwd from sys.path.\n", - "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" - ] - } - ], - "source": [ - "results = []\n", - "\n", - "# define sigmas\n", - "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", - "sigmas = sigmas.to(DEVICE)\n", - "\n", - "for count, data in enumerate(tqdm(dataset)):\n", - " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", - "\n", - " data_input = data.clone()\n", - " data_input['pos_ref'] = None\n", - " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", - "\n", - " # initial configuration\n", - " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", - "\n", - " # for logging animation of denoising\n", - " pos_traj = []\n", - " with torch.no_grad():\n", - "\n", - " # scale initial sample\n", - " pos = pos_init * sigmas[-1]\n", - " for t in scheduler.timesteps:\n", - " batch.pos = pos\n", - "\n", - " # generate geometry with model, then filter it\n", - " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", - "\n", - " # Update\n", - " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", - "\n", - " pos = reconstructed_pos\n", - "\n", - " if torch.isnan(pos).any():\n", - " print(\"NaN detected. Please restart.\")\n", - " raise FloatingPointError()\n", - "\n", - " # recenter graph of positions for next iteration\n", - " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", - "\n", - " # optional clipping\n", - " if clip_pos is not None:\n", - " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", - " pos_traj.append(pos.clone().cpu())\n", - "\n", - " pos_gen = pos.cpu()\n", - " if save_traj:\n", - " pos_gen_traj = pos_traj.cpu()\n", - " data.pos_gen = torch.stack(pos_gen_traj)\n", - " else:\n", - " data.pos_gen = pos_gen\n", - " results.append(data)\n", - "\n", - "\n", - "if save_data:\n", - " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", - "\n", - " with open(save_path, 'wb') as f:\n", - " pickle.dump(results, f)" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Render the results!" - ], - "metadata": { - "id": "fSApwSaZNndW" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d47Zxo2OKdgZ" - }, - "source": [ - "This function allows us to render 3d in colab." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e9Cd0kCAv9b8" - }, - "outputs": [], - "source": [ - "from google.colab import output\n", - "output.enable_custom_widget_manager()" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Helper functions" - ], - "metadata": { - "id": "RjaVuR15NqzF" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "28rBYa9NKhlz" - }, - "source": [ - "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LKdKdwxcyTQ6" - }, - "outputs": [], - "source": [ - "from copy import deepcopy\n", - "def set_rdmol_positions(rdkit_mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " mol = deepcopy(rdkit_mol)\n", - " set_rdmol_positions_(mol, pos)\n", - " return mol\n", - "\n", - "def set_rdmol_positions_(mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " for i in range(pos.shape[0]):\n", - " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", - " return mol\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NuE10hcpKmzK" - }, - "source": [ - "Process the generated data to make it easy to view." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KieVE1vc0_Vs", - "colab": { - "base_uri": "https://localhost:8080/" }, - "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "collect 5 generated molecules in `mols`\n" - ] - } - ], - "source": [ - "# the model can generate multiple conformations per 2d geometry\n", - "num_gen = results[0]['pos_gen'].shape[0]\n", - "\n", - "# init storage objects\n", - "mols_gen = []\n", - "mols_orig = []\n", - "for to_process in results:\n", - "\n", - " # store the reference 3d position\n", - " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # store the generated 3d position\n", - " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # copy data to new object\n", - " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", - "\n", - " # append results\n", - " mols_gen.append(new_mol)\n", - " mols_orig.append(to_process.rdmol)\n", - "\n", - "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tin89JwMKp4v" - }, - "source": [ - "Import tools to visualize the 2d chemical diagram of the molecule." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yqV6gllSZn38" - }, - "outputs": [], - "source": [ - "from rdkit.Chem import AllChem\n", - "from rdkit import Chem\n", - "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", - "from IPython.display import SVG, display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TFNKmGddVoOk" - }, - "source": [ - "Select molecule to visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KzuwLlrrVaGc" - }, - "outputs": [], - "source": [ - "idx = 0\n", - "assert idx < len(results), \"selected molecule that was not generated\"" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Viewing" - ], - "metadata": { - "id": "hkb8w0_SNtU8" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I3R4QBQeKttN" - }, - "source": [ - "This 2D rendering is the equivalent of the **input to the model**!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gkQRWjraaKex", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 321 + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] + } + ], + "source": [ + "import condacolab\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" + }, + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" + ] }, - "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" - }, - "outputs": [ { - "output_type": "display_data", - "data": { - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } ], - "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" - }, - "metadata": {} - } - ], - "source": [ - "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", - "molSize=(450,300)\n", - "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", - "drawer.DrawMolecule(mc)\n", - "drawer.FinishDrawing()\n", - "svg = drawer.GetDrawingText()\n", - "display(SVG(svg.replace('svg:','')))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z4FDMYMxKw2I" - }, - "source": [ - "Generate the 3d molecule!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aT1Bkb8YxJfV", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "695ab5bbf30a4ab19df1f9f33469f314", - "eac6a8dcdc9d4335a2e51031793ead29" - ] - }, - "outputId": "b98870ae-049d-4386-b676-166e9526bda2" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "695ab5bbf30a4ab19df1f9f33469f314" + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." + ], + "metadata": { + "id": "QDS6FPZ0Tu5b" } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + }, + { + "cell_type": "code", + "source": [ + "!rm /usr/local/conda-meta/pinned" + ], + "metadata": { + "id": "dq1lxR10TtrR", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" + ] } - } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1L3DdZOJB30" + }, + "source": [ + "Install torch geometric (used in the model later)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D5ukfCOWfjzK", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - pytorch-geometric=1.7.2\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " decorator-4.4.2 | py_0 11 KB conda-forge\n", + " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", + " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", + " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", + " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", + " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", + " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", + " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", + " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", + " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", + " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", + " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", + " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", + " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", + " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", + " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", + " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", + " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mgQA_XN-XGY2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001b[K\n", + "remote: Counting objects: 100% (40/40), done.\u001b[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers\n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" + }, + "source": [ + "Check that torch is installed correctly and utilizing the GPU in the colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gZt7BNi1e1PA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "True\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'1.8.2'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 8 + } + ], + "source": [ + "import torch\n", + "print(torch.cuda.is_available())\n", + "torch.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLE7CqlfJNUO" + }, + "source": [ + "### Install Chemistry-specific Dependencies\n", + "\n", + "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0CPv_NvehRz3", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting rdkit\n", + " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", + "Installing collected packages: rdkit\n", + "Successfully installed rdkit-2022.3.5\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jcl8GCS2mz6t", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + "Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } + } + }, + "metadata": {} + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a diffusion model" + ], + "metadata": { + "id": "8t8_e_uVLdKB" } - } - } - ], - "source": [ - "from nglview import show_rdkit as show" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pxtq8I-I18C-", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "be446195da2b4ff2aec21ec5ff963a54", - "c6596896148b4a8a9c57963b67c7782f", - "2489b5e5648541fbbdceadb05632a050", - "01e0ba4e5da04914b4652b8d58565d7b", - "c30e6c2f3e2a44dbbb3d63bd519acaa4", - "f31c6e40e9b2466a9064a2669933ecd5", - "19308ccac642498ab8b58462e3f1b0bb", - "4a081cdc2ec3421ca79dd933b7e2b0c4", - "e5c0d75eb5e1447abd560c8f2c6017e1", - "5146907ef6764654ad7d598baebc8b58", - "144ec959b7604a2cabb5ca46ae5e5379", - "abce2a80e6304df3899109c6d6cac199", - "65195cb7a4134f4887e9dd19f3676462" - ] - }, - "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "NGLWidget()" + }, + { + "cell_type": "markdown", + "source": [ + "### Model class(es)" ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "be446195da2b4ff2aec21ec5ff963a54" + "metadata": { + "id": "G0rMncVtNSqU" } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } + }, + { + "cell_type": "markdown", + "source": [ + "Imports" + ], + "metadata": { + "id": "L5FEXz5oXkzt" } - } - } - ], - "source": [ - "# new molecule\n", - "show(mols_gen[idx])" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "KJr4h2mwXeTo" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "d90f304e9560472eacfbdd11e46765eb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", - "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", - "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + }, + { + "cell_type": "code", + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput\n" ], - "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" - } - }, - "1c6246f15b654f4daa11c9bcf997b78c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", - "placeholder": "​", - "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", - "value": "Downloading: 100%" - } - }, - "c2321b3bff6f490ca12040a20308f555": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", - "max": 3271865, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", - "value": 3271865 - } - }, - "b7feb522161f4cf4b7cc7c1a078ff12d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", - "placeholder": "​", - "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", - "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" - } - }, - "e2d368556e494ae7ae4e2e992af2cd4f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "bbef741e76ec41b7ab7187b487a383df": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "561f742d418d4721b0670cc8dd62e22c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "872915dd1bb84f538c44e26badabafdd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d022575f1fa2446d891650897f187b4d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "fdc393f3468c432aa0ada05e238a5436": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2c9362906e4b40189f16d14aa9a348da": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6010fc8daa7a44d5aec4b830ec2ebaa1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", - "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "IPY_MODEL_6526646be5ed415c84d1245b040e629b" + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Helper classes" ], - "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" - } - }, - "7e0bb1b8d65249d3974200686b193be2": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", - "placeholder": "​", - "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", - "value": "Downloading: 100%" - } - }, - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", - "max": 401, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", - "value": 401 - } - }, - "6526646be5ed415c84d1245b040e629b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", - "placeholder": "​", - "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", - "value": " 401/401 [00:00<00:00, 13.5kB/s]" - } - }, - "24d31fc3576e43dd9f8301d2ef3a37ab": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2918bfaadc8d4b1a9832522c40dfefb8": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a4bfdca35cc54dae8812720f1b276a08": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e4901541199b45c6a18824627692fc39": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f915cf874246446595206221e900b2fe": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "a9e388f22a9742aaaf538e22575c9433": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "42f6c3db29d7484ba6b4f73590abd2f4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "695ab5bbf30a4ab19df1f9f33469f314": { - "model_module": "nglview-js-widgets", - "model_name": "ColormakerRegistryModel", - "model_module_version": "3.0.1", - "state": { - "_dom_classes": [], - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "ColormakerRegistryModel", - "_msg_ar": [], - "_msg_q": [], - "_ready": false, - "_view_count": null, - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "ColormakerRegistryView", - "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" - } - }, - "eac6a8dcdc9d4335a2e51031793ead29": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be446195da2b4ff2aec21ec5ff963a54": { - "model_module": "nglview-js-widgets", - "model_name": "NGLModel", - "model_module_version": "3.0.1", - "state": { - "_camera_orientation": [ - -15.519693580202304, - -14.065056548036177, - -23.53197484807691, - 0, - -23.357853515109753, - 20.94055073042662, - 2.888695042134944, - 0, - 14.352363398292777, - 18.870825741878015, - -20.744689572909344, - 0, - 0.2724999189376831, - 0.6940000057220459, - -0.3734999895095825, - 1 + "metadata": { + "id": "EzJQXPN_XrMX" + } + }, + { + "cell_type": "code", + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.Tensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class SchNetEncoder(Module):\n", + " def __init__(\n", + " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.hidden_channels = hidden_channels\n", + " self.num_filters = num_filters\n", + " self.num_interactions = num_interactions\n", + " self.cutoff = cutoff\n", + "\n", + " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", + "\n", + " self.interactions = ModuleList()\n", + " for _ in range(num_interactions):\n", + " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", + " self.interactions.append(block)\n", + "\n", + " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", + " if embed_node:\n", + " assert z.dim() == 1 and z.dtype == torch.long\n", + " h = self.embedding(z)\n", + " else:\n", + " h = z\n", + " for interaction in self.interactions:\n", + " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", + "\n", + " return h\n", + "\n", + "\n", + "class GINEConv(MessagePassing):\n", + " \"\"\"\n", + " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", + " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", + " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", + " self.nn = mlp\n", + " self.initial_eps = eps\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " if train_eps:\n", + " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", + " else:\n", + " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " x: OptPairTensor = (x, x)\n", + "\n", + " # Node and edge feature dimensionalites need to match.\n", + " if isinstance(edge_index, torch.Tensor):\n", + " assert edge_attr is not None\n", + " assert x[0].size(-1) == edge_attr.size(-1)\n", + " elif isinstance(edge_index, SparseTensor):\n", + " assert x[0].size(-1) == edge_index.size(-1)\n", + "\n", + " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", + " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", + "\n", + " x_r = x[1]\n", + " if x_r is not None:\n", + " out += (1 + self.eps) * x_r\n", + "\n", + " return self.nn(out)\n", + "\n", + " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", + " if self.activation:\n", + " return self.activation(x_j + edge_attr)\n", + " else:\n", + " return x_j + edge_attr\n", + "\n", + " def __repr__(self):\n", + " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", + "\n", + "\n", + "class GINEncoder(torch.nn.Module):\n", + " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", + " super().__init__()\n", + "\n", + " self.hidden_dim = hidden_dim\n", + " self.num_convs = num_convs\n", + " self.short_cut = short_cut\n", + " self.concat_hidden = concat_hidden\n", + " self.node_emb = nn.Embedding(100, hidden_dim)\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " self.convs = nn.ModuleList()\n", + " for i in range(self.num_convs):\n", + " self.convs.append(\n", + " GINEConv(\n", + " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", + " activation=activation,\n", + " )\n", + " )\n", + "\n", + " def forward(self, z, edge_index, edge_attr):\n", + " \"\"\"\n", + " Input:\n", + " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", + " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", + " Output:\n", + " node_feature: graph feature\n", + " \"\"\"\n", + "\n", + " node_attr = self.node_emb(z) # (num_node, hidden)\n", + "\n", + " hiddens = []\n", + " conv_input = node_attr # (num_node, hidden)\n", + "\n", + " for conv_idx, conv in enumerate(self.convs):\n", + " hidden = conv(conv_input, edge_index, edge_attr)\n", + " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", + " hidden = self.activation(hidden)\n", + " assert hidden.shape == conv_input.shape\n", + " if self.short_cut and hidden.shape == conv_input.shape:\n", + " hidden += conv_input\n", + "\n", + " hiddens.append(hidden)\n", + " conv_input = hidden\n", + "\n", + " if self.concat_hidden:\n", + " node_feature = torch.cat(hiddens, dim=-1)\n", + " else:\n", + " node_feature = hiddens[-1]\n", + "\n", + " return node_feature\n", + "\n", + "\n", + "class MLPEdgeEncoder(Module):\n", + " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", + " super().__init__()\n", + " self.hidden_dim = hidden_dim\n", + " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", + " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", + "\n", + " @property\n", + " def out_channels(self):\n", + " return self.hidden_dim\n", + "\n", + " def forward(self, edge_length, edge_type):\n", + " \"\"\"\n", + " Input:\n", + " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", + " Returns:\n", + " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", + " \"\"\"\n", + " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", + " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", + " return d_emb * edge_attr # (num_edge, hidden)\n", + "\n", + "\n", + "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", + " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", + " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", + " return h_pair\n", + "\n", + "\n", + "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", + " \"\"\"\n", + " Args:\n", + " num_nodes: Number of atoms.\n", + " edge_index: Bond indices of the original graph.\n", + " edge_type: Bond types of the original graph.\n", + " order: Extension order.\n", + " Returns:\n", + " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", + " \"\"\"\n", + "\n", + " def binarize(x):\n", + " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", + "\n", + " def get_higher_order_adj_matrix(adj, order):\n", + " \"\"\"\n", + " Args:\n", + " adj: (N, N)\n", + " type_mat: (N, N)\n", + " Returns:\n", + " Following attributes will be updated:\n", + " - edge_index\n", + " - edge_type\n", + " Following attributes will be added to the data object:\n", + " - bond_edge_index: Original edge_index.\n", + " \"\"\"\n", + " adj_mats = [\n", + " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", + " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", + " ]\n", + "\n", + " for i in range(2, order + 1):\n", + " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", + " order_mat = torch.zeros_like(adj)\n", + "\n", + " for i in range(1, order + 1):\n", + " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", + "\n", + " return order_mat\n", + "\n", + " num_types = 22\n", + " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", + " # from rdkit.Chem.rdchem import BondType as BT\n", + " N = num_nodes\n", + " adj = to_dense_adj(edge_index).squeeze(0)\n", + " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", + "\n", + " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", + " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", + " assert (type_mat * type_highorder == 0).all()\n", + " type_new = type_mat + type_highorder\n", + "\n", + " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", + " _, edge_order = dense_to_sparse(adj_order)\n", + "\n", + " # data.bond_edge_index = data.edge_index # Save original edges\n", + " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", + " assert edge_type.dim() == 1\n", + " N = pos.size(0)\n", + "\n", + " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", + "\n", + " if is_sidechain is None:\n", + " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", + " else:\n", + " # fetch sidechain and its batch index\n", + " is_sidechain = is_sidechain.bool()\n", + " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", + " sidechain_pos = pos[is_sidechain]\n", + " sidechain_index = dummy_index[is_sidechain]\n", + " sidechain_batch = batch[is_sidechain]\n", + "\n", + " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", + " r_edge_index_x = assign_index[1]\n", + " r_edge_index_y = assign_index[0]\n", + " r_edge_index_y = sidechain_index[r_edge_index_y]\n", + "\n", + " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", + " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", + " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", + " # delete self loop\n", + " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", + "\n", + " rgraph_adj = torch.sparse.LongTensor(\n", + " rgraph_edge_index,\n", + " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", + " torch.Size([N, N]),\n", + " )\n", + "\n", + " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", + "\n", + " new_edge_index = composed_adj.indices()\n", + " new_edge_type = composed_adj.values().long()\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def extend_graph_order_radius(\n", + " num_nodes,\n", + " pos,\n", + " edge_index,\n", + " edge_type,\n", + " batch,\n", + " order=3,\n", + " cutoff=10.0,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + "):\n", + " if extend_order:\n", + " edge_index, edge_type = _extend_graph_order(\n", + " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", + " )\n", + "\n", + " if extend_radius:\n", + " edge_index, edge_type = _extend_to_radius_graph(\n", + " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", + " )\n", + "\n", + " return edge_index, edge_type\n", + "\n", + "\n", + "def get_distance(pos, edge_index):\n", + " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", + "\n", + "\n", + "def graph_field_network(score_d, pos, edge_index, edge_length):\n", + " \"\"\"\n", + " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", + " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", + " \"\"\"\n", + " N = pos.size(0)\n", + " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", + " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", + " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", + " ) # (N, 3)\n", + " return score_pos\n", + "\n", + "\n", + "def clip_norm(vec, limit, p=2):\n", + " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", + " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", + " return vec * denom\n", + "\n", + "\n", + "def is_local_edge(edge_type):\n", + " return edge_type > 0\n" ], - "_camera_str": "orthographic", - "_dom_classes": [], - "_gui_theme": null, - "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", - "_igui": null, - "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "NGLModel", - "_ngl_color_dict": {}, - "_ngl_coordinate_resource": {}, - "_ngl_full_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" + "metadata": { + "id": "oR1Y56QiLY90" }, - "_ngl_msg_archive": [ - { - "target": "Stage", - "type": "call_method", - "methodName": "loadFile", - "reconstruc_color_scheme": false, - "args": [ - { - "type": "blob", - "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", - "binary": false - } - ], - "kwargs": { - "defaultRepresentation": true, - "ext": "pdb" - } - } + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Main model class!" ], - "_ngl_original_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" + "metadata": { + "id": "QWrHJFcYXyUB" + } + }, + { + "cell_type": "code", + "source": [ + "class MoleculeGNN(ModelMixin, ConfigMixin):\n", + " @register_to_config\n", + " def __init__(\n", + " self,\n", + " hidden_dim=128,\n", + " num_convs=6,\n", + " num_convs_local=4,\n", + " cutoff=10.0,\n", + " mlp_act=\"relu\",\n", + " edge_order=3,\n", + " edge_encoder=\"mlp\",\n", + " smooth_conv=True,\n", + " ):\n", + " super().__init__()\n", + " self.cutoff = cutoff\n", + " self.edge_encoder = edge_encoder\n", + " self.edge_order = edge_order\n", + "\n", + " \"\"\"\n", + " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", + " in SchNetEncoder\n", + " \"\"\"\n", + " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + "\n", + " \"\"\"\n", + " The graph neural network that extracts node-wise features.\n", + " \"\"\"\n", + " self.encoder_global = SchNetEncoder(\n", + " hidden_channels=hidden_dim,\n", + " num_filters=hidden_dim,\n", + " num_interactions=num_convs,\n", + " edge_channels=self.edge_encoder_global.out_channels,\n", + " cutoff=cutoff,\n", + " smooth=smooth_conv,\n", + " )\n", + " self.encoder_local = GINEncoder(\n", + " hidden_dim=hidden_dim,\n", + " num_convs=num_convs_local,\n", + " )\n", + "\n", + " \"\"\"\n", + " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", + " gradients w.r.t. edge_length (out_dim = 1).\n", + " \"\"\"\n", + " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " \"\"\"\n", + " Incorporate parameters together\n", + " \"\"\"\n", + " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", + " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", + "\n", + " def _forward(\n", + " self,\n", + " atom_type,\n", + " pos,\n", + " bond_index,\n", + " bond_type,\n", + " batch,\n", + " time_step, # NOTE, model trained without timestep performed best\n", + " edge_index=None,\n", + " edge_type=None,\n", + " edge_length=None,\n", + " return_edges=False,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " atom_type: Types of atoms, (N, ).\n", + " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", + " bond_type: Bond types, (E, ).\n", + " batch: Node index to graph index, (N, ).\n", + " \"\"\"\n", + " N = atom_type.size(0)\n", + " if edge_index is None or edge_type is None or edge_length is None:\n", + " edge_index, edge_type = extend_graph_order_radius(\n", + " num_nodes=N,\n", + " pos=pos,\n", + " edge_index=bond_index,\n", + " edge_type=bond_type,\n", + " batch=batch,\n", + " order=self.edge_order,\n", + " cutoff=self.cutoff,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " is_sidechain=is_sidechain,\n", + " )\n", + " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", + " local_edge_mask = is_local_edge(edge_type) # (E, )\n", + "\n", + " # with the parameterization of NCSNv2\n", + " # DDPM loss implicit handle the noise variance scale conditioning\n", + " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", + "\n", + " # Encoding global\n", + " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + "\n", + " # Global\n", + " node_attr_global = self.encoder_global(\n", + " z=atom_type,\n", + " edge_index=edge_index,\n", + " edge_length=edge_length,\n", + " edge_attr=edge_attr_global,\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_global = assemble_atom_pair_feature(\n", + " node_attr=node_attr_global,\n", + " edge_index=edge_index,\n", + " edge_attr=edge_attr_global,\n", + " ) # (E_global, 2H)\n", + " # Invariant features of edges (radius graph, global)\n", + " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", + "\n", + " # Encoding local\n", + " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + " # edge_attr += temb_edge\n", + "\n", + " # Local\n", + " node_attr_local = self.encoder_local(\n", + " z=atom_type,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_local = assemble_atom_pair_feature(\n", + " node_attr=node_attr_local,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " ) # (E_local, 2H)\n", + "\n", + " # Invariant features of edges (bond graph, local)\n", + " if isinstance(sigma_edge, torch.Tensor):\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", + " 1.0 / sigma_edge[local_edge_mask]\n", + " ) # (E_local, 1)\n", + " else:\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", + "\n", + " if return_edges:\n", + " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", + " else:\n", + " return edge_inv_global, edge_inv_local\n", + "\n", + " def forward(\n", + " self,\n", + " sample,\n", + " timestep: Union[torch.Tensor, float, int],\n", + " return_dict: bool = True,\n", + " sigma=1.0,\n", + " global_start_sigma=0.5,\n", + " w_global=1.0,\n", + " extend_order=False,\n", + " extend_radius=True,\n", + " clip_local=None,\n", + " clip_global=1000.0,\n", + " ) -> Union[MoleculeGNNOutput, Tuple]:\n", + " r\"\"\"\n", + " Args:\n", + " sample: packed torch geometric object\n", + " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", + " return_dict (`bool`, *optional*, defaults to `True`):\n", + " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", + " Returns:\n", + " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", + " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", + " \"\"\"\n", + "\n", + " # unpack sample\n", + " atom_type = sample.atom_type\n", + " bond_index = sample.edge_index\n", + " bond_type = sample.edge_type\n", + " num_graphs = sample.num_graphs\n", + " pos = sample.pos\n", + "\n", + " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", + "\n", + " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", + " atom_type=atom_type,\n", + " pos=sample.pos,\n", + " bond_index=bond_index,\n", + " bond_type=bond_type,\n", + " batch=sample.batch,\n", + " time_step=timesteps,\n", + " return_edges=True,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " ) # (E_global, 1), (E_local, 1)\n", + "\n", + " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", + " node_eq_local = graph_field_network(\n", + " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", + " )\n", + " if clip_local is not None:\n", + " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", + "\n", + " # Global\n", + " if sigma < global_start_sigma:\n", + " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", + " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", + " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", + " else:\n", + " node_eq_global = 0\n", + "\n", + " # Sum\n", + " eps_pos = node_eq_local + node_eq_global * w_global\n", + "\n", + " if not return_dict:\n", + " return (-eps_pos,)\n", + "\n", + " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" + ], + "metadata": { + "id": "MCeZA1qQXzoK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CCIrPYSJj9wd" }, - "_ngl_repr_dict": { - "0": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 + "source": [ + "### Load pretrained model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YdrAr6Ch--Ab" + }, + "source": [ + "#### Load a model\n", + "The model used is a design an\n", + "equivariant convolutional layer, named graph field network (GFN).\n", + "\n", + "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DyCo0nsqjbml", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] } - }, - "1": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load('/content/molecules.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JVjz6iH_H6Eh", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" + ] }, - "flatShaded": false, - "opacity": 1, - "depthWrite": true, - "side": "double", - "wireframe": false, - "colorScheme": "element", - "colorScale": "", - "colorReverse": false, - "colorValue": 9474192, - "colorMode": "hcl", - "roughness": 0.4, - "metalness": 0, - "diffuse": 16777215, - "diffuseInterior": false, - "useInteriorColor": true, - "interiorColor": 2236962, - "interiorDarkening": 0, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] + "metadata": {}, + "execution_count": 20 + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Run the diffusion process" + ], + "metadata": { + "id": "vHNiZAUxNgoy" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZ1KZrxKqENg" + }, + "source": [ + "#### Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s240tYueqKKf" + }, + "outputs": [], + "source": [ + "from torch_geometric.data import Data, Batch\n", + "from torch_scatter import scatter_add, scatter_mean\n", + "from tqdm import tqdm\n", + "import copy\n", + "import os\n", + "\n", + "def repeat_data(data: Data, num_repeat) -> Batch:\n", + " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", + " return Batch.from_data_list(datas)\n", + "\n", + "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", + " datas = batch.to_data_list()\n", + " new_data = []\n", + " for i in range(num_repeat):\n", + " new_data += copy.deepcopy(datas)\n", + " return Batch.from_data_list(new_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMnQTk0eqT7Z" + }, + "source": [ + "#### Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WYGkzqgzrHmF" + }, + "outputs": [], + "source": [ + "num_samples = 1 # solutions per molecule\n", + "num_molecules = 3\n", + "\n", + "DEVICE = 'cuda'\n", + "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", + "# constants for inference\n", + "w_global = 0.5 #0,.3 for qm9\n", + "global_start_sigma = 0.5\n", + "eta = 1.0\n", + "clip_local = None\n", + "clip_pos = None\n", + "\n", + "# constands for data handling\n", + "save_traj = False\n", + "save_data = False\n", + "output_dir = '/content/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-xD5bJ3SqM7t" + }, + "source": [ + "#### Generate samples!\n", + "Note that the 3d representation of a molecule is referred to as the **conformation**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x9xuLUNg26z1", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " after removing the cwd from sys.path.\n", + "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" + ] + } + ], + "source": [ + "results = []\n", + "\n", + "# define sigmas\n", + "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", + "sigmas = sigmas.to(DEVICE)\n", + "\n", + "for count, data in enumerate(tqdm(dataset)):\n", + " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", + "\n", + " data_input = data.clone()\n", + " data_input['pos_ref'] = None\n", + " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", + "\n", + " # initial configuration\n", + " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", + "\n", + " # for logging animation of denoising\n", + " pos_traj = []\n", + " with torch.no_grad():\n", + "\n", + " # scale initial sample\n", + " pos = pos_init * sigmas[-1]\n", + " for t in scheduler.timesteps:\n", + " batch.pos = pos\n", + "\n", + " # generate geometry with model, then filter it\n", + " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", + "\n", + " # Update\n", + " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", + "\n", + " pos = reconstructed_pos\n", + "\n", + " if torch.isnan(pos).any():\n", + " print(\"NaN detected. Please restart.\")\n", + " raise FloatingPointError()\n", + "\n", + " # recenter graph of positions for next iteration\n", + " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", + "\n", + " # optional clipping\n", + " if clip_pos is not None:\n", + " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", + " pos_traj.append(pos.clone().cpu())\n", + "\n", + " pos_gen = pos.cpu()\n", + " if save_traj:\n", + " pos_gen_traj = pos_traj.cpu()\n", + " data.pos_gen = torch.stack(pos_gen_traj)\n", + " else:\n", + " data.pos_gen = pos_gen\n", + " results.append(data)\n", + "\n", + "\n", + "if save_data:\n", + " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", + "\n", + " with open(save_path, 'wb') as f:\n", + " pickle.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Render the results!" + ], + "metadata": { + "id": "fSApwSaZNndW" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d47Zxo2OKdgZ" + }, + "source": [ + "This function allows us to render 3d in colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9Cd0kCAv9b8" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "output.enable_custom_widget_manager()" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Helper functions" + ], + "metadata": { + "id": "RjaVuR15NqzF" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "28rBYa9NKhlz" + }, + "source": [ + "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LKdKdwxcyTQ6" + }, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "def set_rdmol_positions(rdkit_mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " mol = deepcopy(rdkit_mol)\n", + " set_rdmol_positions_(mol, pos)\n", + " return mol\n", + "\n", + "def set_rdmol_positions_(mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " for i in range(pos.shape[0]):\n", + " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", + " return mol\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NuE10hcpKmzK" + }, + "source": [ + "Process the generated data to make it easy to view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KieVE1vc0_Vs", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "collect 5 generated molecules in `mols`\n" + ] + } + ], + "source": [ + "# the model can generate multiple conformations per 2d geometry\n", + "num_gen = results[0]['pos_gen'].shape[0]\n", + "\n", + "# init storage objects\n", + "mols_gen = []\n", + "mols_orig = []\n", + "for to_process in results:\n", + "\n", + " # store the reference 3d position\n", + " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # store the generated 3d position\n", + " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # copy data to new object\n", + " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", + "\n", + " # append results\n", + " mols_gen.append(new_mol)\n", + " mols_orig.append(to_process.rdmol)\n", + "\n", + "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tin89JwMKp4v" + }, + "source": [ + "Import tools to visualize the 2d chemical diagram of the molecule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yqV6gllSZn38" + }, + "outputs": [], + "source": [ + "from rdkit.Chem import AllChem\n", + "from rdkit import Chem\n", + "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", + "from IPython.display import SVG, display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TFNKmGddVoOk" + }, + "source": [ + "Select molecule to visualize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KzuwLlrrVaGc" + }, + "outputs": [], + "source": [ + "idx = 0\n", + "assert idx < len(results), \"selected molecule that was not generated\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Viewing" + ], + "metadata": { + "id": "hkb8w0_SNtU8" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I3R4QBQeKttN" + }, + "source": [ + "This 2D rendering is the equivalent of the **input to the model**!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gkQRWjraaKex", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 321 + }, + "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" }, - "disablePicking": false, - "sele": "" - } + "metadata": {} } - } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", + "molSize=(450,300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace('svg:','')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" }, - "_ngl_serialize": false, - "_ngl_version": "", - "_ngl_view_id": [ - "FB989FD1-5B9C-446B-8914-6B58AF85446D" + "source": [ + "Generate the 3d molecule!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aT1Bkb8YxJfV", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "695ab5bbf30a4ab19df1f9f33469f314" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } ], - "_player_dict": {}, - "_scene_position": {}, - "_scene_rotation": {}, - "_synced_model_ids": [], - "_synced_repr_model_ids": [], - "_view_count": null, - "_view_height": "", - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "NGLView", - "_view_width": "", - "background": "white", - "frame": 0, - "gui_style": null, - "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", - "max_frame": 0, - "n_components": 2, - "picked": {} - } - }, - "c6596896148b4a8a9c57963b67c7782f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2489b5e5648541fbbdceadb05632a050": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ButtonView", - "button_style": "", - "description": "", - "disabled": false, - "icon": "compress", - "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", - "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", - "tooltip": "" - } - }, - "01e0ba4e5da04914b4652b8d58565d7b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", - "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + "source": [ + "from nglview import show_rdkit as show" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxtq8I-I18C-", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "NGLWidget()" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "be446195da2b4ff2aec21ec5ff963a54" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } ], - "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" - } - }, - "c30e6c2f3e2a44dbbb3d63bd519acaa4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f31c6e40e9b2466a9064a2669933ecd5": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "19308ccac642498ab8b58462e3f1b0bb": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4a081cdc2ec3421ca79dd933b7e2b0c4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "SliderStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "", - "handle_color": null - } - }, - "e5c0d75eb5e1447abd560c8f2c6017e1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "PlayModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "PlayModel", - "_playing": false, - "_repeat": false, - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "PlayView", - "description": "", - "description_tooltip": null, - "disabled": false, - "interval": 100, - "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", - "max": 0, - "min": 0, - "show_repeat": true, - "step": 1, - "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", - "value": 0 - } - }, - "5146907ef6764654ad7d598baebc8b58": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntSliderModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "IntSliderModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "IntSliderView", - "continuous_update": true, - "description": "", - "description_tooltip": null, - "disabled": false, - "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", - "max": 0, - "min": 0, - "orientation": "horizontal", - "readout": true, - "readout_format": "d", - "step": 1, - "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", - "value": 0 - } - }, - "144ec959b7604a2cabb5ca46ae5e5379": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "abce2a80e6304df3899109c6d6cac199": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "34px" - } - }, - "65195cb7a4134f4887e9dd19f3676462": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "button_color": null, - "font_weight": "" - } + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + ], + "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "e2d368556e494ae7ae4e2e992af2cd4f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fdc393f3468c432aa0ada05e238a5436": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" + ], + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e4901541199b45c6a18824627692fc39": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f915cf874246446595206221e900b2fe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_name": "ColormakerRegistryModel", + "model_module_version": "3.0.1", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + "eac6a8dcdc9d4335a2e51031793ead29": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_name": "NGLModel", + "model_module_version": "3.0.1", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292777, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 + ], + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_msg_archive": [ + { + "target": "Stage", + "type": "call_method", + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "args": [ + { + "type": "blob", + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "binary": false + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" + } + } + ], + "_ngl_original_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_repr_dict": { + "0": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + }, + "1": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + } + }, + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" + ], + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + ], + "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f31c6e40e9b2466a9064a2669933ecd5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "SliderStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "e5c0d75eb5e1447abd560c8f2c6017e1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "PlayModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PlayModel", + "_playing": false, + "_repeat": false, + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PlayView", + "description": "", + "description_tooltip": null, + "disabled": false, + "interval": 100, + "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", + "max": 0, + "min": 0, + "show_repeat": true, + "step": 1, + "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", + "value": 0 + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntSliderModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + } + } } - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 2be74698213397d22256c4dea9f17ee127207209 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 12:44:18 +0200 Subject: [PATCH 26/94] groups->norm_num_groups --- src/diffusers/models/resnet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 4511817d0731..a5783b3a6150 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -606,7 +606,7 @@ def __init__( in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, - groups: int = 32, + norm_num_groups: int = 32, ): super().__init__() out_dim = out_dim or in_dim @@ -615,24 +615,24 @@ def __init__( # conv layers self.conv1 = nn.Sequential( - nn.GroupNorm(groups, in_dim), + nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv2 = nn.Sequential( - nn.GroupNorm(groups, out_dim), + nn.GroupNorm(norm_num_groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv3 = nn.Sequential( - nn.GroupNorm(groups, out_dim), + nn.GroupNorm(norm_num_groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv4 = nn.Sequential( - nn.GroupNorm(groups, out_dim), + nn.GroupNorm(norm_num_groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), From 9f9d0cbb83bf15e47fc9093f223c442a09c85f33 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 13:14:39 +0200 Subject: [PATCH 27/94] verify CogVideoXSpatialNorm3D implementation --- src/diffusers/models/attention_processor.py | 38 ---- .../models/autoencoders/autoencoder_kl3d.py | 199 +++++++++--------- src/diffusers/models/resnet.py | 124 +---------- 3 files changed, 105 insertions(+), 256 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1c3417e1fe03..5c5464c37683 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2595,44 +2595,6 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: return new_f -class SpatialNorm3D(nn.Module): - """ - Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. - - Args: - f_channels (`int`): - The number of channels for input to group normalization layer, and output of the spatial norm layer. - zq_channels (`int`): - The number of channels for the quantized vector as described in the paper. - """ - - def __init__( - self, - f_channels: int, - zq_channels: int, - ): - super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) - self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - if zq.shape[2] > 1: - f_first, f_rest = f[:, :, :1], f[:, :, 1:] - f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] - z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] - z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) - z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) - zq = torch.cat([z_first, z_rest], dim=2) - else: - zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) - zq = self.conv(zq) - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f - - class IPAdapterAttnProcessor(nn.Module): r""" Attention processor for Multiple IP-Adapters. diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index d1c69e8a7da1..9f81d3145c43 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -166,6 +166,96 @@ # return sample +class CogVideoXResnetBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.non_linearity = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=spatial_norm_dim, + ) + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=spatial_norm_dim, + ) + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + else: + self.nin_shortcut = CogVideoXSaveConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward( + self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, *args, **kwargs + ) -> torch.Tensor: + hidden_states = input_tensor + if zq is not None: + hidden_states = self.norm1(hidden_states, zq) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] + + if zq is not None: + hidden_states = self.norm2(hidden_states, zq) + else: + hidden_states = self.norm2(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + input_tensor = self.conv_shortcut(input_tensor) + else: + input_tensor = self.nin_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + class Encoder3D(nn.Module): r""" The `Encoder3D` layer of a variational autoencoder that encodes its input into a latent representation. @@ -365,7 +455,7 @@ def __init__( temb_channels=0, dropout=dropout, non_linearity=act_fn, - latent_channels=in_channels, + spatial_norm_dim=in_channels, groups=norm_num_groups, pad_mode=pad_mode, ) @@ -376,7 +466,7 @@ def __init__( temb_channels=0, dropout=dropout, non_linearity=act_fn, - latent_channels=in_channels, + spatial_norm_dim=in_channels, groups=norm_num_groups, pad_mode=pad_mode, ) @@ -396,7 +486,7 @@ def __init__( temb_channels=0, non_linearity=act_fn, dropout=dropout, - latent_channels=in_channels, + spatial_norm_dim=in_channels, groups=norm_num_groups, pad_mode=pad_mode, ) @@ -571,8 +661,17 @@ def forward(self, x): # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXSpatialNorm3D(nn.Module): - """ - Use CogVideoXSaveConv3d instead of nn.Conv3d to avoid OOM in CogVideoX Model + r""" + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific + to 3D-video like data. + + CogVideoXSaveConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. """ def __init__( @@ -708,96 +807,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class CogVideoXResnetBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - eps: float = 1e-6, - non_linearity: str = "swish", - conv_shortcut: bool = False, - latent_channels: Optional[int] = None, - pad_mode: str = "first", - ): - super().__init__() - - out_channels = in_channels if out_channels is None else out_channels - - self.in_channels = in_channels - self.out_channels = out_channels - self.non_linearity = get_activation(non_linearity) - self.use_conv_shortcut = conv_shortcut - - if latent_channels is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) - else: - self.norm1 = CogVideoXSpatialNorm3D( - f_channels=in_channels, - zq_channels=latent_channels, - ) - self.norm2 = CogVideoXSpatialNorm3D( - f_channels=out_channels, - zq_channels=latent_channels, - ) - self.conv1 = CogVideoXCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) - - self.dropout = torch.nn.Dropout(dropout) - - self.conv2 = CogVideoXCausalConv3d( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = CogVideoXCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - else: - self.nin_shortcut = CogVideoXSaveConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward( - self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, *args, **kwargs - ) -> torch.Tensor: - hidden_states = input_tensor - if zq is not None: - hidden_states = self.norm1(hidden_states, zq) - else: - hidden_states = self.norm1(hidden_states) - hidden_states = self.non_linearity(hidden_states) - hidden_states = self.conv1(hidden_states) - - if temb is not None: - hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] - - if zq is not None: - hidden_states = self.norm2(hidden_states, zq) - else: - hidden_states = self.norm2(hidden_states) - hidden_states = self.non_linearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - input_tensor = self.conv_shortcut(input_tensor) - else: - input_tensor = self.nin_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encodfing images into latents and decoding latent representations into images. diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index a5783b3a6150..00b55cd9c9d6 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -22,7 +22,7 @@ from ..utils import deprecate from .activations import get_activation -from .attention_processor import SpatialNorm, SpatialNorm3D +from .attention_processor import SpatialNorm from .downsampling import ( # noqa Downsample1D, Downsample2D, @@ -373,128 +373,6 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg return output_tensor -class ResnetBlock3D(nn.Module): - r""" - A Resnet3D block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv2d layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - groups_out (`int`, *optional*, default to None): - The number of groups to use for the second normalization layer. if set to None, same as `groups`. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. - use_in_shortcut (`bool`, *optional*, default to `True`): - If `True`, add a 1x1 nn.conv2d layer for skip-connection. - up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. - down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. - conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the - `conv_shortcut` output. - conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. - If None, same as `out_channels`. - """ - - def __init__( - self, - *, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - eps: float = 1e-6, - non_linearity: str = "swish", - conv_shortcut: bool = False, - ): - super().__init__() - out_channels = in_channels if out_channels is None else out_channels - - self.in_channels = in_channels - self.out_channels = out_channels - self.non_linearity = get_activation(non_linearity) - self.use_conv_shortcut = conv_shortcut - - if out_channels is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) - else: - self.norm1 = SpatialNorm3D( - f_channels=in_channels, - zq_channels=out_channels, - ) - self.norm2 = SpatialNorm3D( - f_channels=out_channels, - zq_channels=out_channels, - ) - self.conv1 = nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) - - self.dropout = torch.nn.Dropout(dropout) - - self.conv2 = nn.Conv3d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - ) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - ) - else: - self.nin_shortcut = nn.Conv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward( - self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, *args, **kwargs - ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - hidden_states = input_tensor - if zq is not None: - hidden_states = self.norm1(hidden_states, zq) - else: - hidden_states = self.norm1(hidden_states) - hidden_states = self.non_linearity(hidden_states) - hidden_states = self.conv1(hidden_states) - - if temb is not None: - hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] - - if zq is not None: - hidden_states = self.norm2(hidden_states, zq) - else: - hidden_states = self.norm2(hidden_states) - hidden_states = self.non_linearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - input_tensor = self.conv_shortcut(input_tensor) - else: - input_tensor = self.nin_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - # unet_rl.py def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor: if len(tensor.shape) == 2: From c43a8f5b2bedcf96896db7c703b1f238c9959731 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 13:27:15 +0200 Subject: [PATCH 28/94] minor factor and repositioning of code in order of invocation --- .../models/autoencoders/autoencoder_kl3d.py | 311 +++++++++--------- 1 file changed, 152 insertions(+), 159 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index 9f81d3145c43..e3122bb3ac2d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -166,6 +166,155 @@ # return sample +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXSaveConv3d(nn.Conv3d): + """ + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + + # Set to 2GB, suitable for CuDNN + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) + + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super().forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super().forward(input) + + +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXCausalConv3d(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: int = 1, + dilation: int = 1, + pad_mode: str = "constant", + ): + super().__init__() + + def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = CogVideoXSaveConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + self.conv_cache = None + + def forward(self, x): + if self.pad_mode == "constant": + causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_3d, mode="constant", value=0) + elif self.pad_mode == "first": + pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) + x = torch.cat([pad_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + elif self.pad_mode == "reflect": + reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) + if reflect_x.shape[2] < self.time_pad: + reflect_x = torch.cat( + [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 + ) + x = torch.cat([reflect_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + else: + raise ValueError("Invalid pad mode") + if self.time_pad != 0 and self.conv_cache is None: + self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() + return self.conv(x) + elif self.time_pad != 0 and self.conv_cache is not None: + x = torch.cat([self.conv_cache.to(x.device), x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + self.conv_cache = None + return self.conv(x) + + return self.conv(x) + + +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXSpatialNorm3D(nn.Module): + r""" + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific + to 3D-video like data. + + CogVideoXSaveConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv = CogVideoXCausalConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + if zq.shape[2] > 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = F.interpolate(z_first, size=f_first_size) + z_rest = F.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = F.interpolate(zq, size=f.shape[-3:]) + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + class CogVideoXResnetBlock3D(nn.Module): def __init__( self, @@ -498,11 +647,9 @@ def __init__( if i_level != 0: if i_level < self.num_resolutions - self.temporal_compress_level: - up.upsample = CogVideoXUpzSample3D( - in_channels=block_in, out_channels=block_in, compress_time=False - ) + up.upsample = CogVideoXUpSample3D(in_channels=block_in, out_channels=block_in, compress_time=False) else: - up.upsample = CogVideoXUpzSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) + up.upsample = CogVideoXUpSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) curr_res = curr_res * 2 self.up.insert(0, up) @@ -548,161 +695,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXSaveConv3d(nn.Conv3d): - """ - A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 - - # Set to 2GB, Suit for CuDNN - if memory_count > 2: - kernel_size = self.kernel_size[0] - part_num = int(memory_count / 2) + 1 - input_chunks = torch.chunk(input, part_num, dim=2) - - if kernel_size > 1: - input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) - for i in range(1, len(input_chunks)) - ] - - output_chunks = [] - for input_chunk in input_chunks: - output_chunks.append(super(CogVideoXSaveConv3d, self).forward(input_chunk)) - output = torch.cat(output_chunks, dim=2) - return output - else: - return super(CogVideoXSaveConv3d, self).forward(input) - - -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXCausalConv3d(nn.Module): - """ - A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - pad_mode: str = "constant", - **kwargs, - ): - super().__init__() - - def cast_tuple(t, length=1): - return t if isinstance(t, tuple) else ((t,) * length) - - kernel_size = cast_tuple(kernel_size, 3) - - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - - dilation = kwargs.pop("dilation", 1) - stride = kwargs.pop("stride", 1) - - self.pad_mode = pad_mode - time_pad = dilation * (time_kernel_size - 1) + (1 - stride) - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 - - self.height_pad = height_pad - self.width_pad = width_pad - self.time_pad = time_pad - self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - - stride = (stride, 1, 1) - dilation = (dilation, 1, 1) - self.conv = CogVideoXSaveConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - **kwargs, - ) - - self.conv_cache = None - - def forward(self, x): - if self.pad_mode == "constant": - causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_3d, mode="constant", value=0) - elif self.pad_mode == "first": - pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) - x = torch.cat([pad_x, x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - elif self.pad_mode == "reflect": - reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) - if reflect_x.shape[2] < self.time_pad: - reflect_x = torch.cat( - [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 - ) - x = torch.cat([reflect_x, x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - else: - raise ValueError("Invalid pad mode") - if self.time_pad != 0 and self.conv_cache is None: - self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() - return self.conv(x) - elif self.time_pad != 0 and self.conv_cache is not None: - x = torch.cat([self.conv_cache.to(x.device), x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - self.conv_cache = None - return self.conv(x) - - return self.conv(x) - - -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXSpatialNorm3D(nn.Module): - r""" - Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific - to 3D-video like data. - - CogVideoXSaveConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. - - Args: - f_channels (`int`): - The number of channels for input to group normalization layer, and output of the spatial norm layer. - zq_channels (`int`): - The number of channels for the quantized vector as described in the paper. - """ - - def __init__( - self, - f_channels: int, - zq_channels: int, - ): - super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv = CogVideoXCausalConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) - self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - if zq.shape[2] > 1: - f_first, f_rest = f[:, :, :1], f[:, :, 1:] - f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] - z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] - z_first = F.interpolate(z_first, size=f_first_size) - z_rest = F.interpolate(z_rest, size=f_rest_size) - zq = torch.cat([z_first, z_rest], dim=2) - else: - zq = F.interpolate(zq, size=f.shape[-3:]) - zq = self.conv(zq) - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f - - -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXUpzSample3D(nn.Module): +class CogVideoXUpSample3D(nn.Module): r""" Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX Model. From 5f183bfe273da59c656824da2f9dc1c0fed28b8f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 13:33:57 +0200 Subject: [PATCH 29/94] reorder upsampling/downsampling blocks in order of invocation --- .../models/autoencoders/autoencoder_kl3d.py | 229 +++++++++--------- 1 file changed, 115 insertions(+), 114 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index e3122bb3ac2d..1fdac868c0c4 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -354,9 +354,9 @@ def __init__( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(in_features=temb_channels, out_features=out_channels) + self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) - self.dropout = torch.nn.Dropout(dropout) + self.dropout = nn.Dropout(dropout) self.conv2 = CogVideoXCausalConv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode @@ -373,7 +373,7 @@ def __init__( ) def forward( - self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None, *args, **kwargs + self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: Optional[torch.Tensor] = None, *args, **kwargs ) -> torch.Tensor: hidden_states = input_tensor if zq is not None: @@ -405,6 +405,112 @@ def forward( return output_tensor +# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged +class CogVideoXUpSample3D(nn.Module): + r""" + Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX + Model. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ + + def __init__(self, in_channels: int, out_channels: int, compress_time: bool = False): + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + x = torch.cat([x_first, x_rest], dim=2) + elif x.shape[2] > 1: + x = F.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = F.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + # only interpolate 2D + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = F.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + + return x + + +# Todo: Create vae_3d.py such as vae.py file? +class CogVideoXDownSample3D(nn.Module): + r""" + Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in + CogVideoX Model. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + compress_time (`bool`, *optional*, defaults to `False`): + Whether to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + compress_time: bool = False, + ): + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.compress_time = compress_time + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + + if x.shape[-1] % 2 == 1: + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] + + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + + return x + + class Encoder3D(nn.Module): r""" The `Encoder3D` layer of a variational autoencoder that encodes its input into a latent representation. @@ -482,9 +588,13 @@ def __init__( if i_level != self.num_resolutions - 1: if i_level < self.temporal_compress_level: - down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) + down.downsample = CogVideoXDownSample3D( + in_channels=block_in, out_channels=block_in, compress_time=True + ) else: - down.downsample = DownSample3D(in_channels=block_in, out_channels=block_in, compress_time=False) + down.downsample = CogVideoXDownSample3D( + in_channels=block_in, out_channels=block_in, compress_time=False + ) curr_res = curr_res // 2 self.down.append(down) @@ -691,115 +801,6 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: return hidden_states -## ==== After this is the special code of CogVideoX ==== ## - - -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXUpSample3D(nn.Module): - r""" - Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX - Model. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. - """ - - def __init__(self, in_channels: int, out_channels: int, compress_time: bool = False): - super().__init__() - - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.compress_time = compress_time - - def forward(self, x): - if self.compress_time: - if x.shape[2] > 1 and x.shape[2] % 2 == 1: - # split first frame - x_first, x_rest = x[:, :, 0], x[:, :, 1:] - - x_first = F.interpolate(x_first, scale_factor=2.0) - x_rest = F.interpolate(x_rest, scale_factor=2.0) - x_first = x_first[:, :, None, :, :] - x = torch.cat([x_first, x_rest], dim=2) - elif x.shape[2] > 1: - x = F.interpolate(x, scale_factor=2.0) - else: - x = x.squeeze(2) - x = F.interpolate(x, scale_factor=2.0) - x = x[:, :, None, :, :] - else: - # only interpolate 2D - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = F.interpolate(x, scale_factor=2.0) - x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - - return x - - -# Todo: Create vae_3d.py such as vae.py file? -class DownSample3D(nn.Module): - r""" - Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in - CogVideoX Model. - - Args: - in_channels (`int`, *optional*): - The number of input channels. - out_channels (`int`, *optional*): - The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - compress_time: bool = False, - ): - super(DownSample3D, self).__init__() - - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) - self.compress_time = compress_time - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.compress_time: - b, c, t, h, w = x.shape - x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) - - if x.shape[-1] % 2 == 1: - # split first frame - x_first, x_rest = x[..., 0], x[..., 1:] - - if x_rest.shape[-1] > 0: - x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) - x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - - else: - x = F.avg_pool1d(x, kernel_size=2, stride=2) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - - return x - - class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encodfing images into latents and decoding latent representations into images. From 470815cefaf71755d2d976a8350811b225c0e49b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 13:38:23 +0200 Subject: [PATCH 30/94] minor refactor --- src/diffusers/models/autoencoders/autoencoder_kl3d.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl3d.py index 1fdac868c0c4..0de9706c6476 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl3d.py @@ -447,12 +447,12 @@ def forward(self, x): b, c, t, h, w = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) x = F.interpolate(x, scale_factor=2.0) - x = x.reshape(b, t, c, x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) b, c, t, h, w = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) return x @@ -550,7 +550,7 @@ def __init__( pad_mode: str = "first", temporal_compress_times: int = 4, ): - super(Encoder3D, self).__init__() + super().__init__() self.act_fn = get_activation(act_fn) self.num_resolutions = len(block_out_channels) self.layers_per_block = layers_per_block @@ -689,7 +689,7 @@ def __init__( temporal_compress_times: int = 4, norm_num_groups=32, ): - super(Decoder3D, self).__init__() + super().__init__() self.act_fn = get_activation(act_fn) self.num_resolutions = len(block_out_channels) From e67cc5ae47d611dbd3db101ecafcec38e0d17bd9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 14:12:22 +0200 Subject: [PATCH 31/94] implement encode prompt --- src/diffusers/__init__.py | 3 +- src/diffusers/models/__init__.py | 3 +- src/diffusers/models/autoencoders/__init__.py | 2 +- ...er_kl3d.py => autoencoder_kl_cogvideox.py} | 2 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 130 +++++++++++++++++- 5 files changed, 132 insertions(+), 8 deletions(-) rename src/diffusers/models/autoencoders/{autoencoder_kl3d.py => autoencoder_kl_cogvideox.py} (99%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c9ba5e0e9cf8..3355c8e9e0ec 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -78,7 +78,7 @@ "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", - "AutoencoderKL3D", + "AutoencoderKLCogVideoX", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", @@ -522,6 +522,7 @@ AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, AutoencoderKL, + AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 231ad1a6945f..e05701c53b86 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -28,6 +28,7 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] @@ -77,7 +78,7 @@ from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderKL, - AutoencoderKL3D, + AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 6e6fd436797e..ccf4552b2a5e 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,6 +1,6 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL -from .autoencoder_kl3d import AutoencoderKL3D +from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_tiny import AutoencoderTiny diff --git a/src/diffusers/models/autoencoders/autoencoder_kl3d.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py similarity index 99% rename from src/diffusers/models/autoencoders/autoencoder_kl3d.py rename to src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 0de9706c6476..a84fbd030bbd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl3d.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -801,7 +801,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: return hidden_states -class AutoencoderKL3D(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encodfing images into latents and decoding latent representations into images. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index ae3b42390e06..865fdb987807 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -21,7 +21,7 @@ import torch from transformers import T5EncoderModel, T5Tokenizer -from ...models import AutoencoderKL3D, CogVideoXTransformer3D +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3D from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -144,7 +144,7 @@ def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, - vae: AutoencoderKL3D, + vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3D, scheduler: KarrasDiffusionSchedulers, ): @@ -154,15 +154,135 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 255, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds def encode_prompt( self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 255, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): r""" - TODO: implement encode prompt with T5 XXL + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of video that should be generated per prompt. + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -224,5 +344,7 @@ def __call__( ) -> Union[CogVideoXPipelineOutput, Tuple]: """ TODO: implement forward pass + + Examples: """ pass From d45d199b99d826cb2d37c69a6217ccdb3609d359 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 14:13:38 +0200 Subject: [PATCH 32/94] make style --- .../pipelines/cogvideo/pipeline_cogvideox.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 865fdb987807..32d8ae5fea77 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -158,7 +158,7 @@ def __init__( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) - + def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -225,9 +225,8 @@ def encode_prompt( less than `1`). do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): Whether to use classifier free guidance or not. - num_images_per_prompt (`int`, *optional*, defaults to 1): - Number of video that should be generated per prompt. - torch device to place the resulting embeddings on + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -247,7 +246,7 @@ def encode_prompt( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - + if prompt_embeds is None: prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt, @@ -280,9 +279,8 @@ def encode_prompt( device=device, dtype=dtype, ) - - return prompt_embeds, negative_prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): From 73469f9562d06bdafc3f352012b0ecbba97dfe0f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 14:13:56 +0200 Subject: [PATCH 33/94] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 699b158cbab7..44a20f57727b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLCogVideoX(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] From 45f7127ade87f9cc6508832d98d31a8395ad2e3f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 14:28:46 +0200 Subject: [PATCH 34/94] fix bug in handling long prompts --- scripts/convert_cogvideox_to_diffusers.py | 9 ++++----- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 4 ++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index a80e4ecdf155..23071106d60d 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -1,13 +1,12 @@ import argparse -from typing import Any, Dict, List, Tuple +from typing import Any, Dict import torch -import torch.nn as nn from diffusers import CogVideoXTransformer3D -def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tuple[str, nn.Module]]: +def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): to_q_key = key.replace("query_key_value", "to_q") to_k_key = key.replace("query_key_value", "to_k") to_v_key = key.replace("query_key_value", "to_v") @@ -18,7 +17,7 @@ def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]) -> Li state_dict.pop(key) -def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tuple[str, nn.Module]]: +def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]): layer_id, weight_or_bias = key.split(".")[-2:] if "query" in key: @@ -29,7 +28,7 @@ def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]) - state_dict[new_key] = state_dict.pop(key) -def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]) -> List[Tuple[str, nn.Module]]: +def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): layer_id, _, weight_or_bias = key.split(".")[-3:] weights_or_biases = state_dict[key].chunk(12, dim=0) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 32d8ae5fea77..3ded6efce1c9 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -157,6 +157,10 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 255 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) def _get_t5_prompt_embeds( From a449ceb3ef8749433025c3c76506e301dd47ebcc Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 31 Jul 2024 14:40:16 +0200 Subject: [PATCH 35/94] update conversion script --- scripts/convert_cogvideox_to_diffusers.py | 44 ++++++++++++++++++++--- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 23071106d60d..4d5593b33877 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -2,8 +2,9 @@ from typing import Any, Dict import torch +from transformers import T5EncoderModel, T5Tokenizer -from diffusers import CogVideoXTransformer3D +from diffusers import CogVideoXPipeline, CogVideoXTransformer3D, DPMSolverMultistepScheduler def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): @@ -86,7 +87,7 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: state_dict[new_key] = state_dict.pop(old_key) -def convert_transformer(ckpt_path: str, output_path: str, fp16: bool = False, push_to_hub: bool = False) -> None: +def convert_transformer(ckpt_path: str): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) @@ -105,14 +106,20 @@ def convert_transformer(ckpt_path: str, output_path: str, fp16: bool = False, pu handler_fn_inplace(key, original_state_dict) transformer.load_state_dict(original_state_dict, strict=True) - transformer.save_pretrained(output_path) + return transformer + + +def convert_vae(ckpt_path: str): + # TODO: wait for implementation + pass def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--transformer_ckpt_path", type=str, default=None, help="Path to original transformercheckpoint" + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") parser.add_argument( @@ -124,5 +131,32 @@ def get_args(): if __name__ == "__main__": args = get_args() + transformer = None + vae = None + if args.transformer_ckpt_path is not None: - convert_transformer(args.transformer_ckpt_path, args.output_path, args.fp16, args.push_to_hub) + transformer = convert_transformer(args.transformer_ckpt_path) + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id) + + # TODO: verify with authors + scheduler = DPMSolverMultistepScheduler.from_pretrained( + "runwayml/stable-diffusion-v1-5", + subfolder="scheduler", + algorithm_type="sde-dpmsolver++", + prediction_type="v_prediction", + ) + + pipe = CogVideoXPipeline( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + if args.fp16: + pipe = pipe.to(dtype=torch.float16) + + variant = "fp16" if args.fp16 else None + pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, push_to_hub=args.push_to_hub) From 4498cfc98caead65a0f40f4e882bff5564bcc244 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Wed, 31 Jul 2024 21:41:09 +0800 Subject: [PATCH 36/94] add doc draft --- .../en/api/models/autoencoderkl_cogvideox.md | 37 +++ .../en/api/models/cogvideox_transformer3d.md | 19 ++ docs/source/en/api/pipelines/cogvideox.md | 76 +++++ .../autoencoders/autoencoder_kl_cogvideox.py | 10 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 3 +- tests/pipelines/cogvideox/__init__.py | 0 tests/pipelines/cogvideox/test_cogvideox.py | 297 ++++++++++++++++++ 7 files changed, 434 insertions(+), 8 deletions(-) create mode 100644 docs/source/en/api/models/autoencoderkl_cogvideox.md create mode 100644 docs/source/en/api/models/cogvideox_transformer3d.md create mode 100644 docs/source/en/api/pipelines/cogvideox.md create mode 100644 tests/pipelines/cogvideox/__init__.py create mode 100644 tests/pipelines/cogvideox/test_cogvideox.py diff --git a/docs/source/en/api/models/autoencoderkl_cogvideox.md b/docs/source/en/api/models/autoencoderkl_cogvideox.md new file mode 100644 index 000000000000..e876092e06af --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_cogvideox.md @@ -0,0 +1,37 @@ + + +# AutoencoderKLCogVideoX + +The 3D variational autoencoder (VAE) model with KL loss using with CogVideoX. + +The abstract from the paper is: + + +## Loading from the original format + +By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded +from the original format using [`FromOriginalModelMixin.from_single_file`] as follows: + +```py +from diffusers import AutoencoderKLCogVideoX + +url = "3d-vae.pt" # can also be a local file +model = AutoencoderKLCogVideoX.from_single_file(url) +``` + +## AutoencoderKLCogVideoX + +[[autodoc]] AutoencoderKLCogVideoX + - decode + - encode + - all diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md new file mode 100644 index 000000000000..1ef71636820e --- /dev/null +++ b/docs/source/en/api/models/cogvideox_transformer3d.md @@ -0,0 +1,19 @@ + + +## CogVideoXTransformer3D + +A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX). + +## CogVideoXTransformer3D + +[[autodoc]] CogVideoXTransformer3D diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md new file mode 100644 index 000000000000..9641c0965685 --- /dev/null +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -0,0 +1,76 @@ + + +# CogVideoX + + +[The paper is still being written]() from Tsinghua University & ZhipuAI. + +The abstract from the paper is: + +The paper is still being written. + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +### Inference + +Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. + +First, load the pipeline: + +```python +import torch +from diffusers import LattePipeline + +pipeline = LattePipeline.from_pretrained( + "THUDM/CogVideoX", torch_dtype=torch.bfloat16 +).to("cuda") +``` + +Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`: + +```python +pipeline.transformer.to(memory_format=torch.channels_last) +pipeline.vae.to(memory_format=torch.channels_last) +``` + +Finally, compile the components and run inference: + +```python +pipeline.transformer = torch.compile(pipeline.transformer) +pipeline.vae.decode = torch.compile(pipeline.vae.decode) + +video = pipeline(prompt="A dog wearing sunglasses floating in space, surreal, nebulae in background").frames[0] +``` + +The [benchmark]() results on an 80GB A100 machine are: + +``` +Without torch.compile(): Average inference time: 16.246 seconds. +With torch.compile(): Average inference time: 14.573 seconds. +``` + +## CogVideoXPipeline + +[[autodoc]] CogVideoXPipeline + - all + - __call__ + +## CogVideoXPipelineOutput +[[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput + diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index a84fbd030bbd..a7aeb0942d9d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -295,9 +295,9 @@ def __init__( ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv = CogVideoXCausalConv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=0) - self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv = CogVideoXCausalConv3d(zq_channels, zq_channels, kernel_size=3, stride=1) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: if zq.shape[2] > 1: @@ -309,7 +309,7 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: zq = torch.cat([z_first, z_rest], dim=2) else: zq = F.interpolate(zq, size=f.shape[-3:]) - zq = self.conv(zq) + norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f @@ -702,8 +702,6 @@ def __init__( block_in = block_out_channels[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, in_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - self.conv_in = CogVideoXCausalConv3d(in_channels, block_in, kernel_size=3, pad_mode=pad_mode) # middle diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 32d8ae5fea77..136109bc45d5 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -1,5 +1,4 @@ -# TODO: Ask collaborators about license -# Copyright 2024 The CogVideoX Team and The HuggingFace Team. +# Copyright 2024 The The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/pipelines/cogvideox/__init__.py b/tests/pipelines/cogvideox/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py new file mode 100644 index 000000000000..db0369306773 --- /dev/null +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -0,0 +1,297 @@ +#Todo: Only a Draft + +# coding=utf-8 +# Copyright 2024 The The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3D +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = CogVideoXPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + required_optional_params = PipelineTesterMixin.required_optional_params + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CogVideoXTransformer3D( + sample_size=8, + num_layers=1, + patch_size=2, + attention_head_dim=8, + num_attention_heads=3, + caption_channels=32, + in_channels=4, + cross_attention_dim=24, + out_channels=8, + attention_bias=True, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + ) + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = DDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "negative_prompt": "low quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "video_length": 1, + "output_type": "pt", + "clean_caption": False, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (1, 3, 8, 8)) + expected_video = torch.randn(1, 3, 8, 8) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass(self): + pass + + def test_save_load_optional_components(self): + if not hasattr(self.pipeline_class, "_optional_components"): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = inputs["prompt"] + generator = inputs["generator"] + + ( + prompt_embeds, + negative_prompt_embeds, + ) = pipe.encode_prompt(prompt) + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "negative_prompt": None, + "negative_prompt_embeds": negative_prompt_embeds, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "video_length": 1, + "mask_feature": False, + "output_type": "pt", + "clean_caption": False, + } + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1.0) + + +@slow +@require_torch_gpu +class CogVideoXPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_cogvideox(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = CogVideoXPipeline.from_pretrained("THUDM/cogvideox-2b", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + prompt = self.prompt + + videos = pipe( + prompt=prompt, + height=512, + width=512, + generator=generator, + num_inference_steps=2, + clean_caption=False, + ).frames + + video = videos[0] + expected_video = torch.randn(1, 512, 512, 3).numpy() + + max_diff = numpy_cosine_similarity_distance(video.fCogVideoXn(), expected_video) + assert max_diff < 1e-3, f"Max diff is too high. got {video.fCogVideoXn()}" From bb4740ce29ddc6cdbfed5f19fbb0e56b76ccd760 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Thu, 1 Aug 2024 00:26:59 +0800 Subject: [PATCH 37/94] add clear_fake_cp_cache --- .../autoencoders/autoencoder_kl_cogvideox.py | 135 ++++++++++-------- tests/pipelines/cogvideox/test_cogvideox.py | 9 +- 2 files changed, 75 insertions(+), 69 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index a7aeb0942d9d..08d78c2842f4 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -167,7 +167,7 @@ # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXSaveConv3d(nn.Conv3d): +class CogVideoXSafeConv3d(nn.Conv3d): """ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. """ @@ -228,9 +228,12 @@ def cast_tuple(t, length=1): self.time_pad = time_pad self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + self.temporal_dim = 2 + self.time_kernel_size = time_kernel_size + stride = (stride, 1, 1) dilation = (dilation, 1, 1) - self.conv = CogVideoXSaveConv3d( + self.conv = CogVideoXSafeConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -240,37 +243,36 @@ def cast_tuple(t, length=1): self.conv_cache = None - def forward(self, x): - if self.pad_mode == "constant": - causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_3d, mode="constant", value=0) - elif self.pad_mode == "first": - pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) - x = torch.cat([pad_x, x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - elif self.pad_mode == "reflect": - reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) - if reflect_x.shape[2] < self.time_pad: - reflect_x = torch.cat( - [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 - ) - x = torch.cat([reflect_x, x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) + def fake_cp_pass_from_previous_rank(self, input_): + dim = self.temporal_dim + kernel_size = self.time_kernel_size + if kernel_size == 1: + return input_ + + input_ = input_.transpose(0, dim) + + if self.conv_cache is not None: + input_ = torch.cat([self.conv_cache.transpose(0, dim).to(input_.device), input_], dim=0) else: - raise ValueError("Invalid pad mode") - if self.time_pad != 0 and self.conv_cache is None: - self.conv_cache = x[:, :, -self.time_pad :].detach().clone().cpu() - return self.conv(x) - elif self.time_pad != 0 and self.conv_cache is not None: - x = torch.cat([self.conv_cache.to(x.device), x], dim=2) - causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - x = F.pad(x, causal_padding_2d, mode="constant", value=0) - self.conv_cache = None - return self.conv(x) + input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + + input_ = input_.transpose(0, dim).contiguous() + return input_ + + def forward(self, input_, clear_fake_cp_cache=True): + input_parallel = self.fake_cp_pass_from_previous_rank(input_) - return self.conv(x) + del self.conv_cache + self.conv_cache = None + if not clear_fake_cp_cache: + self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) + + output_parallel = self.conv(input_parallel) + output = output_parallel + return output # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged @@ -279,7 +281,7 @@ class CogVideoXSpatialNorm3D(nn.Module): Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific to 3D-video like data. - CogVideoXSaveConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. Args: f_channels (`int`): @@ -295,12 +297,11 @@ def __init__( ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv = CogVideoXCausalConv3d(zq_channels, zq_channels, kernel_size=3, stride=1) self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - if zq.shape[2] > 1: + def forward(self, f: torch.Tensor, zq: torch.Tensor, clear_fake_cp_cache=True) -> torch.Tensor: + if f.shape[2] > 1 and f.shape[2] % 2 == 1: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] @@ -368,35 +369,41 @@ def __init__( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) else: - self.nin_shortcut = CogVideoXSaveConv3d( + self.nin_shortcut = CogVideoXSafeConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 ) def forward( - self, input_tensor: torch.Tensor, temb: torch.Tensor, zq: Optional[torch.Tensor] = None, *args, **kwargs + self, + input_tensor: torch.Tensor, + temb: torch.Tensor, + zq: torch.Tensor = None, + clear_fake_cp_cache: bool = True, + *args, + **kwargs, ) -> torch.Tensor: hidden_states = input_tensor if zq is not None: - hidden_states = self.norm1(hidden_states, zq) + hidden_states = self.norm1(hidden_states, zq, clear_fake_cp_cache=clear_fake_cp_cache) else: hidden_states = self.norm1(hidden_states) hidden_states = self.non_linearity(hidden_states) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) if temb is not None: hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] if zq is not None: - hidden_states = self.norm2(hidden_states, zq) + hidden_states = self.norm2(hidden_states, zq, clear_fake_cp_cache=clear_fake_cp_cache) else: hidden_states = self.norm2(hidden_states) hidden_states = self.non_linearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor, clear_fake_cp_cache=clear_fake_cp_cache) else: input_tensor = self.nin_shortcut(input_tensor) @@ -626,23 +633,23 @@ def __init__( block_in, conv_out_channels if double_z else out_channels, kernel_size=3, pad_mode=pad_mode ) - def forward(self, sample: torch.Tensor) -> torch.Tensor: + def forward(self, sample: torch.Tensor, clear_fake_cp_cache=True) -> torch.Tensor: # timestep embedding temb = None # DownSampling - sample = self.conv_in(sample) + sample = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) for i_level in range(self.num_resolutions): for i_block in range(self.layers_per_block): - sample = self.down[i_level].block[i_block](sample, temb) + sample = self.down[i_level].block[i_block](sample, temb, clear_fake_cp_cache=clear_fake_cp_cache) if i_level != self.num_resolutions - 1: sample = self.down[i_level].downsample(sample) # middle - sample = self.mid.block_1(sample, temb) - sample = self.mid.block_2(sample, temb) + sample = self.mid.block_1(sample, temb, clear_fake_cp_cache=clear_fake_cp_cache) + sample = self.mid.block_2(sample, temb, clear_fake_cp_cache=clear_fake_cp_cache) # post-process sample = self.norm_out(sample) @@ -766,24 +773,26 @@ def __init__( self.conv_out = CogVideoXCausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) - def forward(self, sample: torch.Tensor) -> torch.Tensor: + def forward(self, sample: torch.Tensor, clear_fake_cp_cache=True) -> torch.Tensor: r"""The forward method of the `Decoder` class.""" # timestep embedding temb = None - hidden_states = self.conv_in(sample) + hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) # middle - hidden_states = self.mid.block_1(hidden_states, temb, sample) + hidden_states = self.mid.block_1(hidden_states, temb, sample, clear_fake_cp_cache=clear_fake_cp_cache) - hidden_states = self.mid.block_2(hidden_states, temb, sample) + hidden_states = self.mid.block_2(hidden_states, temb, sample, clear_fake_cp_cache=clear_fake_cp_cache) # UpSampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.layers_per_block + 1): - hidden_states = self.up[i_level].block[i_block](hidden_states, temb, sample) + hidden_states = self.up[i_level].block[i_block]( + hidden_states, temb, sample, clear_fake_cp_cache=clear_fake_cp_cache + ) if i_level != 0: hidden_states = self.up[i_level].upsample(hidden_states) @@ -792,9 +801,9 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: if self.give_pre_end: return hidden_states - hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.norm_out(hidden_states, sample, clear_fake_cp_cache=clear_fake_cp_cache) hidden_states = self.act_fn(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) return hidden_states @@ -880,8 +889,8 @@ def __init__( act_fn=act_fn, resolution=sample_size, ) - self.quant_conv = CogVideoXSaveConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None - self.post_quant_conv = CogVideoXSaveConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None self.use_slicing = False self.use_tiling = False @@ -930,7 +939,7 @@ def disable_slicing(self): @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, return_dict: bool = True, fake_cp: bool = False ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -939,12 +948,13 @@ def encode( x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. - + fake_cp (`bool`, *optional*, defaults to `True`): + If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). Returns: The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - h = self.encoder(x) + h = self.encoder(x, clear_fake_cp_cache=not fake_cp) if self.quant_conv is not None: h = self.quant_conv(h) posterior = DiagonalGaussianDistribution(h) @@ -954,7 +964,7 @@ def encode( @apply_forward_hook def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None + self, z: torch.FloatTensor, return_dict: bool = True, generator=None, fake_cp: bool = False ) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. @@ -963,7 +973,8 @@ def decode( z (`torch.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - + fake_cp (`bool`, *optional*, defaults to `True`): + If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). Returns: [`~models.vae.DecoderOutput`] or `tuple`: If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is @@ -972,7 +983,7 @@ def decode( """ if self.post_quant_conv is not None: z = self.post_quant_conv(z) - dec = self.decoder(z) + dec = self.decoder(z, clear_fake_cp_cache=not fake_cp) if not return_dict: return (dec,) return dec diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index db0369306773..0da39160bd18 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -1,4 +1,4 @@ -#Todo: Only a Draft +# Todo: Only a Draft # coding=utf-8 # Copyright 2024 The The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. @@ -24,12 +24,7 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - CogVideoXPipeline, - CogVideoXTransformer3D -) +from diffusers import AutoencoderKL, DDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3D from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, From e05f83479c0ca3afaa05ee6ecb7e8d89d12f4480 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 1 Aug 2024 06:05:49 +0200 Subject: [PATCH 38/94] refactor vae --- scripts/convert_cogvideox_to_diffusers.py | 54 +- .../autoencoders/autoencoder_kl_cogvideox.py | 769 ++++++++---------- src/diffusers/models/downsampling.py | 61 ++ .../transformers/cogvideox_transformer_3d.py | 4 +- src/diffusers/models/upsampling.py | 64 ++ 5 files changed, 515 insertions(+), 437 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 4d5593b33877..1ef11a80ef24 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -4,7 +4,7 @@ import torch from transformers import T5EncoderModel, T5Tokenizer -from diffusers import CogVideoXPipeline, CogVideoXTransformer3D, DPMSolverMultistepScheduler +from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3D, DPMSolverMultistepScheduler def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): @@ -45,6 +45,22 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): state_dict.pop(key) +def remove_loss_keys_inplace(key: str, state_dict: Dict[str, Any]): + state_dict.pop(key) + + +def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): + key_split = key.split(".") + layer_index = int(key_split[2]) + replace_layer_index = 4 - 1 - layer_index + + key_split[1] = "up_blocks" + key_split[2] = str(replace_layer_index) + new_key = ".".join(key_split) + + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { "transformer.final_layernorm": "norm_final", "transformer": "transformer_blocks", @@ -71,6 +87,23 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, } +VAE_KEYS_RENAME_DICT = { + "block.": "resnets.", + "down.": "down_blocks.", + "downsample": "downsamplers.0", + "upsample": "upsamplers.0", + "nin_shortcut": "conv_shortcut", + "encoder.mid.block_1": "encoder.mid_block.resnets.0", + "encoder.mid.block_2": "encoder.mid_block.resnets.1", + "decoder.mid.block_1": "decoder.mid_block.resnets.0", + "decoder.mid.block_2": "decoder.mid_block.resnets.1", +} + +VAE_SPECIAL_KEYS_REMAP = { + "loss": remove_loss_keys_inplace, + "up.": replace_up_keys_inplace, +} + def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: state_dict = saved_dict @@ -110,8 +143,23 @@ def convert_transformer(ckpt_path: str): def convert_vae(ckpt_path: str): - # TODO: wait for implementation - pass + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + vae = AutoencoderKLCogVideoX() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True) + return vae def get_args(): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 08d78c2842f4..4406cd04769c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -9,163 +9,13 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation +from ..downsampling import Downsample3D from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin +from ..upsampling import Upsample3D from .vae import DecoderOutput, DiagonalGaussianDistribution -## == Basic Block of 3D VAE Model design in CogVideoX === ### - - -## Draft of block -# class DownEncoderBlock3D(nn.Module): -# def __init__( -# self, -# in_channels: int, -# out_channels: int, -# dropout: float = 0.0, -# num_layers: int = 1, -# resnet_eps: float = 1e-6, -# resnet_act_fn: str = "swish", -# resnet_groups: int = 32, -# resnet_pre_norm: bool = True, -# pad_mode: str = "first", -# ): -# super().__init__() -# resnets = [] -# -# for i in range(num_layers): -# resnets.append( -# CogVideoXResnetBlock3D( -# in_channels=in_channels if i == 0 else out_channels, -# out_channels=out_channels, -# temb_channels=0, -# eps=resnet_eps, -# groups=resnet_groups, -# dropout=dropout, -# non_linearity=resnet_act_fn, -# conv_shortcut=resnet_pre_norm, -# pad_mode=pad_mode, -# ) -# ) -# in_channels = out_channels -# -# self.resnets = nn.ModuleList(resnets) -# self.downsampler = DownSample3D(in_channels=out_channels, out_channels=out_channels) if num_layers > 0 else None -# -# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: -# for resnet in self.resnets: -# hidden_states = resnet(hidden_states, temb=None) -# -# if self.downsampler is not None: -# hidden_states = self.downsampler(hidden_states) -# -# return hidden_states -# -# -# class Encoder3D(nn.Module): -# def __init__( -# self, -# in_channels: int = 3, -# out_channels: int = 16, -# down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), -# block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), -# layers_per_block: int = 3, -# act_fn: str = "silu", -# norm_num_groups: int = 32, -# dropout: float = 0.0, -# resolution: int = 256, -# double_z: bool = True, -# pad_mode: str = "first", -# temporal_compress_times: int = 4, -# ): -# super().__init__() -# self.act_fn = get_activation(act_fn) -# self.num_resolutions = len(block_out_channels) -# self.layers_per_block = layers_per_block -# self.resolution = resolution -# -# # log2 of temporal_compress_times -# self.temporal_compress_level = int(np.log2(temporal_compress_times)) -# -# self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) -# -# self.down_blocks = nn.ModuleList() -# self.downsamples = nn.ModuleList() -# -# for i_level in range(self.num_resolutions): -# block_in = block_out_channels[i_level - 1] if i_level > 0 else block_out_channels[0] -# block_out = block_out_channels[i_level] -# is_final_block = i_level == self.num_resolutions - 1 -# -# down_block = DownEncoderBlock3D( -# in_channels=block_in, -# out_channels=block_out, -# num_layers=self.layers_per_block, -# dropout=dropout, -# resnet_eps=1e-6, -# resnet_act_fn=act_fn, -# resnet_groups=norm_num_groups, -# resnet_pre_norm=True, -# pad_mode=pad_mode, -# ) -# self.down_blocks.append(down_block) -# -# if not is_final_block: -# compress_time = i_level < self.temporal_compress_level -# self.downsamples.append( -# DownSample3D(in_channels=block_out, out_channels=block_out, compress_time=compress_time) -# ) -# -# # middle -# block_in = block_out_channels[-1] -# self.mid_block_1 = CogVideoXResnetBlock3D( -# in_channels=block_in, -# out_channels=block_in, -# non_linearity=act_fn, -# temb_channels=0, -# groups=norm_num_groups, -# dropout=dropout, -# pad_mode=pad_mode, -# ) -# self.mid_block_2 = CogVideoXResnetBlock3D( -# in_channels=block_in, -# out_channels=block_in, -# non_linearity=act_fn, -# temb_channels=0, -# groups=norm_num_groups, -# dropout=dropout, -# pad_mode=pad_mode, -# ) -# -# # out -# self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) -# self.conv_act = get_activation(act_fn) -# -# conv_out_channels = 2 * out_channels if double_z else out_channels -# self.conv_out = CogVideoXCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, pad_mode=pad_mode) -# -# def forward(self, sample: torch.Tensor) -> torch.Tensor: -# temb = None -# -# # DownSampling -# sample = self.conv_in(sample) -# for i_level in range(self.num_resolutions): -# sample = self.down_blocks[i_level](sample) -# if i_level < len(self.downsamples): -# sample = self.downsamples[i_level](sample) -# -# sample = self.mid_block_1(sample, temb) -# sample = self.mid_block_2(sample, temb) -# -# # post-process -# sample = self.norm_out(sample) -# sample = self.conv_act(sample) -# sample = self.conv_out(sample) -# -# return sample - - # Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXSafeConv3d(nn.Conv3d): """ @@ -369,18 +219,16 @@ def __init__( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) else: - self.nin_shortcut = CogVideoXSafeConv3d( + self.conv_shortcut = CogVideoXSafeConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 ) def forward( self, input_tensor: torch.Tensor, - temb: torch.Tensor, - zq: torch.Tensor = None, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True, - *args, - **kwargs, ) -> torch.Tensor: hidden_states = input_tensor if zq is not None: @@ -405,117 +253,189 @@ def forward( if self.use_conv_shortcut: input_tensor = self.conv_shortcut(input_tensor, clear_fake_cp_cache=clear_fake_cp_cache) else: - input_tensor = self.nin_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor) output_tensor = input_tensor + hidden_states return output_tensor -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged -class CogVideoXUpSample3D(nn.Module): - r""" - Add compress_time option to the `UpSample` layer of a variational autoencoder that upsamples its input in CogVideoX - Model. +class CogVideoXDownBlock3D(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 0, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. - """ + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) - def __init__(self, in_channels: int, out_channels: int, compress_time: bool = False): + self.resnets = nn.ModuleList(resnets) + self.downsamplers = None + + if add_downsample: + self.downsamplers = nn.ModuleList([ + Downsample3D(out_channels, out_channels, padding=downsample_padding, compress_time=compress_time) + ]) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, clear_fake_cp_cache + ) + else: + hidden_states = resnet(hidden_states, temb, clear_fake_cp_cache) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CogVideoXMidBlock3D(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.compress_time = compress_time - - def forward(self, x): - if self.compress_time: - if x.shape[2] > 1 and x.shape[2] % 2 == 1: - # split first frame - x_first, x_rest = x[:, :, 0], x[:, :, 1:] - - x_first = F.interpolate(x_first, scale_factor=2.0) - x_rest = F.interpolate(x_rest, scale_factor=2.0) - x_first = x_first[:, :, None, :, :] - x = torch.cat([x_first, x_rest], dim=2) - elif x.shape[2] > 1: - x = F.interpolate(x, scale_factor=2.0) + resnets = [] + for _ in range(num_layers): + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + spatial_norm_dim=spatial_norm_dim, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, clear_fake_cp_cache + ) else: - x = x.squeeze(2) - x = F.interpolate(x, scale_factor=2.0) - x = x[:, :, None, :, :] - else: - # only interpolate 2D - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = F.interpolate(x, scale_factor=2.0) - x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) - - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) - - return x - - -# Todo: Create vae_3d.py such as vae.py file? -class CogVideoXDownSample3D(nn.Module): - r""" - Add compress_time option to the `DownSample` layer of a variational autoencoder that downsamples its input in - CogVideoX Model. + hidden_states = resnet(hidden_states, temb, clear_fake_cp_cache) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states - Args: - in_channels (`int`, *optional*): - The number of input channels. - out_channels (`int`, *optional*): - The number of output channels. - compress_time (`bool`, *optional*, defaults to `False`): - Whether to compress the time dimension. - """ +class CogVideoXUpBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: int = 16, + add_upsample: bool = True, + upsample_padding: int = 0, compress_time: bool = False, + pad_mode: str = "first", ): super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) - self.compress_time = compress_time - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.compress_time: - b, c, t, h, w = x.shape - x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) - - if x.shape[-1] % 2 == 1: - # split first frame - x_first, x_rest = x[..., 0], x[..., 1:] - - if x_rest.shape[-1] > 0: - x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) - x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - - else: - x = F.avg_pool1d(x, kernel_size=2, stride=2) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_norm_dim=spatial_norm_dim, + pad_mode=pad_mode, + ) + ) - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + self.resnets = nn.ModuleList(resnets) + self.upsamplers = None - return x + if add_upsample: + self.upsamplers = nn.ModuleList([ + Upsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time) + ]) class Encoder3D(nn.Module): @@ -542,120 +462,114 @@ class Encoder3D(nn.Module): Whether to double the number of output channels for the last block. """ + _supports_gradient_checkpointing = True + def __init__( self, in_channels: int = 3, out_channels: int = 16, - down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), + down_block_types: Tuple[str, ...] = ("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D",), block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), layers_per_block: int = 3, act_fn: str = "silu", + norm_eps: float = 1e-6, norm_num_groups: int = 32, dropout: float = 0.0, - resolution: int = 256, - double_z: bool = True, pad_mode: str = "first", - temporal_compress_times: int = 4, + temporal_compression_ratio: float = 4, ): super().__init__() self.act_fn = get_activation(act_fn) self.num_resolutions = len(block_out_channels) self.layers_per_block = layers_per_block - self.resolution = resolution - + # log2 of temporal_compress_times - self.temporal_compress_level = int(np.log2(temporal_compress_times)) + temporal_compress_level = int(np.log2(temporal_compression_ratio)) self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) - - curr_res = resolution - in_ch_mult = (block_out_channels[0],) + tuple(block_out_channels) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - - block_in = in_ch_mult[i_level] - block_out = block_out_channels[i_level] - - for i_block in range(self.layers_per_block): - block.append( - CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_out, - temb_channels=0, - non_linearity=act_fn, - dropout=dropout, - groups=norm_num_groups, - pad_mode=pad_mode, - ) + self.down_blocks = nn.ModuleList([]) + + # down blocks + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if down_block_type == "CogVideoXDownBlock3D": + down_block = CogVideoXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + compress_time=compress_time, ) - block_in = block_out - down = nn.Module() - down.block = block - - if i_level != self.num_resolutions - 1: - if i_level < self.temporal_compress_level: - down.downsample = CogVideoXDownSample3D( - in_channels=block_in, out_channels=block_in, compress_time=True - ) - else: - down.downsample = CogVideoXDownSample3D( - in_channels=block_in, out_channels=block_in, compress_time=False - ) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - block_in = in_ch_mult[-1] - self.mid.block_1 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - non_linearity=act_fn, - temb_channels=0, - groups=norm_num_groups, - dropout=dropout, - pad_mode=pad_mode, - ) - self.mid.block_2 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - non_linearity=act_fn, + else: + raise ValueError( + "Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`" + ) + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=block_out_channels[-1], temb_channels=0, - groups=norm_num_groups, dropout=dropout, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, pad_mode=pad_mode, ) - self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=norm_num_groups, eps=1e-6) - conv_out_channels = 2 * out_channels if double_z else out_channels + self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() self.conv_out = CogVideoXCausalConv3d( - block_in, conv_out_channels if double_z else out_channels, kernel_size=3, pad_mode=pad_mode + block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode ) - def forward(self, sample: torch.Tensor, clear_fake_cp_cache=True) -> torch.Tensor: - # timestep embedding - - temb = None + self.gradient_checkpointing = False - # DownSampling + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True) -> torch.Tensor: + r"""The forward method of the `Encoder3D` class.""" sample = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) - for i_level in range(self.num_resolutions): - for i_block in range(self.layers_per_block): - sample = self.down[i_level].block[i_block](sample, temb, clear_fake_cp_cache=clear_fake_cp_cache) - - if i_level != self.num_resolutions - 1: - sample = self.down[i_level].downsample(sample) - # middle - sample = self.mid.block_1(sample, temb, clear_fake_cp_cache=clear_fake_cp_cache) - sample = self.mid.block_2(sample, temb, clear_fake_cp_cache=clear_fake_cp_cache) - - # post-process + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Down + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, temb, clear_fake_cp_cache + ) + + # 2. Mid + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, temb, clear_fake_cp_cache + ) + else: + # 1. Down + for down_block in self.down_blocks: + sample = down_block(sample, temb, clear_fake_cp_cache) + + # 2. Mid + sample = self.mid_block(sample, temb, clear_fake_cp_cache) + + # 3. Post-process sample = self.norm_out(sample) - sample = self.act_fn(sample) + sample = self.conv_act(sample) sample = self.conv_out(sample) - return sample @@ -682,130 +596,117 @@ class Decoder3D(nn.Module): The normalization type to use. Can be either `"group"` or `"spatial"`. """ + _supports_gradient_checkpointing = True + def __init__( self, in_channels: int = 16, out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D",), block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), layers_per_block: int = 3, act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, dropout: float = 0.0, - resolution: int = 256, - give_pre_end: bool = False, pad_mode: str = "first", - temporal_compress_times: int = 4, - norm_num_groups=32, + temporal_compression_ratio: float = 4, ): super().__init__() - self.act_fn = get_activation(act_fn) - self.num_resolutions = len(block_out_channels) - self.layers_per_block = layers_per_block - self.resolution = resolution - self.give_pre_end = give_pre_end - self.norm_num_groups = norm_num_groups - self.temporal_compress_level = int(np.log2(temporal_compress_times)) - - block_in = block_out_channels[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, in_channels, curr_res, curr_res) - self.conv_in = CogVideoXCausalConv3d(in_channels, block_in, kernel_size=3, pad_mode=pad_mode) - - # middle - self.mid = nn.Module() - self.mid.block_1 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - temb_channels=0, - dropout=dropout, - non_linearity=act_fn, - spatial_norm_dim=in_channels, - groups=norm_num_groups, - pad_mode=pad_mode, - ) + reversed_block_out_channels = list(reversed(block_out_channels)) + + resolution = block_out_channels[-1] // 2 ** (len(block_out_channels) - 1) + self.z_shape = (1, in_channels, resolution, resolution) + self.conv_in = CogVideoXCausalConv3d(in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode) - self.mid.block_2 = CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_in, + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=reversed_block_out_channels[0], temb_channels=0, - dropout=dropout, - non_linearity=act_fn, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, spatial_norm_dim=in_channels, - groups=norm_num_groups, pad_mode=pad_mode, ) - # UpSampling - - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - - block_out = block_out_channels[i_level] - for i_block in range(self.layers_per_block + 1): - block.append( - CogVideoXResnetBlock3D( - in_channels=block_in, - out_channels=block_out, - temb_channels=0, - non_linearity=act_fn, - dropout=dropout, - spatial_norm_dim=in_channels, - groups=norm_num_groups, - pad_mode=pad_mode, - ) + # up blocks + self.up_blocks = nn.ModuleList([]) + + output_channel = reversed_block_out_channels[0] + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if up_block_type == "CogVideoXUpBlock3D": + up_block = CogVideoXUpBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final_block, + compress_time=compress_time, + pad_mode=pad_mode, ) - block_in = block_out - - up = nn.Module() - up.block = block - - if i_level != 0: - if i_level < self.num_resolutions - self.temporal_compress_level: - up.upsample = CogVideoXUpSample3D(in_channels=block_in, out_channels=block_in, compress_time=False) - else: - up.upsample = CogVideoXUpSample3D(in_channels=block_in, out_channels=block_in, compress_time=True) - curr_res = curr_res * 2 - - self.up.insert(0, up) - - self.norm_out = CogVideoXSpatialNorm3D(f_channels=block_in, zq_channels=in_channels) - - self.conv_out = CogVideoXCausalConv3d(block_in, out_channels, kernel_size=3, pad_mode=pad_mode) + prev_output_channel = output_channel + else: + raise ValueError( + "Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`" + ) + + self.up_blocks.append(up_block) - def forward(self, sample: torch.Tensor, clear_fake_cp_cache=True) -> torch.Tensor: - r"""The forward method of the `Decoder` class.""" - # timestep embedding + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d(reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) - temb = None + self.gradient_checkpointing = False - hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True) -> torch.Tensor: + r"""The forward method of the `Decoder3D` class.""" + sample = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) - # middle - hidden_states = self.mid.block_1(hidden_states, temb, sample, clear_fake_cp_cache=clear_fake_cp_cache) + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) - hidden_states = self.mid.block_2(hidden_states, temb, sample, clear_fake_cp_cache=clear_fake_cp_cache) + return custom_forward - # UpSampling + # 1. Mid + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, temb, clear_fake_cp_cache + ) - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.layers_per_block + 1): - hidden_states = self.up[i_level].block[i_block]( - hidden_states, temb, sample, clear_fake_cp_cache=clear_fake_cp_cache + # 2. Up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), sample, temb, clear_fake_cp_cache ) - - if i_level != 0: - hidden_states = self.up[i_level].upsample(hidden_states) - - # end - if self.give_pre_end: - return hidden_states - - hidden_states = self.norm_out(hidden_states, sample, clear_fake_cp_cache=clear_fake_cp_cache) - hidden_states = self.act_fn(hidden_states) - hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) - - return hidden_states + else: + # 1. Mid + sample = self.mid_block(sample, temb, clear_fake_cp_cache) + + # 2. Up + for up_block in self.up_blocks: + sample = up_block(sample, temb, clear_fake_cp_cache) + + # 3. Post-process + sample = self.norm_out(sample, sample, clear_fake_cp_cache=clear_fake_cp_cache) + sample = self.conv_act(sample) + sample = self.conv_out(sample, clear_fake_cp_cache=clear_fake_cp_cache) + return sample class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): @@ -849,14 +750,16 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): def __init__( self, in_channels: int = 3, - out_channels: int = 16, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + out_channels: int = 3, + down_block_types: Tuple[str] = ("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D"), + up_block_types: Tuple[str] = ("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D",), block_out_channels: Tuple[int] = (128, 256, 256, 512), latent_channels: int = 16, layers_per_block: int = 3, act_fn: str = "silu", + norm_eps: float = 1e-6, norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, sample_size: int = 256, scaling_factor: float = 1.15258426, shift_factor: Optional[float] = None, @@ -872,22 +775,24 @@ def __init__( self.encoder = Encoder3D( in_channels=in_channels, out_channels=latent_channels, - block_out_channels=block_out_channels, down_block_types=down_block_types, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, + norm_eps=norm_eps, norm_num_groups=norm_num_groups, - double_z=True, - resolution=sample_size, + temporal_compression_ratio=temporal_compression_ratio, ) self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, + up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, act_fn=act_fn, - resolution=sample_size, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, ) self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 4e384e731c74..e107117c213e 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -285,6 +285,67 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.conv2d(inputs, weight, stride=2) +class Downsample3D(nn.Module): + r""" + A 3D Downsampling layer. + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `2`): + Stride of the convolution. + padding (`int`, defaults to `0`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 2, + padding: int = 0, + compress_time: bool = False, + ): + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + + if x.shape[-1] % 2 == 1: + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] + + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + + return x + + def downsample_2d( hidden_states: torch.Tensor, kernel: Optional[torch.Tensor] = None, diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 8263afa1ed0d..45e7d62eed2f 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -232,7 +232,7 @@ def __init__( sample_height: int = 60, sample_frames: int = 49, patch_size: int = 2, - time_compression: int = 4, + temporal_compression_ratio: int = 4, max_text_seq_length: int = 225, activation_fn: str = "gelu-approximate", timestep_activation_fn: str = "silu", @@ -251,7 +251,7 @@ def __init__( post_patch_height = sample_height // patch_size post_patch_width = sample_width // patch_size - post_time_compression_frames = (sample_frames - 1) // time_compression + 1 + post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames # 1. Patch embedding diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 572844d2de0a..99bac8dad7c3 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -348,6 +348,70 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) +class Upsample3D(nn.Module): + r""" + A 3D Upsampling layer. + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `1`): + Stride of the convolution. + padding (`int`, defaults to `1`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + compress_time: bool = False + ) -> None: + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if self.compress_time: + if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + inputs = torch.cat([x_first, x_rest], dim=2) + elif inputs.shape[2] > 1: + inputs = F.interpolate(inputs, scale_factor=2.0) + else: + inputs = inputs.squeeze(2) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + else: + # only interpolate 2D + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = self.conv(inputs) + inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) + + return inputs + + def upfirdn2d_native( tensor: torch.Tensor, kernel: torch.Tensor, From 03c28eef5b8f13fe48f4122b0281939064650a4c Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 1 Aug 2024 12:18:01 +0200 Subject: [PATCH 39/94] modeling fixes --- .../autoencoders/autoencoder_kl_cogvideox.py | 129 ++++++++++-------- 1 file changed, 70 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 4406cd04769c..5021136678f3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -93,24 +93,24 @@ def cast_tuple(t, length=1): self.conv_cache = None - def fake_cp_pass_from_previous_rank(self, input_): + def fake_cp_pass_from_previous_rank(self, inputs: torch.Tensor) -> torch.Tensor: dim = self.temporal_dim kernel_size = self.time_kernel_size if kernel_size == 1: - return input_ + return inputs - input_ = input_.transpose(0, dim) + inputs = inputs.transpose(0, dim) if self.conv_cache is not None: - input_ = torch.cat([self.conv_cache.transpose(0, dim).to(input_.device), input_], dim=0) + inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0) else: - input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0) - input_ = input_.transpose(0, dim).contiguous() - return input_ + inputs = inputs.transpose(0, dim).contiguous() + return inputs - def forward(self, input_, clear_fake_cp_cache=True): - input_parallel = self.fake_cp_pass_from_previous_rank(input_) + def forward(self, inputs: torch.Tensor, clear_fake_cp_cache: bool = True): + input_parallel = self.fake_cp_pass_from_previous_rank(inputs) del self.conv_cache self.conv_cache = None @@ -150,7 +150,7 @@ def __init__( self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) - def forward(self, f: torch.Tensor, zq: torch.Tensor, clear_fake_cp_cache=True) -> torch.Tensor: + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: if f.shape[2] > 1 and f.shape[2] % 2 == 1: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] @@ -186,7 +186,7 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels - self.non_linearity = get_activation(non_linearity) + self.nonlinearity = get_activation(non_linearity) self.use_conv_shortcut = conv_shortcut if spatial_norm_dim is None: @@ -232,20 +232,20 @@ def forward( ) -> torch.Tensor: hidden_states = input_tensor if zq is not None: - hidden_states = self.norm1(hidden_states, zq, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.norm1(hidden_states, zq) else: hidden_states = self.norm1(hidden_states) - hidden_states = self.non_linearity(hidden_states) + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) if temb is not None: - hidden_states = hidden_states + self.temb_proj(self.non_linearity(temb))[:, :, None, None, None] + hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] if zq is not None: - hidden_states = self.norm2(hidden_states, zq, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.norm2(hidden_states, zq) else: hidden_states = self.norm2(hidden_states) - hidden_states = self.non_linearity(hidden_states) + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) @@ -306,9 +306,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: - output_states = () - + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -319,19 +317,16 @@ def create_forward(*inputs): return create_forward hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, clear_fake_cp_cache + create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache ) else: - hidden_states = resnet(hidden_states, temb, clear_fake_cp_cache) - - output_states = output_states + (hidden_states,) + hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states = output_states + (hidden_states,) - return hidden_states, output_states + return hidden_states class CogVideoXMidBlock3D(nn.Module): @@ -370,9 +365,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: - output_states = () - + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -383,14 +376,12 @@ def create_forward(*inputs): return create_forward hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, clear_fake_cp_cache + create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache ) else: - hidden_states = resnet(hidden_states, temb, clear_fake_cp_cache) - - output_states = output_states + (hidden_states,) + hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) - return hidden_states, output_states + return hidden_states class CogVideoXUpBlock3D(nn.Module): @@ -436,6 +427,29 @@ def __init__( self.upsamplers = nn.ModuleList([ Upsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time) ]) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: + r"""Forward method of the `CogVideoXUpBlock3D` class.""" + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache + ) + else: + hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states class Encoder3D(nn.Module): @@ -479,9 +493,6 @@ def __init__( temporal_compression_ratio: float = 4, ): super().__init__() - self.act_fn = get_activation(act_fn) - self.num_resolutions = len(block_out_channels) - self.layers_per_block = layers_per_block # log2 of temporal_compress_times temporal_compress_level = int(np.log2(temporal_compression_ratio)) @@ -539,7 +550,7 @@ def __init__( def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True) -> torch.Tensor: r"""The forward method of the `Encoder3D` class.""" - sample = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -550,27 +561,27 @@ def custom_forward(*inputs): # 1. Down for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, temb, clear_fake_cp_cache + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, temb, None, clear_fake_cp_cache ) # 2. Mid - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, temb, clear_fake_cp_cache + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, None, clear_fake_cp_cache ) else: # 1. Down for down_block in self.down_blocks: - sample = down_block(sample, temb, clear_fake_cp_cache) + hidden_states = down_block(hidden_states, temb, None, clear_fake_cp_cache) # 2. Mid - sample = self.mid_block(sample, temb, clear_fake_cp_cache) + hidden_states = self.mid_block(hidden_states, temb, None, clear_fake_cp_cache) # 3. Post-process - sample = self.norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - return sample + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + return hidden_states class Decoder3D(nn.Module): @@ -675,7 +686,7 @@ def __init__( def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True) -> torch.Tensor: r"""The forward method of the `Decoder3D` class.""" - sample = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -685,28 +696,28 @@ def custom_forward(*inputs): return custom_forward # 1. Mid - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, temb, clear_fake_cp_cache + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, sample, clear_fake_cp_cache ) # 2. Up for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), sample, temb, clear_fake_cp_cache + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, temb, sample, clear_fake_cp_cache ) else: # 1. Mid - sample = self.mid_block(sample, temb, clear_fake_cp_cache) + hidden_states = self.mid_block(hidden_states, temb, sample, clear_fake_cp_cache) # 2. Up for up_block in self.up_blocks: - sample = up_block(sample, temb, clear_fake_cp_cache) + hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache) # 3. Post-process - sample = self.norm_out(sample, sample, clear_fake_cp_cache=clear_fake_cp_cache) - sample = self.conv_act(sample) - sample = self.conv_out(sample, clear_fake_cp_cache=clear_fake_cp_cache) - return sample + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + return hidden_states class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): @@ -869,7 +880,7 @@ def encode( @apply_forward_hook def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None, fake_cp: bool = False + self, z: torch.FloatTensor, return_dict: bool = True, fake_cp: bool = False ) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. From 712ddbeac6843d55bcf97bc58f87fa509ed33281 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 1 Aug 2024 12:18:18 +0200 Subject: [PATCH 40/94] make style --- scripts/convert_cogvideox_to_diffusers.py | 4 +- .../autoencoders/autoencoder_kl_cogvideox.py | 150 +++++++++++------- src/diffusers/models/upsampling.py | 2 +- tests/pipelines/cogvideox/test_cogvideox.py | 2 +- 4 files changed, 101 insertions(+), 57 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 1ef11a80ef24..d9e452cd6c1e 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -53,7 +53,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): key_split = key.split(".") layer_index = int(key_split[2]) replace_layer_index = 4 - 1 - layer_index - + key_split[1] = "up_blocks" key_split[2] = str(replace_layer_index) new_key = ".".join(key_split) @@ -157,7 +157,7 @@ def convert_vae(ckpt_path: str): if special_key not in key: continue handler_fn_inplace(key, original_state_dict) - + vae.load_state_dict(original_state_dict, strict=True) return vae diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 5021136678f3..0df91e627679 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -262,7 +262,7 @@ def forward( class CogVideoXDownBlock3D(nn.Module): _supports_gradient_checkpointing = True - + def __init__( self, in_channels: int, @@ -298,40 +298,46 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.downsamplers = None - + if add_downsample: - self.downsamplers = nn.ModuleList([ - Downsample3D(out_channels, out_channels, padding=downsample_padding, compress_time=compress_time) - ]) - + self.downsamplers = nn.ModuleList( + [Downsample3D(out_channels, out_channels, padding=downsample_padding, compress_time=compress_time)] + ) + self.gradient_checkpointing = False - - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + clear_fake_cp_cache: bool = False, + ) -> torch.Tensor: for resnet in self.resnets: if self.training and self.gradient_checkpointing: - + def create_custom_forward(module): def create_forward(*inputs): return module(*inputs) - + return create_forward - + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache ) else: hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) - + if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - + return hidden_states class CogVideoXMidBlock3D(nn.Module): _supports_gradient_checkpointing = True - + def __init__( self, in_channels: int, @@ -362,25 +368,31 @@ def __init__( ) ) self.resnets = nn.ModuleList(resnets) - + self.gradient_checkpointing = False - - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + clear_fake_cp_cache: bool = False, + ) -> torch.Tensor: for resnet in self.resnets: if self.training and self.gradient_checkpointing: - + def create_custom_forward(module): def create_forward(*inputs): return module(*inputs) - + return create_forward - + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache ) else: hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) - + return hidden_states @@ -424,31 +436,37 @@ def __init__( self.upsamplers = None if add_upsample: - self.upsamplers = nn.ModuleList([ - Upsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time) - ]) - - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = False) -> torch.Tensor: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time)] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + clear_fake_cp_cache: bool = False, + ) -> torch.Tensor: r"""Forward method of the `CogVideoXUpBlock3D` class.""" for resnet in self.resnets: if self.training and self.gradient_checkpointing: - + def create_custom_forward(module): def create_forward(*inputs): return module(*inputs) - + return create_forward - + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache ) else: hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) - + if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) - + return hidden_states @@ -482,7 +500,12 @@ def __init__( self, in_channels: int = 3, out_channels: int = 16, - down_block_types: Tuple[str, ...] = ("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D",), + down_block_types: Tuple[str, ...] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), layers_per_block: int = 3, act_fn: str = "silu", @@ -493,7 +516,7 @@ def __init__( temporal_compression_ratio: float = 4, ): super().__init__() - + # log2 of temporal_compress_times temporal_compress_level = int(np.log2(temporal_compression_ratio)) @@ -522,10 +545,8 @@ def __init__( compress_time=compress_time, ) else: - raise ValueError( - "Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`" - ) - + raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`") + self.down_blocks.append(down_block) # mid block @@ -548,23 +569,26 @@ def __init__( self.gradient_checkpointing = False - def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True) -> torch.Tensor: + def forward( + self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True + ) -> torch.Tensor: r"""The forward method of the `Encoder3D` class.""" hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) if self.training and self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - + # 1. Down for down_block in self.down_blocks: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(down_block), hidden_states, temb, None, clear_fake_cp_cache ) - + # 2. Mid hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(self.mid_block), hidden_states, temb, None, clear_fake_cp_cache @@ -573,7 +597,7 @@ def custom_forward(*inputs): # 1. Down for down_block in self.down_blocks: hidden_states = down_block(hidden_states, temb, None, clear_fake_cp_cache) - + # 2. Mid hidden_states = self.mid_block(hidden_states, temb, None, clear_fake_cp_cache) @@ -613,7 +637,12 @@ def __init__( self, in_channels: int = 16, out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D",), + up_block_types: Tuple[str, ...] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), layers_per_block: int = 3, act_fn: str = "silu", @@ -629,7 +658,9 @@ def __init__( resolution = block_out_channels[-1] // 2 ** (len(block_out_channels) - 1) self.z_shape = (1, in_channels, resolution, resolution) - self.conv_in = CogVideoXCausalConv3d(in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.conv_in = CogVideoXCausalConv3d( + in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode + ) # mid block self.mid_block = CogVideoXMidBlock3D( @@ -645,7 +676,7 @@ def __init__( # up blocks self.up_blocks = nn.ModuleList([]) - + output_channel = reversed_block_out_channels[0] temporal_compress_level = int(np.log2(temporal_compression_ratio)) @@ -672,23 +703,26 @@ def __init__( ) prev_output_channel = output_channel else: - raise ValueError( - "Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`" - ) - + raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`") + self.up_blocks.append(up_block) self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels) self.conv_act = nn.SiLU() - self.conv_out = CogVideoXCausalConv3d(reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) + self.conv_out = CogVideoXCausalConv3d( + reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode + ) self.gradient_checkpointing = False - def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True) -> torch.Tensor: + def forward( + self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True + ) -> torch.Tensor: r"""The forward method of the `Decoder3D` class.""" hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) if self.training and self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) @@ -712,7 +746,7 @@ def custom_forward(*inputs): # 2. Up for up_block in self.up_blocks: hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache) - + # 3. Post-process hidden_states = self.norm_out(hidden_states, sample) hidden_states = self.conv_act(hidden_states) @@ -762,8 +796,18 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D"), - up_block_types: Tuple[str] = ("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D",), + down_block_types: Tuple[str] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types: Tuple[str] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), block_out_channels: Tuple[int] = (128, 256, 256, 512), latent_channels: int = 16, layers_per_block: int = 3, diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 99bac8dad7c3..7638c3880c0d 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -374,7 +374,7 @@ def __init__( kernel_size: int = 3, stride: int = 1, padding: int = 1, - compress_time: bool = False + compress_time: bool = False, ) -> None: super().__init__() diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index 0da39160bd18..85968995d765 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -24,7 +24,7 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKL, DDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3D +from diffusers import AutoencoderKL, CogVideoXPipeline, CogVideoXTransformer3D, DDIMScheduler from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, From 03ee7cd10918ddb9f4e37aaa57511d3ac9d0e778 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 1 Aug 2024 13:52:58 +0200 Subject: [PATCH 41/94] add pipeline implementation --- .../autoencoders/autoencoder_kl_cogvideox.py | 23 +- src/diffusers/models/embeddings.py | 2 +- .../transformers/cogvideox_transformer_3d.py | 18 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 274 ++++++++++++++++-- 4 files changed, 274 insertions(+), 43 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 0df91e627679..12d90151425e 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -470,9 +470,9 @@ def create_forward(*inputs): return hidden_states -class Encoder3D(nn.Module): +class CogVideoXEncoder3D(nn.Module): r""" - The `Encoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. Args: in_channels (`int`, *optional*, defaults to 3): @@ -572,7 +572,7 @@ def __init__( def forward( self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True ) -> torch.Tensor: - r"""The forward method of the `Encoder3D` class.""" + r"""The forward method of the `CogVideoXEncoder3D` class.""" hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) if self.training and self.gradient_checkpointing: @@ -608,9 +608,10 @@ def custom_forward(*inputs): return hidden_states -class Decoder3D(nn.Module): +class CogVideoXDecoder3D(nn.Module): r""" - The `Decoder3D` layer of a variational autoencoder that decodes its latent representation into an output sample. + The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. Args: in_channels (`int`, *optional*, defaults to 3): @@ -718,7 +719,7 @@ def __init__( def forward( self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True ) -> torch.Tensor: - r"""The forward method of the `Decoder3D` class.""" + r"""The forward method of the `CogVideoXDecoder3D` class.""" hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) if self.training and self.gradient_checkpointing: @@ -827,7 +828,7 @@ def __init__( ): super().__init__() - self.encoder = Encoder3D( + self.encoder = CogVideoXEncoder3D( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, @@ -838,7 +839,7 @@ def __init__( norm_num_groups=norm_num_groups, temporal_compression_ratio=temporal_compression_ratio, ) - self.decoder = Decoder3D( + self.decoder = CogVideoXDecoder3D( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, @@ -865,7 +866,7 @@ def __init__( self.tile_overlap_factor = 0.25 def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Encoder3D, Decoder3D)): + if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): module.gradient_checkpointing = value def enable_tiling(self, use_tiling: bool = True): @@ -910,6 +911,7 @@ def encode( Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. fake_cp (`bool`, *optional*, defaults to `True`): If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). + Returns: The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. @@ -935,6 +937,7 @@ def decode( Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. fake_cp (`bool`, *optional*, defaults to `True`): If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). + Returns: [`~models.vae.DecoderOutput`] or `tuple`: If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is @@ -946,7 +949,7 @@ def decode( dec = self.decoder(z, clear_fake_cp_cache=not fake_cp) if not return_dict: return (dec,) - return dec + return DecoderOutput(sample=dec) def forward( self, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 5ecd77840feb..937abb6fdca9 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -362,7 +362,7 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): text_embeds = self.text_proj(text_embeds) B, F, C, H, W = image_embeds.shape - image_embeds = image_embeds.view(-1, C, H, W) + image_embeds = image_embeds.reshape(-1, C, H, W) image_embeds = self.proj(image_embeds) image_embeds = image_embeds.view(B, F, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) # [B, F, H x W, C] diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 45e7d62eed2f..f55b2ca3a5a3 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -308,41 +308,41 @@ def _set_gradient_checkpointing(self, module, value=False): def forward( self, - sample: torch.Tensor, + hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], attention_mask: Optional[Union[int, torch.Tensor]] = None, timestep_cond: Optional[torch.Tensor] = None, return_dict: bool = True, ): - batch_size, num_frames, channels, height, width = sample.shape + batch_size, num_frames, channels, height, width = hidden_states.shape # 1. Time embedding timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" + is_mps = hidden_states.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + timesteps = torch.tensor([timesteps], dtype=dtype, device=hidden_states.device) elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) + timesteps = timesteps[None].to(hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) + timesteps = timesteps.expand(hidden_states.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) + t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) # 2. Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, sample) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # 3. Position embedding seq_length = height * width * num_frames // (self.config.patch_size**2) @@ -358,7 +358,7 @@ def forward( # 4. Prepare attention mask if attention_mask is None: attention_mask = torch.ones(batch_size, self.num_patches + self.config.max_text_seq_length) - attention_mask = attention_mask.to(device=sample.device, dtype=sample.dtype) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) # 5. Transformer blocks for i, block in enumerate(self.transformer_blocks): diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 074f01356c6f..02b03d6c7161 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -1,4 +1,4 @@ -# Copyright 2024 The The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,19 +15,17 @@ import inspect from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from transformers import T5EncoderModel, T5Tokenizer +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3D from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - BaseOutput, - logging, - replace_example_docstring, -) +from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -153,20 +151,23 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = ( + self.vae_scale_factor_spatial = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 255 + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 225 ) - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, - max_sequence_length: int = 255, + max_sequence_length: int = 225, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -212,7 +213,7 @@ def encode_prompt( num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: int = 255, + max_sequence_length: int = 225, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -285,6 +286,48 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor, num_seconds: int): + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = [] + for i in range(num_seconds): + # Whether or not to clear fake context parallel cache + fake_cp = i + 1 < num_seconds + start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) + + current_frames = self.vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample + frames.append(current_frames) + + frames = torch.cat(frames, dim=2) + return frames + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -303,21 +346,57 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, ): - # TODO: implement check_inputs - pass + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None - ): - # TODO: implement prepare_latents - pass + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - def decode_latents(self, latents: torch.Tensor, video_length: int, vae_batch_size: int = 16): - # TODO: implement decode_latents - pass + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) @property def guidance_scale(self): @@ -342,10 +421,159 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_seconds: int = 6, + fps: int = 8, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 7.5, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[CogVideoXPipelineOutput, Tuple]: """ TODO: implement forward pass Examples: """ - pass + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + num_frames = 1 + num_seconds * fps + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # TODO: check if this is needed + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latents": + video = self.decode_latents(latents, num_seconds) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) From a31db5f952385d04fe5ba1ce2d9efbc53c08b468 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Thu, 1 Aug 2024 22:58:27 +0800 Subject: [PATCH 42/94] using with 226 instead of 225 of final weight --- .../models/transformers/cogvideox_transformer_3d.py | 2 +- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index f55b2ca3a5a3..b2a2c87e9275 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -233,7 +233,7 @@ def __init__( sample_frames: int = 49, patch_size: int = 2, temporal_compression_ratio: int = 4, - max_text_seq_length: int = 225, + max_text_seq_length: int = 226, activation_fn: str = "gelu-approximate", timestep_activation_fn: str = "silu", norm_type: str = "layer_norm", diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 02b03d6c7161..b4f78d1690ed 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -158,7 +158,7 @@ def __init__( self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 ) self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 225 + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226 ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -167,7 +167,7 @@ def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, - max_sequence_length: int = 225, + max_sequence_length: int = 226, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -213,7 +213,7 @@ def encode_prompt( num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: int = 225, + max_sequence_length: int = 226, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): From 351d1f009e8367244f428385bbbecb8f90f6f821 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Thu, 1 Aug 2024 23:31:03 +0800 Subject: [PATCH 43/94] remove 0.transformer_blocks.encoder.embed_tokens.weight --- scripts/convert_cogvideox_to_diffusers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index d9e452cd6c1e..a72210756d93 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -138,6 +138,12 @@ def convert_transformer(ckpt_path: str): continue handler_fn_inplace(key, original_state_dict) + # remove incompatible key + + incompatible_key = '0.transformer_blocks.encoder.embed_tokens.weight' + if incompatible_key in original_state_dict.keys(): + original_state_dict.pop(incompatible_key) + transformer.load_state_dict(original_state_dict, strict=True) return transformer From d0b8db2b117ff28ea5da84873ccae6337833a9ba Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 1 Aug 2024 21:12:49 +0200 Subject: [PATCH 44/94] update --- scripts/convert_cogvideox_to_diffusers.py | 11 +++-------- .../autoencoders/autoencoder_kl_cogvideox.py | 17 ++++++++--------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index a72210756d93..7300b5f2d778 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -45,7 +45,7 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): state_dict.pop(key) -def remove_loss_keys_inplace(key: str, state_dict: Dict[str, Any]): +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): state_dict.pop(key) @@ -85,6 +85,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): "query_layernorm_list": reassign_query_key_layernorm_inplace, "key_layernorm_list": reassign_query_key_layernorm_inplace, "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, + "embed_tokens": remove_keys_inplace, } VAE_KEYS_RENAME_DICT = { @@ -100,7 +101,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): } VAE_SPECIAL_KEYS_REMAP = { - "loss": remove_loss_keys_inplace, + "loss": remove_keys_inplace, "up.": replace_up_keys_inplace, } @@ -138,12 +139,6 @@ def convert_transformer(ckpt_path: str): continue handler_fn_inplace(key, original_state_dict) - # remove incompatible key - - incompatible_key = '0.transformer_blocks.encoder.embed_tokens.weight' - if incompatible_key in original_state_dict.keys(): - original_state_dict.pop(incompatible_key) - transformer.load_state_dict(original_state_dict, strict=True) return transformer diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 12d90151425e..04ed39fe0530 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -201,14 +201,15 @@ def __init__( f_channels=out_channels, zq_channels=spatial_norm_dim, ) + self.conv1 = CogVideoXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) + if temb_channels > 0: self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) self.dropout = nn.Dropout(dropout) - self.conv2 = CogVideoXCausalConv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) @@ -225,12 +226,12 @@ def __init__( def forward( self, - input_tensor: torch.Tensor, + inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True, ) -> torch.Tensor: - hidden_states = input_tensor + hidden_states = inputs if zq is not None: hidden_states = self.norm1(hidden_states, zq) else: @@ -245,18 +246,18 @@ def forward( hidden_states = self.norm2(hidden_states, zq) else: hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - input_tensor = self.conv_shortcut(input_tensor, clear_fake_cp_cache=clear_fake_cp_cache) + inputs = self.conv_shortcut(inputs, clear_fake_cp_cache=clear_fake_cp_cache) else: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states + inputs = self.conv_shortcut(inputs) + output_tensor = inputs + hidden_states return output_tensor @@ -657,8 +658,6 @@ def __init__( reversed_block_out_channels = list(reversed(block_out_channels)) - resolution = block_out_channels[-1] // 2 ** (len(block_out_channels) - 1) - self.z_shape = (1, in_channels, resolution, resolution) self.conv_in = CogVideoXCausalConv3d( in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode ) From fe6f5d64191786bd318604434aa4870900c131e1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 1 Aug 2024 22:06:08 +0200 Subject: [PATCH 45/94] ensure tokenizer config correctly uses 226 as text length --- scripts/convert_cogvideox_to_diffusers.py | 4 +++- .../models/autoencoders/autoencoder_kl_cogvideox.py | 2 +- .../models/transformers/cogvideox_transformer_3d.py | 5 ++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 7300b5f2d778..38000313cafd 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -105,6 +105,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): "up.": replace_up_keys_inplace, } +TOKENIZER_MAX_LENGTH = 226 + def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: state_dict = saved_dict @@ -189,7 +191,7 @@ def get_args(): vae = convert_vae(args.vae_ckpt_path) text_encoder_id = "google/t5-v1_1-xxl" - tokenizer = T5Tokenizer.from_pretrained(text_encoder_id) + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id) # TODO: verify with authors diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 04ed39fe0530..9b86f1a71f9d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -748,7 +748,7 @@ def custom_forward(*inputs): hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache) # 3. Post-process - hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.norm_out(hidden_states, sample, clear_fake_cp_cache=clear_fake_cp_cache) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) return hidden_states diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index b2a2c87e9275..7c6c1b0124dd 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -346,14 +346,13 @@ def forward( # 3. Position embedding seq_length = height * width * num_frames // (self.config.patch_size**2) - text_seq_length = encoder_hidden_states.size(1) pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] hidden_states = hidden_states + pos_embeds hidden_states = self.embedding_dropout(hidden_states) - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] + encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] + hidden_states = hidden_states[:, self.config.max_text_seq_length :] # 4. Prepare attention mask if attention_mask is None: From 4c2e8870e6ffab1e7266ac085d7a2115c08b6bc7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 00:51:15 +0200 Subject: [PATCH 46/94] add cogvideo specific attn processor --- src/diffusers/models/attention_processor.py | 79 +++++++++++++++++++ .../autoencoders/autoencoder_kl_cogvideox.py | 2 +- src/diffusers/models/embeddings.py | 3 + src/diffusers/models/normalization.py | 11 ++- .../transformers/cogvideox_transformer_3d.py | 25 +++++- .../pipelines/cogvideo/pipeline_cogvideox.py | 4 +- 6 files changed, 115 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5c5464c37683..d55cd74535cc 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2056,6 +2056,85 @@ def __call__( return hidden_states +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the CogVideoXTransformer3D model. It applies a normalization layer on the query and key vectors. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + print("attention hidden states:", hidden_states.sum()) + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + print("attention proj:", query.sum(), key.sum(), value.sum()) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Apply QK norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + print("attention norm:", query.sum(), key.sum(), value.sum()) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + print("attn hidden_states:", hidden_states.sum()) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + print("attn to_out:", hidden_states.sum()) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 9b86f1a71f9d..04ed39fe0530 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -748,7 +748,7 @@ def custom_forward(*inputs): hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache) # 3. Post-process - hidden_states = self.norm_out(hidden_states, sample, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.norm_out(hidden_states, sample) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) return hidden_states diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 937abb6fdca9..213354630f61 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -360,6 +360,7 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). """ text_embeds = self.text_proj(text_embeds) + print("patch text_embeds:", text_embeds.sum()) B, F, C, H, W = image_embeds.shape image_embeds = image_embeds.reshape(-1, C, H, W) @@ -367,8 +368,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): image_embeds = image_embeds.view(B, F, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) # [B, F, H x W, C] image_embeds = image_embeds.flatten(1, 2) # [B, F x H x W, C] + print("patch image_embeds:", image_embeds.sum()) embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() # [B, S + F x H x W, C] + print("patch concat:", embeds.sum()) return embeds diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index ecb831a3390d..1945c4d3172e 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -307,14 +307,17 @@ def __init__( def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - shift_msa, scale_msa, gate_msa, enc_shift_msa, enc_scale_msa, enc_gate_msa = self.linear( + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear( self.silu(temb) ).chunk(6, dim=1) - hidden_states = self.norm(hidden_states) * (1 + scale_msa)[:, None, :] + shift_msa[:, None, :] + print("adaln debug:", shift.sum(), scale.sum(), gate.sum(), enc_shift.sum(), enc_scale.sum(), enc_gate.sum()) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + print("hidden_states adaln:", hidden_states.sum()) encoder_hidden_states = ( - self.norm(encoder_hidden_states) * (1 + enc_scale_msa)[:, None, :] + enc_shift_msa[:, None, :] + self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] ) - return hidden_states, encoder_hidden_states, gate_msa[:, None, :], enc_gate_msa[:, None, :] + print("encoder_hidden_states adaln:", encoder_hidden_states.sum()) + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] if is_torch_version(">=", "2.1.0"): diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 7c6c1b0124dd..6a5df4d4be91 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -21,6 +21,7 @@ from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward +from ..attention_processor import CogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -114,6 +115,7 @@ def __init__( eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), ) # 2. Feed Forward @@ -141,17 +143,20 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, + temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) + print("norm 1:", norm_hidden_states.sum(), norm_encoder_hidden_states.sum(), gate_msa.sum(), enc_gate_msa.sum()) # attention text_length = norm_encoder_hidden_states.size(1) norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + print("norm concat attn:", norm_hidden_states.sum()) attn_output = self.attn1(norm_hidden_states, attention_mask=attention_mask) + print("attn output:", attn_output.sum()) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] @@ -160,17 +165,21 @@ def forward( norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( hidden_states, encoder_hidden_states, temb ) + print("norm 2:", norm_hidden_states.sum(), norm_encoder_hidden_states.sum(), gate_ff.sum(), enc_gate_ff.sum()) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + print("norm concat mlp:", norm_hidden_states.sum()) if self._chunk_size is not None: ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) + print("ff:", ff_output.sum()) hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] + print("block output:", hidden_states.sum(), encoder_hidden_states.sum()) return hidden_states, encoder_hidden_states @@ -340,16 +349,20 @@ def forward( # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) + print("emb:", emb.sum()) # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + print("hidden_states patch_embed:", hidden_states.sum()) # 3. Position embedding seq_length = height * width * num_frames // (self.config.patch_size**2) pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] + print("pos_embeds:", pos_embeds.sum()) hidden_states = hidden_states + pos_embeds hidden_states = self.embedding_dropout(hidden_states) + print("hidden_states embedding:", hidden_states.sum()) encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] hidden_states = hidden_states[:, self.config.max_text_seq_length :] @@ -385,18 +398,26 @@ def custom_forward(*inputs): temb=emb, attention_mask=attention_mask, ) + + print("hidden_states loop:", i, hidden_states.sum()) + print("encoder_hidden_states loop:", i, encoder_hidden_states.sum()) hidden_states = self.norm_final(hidden_states) + print("hidden_states norm_final:", hidden_states.sum()) # 6. Final block shift, scale = self.adaln_out(emb).chunk(2, dim=1) + print("adaln shift scale:", shift.sum(), scale.sum()) hidden_states = self.norm_out(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + print("hidden_states norm_out:", hidden_states.sum()) hidden_states = self.proj_out(hidden_states) + print("hidden_states proj_out:", hidden_states.sum()) # 7. Unpatchify p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, p, p, self.config.out_channels) + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, p, p, channels) output = output.permute(0, 1, 6, 2, 4, 3, 5).flatten(5, 6).flatten(3, 4) + print("output:", output.sum()) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index b4f78d1690ed..bf6f089cedff 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -529,14 +529,14 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input( latent_model_input, t - ) # TODO: check if this is needed + ) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output noise_pred = self.transformer( - latent_model_input, + hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, return_dict=False, From 41da084fbe09d15310c58cb16dab4a35bf299c11 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 00:53:09 +0200 Subject: [PATCH 47/94] remove debug prints --- src/diffusers/models/attention_processor.py | 5 ----- src/diffusers/models/normalization.py | 3 --- .../transformers/cogvideox_transformer_3d.py | 19 ------------------- 3 files changed, 27 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d55cd74535cc..1746e44a886a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2091,11 +2091,9 @@ def __call__( # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - print("attention hidden states:", hidden_states.sum()) query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - print("attention proj:", query.sum(), key.sum(), value.sum()) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -2109,14 +2107,12 @@ def __call__( query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) - print("attention norm:", query.sum(), key.sum(), value.sum()) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - print("attn hidden_states:", hidden_states.sum()) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2125,7 +2121,6 @@ def __call__( hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - print("attn to_out:", hidden_states.sum()) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 1945c4d3172e..4543c030b1cf 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -310,13 +310,10 @@ def forward( shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear( self.silu(temb) ).chunk(6, dim=1) - print("adaln debug:", shift.sum(), scale.sum(), gate.sum(), enc_shift.sum(), enc_scale.sum(), enc_gate.sum()) hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] - print("hidden_states adaln:", hidden_states.sum()) encoder_hidden_states = ( self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] ) - print("encoder_hidden_states adaln:", encoder_hidden_states.sum()) return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 6a5df4d4be91..5106cbadbf82 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -149,14 +149,11 @@ def forward( norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) - print("norm 1:", norm_hidden_states.sum(), norm_encoder_hidden_states.sum(), gate_msa.sum(), enc_gate_msa.sum()) # attention text_length = norm_encoder_hidden_states.size(1) norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - print("norm concat attn:", norm_hidden_states.sum()) attn_output = self.attn1(norm_hidden_states, attention_mask=attention_mask) - print("attn output:", attn_output.sum()) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] @@ -165,21 +162,17 @@ def forward( norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( hidden_states, encoder_hidden_states, temb ) - print("norm 2:", norm_hidden_states.sum(), norm_encoder_hidden_states.sum(), gate_ff.sum(), enc_gate_ff.sum()) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - print("norm concat mlp:", norm_hidden_states.sum()) if self._chunk_size is not None: ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) - print("ff:", ff_output.sum()) hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] - print("block output:", hidden_states.sum(), encoder_hidden_states.sum()) return hidden_states, encoder_hidden_states @@ -349,20 +342,16 @@ def forward( # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) - print("emb:", emb.sum()) # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - print("hidden_states patch_embed:", hidden_states.sum()) # 3. Position embedding seq_length = height * width * num_frames // (self.config.patch_size**2) pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] - print("pos_embeds:", pos_embeds.sum()) hidden_states = hidden_states + pos_embeds hidden_states = self.embedding_dropout(hidden_states) - print("hidden_states embedding:", hidden_states.sum()) encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] hidden_states = hidden_states[:, self.config.max_text_seq_length :] @@ -398,26 +387,18 @@ def custom_forward(*inputs): temb=emb, attention_mask=attention_mask, ) - - print("hidden_states loop:", i, hidden_states.sum()) - print("encoder_hidden_states loop:", i, encoder_hidden_states.sum()) hidden_states = self.norm_final(hidden_states) - print("hidden_states norm_final:", hidden_states.sum()) # 6. Final block shift, scale = self.adaln_out(emb).chunk(2, dim=1) - print("adaln shift scale:", shift.sum(), scale.sum()) hidden_states = self.norm_out(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] - print("hidden_states norm_out:", hidden_states.sum()) hidden_states = self.proj_out(hidden_states) - print("hidden_states proj_out:", hidden_states.sum()) # 7. Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, p, p, channels) output = output.permute(0, 1, 6, 2, 4, 3, 5).flatten(5, 6).flatten(3, 4) - print("output:", output.sum()) if not return_dict: return (output,) From 77558f31bf4cf8722d0c46fb8099df6539f34fd7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 01:03:02 +0200 Subject: [PATCH 48/94] add pipeline docs --- .../pipelines/cogvideo/pipeline_cogvideox.py | 88 ++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index bf6f089cedff..81f3fb309e2b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -35,7 +35,14 @@ EXAMPLE_DOC_STRING = """ Examples: ```python - # TODO: update example + >>> # TODO: verify this before merge + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX", torch_dtype=torch.bfloat16).to("cuda") + >>> video = pipe("a polar bear dancing, high quality, realistic", num_inference_steps=20).frames[0] + + >>> export_to_video(video, "output.mp4", fps=8) ``` """ @@ -102,6 +109,15 @@ def retrieve_timesteps( @dataclass class CogVideoXPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ frames: torch.Tensor @@ -444,10 +460,78 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[CogVideoXPipelineOutput, Tuple]: """ - TODO: implement forward pass + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_seconds (`int`, defaults to `6`): + Duration of video in seconds. Must be less than or equal to 6. + fps (`int`, defaults to `8`): + Number of frames per second in video. Must be equal to 8 (for now). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. """ + assert num_seconds <= 6 and fps == 8 + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs From e12458e16c41e6fcc1ea3ee21d69fcaccd345121 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 01:03:37 +0200 Subject: [PATCH 49/94] make style --- src/diffusers/models/attention_processor.py | 4 +++- src/diffusers/models/normalization.py | 8 ++------ src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 5 ++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1746e44a886a..8c91f8310adc 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2064,7 +2064,9 @@ class CogVideoXAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "CogVideoXAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 4543c030b1cf..cb36ea0473ce 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -307,13 +307,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear( - self.silu(temb) - ).chunk(6, dim=1) + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] - encoder_hidden_states = ( - self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] - ) + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 81f3fb309e2b..1d28d4a589b3 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -118,6 +118,7 @@ class CogVideoXPipelineOutput(BaseOutput): denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape `(batch_size, num_frames, channels, height, width)`. """ + frames: torch.Tensor @@ -611,9 +612,7 @@ def __call__( continue latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) From c33dd0213bb9c798c58bfff965723d5f5ae93fa3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 01:04:35 +0200 Subject: [PATCH 50/94] remove incorrect copied from --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 1d28d4a589b3..6e144661eafc 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -303,7 +303,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): From 71e7c82ae8c1f9a1c94d45067f7216c6ea19436f Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Fri, 2 Aug 2024 16:23:24 +0800 Subject: [PATCH 51/94] vae problem fix --- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 04ed39fe0530..30ddaa431c34 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -410,7 +410,7 @@ def __init__( resnet_groups: int = 32, spatial_norm_dim: int = 16, add_upsample: bool = True, - upsample_padding: int = 0, + upsample_padding: int = 1, compress_time: bool = False, pad_mode: str = "first", ): From ec53a30a0ef26a312fa9eee7c75937fa6563a0a5 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Fri, 2 Aug 2024 22:38:58 +0800 Subject: [PATCH 52/94] schedule --- src/diffusers/__init__.py | 1 + .../pipelines/cogvideo/pipeline_cogvideox.py | 5 +- src/diffusers/schedulers/__init__.py | 2 + .../schedulers/scheduling_ddim_cogvideox.py | 478 ++++++++++++++++++ 4 files changed, 483 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/schedulers/scheduling_ddim_cogvideox.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3355c8e9e0ec..0cb86d407791 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -155,6 +155,7 @@ [ "AmusedScheduler", "CMStochasticIterativeScheduler", + "CogVideoXDDIMScheduler", "DDIMInverseScheduler", "DDIMParallelScheduler", "DDIMScheduler", diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 6e144661eafc..866b70c51bd3 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3D from ...pipelines.pipeline_utils import DiffusionPipeline -from ...schedulers import KarrasDiffusionSchedulers +from ...schedulers import CogVideoXDDIMScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -160,14 +160,13 @@ def __init__( text_encoder: T5EncoderModel, vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3D, - scheduler: KarrasDiffusionSchedulers, + scheduler: CogVideoXDDIMScheduler, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor_spatial = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 696e9c3ad5d5..17fcf30f4f3e 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -43,6 +43,7 @@ _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] + _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"] _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"] _import_structure["scheduling_ddpm"] = ["DDPMScheduler"] @@ -140,6 +141,7 @@ from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler + from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py new file mode 100644 index 000000000000..d857a4b4aaef --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -0,0 +1,478 @@ +# Copyright 2024 The The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0 + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1-snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + a_t = ((1-alpha_prod_t_prev)/(1-alpha_prod_t))**0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + # breakpoint() + # # 5. compute variance: "sigma_t(η)" -> see formula (16) + # # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + # variance = self._get_variance(timestep, prev_timestep) + # std_dev_t = eta * variance ** (0.5) + + # if use_clipped_model_output: + # # the pred_epsilon is always re-derived from the clipped x_0 in Glide + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + # if eta > 0: + # if variance_noise is not None and generator is not None: + # raise ValueError( + # "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + # " `variance_noise` stays `None`." + # ) + + # if variance_noise is None: + # variance_noise = randn_tensor( + # model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + # ) + # variance = std_dev_t * variance_noise + + # prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps From 551c884acdd0fda16f926e1baa6fa708350fe39d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 17:40:23 +0200 Subject: [PATCH 53/94] remove debug prints --- src/diffusers/models/embeddings.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 213354630f61..937abb6fdca9 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -360,7 +360,6 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). """ text_embeds = self.text_proj(text_embeds) - print("patch text_embeds:", text_embeds.sum()) B, F, C, H, W = image_embeds.shape image_embeds = image_embeds.reshape(-1, C, H, W) @@ -368,10 +367,8 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): image_embeds = image_embeds.view(B, F, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) # [B, F, H x W, C] image_embeds = image_embeds.flatten(1, 2) # [B, F x H x W, C] - print("patch image_embeds:", image_embeds.sum()) embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() # [B, S + F x H x W, C] - print("patch concat:", embeds.sum()) return embeds From 3def90523d3829bb332e7fb743b872dc76adaaf1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 17:40:36 +0200 Subject: [PATCH 54/94] update --- scripts/convert_cogvideox_to_diffusers.py | 21 ++++++++++++------- src/diffusers/__init__.py | 1 + .../transformers/cogvideox_transformer_3d.py | 6 ++++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 3 +-- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 38000313cafd..c8a2316b9d01 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -4,7 +4,7 @@ import torch from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3D, DPMSolverMultistepScheduler +from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3D, CogVideoXDDIMScheduler def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): @@ -194,13 +194,18 @@ def get_args(): tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id) - # TODO: verify with authors - scheduler = DPMSolverMultistepScheduler.from_pretrained( - "runwayml/stable-diffusion-v1-5", - subfolder="scheduler", - algorithm_type="sde-dpmsolver++", - prediction_type="v_prediction", - ) + scheduler = CogVideoXDDIMScheduler.from_config({ + "snr_shift_scale": 3.0, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "num_train_timesteps": 1000, + "prediction_type": "v_prediction", + "rescale_betas_zero_snr": True, + "set_alpha_to_one": True, + "timestep_spacing": "linspace" + }) pipe = CogVideoXPipeline( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0cb86d407791..589991ba18d2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -597,6 +597,7 @@ from .schedulers import ( AmusedScheduler, CMStochasticIterativeScheduler, + CogVideoXDDIMScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 5106cbadbf82..44ebb2a3897d 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -27,6 +27,8 @@ from ..modeling_utils import ModelMixin from ..normalization import CogVideoXLayerNormZero +from einops import rearrange + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -397,8 +399,8 @@ def custom_forward(*inputs): # 7. Unpatchify p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, p, p, channels) - output = output.permute(0, 1, 6, 2, 4, 3, 5).flatten(5, 6).flatten(3, 4) + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 866b70c51bd3..65edf9247d2a 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -35,7 +35,6 @@ EXAMPLE_DOC_STRING = """ Examples: ```python - >>> # TODO: verify this before merge >>> from diffusers import CogVideoXPipeline >>> from diffusers.utils import export_to_video @@ -444,7 +443,7 @@ def __call__( fps: int = 8, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, - guidance_scale: float = 7.5, + guidance_scale: float = 6, num_videos_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, From 21509aa7f51afd83c3e3cd2845fd2e985139d31b Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 3 Aug 2024 00:34:55 +0800 Subject: [PATCH 55/94] fp16 problem --- docs/source/en/api/pipelines/cogvideox.md | 2 +- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 3 --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 9641c0965685..46c84765b28b 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -38,7 +38,7 @@ import torch from diffusers import LattePipeline pipeline = LattePipeline.from_pretrained( - "THUDM/CogVideoX", torch_dtype=torch.bfloat16 + "THUDM/CogVideoX-2b", torch_dtype=torch.float16 ).to("cuda") ``` diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 30ddaa431c34..4600f880bbd8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -16,7 +16,6 @@ from .vae import DecoderOutput, DiagonalGaussianDistribution -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXSafeConv3d(nn.Conv3d): """ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. @@ -46,7 +45,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input) -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXCausalConv3d(nn.Module): r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.""" @@ -125,7 +123,6 @@ def forward(self, inputs: torch.Tensor, clear_fake_cp_cache: bool = True): return output -# Todo: zRzRzRzRzRzRzR Move it to cogvideox model file since pr#2 has been merged class CogVideoXSpatialNorm3D(nn.Module): r""" Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 65edf9247d2a..d5957107629b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -38,7 +38,7 @@ >>> from diffusers import CogVideoXPipeline >>> from diffusers.utils import export_to_video - >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX", torch_dtype=torch.bfloat16).to("cuda") + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda") >>> video = pipe("a polar bear dancing, high quality, realistic", num_inference_steps=20).frames[0] >>> export_to_video(video, "output.mp4", fps=8) From b42b079213b1711120116e4bbcc0a9393dd0d2a8 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 3 Aug 2024 13:46:57 +0800 Subject: [PATCH 56/94] fix some comment --- docs/source/en/api/loaders/single_file.md | 2 ++ docs/source/en/api/pipelines/cogvideox.md | 7 +++-- scripts/convert_cogvideox_to_diffusers.py | 5 ++-- .../autoencoders/autoencoder_kl_cogvideox.py | 6 ++--- src/diffusers/models/downsampling.py | 26 ++++++++++++------- 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/docs/source/en/api/loaders/single_file.md b/docs/source/en/api/loaders/single_file.md index 0af0ce6488d4..1c17a740755f 100644 --- a/docs/source/en/api/loaders/single_file.md +++ b/docs/source/en/api/loaders/single_file.md @@ -22,6 +22,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load: ## Supported pipelines +- [`CogVideoXPipeline`] - [`StableDiffusionPipeline`] - [`StableDiffusionImg2ImgPipeline`] - [`StableDiffusionInpaintPipeline`] @@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load: - [`UNet2DConditionModel`] - [`StableCascadeUNet`] - [`AutoencoderKL`] +- [`AutoencoderKLCogVideoX`] - [`ControlNetModel`] - [`SD3Transformer2DModel`] diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 46c84765b28b..2c96d714a723 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -10,12 +10,15 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. --> +# limitations under the License. + +## TODO: The paper is still being written. +--> # CogVideoX -[The paper is still being written]() from Tsinghua University & ZhipuAI. +[]() from Tsinghua University & ZhipuAI. The abstract from the paper is: diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index c8a2316b9d01..1d27c7b62d9e 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -172,7 +172,7 @@ def get_args(): ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") - parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16") parser.add_argument( "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" ) @@ -214,5 +214,4 @@ def get_args(): if args.fp16: pipe = pipe.to(dtype=torch.float16) - variant = "fp16" if args.fp16 else None - pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, push_to_hub=args.push_to_hub) + pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 4600f880bbd8..3c370cc32106 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -59,10 +59,8 @@ def __init__( ): super().__init__() - def cast_tuple(t, length=1): - return t if isinstance(t, tuple) else ((t,) * length) - - kernel_size = cast_tuple(kernel_size, 3) + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 time_kernel_size, height_kernel_size, width_kernel_size = kernel_size diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index e107117c213e..4e69f2350e32 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -321,28 +321,34 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if self.compress_time: b, c, t, h, w = x.shape + + # (b, c, t, h, w) -> (b, h, w, c, t) -> (b * h * w, c, t) x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) if x.shape[-1] % 2 == 1: - # split first frame x_first, x_rest = x[..., 0], x[..., 1:] - if x_rest.shape[-1] > 0: + # (b * h * w, c, t - 1) -> (b * h * w, c, (t - 1) // 2) x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + # (b * h * w, c, (t // 2) + 1) -> (b, h, w, c, (t // 2) + 1) -> (b, c, (t // 2) + 1, h, w) x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - else: + # (b * h * w, c, t) -> (b * h * w, c, t // 2) x = F.avg_pool1d(x, kernel_size=2, stride=2) + # (b * h * w, c, t // 2) -> (b, h, w, c, t // 2) -> (b, c, t // 2, h, w) x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - + # Pad the tensor + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + # (b, c, t, h, w) -> (b, t, c, h, w) -> (b * t, c, h, w) + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + # (b * t, c, h, w) -> (b, t, c, h, w) -> (b, c, t, h, w) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) return x From 477e12b235e16a6f4e67d616b9b1926887abe639 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 3 Aug 2024 15:08:07 +0800 Subject: [PATCH 57/94] fix --- docs/source/en/api/models/autoencoderkl_cogvideox.md | 2 +- docs/source/en/api/models/cogvideox_transformer3d.md | 6 +++--- scripts/convert_cogvideox_to_diffusers.py | 4 ++-- src/diffusers/__init__.py | 4 ++-- src/diffusers/models/__init__.py | 4 ++-- src/diffusers/models/attention_processor.py | 7 +------ src/diffusers/models/downsampling.py | 3 ++- src/diffusers/models/transformers/__init__.py | 2 +- .../models/transformers/cogvideox_transformer_3d.py | 12 +++++------- src/diffusers/models/upsampling.py | 3 ++- .../pipelines/cogvideo/pipeline_cogvideox.py | 8 ++++---- src/diffusers/schedulers/__init__.py | 2 +- .../schedulers/scheduling_ddim_cogvideox.py | 7 +++---- src/diffusers/utils/dummy_pt_objects.py | 2 +- tests/pipelines/cogvideox/test_cogvideox.py | 4 ++-- 15 files changed, 32 insertions(+), 38 deletions(-) diff --git a/docs/source/en/api/models/autoencoderkl_cogvideox.md b/docs/source/en/api/models/autoencoderkl_cogvideox.md index e876092e06af..a74c4d062ad6 100644 --- a/docs/source/en/api/models/autoencoderkl_cogvideox.md +++ b/docs/source/en/api/models/autoencoderkl_cogvideox.md @@ -19,7 +19,7 @@ The abstract from the paper is: ## Loading from the original format -By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded +By default the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows: ```py diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md index 1ef71636820e..242c7d0bbbe8 100644 --- a/docs/source/en/api/models/cogvideox_transformer3d.md +++ b/docs/source/en/api/models/cogvideox_transformer3d.md @@ -10,10 +10,10 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -## CogVideoXTransformer3D +## CogVideoXTransformer3DModel A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX). -## CogVideoXTransformer3D +## CogVideoXTransformer3DModel -[[autodoc]] CogVideoXTransformer3D +[[autodoc]] CogVideoXTransformer3DModel diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 1d27c7b62d9e..89c891ca68bd 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -4,7 +4,7 @@ import torch from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3D, CogVideoXDDIMScheduler +from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, CogVideoXDDIMScheduler def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): @@ -127,7 +127,7 @@ def convert_transformer(ckpt_path: str): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - transformer = CogVideoXTransformer3D() + transformer = CogVideoXTransformer3DModel() for key in list(original_state_dict.keys()): new_key = key[len(PREFIX_KEY) :] diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 589991ba18d2..effcde491f11 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -82,7 +82,7 @@ "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", - "CogVideoXTransformer3D", + "CogVideoXTransformer3DModel", "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetXSAdapter", @@ -527,7 +527,7 @@ AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, - CogVideoXTransformer3D, + CogVideoXTransformer3DModel, ConsistencyDecoderVAE, ControlNetModel, ControlNetXSAdapter, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e05701c53b86..5514eda26f82 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -42,7 +42,7 @@ _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] - _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3D"] + _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] @@ -94,7 +94,7 @@ from .modeling_utils import ModelMixin from .transformers import ( AuraFlowTransformer2DModel, - CogVideoXTransformer3D, + CogVideoXTransformer3DModel, DiTTransformer2DModel, DualTransformer2DModel, HunyuanDiT2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8c91f8310adc..368ea7e9c5bd 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2059,7 +2059,7 @@ def __call__( class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the CogVideoXTransformer3D model. It applies a normalization layer on the query and key vectors. + used in the CogVideoXTransformer3DModel. It applies a normalization layer on the query and key vectors. """ def __init__(self): @@ -2073,12 +2073,7 @@ def __call__( attn: Attention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - *args, - **kwargs, ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) input_ndim = hidden_states.ndim if input_ndim == 4: diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 4e69f2350e32..05ef627e34ec 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -286,8 +286,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class Downsample3D(nn.Module): + # Todo: Wait for paper relase. r""" - A 3D Downsampling layer. + A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI Args: in_channels (`int`): diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 0f8e6519a407..6ddb36aee82a 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -3,7 +3,7 @@ if is_torch_available(): from .auraflow_transformer_2d import AuraFlowTransformer2DModel - from .cogvideox_transformer_3d import CogVideoXTransformer3D + from .cogvideox_transformer_3d import CogVideoXTransformer3DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 44ebb2a3897d..3e6e3868faed 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -178,7 +178,7 @@ def forward( return hidden_states, encoder_hidden_states -class CogVideoXTransformer3D(ModelMixin, ConfigMixin): +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True """ @@ -322,8 +322,7 @@ def forward( batch_size, num_frames, channels, height, width = hidden_states.shape # 1. Time embedding - timesteps = timestep - if not torch.is_tensor(timesteps): + if not torch.is_tensor(timestep): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = hidden_states.device.type == "mps" @@ -331,12 +330,11 @@ def forward( dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=hidden_states.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(hidden_states.device) + timesteps = torch.tensor([timestep], dtype=dtype, device=hidden_states.device) + elif len(timestep.shape) == 0: + timesteps = timestep[None].to(hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(hidden_states.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 7638c3880c0d..e04e1dd4c448 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -349,8 +349,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class Upsample3D(nn.Module): + # Todo: Wait for paper relase. r""" - A 3D Upsampling layer. + A 3D Upsample3D layer using in [CogVideoX]() by Tsinghua University & ZhipuAI Args: in_channels (`int`): diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index d5957107629b..f896658693b7 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -21,7 +21,7 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3D +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler from ...utils import BaseOutput, logging, replace_example_docstring @@ -138,8 +138,8 @@ class CogVideoXPipeline(DiffusionPipeline): tokenizer (`T5Tokenizer`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - transformer ([`CogVideoXTransformer3D`]): - A text conditioned `CogVideoXTransformer3D` to denoise the encoded video latents. + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ @@ -158,7 +158,7 @@ def __init__( tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, vae: AutoencoderKLCogVideoX, - transformer: CogVideoXTransformer3D, + transformer: CogVideoXTransformer3DModel, scheduler: CogVideoXDDIMScheduler, ): super().__init__() diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 17fcf30f4f3e..35684b88f2e7 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -141,8 +141,8 @@ from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler - from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from .scheduling_ddim import DDIMScheduler + from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index d857a4b4aaef..11cdec7797ff 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -24,7 +24,6 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @@ -192,7 +191,7 @@ def __init__( sample_max_value: float = 1.0, timestep_spacing: str = "leading", rescale_betas_zero_snr: bool = False, - snr_shift_scale: float = 3.0 + snr_shift_scale: float = 3.0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -211,7 +210,7 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # Modify: SNR shift following SD3 - self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1-snr_shift_scale) * self.alphas_cumprod) + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) # Rescale for zero SNR if rescale_betas_zero_snr: @@ -387,7 +386,7 @@ def step( " `v_prediction`" ) - a_t = ((1-alpha_prod_t_prev)/(1-alpha_prod_t))**0.5 + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t prev_sample = a_t * sample + b_t * pred_original_sample diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 44a20f57727b..66bbb093fb9e 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -107,7 +107,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class CogVideoXTransformer3D(metaclass=DummyObject): +class CogVideoXTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index 85968995d765..ade99a943290 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -24,7 +24,7 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKL, CogVideoXPipeline, CogVideoXTransformer3D, DDIMScheduler +from diffusers import AutoencoderKL, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, @@ -51,7 +51,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) - transformer = CogVideoXTransformer3D( + transformer = CogVideoXTransformer3DModel( sample_size=8, num_layers=1, patch_size=2, From fd0831c52c7a9ef71c9e3ec4d6fb1b64a148ef6b Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 3 Aug 2024 15:11:12 +0800 Subject: [PATCH 58/94] timestep fix --- src/diffusers/models/attention_processor.py | 1 + src/diffusers/models/transformers/cogvideox_transformer_3d.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 368ea7e9c5bd..fb39459a9051 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2072,6 +2072,7 @@ def __call__( self, attn: Attention, hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 3e6e3868faed..85b1580165d2 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -335,7 +335,7 @@ def forward( timesteps = timestep[None].to(hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - t_emb = self.time_proj(timesteps) + t_emb = self.time_proj(timestep) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. From d99528be943ee7c5966eb0b527178a1477c36f5e Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 3 Aug 2024 16:04:46 +0800 Subject: [PATCH 59/94] Restore the timesteps parameter --- src/diffusers/models/downsampling.py | 28 +++++++++---------- .../transformers/cogvideox_transformer_3d.py | 12 ++++---- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 05ef627e34ec..783d3971a242 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -321,35 +321,35 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if self.compress_time: - b, c, t, h, w = x.shape + batch_size, channels, frames, height, width = x.shape - # (b, c, t, h, w) -> (b, h, w, c, t) -> (b * h * w, c, t) - x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames) + x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) if x.shape[-1] % 2 == 1: x_first, x_rest = x[..., 0], x[..., 1:] if x_rest.shape[-1] > 0: - # (b * h * w, c, t - 1) -> (b * h * w, c, (t - 1) // 2) + # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) x = torch.cat([x_first[..., None], x_rest], dim=-1) - # (b * h * w, c, (t // 2) + 1) -> (b, h, w, c, (t // 2) + 1) -> (b, c, (t // 2) + 1, h, w) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) else: - # (b * h * w, c, t) -> (b * h * w, c, t // 2) + # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2) x = F.avg_pool1d(x, kernel_size=2, stride=2) - # (b * h * w, c, t // 2) -> (b, h, w, c, t // 2) -> (b, c, t // 2, h, w) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) # Pad the tensor pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) - b, c, t, h, w = x.shape - # (b, c, t, h, w) -> (b, t, c, h, w) -> (b * t, c, h, w) - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + batch_size, channels, frames, height, width = x.shape + # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) x = self.conv(x) - # (b * t, c, h, w) -> (b, t, c, h, w) -> (b, c, t, h, w) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) + x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) return x diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 85b1580165d2..1f8811a9b75e 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -322,7 +322,8 @@ def forward( batch_size, num_frames, channels, height, width = hidden_states.shape # 1. Time embedding - if not torch.is_tensor(timestep): + timesteps = timestep + if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = hidden_states.device.type == "mps" @@ -330,12 +331,13 @@ def forward( dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timestep], dtype=dtype, device=hidden_states.device) - elif len(timestep.shape) == 0: - timesteps = timestep[None].to(hidden_states.device) + timesteps = torch.tensor([timesteps], dtype=dtype, device=hidden_states.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - t_emb = self.time_proj(timestep) + timesteps = timesteps.expand(hidden_states.shape[0]) + t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. From c7ee165c4f1868bf68cf922db080ebf9cdb4f80d Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 3 Aug 2024 22:22:28 +0800 Subject: [PATCH 60/94] Update downsampling.py --- src/diffusers/models/downsampling.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 783d3971a242..bfb5d04e13ec 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -342,14 +342,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) # Pad the tensor - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - batch_size, channels, frames, height, width = x.shape - # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) - x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) - x = self.conv(x) - # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) - x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + batch_size, channels, frames, height, width = x.shape + # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) + x = self.conv(x) + # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) + x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) return x From 61c6da076aff5389b48d8069de878bb8f4976038 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 3 Aug 2024 17:43:09 +0200 Subject: [PATCH 61/94] remove chunked ff code; reuse and refactor to support temb directly in adalayernorm --- scripts/convert_cogvideox_to_diffusers.py | 9 +-- src/diffusers/models/downsampling.py | 2 +- src/diffusers/models/normalization.py | 36 +++++++++--- .../transformers/cogvideox_transformer_3d.py | 56 ++----------------- 4 files changed, 39 insertions(+), 64 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 89c891ca68bd..9048844085a4 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -4,7 +4,7 @@ import torch from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, CogVideoXDDIMScheduler +from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): @@ -75,9 +75,9 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): "time_embed.0": "time_embedding.linear_1", "time_embed.2": "time_embedding.linear_2", "mixins.patch_embed": "patch_embed", - "mixins.final_layer.norm_final": "norm_out", + "mixins.final_layer.norm_final": "norm_out.norm", "mixins.final_layer.linear": "proj_out", - "mixins.final_layer.adaLN_modulation.1": "adaln_out.1", + "mixins.final_layer.adaLN_modulation.1": "norm_out.linear", } TRANSFORMER_SPECIAL_KEYS_REMAP = { @@ -176,6 +176,7 @@ def get_args(): parser.add_argument( "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" ) + parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory") return parser.parse_args() @@ -192,7 +193,7 @@ def get_args(): text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) - text_encoder = T5EncoderModel.from_pretrained(text_encoder_id) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) scheduler = CogVideoXDDIMScheduler.from_config({ "snr_shift_scale": 3.0, diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index bfb5d04e13ec..905e7d9c374e 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -341,7 +341,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) - # Pad the tensor + # Pad the tensor pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) batch_size, channels, frames, height, width = x.shape diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index cb36ea0473ce..4bdd1aa16a4d 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -37,18 +37,36 @@ class AdaLayerNorm(nn.Module): num_embeddings (`int`): The size of the embeddings dictionary. """ - def __init__(self, embedding_dim: int, num_embeddings: int): + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, output_dim: Optional[int] = None, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, use_embedding: bool = True): super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) + if use_embedding: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + output_dim = output_dim or embedding_dim * 2 + self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) - def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = torch.chunk(emb, 2) - x = self.norm(x) * (1 + scale) + shift - return x + def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + input_ndim = hidden_states.ndim + + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if input_ndim == 3: + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + hidden_states = self.norm(hidden_states) * (1 + scale) + shift + return hidden_states class FP32LayerNorm(nn.LayerNorm): diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 1f8811a9b75e..614853661fa5 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -25,29 +25,11 @@ from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import CogVideoXLayerNormZero - -from einops import rearrange +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): - # "feed_forward_chunk_size" can be used to save memory - if hidden_states.shape[chunk_dim] % chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = hidden_states.shape[chunk_dim] // chunk_size - ff_output = torch.cat( - [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, - ) - return ff_output - - @maybe_allow_in_graph class CogVideoXBlock(nn.Module): r""" @@ -91,12 +73,10 @@ def __init__( attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, @@ -110,10 +90,9 @@ def __init__( self.attn1 = Attention( query_dim=dim, - cross_attention_dim=cross_attention_dim, dim_head=attention_head_dim, heads=num_attention_heads, - qk_norm=norm_type if qk_norm else None, + qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, @@ -132,15 +111,6 @@ def __init__( bias=ff_bias, ) - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - def forward( self, hidden_states: torch.Tensor, @@ -167,11 +137,7 @@ def forward( # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - - if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) + ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] @@ -230,7 +196,6 @@ def __init__( text_embed_dim: int = 4096, num_layers: int = 30, dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, attention_bias: bool = True, sample_width: int = 90, sample_height: int = 60, @@ -240,7 +205,6 @@ def __init__( max_text_seq_length: int = 226, activation_fn: str = "gelu-approximate", timestep_activation_fn: str = "silu", - norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, spatial_interpolation_scale: float = 1.875, @@ -249,10 +213,6 @@ def __init__( super().__init__() inner_dim = num_attention_heads * attention_head_dim - self.height = sample_height - self.width = sample_width - self.frames = sample_frames - post_patch_height = sample_height // patch_size post_patch_width = sample_width // patch_size post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 @@ -279,7 +239,7 @@ def __init__( self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - # 4. Define spatial transformers blocks + # 4. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ CogVideoXBlock( @@ -288,10 +248,8 @@ def __init__( attention_head_dim=attention_head_dim, time_embed_dim=time_embed_dim, dropout=dropout, - cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, - norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, ) @@ -301,8 +259,7 @@ def __init__( self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) # 5. Output blocks - self.adaln_out = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * inner_dim)) - self.norm_out = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + self.norm_out = AdaLayerNorm(embedding_dim=time_embed_dim, output_dim=2 * inner_dim, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, use_embedding=False) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) self.gradient_checkpointing = False @@ -393,8 +350,7 @@ def custom_forward(*inputs): hidden_states = self.norm_final(hidden_states) # 6. Final block - shift, scale = self.adaln_out(emb).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.proj_out(hidden_states) # 7. Unpatchify From fa7fa9ccedc2d9b6c19e54e1b1c8aee17b7ee9fa Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 3 Aug 2024 18:48:54 +0200 Subject: [PATCH 62/94] =?UTF-8?q?make=20inference=202-3x=20faster=20(by=20?= =?UTF-8?q?fixing=20the=20bug=20i=20introduced)=20=F0=9F=9A=80=F0=9F=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/diffusers/models/attention_processor.py | 14 ++------------ .../transformers/cogvideox_transformer_3d.py | 12 +++++------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index fb39459a9051..7af037eaa65d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2075,12 +2075,8 @@ def __call__( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - + r"""Forward method for the CogVideoX Attention Processor.""" + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) batch_size, sequence_length, _ = hidden_states.shape if attention_mask is not None: @@ -2107,11 +2103,9 @@ def __call__( key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2120,11 +2114,7 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 614853661fa5..d1bd0c30894c 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -124,8 +124,11 @@ def forward( # attention text_length = norm_encoder_hidden_states.size(1) - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - attn_output = self.attn1(norm_hidden_states, attention_mask=attention_mask) + attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + ) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] @@ -315,11 +318,6 @@ def forward( encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] hidden_states = hidden_states[:, self.config.max_text_seq_length :] - # 4. Prepare attention mask - if attention_mask is None: - attention_mask = torch.ones(batch_size, self.num_patches + self.config.max_text_seq_length) - attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) - # 5. Transformer blocks for i, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: From 6988cc3a867f6df55c3adbf7edb9fdc3abc177d1 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sun, 4 Aug 2024 14:26:00 +0800 Subject: [PATCH 63/94] new schedule with dpm --- src/diffusers/__init__.py | 2 + .../pipelines/cogvideo/pipeline_cogvideox.py | 23 +- src/diffusers/schedulers/__init__.py | 2 + .../schedulers/scheduling_dpm_cogvideox.py | 482 ++++++++++++++++++ 4 files changed, 504 insertions(+), 5 deletions(-) create mode 100644 src/diffusers/schedulers/scheduling_dpm_cogvideox.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index effcde491f11..414bc82963f7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -156,6 +156,7 @@ "AmusedScheduler", "CMStochasticIterativeScheduler", "CogVideoXDDIMScheduler", + "CogVideoXDPMScheduler", "DDIMInverseScheduler", "DDIMParallelScheduler", "DDIMScheduler", @@ -598,6 +599,7 @@ AmusedScheduler, CMStochasticIterativeScheduler, CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index f896658693b7..0e5fb5a47aff 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...pipelines.pipeline_utils import DiffusionPipeline -from ...schedulers import CogVideoXDDIMScheduler +from ...schedulers import CogVideoXDDIMScheduler,CogVideoXDPMScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -159,7 +159,7 @@ def __init__( text_encoder: T5EncoderModel, vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, - scheduler: CogVideoXDDIMScheduler, + scheduler: CogVideoXDPMScheduler, ): super().__init__() @@ -443,7 +443,10 @@ def __call__( fps: int = 8, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, - guidance_scale: float = 6, + guidance_scale: float = 7.5, + use_dpm_solver: bool = False, + use_dynamic_cfg: bool = True, + guidance_scale_schedule: Callable[[int, int], int] = lambda x, y: x + y, num_videos_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -604,6 +607,8 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None for i, t in enumerate(timesteps): if self.interrupt: continue @@ -624,11 +629,19 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: + if use_dynamic_cfg: + new_guidance_scale = guidance_scale_schedule(guidance_scale, num_inference_steps - t.item()) # here "num_inference-t.item()" must be an int, not a tensor + else: + new_guidance_scale = guidance_scale noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + new_guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if not use_dpm_solver: + latents = self.scheduler.step(noise_pred.float(), t, latents.float(), **extra_step_kwargs, return_dict=False)[0].to(noise_pred.dtype) + else: + latents, old_pred_original_sample = self.scheduler.step(noise_pred.float(), old_pred_original_sample, t, timesteps[i-1] if i > 0 else None, latents.float(), **extra_step_kwargs, return_dict=False) + latents = latents.to(noise_pred.dtype) # call the callback, if provided if callback_on_step_end is not None: diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 35684b88f2e7..a52124fb1493 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -44,6 +44,7 @@ _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"] + _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"] _import_structure["scheduling_ddpm"] = ["DDPMScheduler"] @@ -143,6 +144,7 @@ from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler + from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py new file mode 100644 index 000000000000..705632953fcf --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -0,0 +1,482 @@ +# Copyright 2024 The The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None): + + lamb = ((alpha_prod_t / (1-alpha_prod_t))**0.5).log() + lamb_next = ((alpha_prod_t_prev / (1-alpha_prod_t_prev))**0.5).log() + h = lamb_next - lamb + + if alpha_prod_t_back is not None: + lamb_previous = ((alpha_prod_t_back / (1-alpha_prod_t_back))**0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): + mult1 = ((1-alpha_prod_t_prev) / (1-alpha_prod_t))**0.5 * (-h).exp() + mult2 = (-2*h).expm1() * alpha_prod_t_prev**0.5 + + if alpha_prod_t_back is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def step( + self, + model_output: torch.Tensor, + old_pred_original_sample: torch.Tensor, + timestep: int, + timestep_back: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = False, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + mult = [mult for mult in self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)] + mult_noise = (1-alpha_prod_t_prev)**0.5 * (1 - (-2*h).exp())**0.5 + + prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * torch.randn_like(sample) + + if old_pred_original_sample is None or prev_timestep < 0: + # Save a network evaluation if all noise levels are 0 or on the first step + return prev_sample, pred_original_sample + else: + denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample + x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * torch.randn_like(sample) + + prev_sample = x_advanced + + return prev_sample, pred_original_sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps From ba4223ac3b901c392503b08cd6a4e4e50e41e47d Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sun, 4 Aug 2024 15:52:37 +0800 Subject: [PATCH 64/94] remove attenstion mask --- .../models/transformers/cogvideox_transformer_3d.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index d1bd0c30894c..fd89c19f7868 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -116,7 +116,6 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb @@ -127,7 +126,6 @@ def forward( attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, - attention_mask=attention_mask, ) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] @@ -275,7 +273,6 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], - attention_mask: Optional[Union[int, torch.Tensor]] = None, timestep_cond: Optional[torch.Tensor] = None, return_dict: bool = True, ): @@ -334,7 +331,6 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, emb, - attention_mask, **ckpt_kwargs, ) else: @@ -342,7 +338,6 @@ def custom_forward(*inputs): hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, - attention_mask=attention_mask, ) hidden_states = self.norm_final(hidden_states) From 312f7dc4fde9f6820459810beeabc4a51351249d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 16:29:35 +0200 Subject: [PATCH 65/94] apply suggestions from review --- .../transformers/cogvideox_transformer_3d.py | 23 +++++--------- .../pipelines/cogvideo/pipeline_cogvideox.py | 30 +++++++++---------- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index fd89c19f7868..1efa66630bc5 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -30,6 +30,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + @maybe_allow_in_graph class CogVideoXBlock(nn.Module): r""" @@ -260,7 +261,13 @@ def __init__( self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) # 5. Output blocks - self.norm_out = AdaLayerNorm(embedding_dim=time_embed_dim, output_dim=2 * inner_dim, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, use_embedding=False) + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + use_embedding=False, + ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) self.gradient_checkpointing = False @@ -280,20 +287,6 @@ def forward( # 1. Time embedding timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = hidden_states.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=hidden_states.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(hidden_states.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(hidden_states.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 0e5fb5a47aff..46842a4f04da 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...pipelines.pipeline_utils import DiffusionPipeline -from ...schedulers import CogVideoXDDIMScheduler,CogVideoXDPMScheduler +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -159,7 +159,7 @@ def __init__( text_encoder: T5EncoderModel, vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, - scheduler: CogVideoXDPMScheduler, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], ): super().__init__() @@ -443,10 +443,7 @@ def __call__( fps: int = 8, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, - guidance_scale: float = 7.5, - use_dpm_solver: bool = False, - use_dynamic_cfg: bool = True, - guidance_scale_schedule: Callable[[int, int], int] = lambda x, y: x + y, + guidance_scale: float = 6, num_videos_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -629,19 +626,22 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: - if use_dynamic_cfg: - new_guidance_scale = guidance_scale_schedule(guidance_scale, num_inference_steps - t.item()) # here "num_inference-t.item()" must be an int, not a tensor - else: - new_guidance_scale = guidance_scale noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + new_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if not use_dpm_solver: - latents = self.scheduler.step(noise_pred.float(), t, latents.float(), **extra_step_kwargs, return_dict=False)[0].to(noise_pred.dtype) + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] else: - latents, old_pred_original_sample = self.scheduler.step(noise_pred.float(), old_pred_original_sample, t, timesteps[i-1] if i > 0 else None, latents.float(), **extra_step_kwargs, return_dict=False) - latents = latents.to(noise_pred.dtype) + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) # call the callback, if provided if callback_on_step_end is not None: From 1b1b26b65cd0b1ba92c040e26b53bc4d212c1132 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 16:29:57 +0200 Subject: [PATCH 66/94] make style --- scripts/convert_cogvideox_to_diffusers.py | 30 +++++++++++-------- src/diffusers/models/normalization.py | 14 +++++++-- src/diffusers/schedulers/__init__.py | 4 +-- .../schedulers/scheduling_ddim_cogvideox.py | 7 +++-- .../schedulers/scheduling_dpm_cogvideox.py | 29 ++++++++++-------- 5 files changed, 51 insertions(+), 33 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 9048844085a4..c03013a7fff9 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -176,7 +176,9 @@ def get_args(): parser.add_argument( "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" ) - parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory") + parser.add_argument( + "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" + ) return parser.parse_args() @@ -195,18 +197,20 @@ def get_args(): tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) - scheduler = CogVideoXDDIMScheduler.from_config({ - "snr_shift_scale": 3.0, - "beta_end": 0.012, - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "clip_sample": False, - "num_train_timesteps": 1000, - "prediction_type": "v_prediction", - "rescale_betas_zero_snr": True, - "set_alpha_to_one": True, - "timestep_spacing": "linspace" - }) + scheduler = CogVideoXDDIMScheduler.from_config( + { + "snr_shift_scale": 3.0, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "num_train_timesteps": 1000, + "prediction_type": "v_prediction", + "rescale_betas_zero_snr": True, + "set_alpha_to_one": True, + "timestep_spacing": "linspace", + } + ) pipe = CogVideoXPipeline( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 4bdd1aa16a4d..acc4417c15aa 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -37,7 +37,15 @@ class AdaLayerNorm(nn.Module): num_embeddings (`int`): The size of the embeddings dictionary. """ - def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, output_dim: Optional[int] = None, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, use_embedding: bool = True): + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + use_embedding: bool = True, + ): super().__init__() if use_embedding: self.emb = nn.Embedding(num_embeddings, embedding_dim) @@ -50,7 +58,9 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, out self.linear = nn.Linear(embedding_dim, output_dim) self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) - def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: input_ndim = hidden_states.ndim if self.emb is not None: diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index a52124fb1493..bb9088538653 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -44,13 +44,13 @@ _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"] - _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"] _import_structure["scheduling_ddpm"] = ["DDPMScheduler"] _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"] @@ -144,13 +144,13 @@ from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler - from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index 11cdec7797ff..2528809b84da 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -371,15 +371,16 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output + # pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index 705632953fcf..783e54d61854 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -302,13 +302,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps).to(device) def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None): - - lamb = ((alpha_prod_t / (1-alpha_prod_t))**0.5).log() - lamb_next = ((alpha_prod_t_prev / (1-alpha_prod_t_prev))**0.5).log() + lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log() + lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log() h = lamb_next - lamb if alpha_prod_t_back is not None: - lamb_previous = ((alpha_prod_t_back / (1-alpha_prod_t_back))**0.5).log() + lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log() h_last = lamb - lamb_previous r = h_last / h return h, r, lamb, lamb_next @@ -316,8 +315,8 @@ def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None) return h, None, lamb, lamb_next def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): - mult1 = ((1-alpha_prod_t_prev) / (1-alpha_prod_t))**0.5 * (-h).exp() - mult2 = (-2*h).expm1() * alpha_prod_t_prev**0.5 + mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5 if alpha_prod_t_back is not None: mult3 = 1 + 1 / (2 * r) @@ -399,24 +398,25 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output + # pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) - + h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) - mult = [mult for mult in self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)] - mult_noise = (1-alpha_prod_t_prev)**0.5 * (1 - (-2*h).exp())**0.5 + mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)) + mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5 prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * torch.randn_like(sample) @@ -429,7 +429,10 @@ def step( prev_sample = x_advanced - return prev_sample, pred_original_sample + if not return_dict: + return (prev_sample, pred_original_sample) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( From ba1855c07e0dd7a8a6c24e365a0c20124678ad36 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 29 Jul 2024 12:37:47 +0530 Subject: [PATCH 67/94] add workflow to rebase with upstream main nightly. --- .github/workflows/upstream.yml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/upstream.yml diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml new file mode 100644 index 000000000000..331e81e5742f --- /dev/null +++ b/.github/workflows/upstream.yml @@ -0,0 +1,32 @@ +name: Rebase Upstream + +on: + schedule: + - cron: '0 0 * * *' # This runs the job nightly at midnight UTC + workflow_dispatch: + pull_request: + +permissions: + contents: write + +jobs: + rebase: + runs-on: ubuntu-latest + + steps: + - name: Checkout private repository + uses: actions/checkout@v2 + with: + ref: main + + - name: Fetch upstream changes + run: git fetch upstream + + - name: Rebase onto upstream main + run: git rebase upstream/main + + - name: Push changes to private main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + git push origin main --force From 7360ea1d03efcac8d27f8e2874e037cfc5445128 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 29 Jul 2024 12:39:27 +0530 Subject: [PATCH 68/94] add upstream --- .github/workflows/upstream.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml index 331e81e5742f..1965b9e9492b 100644 --- a/.github/workflows/upstream.yml +++ b/.github/workflows/upstream.yml @@ -19,6 +19,9 @@ jobs: with: ref: main + - name: Add upstream repository + run: git remote add upstream https://github.com/huggingface/diffusers.git + - name: Fetch upstream changes run: git fetch upstream From 2f1b7870e287926ce012ed13a08f82b3bcf09e5b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 29 Jul 2024 12:46:46 +0530 Subject: [PATCH 69/94] Revert "add workflow to rebase with upstream main nightly." --- .github/workflows/upstream.yml | 35 ---------------------------------- 1 file changed, 35 deletions(-) delete mode 100644 .github/workflows/upstream.yml diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml deleted file mode 100644 index 1965b9e9492b..000000000000 --- a/.github/workflows/upstream.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Rebase Upstream - -on: - schedule: - - cron: '0 0 * * *' # This runs the job nightly at midnight UTC - workflow_dispatch: - pull_request: - -permissions: - contents: write - -jobs: - rebase: - runs-on: ubuntu-latest - - steps: - - name: Checkout private repository - uses: actions/checkout@v2 - with: - ref: main - - - name: Add upstream repository - run: git remote add upstream https://github.com/huggingface/diffusers.git - - - name: Fetch upstream changes - run: git fetch upstream - - - name: Rebase onto upstream main - run: git rebase upstream/main - - - name: Push changes to private main - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - git push origin main --force From 90aa8be5344a08f2830dd5408b4ff85c0121df56 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 29 Jul 2024 12:51:46 +0530 Subject: [PATCH 70/94] add workflow for rebasing with upstream automatically. --- .github/workflows/upstream.yaml | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 .github/workflows/upstream.yaml diff --git a/.github/workflows/upstream.yaml b/.github/workflows/upstream.yaml new file mode 100644 index 000000000000..124cd50361d5 --- /dev/null +++ b/.github/workflows/upstream.yaml @@ -0,0 +1,36 @@ + +name: Rebase Upstream + +on: + schedule: + - cron: '0 0 * * *' # This runs the job nightly at midnight UTC + workflow_dispatch: + +permissions: + contents: write + +jobs: + rebase: + runs-on: ubuntu-latest + + steps: + - name: Checkout private repository + uses: actions/checkout@v2 + with: + ref: main + + - name: Add upstream repository + run: git remote add upstream https://github.com/huggingface/diffusers.git + + - name: Fetch upstream changes + run: git fetch upstream + + - name: Rebase onto upstream main + run: git rebase upstream/main + + - name: Push changes to private main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + git push origin main --force + From 92c8c0075613a6e5e0239e732a8d3470db0fe22d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 23:29:11 +0200 Subject: [PATCH 71/94] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 30 +++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 66bbb093fb9e..7dd596ba6535 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -990,6 +990,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CogVideoXDDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CogVideoXDPMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] From 03580c07b9d67d8d23df4b12cb7d2101e461650f Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 23:37:31 +0200 Subject: [PATCH 72/94] remove cogvideox-specific attention processor --- src/diffusers/models/attention_processor.py | 62 ------------------- .../transformers/cogvideox_transformer_3d.py | 8 ++- 2 files changed, 5 insertions(+), 65 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d78d4e05dfed..784eaaa62c55 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2243,68 +2243,6 @@ def __call__( return hidden_states -class CogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the CogVideoXTransformer3DModel. It applies a normalization layer on the query and key vectors. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "CogVideoXAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r"""Forward method for the CogVideoX Attention Processor.""" - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - batch_size, sequence_length, _ = hidden_states.shape - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Apply QK norms - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - - class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 1efa66630bc5..b2572712618f 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -21,7 +21,6 @@ from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward -from ..attention_processor import CogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -97,7 +96,6 @@ def __init__( eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), ) # 2. Feed Forward @@ -124,9 +122,13 @@ def forward( # attention text_length = norm_encoder_hidden_states.size(1) + + # CogVideoX uses concatenated text + video embeddings with self-attention instead of using + # them in cross-attention individually + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) attn_output = self.attn1( hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, + encoder_hidden_states=None, ) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] From 01c2dff338e7fbd8ae42b71682ed5346988586a5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 23:45:56 +0200 Subject: [PATCH 73/94] update docs --- .../en/api/models/cogvideox_transformer3d.md | 5 ++--- docs/source/en/api/pipelines/cogvideox.md | 14 +++++++------- .../pipelines/cogvideo/pipeline_cogvideox.py | 8 ++++++++ tests/pipelines/cogvideox/test_cogvideox.py | 5 +---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md index 242c7d0bbbe8..50f62133d1c3 100644 --- a/docs/source/en/api/models/cogvideox_transformer3d.md +++ b/docs/source/en/api/models/cogvideox_transformer3d.md @@ -1,4 +1,4 @@ - +specific language governing permissions and limitations under the License. --> ## CogVideoXTransformer3DModel diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 2c96d714a723..acb397dc3879 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -17,8 +17,7 @@ # CogVideoX - -[]() from Tsinghua University & ZhipuAI. +[TODO]() from Tsinghua University & ZhipuAI. The abstract from the paper is: @@ -58,14 +57,16 @@ Finally, compile the components and run inference: pipeline.transformer = torch.compile(pipeline.transformer) pipeline.vae.decode = torch.compile(pipeline.vae.decode) -video = pipeline(prompt="A dog wearing sunglasses floating in space, surreal, nebulae in background").frames[0] +# CogVideoX works very well with long and well-described prompts +prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." +video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] ``` -The [benchmark]() results on an 80GB A100 machine are: +The [benchmark](TODO: link) results on an 80GB A100 machine are: ``` -Without torch.compile(): Average inference time: 16.246 seconds. -With torch.compile(): Average inference time: 14.573 seconds. +Without torch.compile(): Average inference time: TODO seconds. +With torch.compile(): Average inference time: TODO seconds. ``` ## CogVideoXPipeline @@ -76,4 +77,3 @@ With torch.compile(): Average inference time: 14.573 seconds. ## CogVideoXPipelineOutput [[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput - diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 46842a4f04da..81f55b298011 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -39,6 +39,14 @@ >>> from diffusers.utils import export_to_video >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) >>> video = pipe("a polar bear dancing, high quality, realistic", num_inference_steps=20).frames[0] >>> export_to_video(video, "output.mp4", fps=8) diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index ade99a943290..b3402dedef65 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -1,7 +1,4 @@ -# Todo: Only a Draft - -# coding=utf-8 -# Copyright 2024 The The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# Copyright 2024 The HuggingFace Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 311845fc775e9e504912fe2788b42c40a4e1a792 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 23:46:50 +0200 Subject: [PATCH 74/94] update docs --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 81f55b298011..e2f8859ad1bc 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -47,8 +47,9 @@ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " ... "atmosphere of this unique musical performance." ... ) - >>> video = pipe("a polar bear dancing, high quality, realistic", num_inference_steps=20).frames[0] - + >>> video = pipe( + ... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=50 + ... ).frames[0] >>> export_to_video(video, "output.mp4", fps=8) ``` """ From 1b1b737acb0d431633ef245ac2af239a98bb187a Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Mon, 5 Aug 2024 11:28:31 +0800 Subject: [PATCH 75/94] cogvideox branch --- docs/source/en/api/loaders/single_file.md | 2 + .../en/api/models/autoencoderkl_cogvideox.md | 37 + .../en/api/models/cogvideox_transformer3d.md | 18 + docs/source/en/api/pipelines/cogvideox.md | 79 ++ scripts/convert_cogvideox_to_diffusers.py | 222 ++++ src/diffusers/__init__.py | 10 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_cogvideox.py | 964 ++++++++++++++++++ src/diffusers/models/downsampling.py | 68 ++ src/diffusers/models/embeddings.py | 85 ++ src/diffusers/models/normalization.py | 70 +- src/diffusers/models/transformers/__init__.py | 1 + .../transformers/cogvideox_transformer_3d.py | 352 +++++++ src/diffusers/models/upsampling.py | 65 ++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/cogvideo/__init__.py | 48 + .../pipelines/cogvideo/pipeline_cogvideox.py | 681 +++++++++++++ src/diffusers/schedulers/__init__.py | 4 + .../schedulers/scheduling_ddim_cogvideox.py | 479 +++++++++ .../schedulers/scheduling_dpm_cogvideox.py | 486 +++++++++ src/diffusers/utils/dummy_pt_objects.py | 60 ++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/cogvideox/__init__.py | 0 tests/pipelines/cogvideox/test_cogvideox.py | 289 ++++++ 25 files changed, 4033 insertions(+), 9 deletions(-) create mode 100644 docs/source/en/api/models/autoencoderkl_cogvideox.md create mode 100644 docs/source/en/api/models/cogvideox_transformer3d.md create mode 100644 docs/source/en/api/pipelines/cogvideox.md create mode 100644 scripts/convert_cogvideox_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py create mode 100644 src/diffusers/models/transformers/cogvideox_transformer_3d.py create mode 100644 src/diffusers/pipelines/cogvideo/__init__.py create mode 100644 src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py create mode 100644 src/diffusers/schedulers/scheduling_ddim_cogvideox.py create mode 100644 src/diffusers/schedulers/scheduling_dpm_cogvideox.py create mode 100644 tests/pipelines/cogvideox/__init__.py create mode 100644 tests/pipelines/cogvideox/test_cogvideox.py diff --git a/docs/source/en/api/loaders/single_file.md b/docs/source/en/api/loaders/single_file.md index 0af0ce6488d4..1c17a740755f 100644 --- a/docs/source/en/api/loaders/single_file.md +++ b/docs/source/en/api/loaders/single_file.md @@ -22,6 +22,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load: ## Supported pipelines +- [`CogVideoXPipeline`] - [`StableDiffusionPipeline`] - [`StableDiffusionImg2ImgPipeline`] - [`StableDiffusionInpaintPipeline`] @@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load: - [`UNet2DConditionModel`] - [`StableCascadeUNet`] - [`AutoencoderKL`] +- [`AutoencoderKLCogVideoX`] - [`ControlNetModel`] - [`SD3Transformer2DModel`] diff --git a/docs/source/en/api/models/autoencoderkl_cogvideox.md b/docs/source/en/api/models/autoencoderkl_cogvideox.md new file mode 100644 index 000000000000..a74c4d062ad6 --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_cogvideox.md @@ -0,0 +1,37 @@ + + +# AutoencoderKLCogVideoX + +The 3D variational autoencoder (VAE) model with KL loss using with CogVideoX. + +The abstract from the paper is: + + +## Loading from the original format + +By default the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded +from the original format using [`FromOriginalModelMixin.from_single_file`] as follows: + +```py +from diffusers import AutoencoderKLCogVideoX + +url = "3d-vae.pt" # can also be a local file +model = AutoencoderKLCogVideoX.from_single_file(url) +``` + +## AutoencoderKLCogVideoX + +[[autodoc]] AutoencoderKLCogVideoX + - decode + - encode + - all diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md new file mode 100644 index 000000000000..065509d5a69e --- /dev/null +++ b/docs/source/en/api/models/cogvideox_transformer3d.md @@ -0,0 +1,18 @@ + + +## CogVideoXTransformer3DModel + +A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX). + +## CogVideoXTransformer3DModel + +[[autodoc]] CogVideoXTransformer3DModel diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md new file mode 100644 index 000000000000..52a4f9fc6d89 --- /dev/null +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -0,0 +1,79 @@ + + +# CogVideoX + +[TODO]() from Tsinghua University & ZhipuAI. + +The abstract from the paper is: + +The paper is still being written. + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +### Inference + +Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. + +First, load the pipeline: + +```python +import torch +from diffusers import LattePipeline + +pipeline = LattePipeline.from_pretrained( + "THUDM/CogVideoX-2b", torch_dtype=torch.float16 +).to("cuda") +``` + +Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`: + +```python +pipeline.transformer.to(memory_format=torch.channels_last) +pipeline.vae.to(memory_format=torch.channels_last) +``` + +Finally, compile the components and run inference: + +```python +pipeline.transformer = torch.compile(pipeline.transformer) +pipeline.vae.decode = torch.compile(pipeline.vae.decode) + +# CogVideoX works very well with long and well-described prompts +prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." +video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +``` + +The [benchmark](TODO: link) results on an 80GB A100 machine are: + +``` +Without torch.compile(): Average inference time: TODO seconds. +With torch.compile(): Average inference time: TODO seconds. +``` + +## CogVideoXPipeline + +[[autodoc]] CogVideoXPipeline + - all + - __call__ + +## CogVideoXPipelineOutput +[[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py new file mode 100644 index 000000000000..c03013a7fff9 --- /dev/null +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -0,0 +1,222 @@ +import argparse +from typing import Any, Dict + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel + + +def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): + to_q_key = key.replace("query_key_value", "to_q") + to_k_key = key.replace("query_key_value", "to_k") + to_v_key = key.replace("query_key_value", "to_v") + to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0) + state_dict[to_q_key] = to_q + state_dict[to_k_key] = to_k + state_dict[to_v_key] = to_v + state_dict.pop(key) + + +def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]): + layer_id, weight_or_bias = key.split(".")[-2:] + + if "query" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}" + elif "key" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}" + + state_dict[new_key] = state_dict.pop(key) + + +def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): + layer_id, _, weight_or_bias = key.split(".")[-3:] + + weights_or_biases = state_dict[key].chunk(12, dim=0) + norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9]) + norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12]) + + norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}" + state_dict[norm1_key] = norm1_weights_or_biases + + norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}" + state_dict[norm2_key] = norm2_weights_or_biases + + state_dict.pop(key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): + state_dict.pop(key) + + +def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): + key_split = key.split(".") + layer_index = int(key_split[2]) + replace_layer_index = 4 - 1 - layer_index + + key_split[1] = "up_blocks" + key_split[2] = str(replace_layer_index) + new_key = ".".join(key_split) + + state_dict[new_key] = state_dict.pop(key) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "transformer.final_layernorm": "norm_final", + "transformer": "transformer_blocks", + "attention": "attn1", + "mlp": "ff.net", + "dense_h_to_4h": "0.proj", + "dense_4h_to_h": "2", + ".layers": "", + "dense": "to_out.0", + "input_layernorm": "norm1.norm", + "post_attn1_layernorm": "norm2.norm", + "time_embed.0": "time_embedding.linear_1", + "time_embed.2": "time_embedding.linear_2", + "mixins.patch_embed": "patch_embed", + "mixins.final_layer.norm_final": "norm_out.norm", + "mixins.final_layer.linear": "proj_out", + "mixins.final_layer.adaLN_modulation.1": "norm_out.linear", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "query_key_value": reassign_query_key_value_inplace, + "query_layernorm_list": reassign_query_key_layernorm_inplace, + "key_layernorm_list": reassign_query_key_layernorm_inplace, + "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, + "embed_tokens": remove_keys_inplace, +} + +VAE_KEYS_RENAME_DICT = { + "block.": "resnets.", + "down.": "down_blocks.", + "downsample": "downsamplers.0", + "upsample": "upsamplers.0", + "nin_shortcut": "conv_shortcut", + "encoder.mid.block_1": "encoder.mid_block.resnets.0", + "encoder.mid.block_2": "encoder.mid_block.resnets.1", + "decoder.mid.block_1": "decoder.mid_block.resnets.0", + "decoder.mid.block_2": "decoder.mid_block.resnets.1", +} + +VAE_SPECIAL_KEYS_REMAP = { + "loss": remove_keys_inplace, + "up.": replace_up_keys_inplace, +} + +TOKENIZER_MAX_LENGTH = 226 + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def convert_transformer(ckpt_path: str): + PREFIX_KEY = "model.diffusion_model." + + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + transformer = CogVideoXTransformer3DModel() + + for key in list(original_state_dict.keys()): + new_key = key[len(PREFIX_KEY) :] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True) + return transformer + + +def convert_vae(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + vae = AutoencoderKLCogVideoX() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16") + parser.add_argument( + "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" + ) + parser.add_argument( + "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + transformer = None + vae = None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + scheduler = CogVideoXDDIMScheduler.from_config( + { + "snr_shift_scale": 3.0, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "num_train_timesteps": 1000, + "prediction_type": "v_prediction", + "rescale_betas_zero_snr": True, + "set_alpha_to_one": True, + "timestep_spacing": "linspace", + } + ) + + pipe = CogVideoXPipeline( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + if args.fp16: + pipe = pipe.to(dtype=torch.float16) + + pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d58bbdac1867..4e840a823b83 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -78,9 +78,11 @@ "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", + "AutoencoderKLCogVideoX", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", + "CogVideoXTransformer3DModel", "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetXSAdapter", @@ -154,6 +156,8 @@ [ "AmusedScheduler", "CMStochasticIterativeScheduler", + "CogVideoXDDIMScheduler", + "CogVideoXDPMScheduler", "DDIMInverseScheduler", "DDIMParallelScheduler", "DDIMScheduler", @@ -249,6 +253,7 @@ "ChatGLMModel", "ChatGLMTokenizer", "CLIPImageProjection", + "CogVideoXPipeline", "CycleDiffusionPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", @@ -523,9 +528,11 @@ AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, AutoencoderKL, + AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, + CogVideoXTransformer3DModel, ConsistencyDecoderVAE, ControlNetModel, ControlNetXSAdapter, @@ -596,6 +603,8 @@ from .schedulers import ( AmusedScheduler, CMStochasticIterativeScheduler, + CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, @@ -672,6 +681,7 @@ ChatGLMModel, ChatGLMTokenizer, CLIPImageProjection, + CogVideoXPipeline, CycleDiffusionPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 53fd3ebd4bbd..fe57b646664d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -28,6 +28,7 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] @@ -41,6 +42,7 @@ _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] + _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] @@ -77,6 +79,7 @@ from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, @@ -92,6 +95,7 @@ from .modeling_utils import ModelMixin from .transformers import ( AuraFlowTransformer2DModel, + CogVideoXTransformer3DModel, DiTTransformer2DModel, DualTransformer2DModel, FluxTransformer2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 885007b54ea1..ccf4552b2a5e 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,5 +1,6 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_tiny import AutoencoderTiny diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py new file mode 100644 index 000000000000..3c370cc32106 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -0,0 +1,964 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..downsampling import Downsample3D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..upsampling import Upsample3D +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +class CogVideoXSafeConv3d(nn.Conv3d): + """ + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + + # Set to 2GB, suitable for CuDNN + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) + + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super().forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super().forward(input) + + +class CogVideoXCausalConv3d(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: int = 1, + dilation: int = 1, + pad_mode: str = "constant", + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + self.temporal_dim = 2 + self.time_kernel_size = time_kernel_size + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = CogVideoXSafeConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + self.conv_cache = None + + def fake_cp_pass_from_previous_rank(self, inputs: torch.Tensor) -> torch.Tensor: + dim = self.temporal_dim + kernel_size = self.time_kernel_size + if kernel_size == 1: + return inputs + + inputs = inputs.transpose(0, dim) + + if self.conv_cache is not None: + inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0) + else: + inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0) + + inputs = inputs.transpose(0, dim).contiguous() + return inputs + + def forward(self, inputs: torch.Tensor, clear_fake_cp_cache: bool = True): + input_parallel = self.fake_cp_pass_from_previous_rank(inputs) + + del self.conv_cache + self.conv_cache = None + if not clear_fake_cp_cache: + self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) + + output_parallel = self.conv(input_parallel) + output = output_parallel + return output + + +class CogVideoXSpatialNorm3D(nn.Module): + r""" + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific + to 3D-video like data. + + CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = F.interpolate(z_first, size=f_first_size) + z_rest = F.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = F.interpolate(zq, size=f.shape[-3:]) + + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class CogVideoXResnetBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=spatial_norm_dim, + ) + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=spatial_norm_dim, + ) + + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if temb_channels > 0: + self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) + + self.dropout = nn.Dropout(dropout) + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + else: + self.conv_shortcut = CogVideoXSafeConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward( + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + clear_fake_cp_cache: bool = True, + ) -> torch.Tensor: + hidden_states = inputs + if zq is not None: + hidden_states = self.norm1(hidden_states, zq) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + hidden_states = self.norm2(hidden_states, zq) + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + inputs = self.conv_shortcut(inputs, clear_fake_cp_cache=clear_fake_cp_cache) + else: + inputs = self.conv_shortcut(inputs) + + output_tensor = inputs + hidden_states + return output_tensor + + +class CogVideoXDownBlock3D(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 0, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.downsamplers = None + + if add_downsample: + self.downsamplers = nn.ModuleList( + [Downsample3D(out_channels, out_channels, padding=downsample_padding, compress_time=compress_time)] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + clear_fake_cp_cache: bool = False, + ) -> torch.Tensor: + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache + ) + else: + hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class CogVideoXMidBlock3D(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + spatial_norm_dim=spatial_norm_dim, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + clear_fake_cp_cache: bool = False, + ) -> torch.Tensor: + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache + ) + else: + hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) + + return hidden_states + + +class CogVideoXUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: int = 16, + add_upsample: bool = True, + upsample_padding: int = 1, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_norm_dim=spatial_norm_dim, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.upsamplers = None + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time)] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + clear_fake_cp_cache: bool = False, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXUpBlock3D` class.""" + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache + ) + else: + hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CogVideoXEncoder3D(nn.Module): + r""" + The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + # log2 of temporal_compress_times + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.down_blocks = nn.ModuleList([]) + + # down blocks + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if down_block_type == "CogVideoXDownBlock3D": + down_block = CogVideoXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + compress_time=compress_time, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=block_out_channels[-1], + temb_channels=0, + dropout=dropout, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward( + self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True + ) -> torch.Tensor: + r"""The forward method of the `CogVideoXEncoder3D` class.""" + hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Down + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, temb, None, clear_fake_cp_cache + ) + + # 2. Mid + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, None, clear_fake_cp_cache + ) + else: + # 1. Down + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states, temb, None, clear_fake_cp_cache) + + # 2. Mid + hidden_states = self.mid_block(hidden_states, temb, None, clear_fake_cp_cache) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + return hidden_states + + +class CogVideoXDecoder3D(nn.Module): + r""" + The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + reversed_block_out_channels = list(reversed(block_out_channels)) + + self.conv_in = CogVideoXCausalConv3d( + in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode + ) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=reversed_block_out_channels[0], + temb_channels=0, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + pad_mode=pad_mode, + ) + + # up blocks + self.up_blocks = nn.ModuleList([]) + + output_channel = reversed_block_out_channels[0] + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if up_block_type == "CogVideoXUpBlock3D": + up_block = CogVideoXUpBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final_block, + compress_time=compress_time, + pad_mode=pad_mode, + ) + prev_output_channel = output_channel + else: + raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`") + + self.up_blocks.append(up_block) + + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward( + self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True + ) -> torch.Tensor: + r"""The forward method of the `CogVideoXDecoder3D` class.""" + hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Mid + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, sample, clear_fake_cp_cache + ) + + # 2. Up + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, temb, sample, clear_fake_cp_cache + ) + else: + # 1. Mid + hidden_states = self.mid_block(hidden_states, temb, sample, clear_fake_cp_cache) + + # 2. Up + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + return hidden_states + + +class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encodfing images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["CogVideoXResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types: Tuple[str] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int] = (128, 256, 256, 512), + latent_channels: int = 16, + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_size: int = 256, + scaling_factor: float = 1.15258426, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + mid_block_add_attention: bool = True, + ): + super().__init__() + + self.encoder = CogVideoXEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = CogVideoXDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = False + + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True, fake_cp: bool = False + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + fake_cp (`bool`, *optional*, defaults to `True`): + If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + h = self.encoder(x, clear_fake_cp_cache=not fake_cp) + if self.quant_conv is not None: + h = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(h) + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, fake_cp: bool = False + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + fake_cp (`bool`, *optional*, defaults to `True`): + If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + dec = self.decoder(z, clear_fake_cp_cache=not fake_cp) + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 4e384e731c74..905e7d9c374e 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -285,6 +285,74 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.conv2d(inputs, weight, stride=2) +class Downsample3D(nn.Module): + # Todo: Wait for paper relase. + r""" + A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `2`): + Stride of the convolution. + padding (`int`, defaults to `0`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 2, + padding: int = 0, + compress_time: bool = False, + ): + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.compress_time: + batch_size, channels, frames, height, width = x.shape + + # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames) + x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) + + if x.shape[-1] % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + + x = torch.cat([x_first[..., None], x_rest], dim=-1) + # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2) + x = F.avg_pool1d(x, kernel_size=2, stride=2) + # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + + # Pad the tensor + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + batch_size, channels, frames, height, width = x.shape + # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) + x = self.conv(x) + # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) + x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + def downsample_2d( hidden_states: torch.Tensor, kernel: Optional[torch.Tensor] = None, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a81f9e17cd0e..fdcdee620109 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -78,6 +78,53 @@ def get_timestep_embedding( return emb +def get_3d_sincos_pos_embed( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, +) -> np.ndarray: + r""" + Args: + embed_dim (`int`): + spatial_size (`int` or `Tuple[int, int]`): + temporal_size (`int`): + spatial_interpolation_scale (`float`, defaults to 1.0): + temporal_interpolation_scale (`float`, defaults to 1.0): + """ + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. Spatial + grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # 2. Temporal + grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # 3. Concat + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] + return pos_embed + + def get_2d_sincos_pos_embed( embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): @@ -287,6 +334,44 @@ def forward(self, x, freqs_cis): ) +class CogVideoXPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, + ) -> None: + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + r""" + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + + B, F, C, H, W = image_embeds.shape + image_embeds = image_embeds.reshape(-1, C, H, W) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(B, F, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [B, F, H x W, C] + image_embeds = image_embeds.flatten(1, 2) # [B, F x H x W, C] + + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() # [B, S + F x H x W, C] + return embeds + + def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 8d09999e5c95..dcedd529fab9 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -37,18 +37,46 @@ class AdaLayerNorm(nn.Module): num_embeddings (`int`): The size of the embeddings dictionary. """ - def __init__(self, embedding_dim: int, num_embeddings: int): + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + use_embedding: bool = True, + ): super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) + if use_embedding: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + output_dim = output_dim or embedding_dim * 2 + self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) - def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = torch.chunk(emb, 2) - x = self.norm(x) * (1 + scale) + shift - return x + def forward( + self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + input_ndim = hidden_states.ndim + + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if input_ndim == 3: + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + hidden_states = self.norm(hidden_states) * (1 + scale) + shift + return hidden_states class FP32LayerNorm(nn.LayerNorm): @@ -321,6 +349,30 @@ def forward( return x +class CogVideoXLayerNormZero(nn.Module): + def __init__( + self, + conditioning_dim: int, + embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + if is_torch_version(">=", "2.1.0"): LayerNorm = nn.LayerNorm else: diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d0d351ce88e1..d55dfe57d6f3 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -3,6 +3,7 @@ if is_torch_available(): from .auraflow_transformer_2d import AuraFlowTransformer2DModel + from .cogvideox_transformer_3d import CogVideoXTransformer3DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py new file mode 100644 index 000000000000..8d3111296ec6 --- /dev/null +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -0,0 +1,352 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, FeedForward +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in CogVideoX model. TODO: add link to CogVideoX upon release + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + # 2. Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + ) -> torch.Tensor: + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + text_length = norm_encoder_hidden_states.size(1) + + # CogVideoX uses concatenated text + video embeddings with self-attention instead of using + # them in cross-attention individually + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + ) + + hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input. + out_channels (`int`, *optional*): + The number of channels in the output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + patch_size (`int`, *optional*): + The size of the patches to use in the patch embedding layer. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. During inference, you can denoise for up to but not more steps than + `num_embeds_ada_norm`. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. + caption_channels (`int`, *optional*): + The number of channels in the caption embeddings. + video_length (`int`, *optional*): + The number of frames in the video-like data. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: Optional[int] = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + post_patch_height = sample_height // patch_size + post_patch_width = sample_width // patch_size + post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 + self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. 3D positional embeddings + spatial_pos_embedding = get_3d_sincos_pos_embed( + inner_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + spatial_interpolation_scale, + temporal_interpolation_scale, + ) + spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) + pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) + pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) + self.register_buffer("pos_embedding", pos_embedding, persistent=False) + + # 3. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 4. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 5. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + use_embedding=False, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + + # 3. Position embedding + seq_length = height * width * num_frames // (self.config.patch_size**2) + + pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] + hidden_states = hidden_states + pos_embeds + hidden_states = self.embedding_dropout(hidden_states) + + encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] + hidden_states = hidden_states[:, self.config.max_text_seq_length :] + + # 5. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + ) + + hidden_states = self.norm_final(hidden_states) + + # 6. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 7. Unpatchify + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 572844d2de0a..e04e1dd4c448 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -348,6 +348,71 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) +class Upsample3D(nn.Module): + # Todo: Wait for paper relase. + r""" + A 3D Upsample3D layer using in [CogVideoX]() by Tsinghua University & ZhipuAI + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `1`): + Stride of the convolution. + padding (`int`, defaults to `1`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + compress_time: bool = False, + ) -> None: + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if self.compress_time: + if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + inputs = torch.cat([x_first, x_rest], dim=2) + elif inputs.shape[2] > 1: + inputs = F.interpolate(inputs, scale_factor=2.0) + else: + inputs = inputs.squeeze(2) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + else: + # only interpolate 2D + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = self.conv(inputs) + inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) + + return inputs + + def upfirdn2d_native( tensor: torch.Tensor, kernel: torch.Tensor, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 10f6c4a92054..c08f45bb0c97 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -131,6 +131,7 @@ "AudioLDM2UNet2DConditionModel", ] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] + _import_structure["cogvideo"] = ["CogVideoXPipeline"] _import_structure["controlnet"].extend( [ "BlipDiffusionControlNetPipeline", @@ -437,6 +438,7 @@ ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline + from .cogvideo import CogVideoXPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, diff --git a/src/diffusers/pipelines/cogvideo/__init__.py b/src/diffusers/pipelines/cogvideo/__init__.py new file mode 100644 index 000000000000..d155d3ef51b7 --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_cogvideox import CogVideoXPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py new file mode 100644 index 000000000000..e2f8859ad1bc --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -0,0 +1,681 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe( + ... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=50 + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class CogVideoXPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +class CogVideoXPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor, num_seconds: int): + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = [] + for i in range(num_seconds): + # Whether or not to clear fake context parallel cache + fake_cp = i + 1 < num_seconds + start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) + + current_frames = self.vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample + frames.append(current_frames) + + frames = torch.cat(frames, dim=2) + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_seconds: int = 6, + fps: int = 8, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_seconds (`int`, defaults to `6`): + Duration of video in seconds. Must be less than or equal to 6. + fps (`int`, defaults to `8`): + Number of frames per second in video. Must be equal to 8 (for now). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + assert num_seconds <= 6 and fps == 8 + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + num_frames = 1 + num_seconds * fps + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latents": + video = self.decode_latents(latents, num_seconds) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 696e9c3ad5d5..bb9088538653 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -43,12 +43,14 @@ _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] + _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"] _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"] _import_structure["scheduling_ddpm"] = ["DDPMScheduler"] _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"] @@ -141,12 +143,14 @@ from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler + from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py new file mode 100644 index 000000000000..edb443169aa2 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -0,0 +1,479 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + # breakpoint() + # # 5. compute variance: "sigma_t(η)" -> see formula (16) + # # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + # variance = self._get_variance(timestep, prev_timestep) + # std_dev_t = eta * variance ** (0.5) + + # if use_clipped_model_output: + # # the pred_epsilon is always re-derived from the clipped x_0 in Glide + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + # if eta > 0: + # if variance_noise is not None and generator is not None: + # raise ValueError( + # "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + # " `variance_noise` stays `None`." + # ) + + # if variance_noise is None: + # variance_noise = randn_tensor( + # model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + # ) + # variance = std_dev_t * variance_noise + + # prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py new file mode 100644 index 000000000000..c02300708de9 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -0,0 +1,486 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None): + lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log() + lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log() + h = lamb_next - lamb + + if alpha_prod_t_back is not None: + lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): + mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5 + + if alpha_prod_t_back is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def step( + self, + model_output: torch.Tensor, + old_pred_original_sample: torch.Tensor, + timestep: int, + timestep_back: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = False, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)) + mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5 + + prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * torch.randn_like(sample) + + if old_pred_original_sample is None or prev_timestep < 0: + # Save a network evaluation if all noise levels are 0 or on the first step + return prev_sample, pred_original_sample + else: + denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample + x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * torch.randn_like(sample) + + prev_sample = x_advanced + + if not return_dict: + return (prev_sample, pred_original_sample) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 34c6c20f7f2c..740b08249104 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLCogVideoX(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] @@ -92,6 +107,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CogVideoXTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ConsistencyDecoderVAE(metaclass=DummyObject): _backends = ["torch"] @@ -975,6 +1005,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CogVideoXDDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CogVideoXDPMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3e9a33503906..dd8c7f624406 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -287,6 +287,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CogVideoXPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/cogvideox/__init__.py b/tests/pipelines/cogvideox/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py new file mode 100644 index 000000000000..b3402dedef65 --- /dev/null +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -0,0 +1,289 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = CogVideoXPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + required_optional_params = PipelineTesterMixin.required_optional_params + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CogVideoXTransformer3DModel( + sample_size=8, + num_layers=1, + patch_size=2, + attention_head_dim=8, + num_attention_heads=3, + caption_channels=32, + in_channels=4, + cross_attention_dim=24, + out_channels=8, + attention_bias=True, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + ) + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = DDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "negative_prompt": "low quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "video_length": 1, + "output_type": "pt", + "clean_caption": False, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (1, 3, 8, 8)) + expected_video = torch.randn(1, 3, 8, 8) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass(self): + pass + + def test_save_load_optional_components(self): + if not hasattr(self.pipeline_class, "_optional_components"): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = inputs["prompt"] + generator = inputs["generator"] + + ( + prompt_embeds, + negative_prompt_embeds, + ) = pipe.encode_prompt(prompt) + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "negative_prompt": None, + "negative_prompt_embeds": negative_prompt_embeds, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "video_length": 1, + "mask_feature": False, + "output_type": "pt", + "clean_caption": False, + } + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1.0) + + +@slow +@require_torch_gpu +class CogVideoXPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_cogvideox(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = CogVideoXPipeline.from_pretrained("THUDM/cogvideox-2b", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + prompt = self.prompt + + videos = pipe( + prompt=prompt, + height=512, + width=512, + generator=generator, + num_inference_steps=2, + clean_caption=False, + ).frames + + video = videos[0] + expected_video = torch.randn(1, 512, 512, 3).numpy() + + max_diff = numpy_cosine_similarity_distance(video.fCogVideoXn(), expected_video) + assert max_diff < 1e-3, f"Max diff is too high. got {video.fCogVideoXn()}" From 2d9602cc96937a25f37d254538f1606cf25ebe5e Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Mon, 5 Aug 2024 13:55:25 +0800 Subject: [PATCH 76/94] add CogVideoX team, Tsinghua University & ZhipuAI --- docs/source/en/api/models/cogvideox_transformer3d.md | 2 +- docs/source/en/api/pipelines/cogvideox.md | 2 +- src/diffusers/models/transformers/cogvideox_transformer_3d.py | 3 ++- src/diffusers/schedulers/scheduling_ddim_cogvideox.py | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md index 50f62133d1c3..5cf7549812ea 100644 --- a/docs/source/en/api/models/cogvideox_transformer3d.md +++ b/docs/source/en/api/models/cogvideox_transformer3d.md @@ -1,4 +1,4 @@ - +specific language governing permissions and limitations under the License. --> # AutoencoderKLCogVideoX -The 3D variational autoencoder (VAE) model with KL loss using with CogVideoX. - -The abstract from the paper is: - +The 3D variational autoencoder (VAE) model with KL loss using CogVideoX. ## Loading from the original format -By default the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded -from the original format using [`FromOriginalModelMixin.from_single_file`] as follows: +By default, the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows: ```py from diffusers import AutoencoderKLCogVideoX url = "THUDM/CogVideoX-2b" # can also be a local file model = AutoencoderKLCogVideoX.from_single_file(url) + ``` ## AutoencoderKLCogVideoX @@ -35,3 +31,39 @@ model = AutoencoderKLCogVideoX.from_single_file(url) - decode - encode - all + +## CogVideoXSafeConv3d + +[[autodoc]] CogVideoXSafeConv3d + +## CogVideoXCausalConv3d + +[[autodoc]] CogVideoXCausalConv3d + +## CogVideoXSpatialNorm3D + +[[autodoc]] CogVideoXSpatialNorm3D + +## CogVideoXResnetBlock3D + +[[autodoc]] CogVideoXResnetBlock3D + +## CogVideoXDownBlock3D + +[[autodoc]] CogVideoXDownBlock3D + +## CogVideoXMidBlock3D + +[[autodoc]] CogVideoXMidBlock3D + +## CogVideoXUpBlock3D + +[[autodoc]] CogVideoXUpBlock3D + +## CogVideoXEncoder3D + +[[autodoc]] CogVideoXEncoder3D + +## CogVideoXDecoder3D + +[[autodoc]] CogVideoXDecoder3D \ No newline at end of file diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index df1d351ca1f7..072f836dce76 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -175,9 +175,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ # Store DoRA scale if present. if dora_present_in_unet: dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." - unet_state_dict[ - diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") - ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = ( + state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + ) # Handle text encoder LoRAs. elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): @@ -197,13 +197,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." ) if lora_name.startswith(("lora_te_", "lora_te1_")): - te_state_dict[ - diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") - ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = ( + state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + ) elif lora_name.startswith("lora_te2_"): - te2_state_dict[ - diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") - ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = ( + state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + ) # Store alpha if present. if lora_name_alpha in state_dict: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 3c370cc32106..b6ae3975d526 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -9,10 +9,10 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation -from ..downsampling import Downsample3D +from ..downsampling import CogVideoXDownsample3D from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..upsampling import Upsample3D +from ..upsampling import CogVideoXUpsample3D from .vae import DecoderOutput, DiagonalGaussianDistribution @@ -46,7 +46,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class CogVideoXCausalConv3d(nn.Module): - r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.""" + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of output channels. + kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel. + stride (int, optional): Stride of the convolution. Default is 1. + dilation (int, optional): Dilation rate of the convolution. Default is 1. + pad_mode (str, optional): Padding mode. Default is "constant". + """ def __init__( self, @@ -162,6 +171,22 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: class CogVideoXResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the CogVideoX model. + + Args: + in_channels (int): Number of input channels. + out_channels (Optional[int], optional): Number of output channels. If None, defaults to `in_channels`. Default is None. + dropout (float, optional): Dropout rate. Default is 0.0. + temb_channels (int, optional): Number of time embedding channels. Default is 512. + groups (int, optional): Number of groups for group normalization. Default is 32. + eps (float, optional): Epsilon value for normalization layers. Default is 1e-6. + non_linearity (str, optional): Activation function to use. Default is "swish". + conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False. + spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. + pad_mode (str, optional): Padding mode. Default is "first". + """ + def __init__( self, in_channels: int, @@ -257,6 +282,24 @@ def forward( class CogVideoXDownBlock3D(nn.Module): + r""" + A downsampling block used in the CogVideoX model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + temb_channels (int): Number of time embedding channels. + dropout (float, optional): Dropout rate. Default is 0.0. + num_layers (int, optional): Number of layers in the block. Default is 1. + resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. + resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". + resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. + add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True. + downsample_padding (int, optional): Padding for the downsampling layer. Default is 0. + compress_time (bool, optional): If True, apply temporal compression. Default is False. + pad_mode (str, optional): Padding mode. Default is "first". + """ + _supports_gradient_checkpointing = True def __init__( @@ -297,7 +340,11 @@ def __init__( if add_downsample: self.downsamplers = nn.ModuleList( - [Downsample3D(out_channels, out_channels, padding=downsample_padding, compress_time=compress_time)] + [ + CogVideoXDownsample3D( + out_channels, out_channels, padding=downsample_padding, compress_time=compress_time + ) + ] ) self.gradient_checkpointing = False @@ -332,6 +379,21 @@ def create_forward(*inputs): class CogVideoXMidBlock3D(nn.Module): + r""" + A middle block used in the CogVideoX model. + + Args: + in_channels (int): Number of input channels. + temb_channels (int): Number of time embedding channels. + dropout (float, optional): Dropout rate. Default is 0.0. + num_layers (int, optional): Number of layers in the block. Default is 1. + resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. + resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". + resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. + spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. + pad_mode (str, optional): Padding mode. Default is "first". + """ + _supports_gradient_checkpointing = True def __init__( @@ -393,6 +455,25 @@ def create_forward(*inputs): class CogVideoXUpBlock3D(nn.Module): + r""" + An upsampling block used in the CogVideoX model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + temb_channels (int): Number of time embedding channels. + dropout (float, optional): Dropout rate. Default is 0.0. + num_layers (int, optional): Number of layers in the block. Default is 1. + resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. + resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". + resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. + spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16. + add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True. + upsample_padding (int, optional): Padding for the upsampling layer. Default is 1. + compress_time (bool, optional): If True, apply temporal compression. Default is False. + pad_mode (str, optional): Padding mode. Default is "first". + """ + def __init__( self, in_channels: int, @@ -433,7 +514,11 @@ def __init__( if add_upsample: self.upsamplers = nn.ModuleList( - [Upsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time)] + [ + CogVideoXUpsample3D( + out_channels, out_channels, padding=upsample_padding, compress_time=compress_time + ) + ] ) def forward( diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 905e7d9c374e..3ac8953e3dcc 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -285,7 +285,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.conv2d(inputs, weight, stride=2) -class Downsample3D(nn.Module): +class CogVideoXDownsample3D(nn.Module): # Todo: Wait for paper relase. r""" A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e397a95a9ddb..1258964385da 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -368,7 +368,9 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] - embeds = torch.cat([text_embeds, image_embeds],dim=1).contiguous() # [batch, seq_length + num_frames x height x width, channels] + embeds = torch.cat( + [text_embeds, image_embeds], dim=1 + ).contiguous() # [batch, seq_length + num_frames x height x width, channels] return embeds diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 8d3111296ec6..16899fd32f30 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -150,8 +150,6 @@ def forward( class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): - _supports_gradient_checkpointing = True - """ A Transformer model for video-like data. @@ -188,6 +186,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): The number of frames in the video-like data. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index e04e1dd4c448..007a055c1c91 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -348,10 +348,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) -class Upsample3D(nn.Module): - # Todo: Wait for paper relase. +class CogVideoXUpsample3D(nn.Module): r""" - A 3D Upsample3D layer using in [CogVideoX]() by Tsinghua University & ZhipuAI + A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. Args: in_channels (`int`): diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 717151eb0d02..1bd58e205787 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -538,7 +538,9 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - assert num_seconds in [4, 5, 6] and fps == 8, "The number of seconds must be 4, 5, or 6, and the fps must be 8." + assert ( + num_seconds in [4, 5, 6] and fps == 8 + ), "The number of seconds must be 4, 5, or 6, and the fps must be 8." if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs From 22dcceb8588beed8dfacbab276ccd428a6d17bb9 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Mon, 5 Aug 2024 19:17:16 +0800 Subject: [PATCH 80/94] messages --- src/diffusers/models/transformers/cogvideox_transformer_3d.py | 2 +- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 16899fd32f30..386ac90f1805 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -151,7 +151,7 @@ def forward( class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): """ - A Transformer model for video-like data. + A Transformer model for video-like data in CogVideoX. TODO: add link to CogVideoX upon release Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 1bd58e205787..85c434d4a162 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -540,7 +540,7 @@ def __call__( assert ( num_seconds in [4, 5, 6] and fps == 8 - ), "The number of seconds must be 4, 5, or 6, and the fps must be 8." + ), "The number of seconds must be 4, 5, or 6, and the fps must be 8. Other values are not supported in CogVideoX." if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs From e4d65ccdd7fce42cce1ebed22ebfbc56e4a5e1e8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 5 Aug 2024 13:57:42 +0200 Subject: [PATCH 81/94] update --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 8 +++++++- src/diffusers/schedulers/scheduling_dpm_cogvideox.py | 7 +++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 1bd58e205787..33fa46012d7c 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -14,6 +14,7 @@ # limitations under the License. import inspect +import math from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union @@ -622,7 +623,7 @@ def __call__( if self.interrupt: continue - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -635,8 +636,12 @@ def __call__( timestep=timestep, return_dict=False, )[0] + noise_pred = noise_pred.float() # perform guidance + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -654,6 +659,7 @@ def __call__( **extra_step_kwargs, return_dict=False, ) + latents = latents.to(prompt_embeds.dtype) # call the callback, if provided if callback_on_step_end is not None: diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index c02300708de9..1a2c7be7115b 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -25,6 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @@ -419,14 +420,16 @@ def step( mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)) mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5 - prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * torch.randn_like(sample) + noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise if old_pred_original_sample is None or prev_timestep < 0: # Save a network evaluation if all noise levels are 0 or on the first step return prev_sample, pred_original_sample else: denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample - x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * torch.randn_like(sample) + noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise prev_sample = x_advanced From 70a54a8230c72c64bf1bf1997f771b426ee5f5fa Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 5 Aug 2024 14:38:42 +0200 Subject: [PATCH 82/94] use num_frames instead of num_seconds --- .../pipelines/cogvideo/pipeline_cogvideox.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index d220b84731c3..08bc012584ed 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -49,7 +49,7 @@ ... "atmosphere of this unique musical performance." ... ) >>> video = pipe( - ... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=50 + ... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=20 ... ).frames[0] >>> export_to_video(video, "output.mp4", fps=8) ``` @@ -449,7 +449,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 480, width: int = 720, - num_seconds: int = 6, + num_frames: int = 48, fps: int = 8, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, @@ -482,10 +482,11 @@ def __call__( The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_seconds (`int`, defaults to `6`): - Duration of video in seconds. Must be 4, 5, or 6. - fps (`int`, defaults to `8`): - Number of frames per second in video. Must be equal to 8 (for now). + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -540,8 +541,8 @@ def __call__( """ assert ( - num_seconds in [4, 5, 6] and fps == 8 - ), "The number of seconds must be 4, 5, or 6, and the fps must be 8. Other values are not supported in CogVideoX." + num_frames <= 48 and num_frames % fps == 0 and fps == 8 + ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX." if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -597,7 +598,7 @@ def __call__( # 5. Prepare latents. latent_channels = self.transformer.config.in_channels - num_frames = 1 + num_seconds * fps + num_frames += 1 latents = self.prepare_latents( batch_size * num_videos_per_prompt, latent_channels, @@ -642,6 +643,7 @@ def __call__( self._guidance_scale = 1 + guidance_scale * ( (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 ) + print(self._guidance_scale) if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -676,7 +678,7 @@ def __call__( progress_bar.update() if not output_type == "latents": - video = self.decode_latents(latents, num_seconds) + video = self.decode_latents(latents, num_frames // fps) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents From 9a0b9065181d0f2627ede107494f900e0054fffb Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Mon, 5 Aug 2024 22:08:50 +0800 Subject: [PATCH 83/94] restore --- .../loaders/lora_conversion_utils.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 072f836dce76..b5a617f6708b 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -175,9 +175,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ # Store DoRA scale if present. if dora_present_in_unet: dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." - unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = ( - state_dict.pop(key.replace("lora_down.weight", "dora_scale")) - ) + unet_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) # Handle text encoder LoRAs. elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): @@ -197,13 +197,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." ) if lora_name.startswith(("lora_te_", "lora_te1_")): - te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = ( - state_dict.pop(key.replace("lora_down.weight", "dora_scale")) - ) + te_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) elif lora_name.startswith("lora_te2_"): - te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = ( - state_dict.pop(key.replace("lora_down.weight", "dora_scale")) - ) + te2_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) # Store alpha if present. if lora_name_alpha in state_dict: @@ -325,4 +325,4 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): else: prefix = "text_encoder_2." new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" - return {new_name: alpha} + return {new_name: alpha} \ No newline at end of file From 32da2e7673cfe0475a47c41b859f5fbd8bf17a40 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Mon, 5 Aug 2024 22:20:06 +0800 Subject: [PATCH 84/94] Update lora_conversion_utils.py --- src/diffusers/loaders/lora_conversion_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index b5a617f6708b..df1d351ca1f7 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -325,4 +325,4 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): else: prefix = "text_encoder_2." new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" - return {new_name: alpha} \ No newline at end of file + return {new_name: alpha} From 878f609aa5ce4a78fea0f048726889debde1d7e8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 5 Aug 2024 17:52:13 +0200 Subject: [PATCH 85/94] remove dynamic guidance scale --- .../models/autoencoders/autoencoder_kl_cogvideox.py | 3 ++- src/diffusers/models/upsampling.py | 2 +- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 9 ++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index b6ae3975d526..5218d42b03df 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -176,7 +176,8 @@ class CogVideoXResnetBlock3D(nn.Module): Args: in_channels (int): Number of input channels. - out_channels (Optional[int], optional): Number of output channels. If None, defaults to `in_channels`. Default is None. + out_channels (Optional[int], optional): + Number of output channels. If None, defaults to `in_channels`. Default is None. dropout (float, optional): Dropout rate. Default is 0.0. temb_channels (int, optional): Number of time embedding channels. Default is 512. groups (int, optional): Number of groups for group normalization. Default is 32. diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 007a055c1c91..fd5ed28c7070 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -350,7 +350,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class CogVideoXUpsample3D(nn.Module): r""" - A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. + A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. Args: in_channels (`int`): diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 08bc012584ed..96a8cfc93b80 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -14,7 +14,6 @@ # limitations under the License. import inspect -import math from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union @@ -640,10 +639,10 @@ def __call__( noise_pred = noise_pred.float() # perform guidance - self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 - ) - print(self._guidance_scale) + # self._guidance_scale = 1 + guidance_scale * ( + # (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + # ) + # print(self._guidance_scale) if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) From de9e0b2f5a5ad8998c4e9fd636a52d1d73baa1cd Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 Aug 2024 07:02:32 +0200 Subject: [PATCH 86/94] address review comments Co-Authored-By: YiYi Xu --- .../en/api/models/autoencoderkl_cogvideox.md | 48 ++----------------- .../en/api/models/cogvideox_transformer3d.md | 14 +++++- .../autoencoders/autoencoder_kl_cogvideox.py | 15 ++++++ .../transformers/cogvideox_transformer_3d.py | 31 +++++------- .../pipelines/cogvideo/pipeline_cogvideox.py | 7 --- .../schedulers/scheduling_ddim_cogvideox.py | 30 ------------ 6 files changed, 44 insertions(+), 101 deletions(-) diff --git a/docs/source/en/api/models/autoencoderkl_cogvideox.md b/docs/source/en/api/models/autoencoderkl_cogvideox.md index bcff73e0ebc5..6aae61d39035 100644 --- a/docs/source/en/api/models/autoencoderkl_cogvideox.md +++ b/docs/source/en/api/models/autoencoderkl_cogvideox.md @@ -11,18 +11,16 @@ specific language governing permissions and limitations under the License. --> # AutoencoderKLCogVideoX -The 3D variational autoencoder (VAE) model with KL loss using CogVideoX. +The 3D variational autoencoder (VAE) model with KL loss using [CogVideoX](https://github.com/THUDM/CogVideo). -## Loading from the original format +TODO: add paper and abstract here -By default, the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows: +The model can be loaded with the following code snippet. -```py +```python from diffusers import AutoencoderKLCogVideoX -url = "THUDM/CogVideoX-2b" # can also be a local file -model = AutoencoderKLCogVideoX.from_single_file(url) - +vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16).to("cuda") ``` ## AutoencoderKLCogVideoX @@ -31,39 +29,3 @@ model = AutoencoderKLCogVideoX.from_single_file(url) - decode - encode - all - -## CogVideoXSafeConv3d - -[[autodoc]] CogVideoXSafeConv3d - -## CogVideoXCausalConv3d - -[[autodoc]] CogVideoXCausalConv3d - -## CogVideoXSpatialNorm3D - -[[autodoc]] CogVideoXSpatialNorm3D - -## CogVideoXResnetBlock3D - -[[autodoc]] CogVideoXResnetBlock3D - -## CogVideoXDownBlock3D - -[[autodoc]] CogVideoXDownBlock3D - -## CogVideoXMidBlock3D - -[[autodoc]] CogVideoXMidBlock3D - -## CogVideoXUpBlock3D - -[[autodoc]] CogVideoXUpBlock3D - -## CogVideoXEncoder3D - -[[autodoc]] CogVideoXEncoder3D - -## CogVideoXDecoder3D - -[[autodoc]] CogVideoXDecoder3D \ No newline at end of file diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md index 2514bf2cf03c..9a0bbc3061ba 100644 --- a/docs/source/en/api/models/cogvideox_transformer3d.md +++ b/docs/source/en/api/models/cogvideox_transformer3d.md @@ -9,9 +9,19 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -## CogVideoXTransformer3DModel +# CogVideoXTransformer3DModel + +A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo). + +TODO: add paper and abstract here + +The model can be loaded with the following code snippet. + +```python +from diffusers import CogVideoXTransformer3DModel -A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX). +vae = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16).to("cuda") +``` ## CogVideoXTransformer3DModel diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 5218d42b03df..4561e0d58305 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -1,3 +1,18 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Tuple, Union import numpy as np diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 386ac90f1805..a27172ba5a3b 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -41,30 +41,23 @@ class CogVideoXBlock(nn.Module): num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, *optional*, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. """ def __init__( diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 96a8cfc93b80..e5b6fc459e5f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -425,13 +425,6 @@ def check_inputs( def guidance_scale(self): return self._guidance_scale - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 - @property def num_timesteps(self): return self._num_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index edb443169aa2..ec5c5f3e1c5d 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -392,36 +392,6 @@ def step( b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t prev_sample = a_t * sample + b_t * pred_original_sample - # breakpoint() - # # 5. compute variance: "sigma_t(η)" -> see formula (16) - # # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - # variance = self._get_variance(timestep, prev_timestep) - # std_dev_t = eta * variance ** (0.5) - - # if use_clipped_model_output: - # # the pred_epsilon is always re-derived from the clipped x_0 in Glide - # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - - # # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon - - # # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - # prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - - # if eta > 0: - # if variance_noise is not None and generator is not None: - # raise ValueError( - # "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" - # " `variance_noise` stays `None`." - # ) - - # if variance_noise is None: - # variance_noise = randn_tensor( - # model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype - # ) - # variance = std_dev_t * variance_noise - - # prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) From 9c086f5afa5b81a4b6801750b1f44aeae7e99712 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 Aug 2024 07:08:17 +0200 Subject: [PATCH 87/94] dynamic cfg; fix cfg support Co-Authored-By: YiYi Xu --- .../pipelines/cogvideo/pipeline_cogvideox.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index e5b6fc459e5f..cce0f8dc459a 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -14,6 +14,7 @@ # limitations under the License. import inspect +import math from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union @@ -446,6 +447,7 @@ def __call__( num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, + use_dynamic_cfg: bool = False, num_videos_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -616,7 +618,7 @@ def __call__( if self.interrupt: continue - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -632,11 +634,11 @@ def __call__( noise_pred = noise_pred.float() # perform guidance - # self._guidance_scale = 1 + guidance_scale * ( - # (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 - # ) - # print(self._guidance_scale) - if self.do_classifier_free_guidance: + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) From 62d94aaadd03970c7bb732a141f7f1069beb0f68 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 Aug 2024 07:44:51 +0200 Subject: [PATCH 88/94] address review comments Co-Authored-By: YiYi Xu --- .../en/api/models/autoencoderkl_cogvideox.md | 12 ++++++-- .../en/api/models/cogvideox_transformer3d.md | 8 +++-- docs/source/en/api/pipelines/cogvideox.md | 29 ++++++++++++------- src/diffusers/models/normalization.py | 22 +++++++------- .../pipelines/cogvideo/pipeline_cogvideox.py | 7 ++--- 5 files changed, 48 insertions(+), 30 deletions(-) diff --git a/docs/source/en/api/models/autoencoderkl_cogvideox.md b/docs/source/en/api/models/autoencoderkl_cogvideox.md index 6aae61d39035..122812b31d2e 100644 --- a/docs/source/en/api/models/autoencoderkl_cogvideox.md +++ b/docs/source/en/api/models/autoencoderkl_cogvideox.md @@ -11,9 +11,7 @@ specific language governing permissions and limitations under the License. --> # AutoencoderKLCogVideoX -The 3D variational autoencoder (VAE) model with KL loss using [CogVideoX](https://github.com/THUDM/CogVideo). - -TODO: add paper and abstract here +The 3D variational autoencoder (VAE) model with KL loss used in [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI. The model can be loaded with the following code snippet. @@ -29,3 +27,11 @@ vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="va - decode - encode - all + +## AutoencoderKLOutput + +[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md index 9a0bbc3061ba..8c8baae7b537 100644 --- a/docs/source/en/api/models/cogvideox_transformer3d.md +++ b/docs/source/en/api/models/cogvideox_transformer3d.md @@ -11,9 +11,7 @@ specific language governing permissions and limitations under the License. --> # CogVideoXTransformer3DModel -A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo). - -TODO: add paper and abstract here +A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI. The model can be loaded with the following code snippet. @@ -26,3 +24,7 @@ vae = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolde ## CogVideoXTransformer3DModel [[autodoc]] CogVideoXTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 521125ff189b..21a87576696f 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -10,18 +10,18 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. - -## TODO: The paper is still being written. +# limitations under the License. --> # CogVideoX -[TODO]() from Tsinghua University & ZhipuAI. + + +[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) from Tsinghua University & ZhipuAI. The abstract from the paper is: -The paper is still being written. +*We introduce CogVideoX, a large-scale diffusion transformer model designed for generating videos based on text prompts. To efficently model video data, we propose to levearge a 3D Variational Autoencoder (VAE) to compresses videos along both spatial and temporal dimensions. To improve the text-video alignment, we propose an expert transformer with the expert adaptive LayerNorm to facilitate the deep fusion between the two modalities. By employing a progressive training technique, CogVideoX is adept at producing coherent, long-duration videos characterized by significant motion. In addition, we develop an effectively text-video data processing pipeline that includes various data preprocessing strategies and a video captioning method. It significantly helps enhance the performance of CogVideoX, improving both generation quality and semantic alignment. Results show that CogVideoX demonstrates state-of-the-art performance across both multiple machine metrics and human evaluations. The model weight of CogVideoX-2B is publicly available at https://github.com/THUDM/CogVideo.* @@ -37,11 +37,20 @@ First, load the pipeline: ```python import torch -from diffusers import LattePipeline - -pipeline = LattePipeline.from_pretrained( - "THUDM/CogVideoX-2b", torch_dtype=torch.float16 -).to("cuda") +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda") +prompt = ( + "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + "atmosphere of this unique musical performance." +) +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +export_to_video(video, "output.mp4", fps=8) ``` Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index dcedd529fab9..54a1d511966c 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -44,39 +44,41 @@ def __init__( output_dim: Optional[int] = None, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, - use_embedding: bool = True, + chunk_dim: int = 0, ): super().__init__() - if use_embedding: + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: self.emb = nn.Embedding(num_embeddings, embedding_dim) else: self.emb = None - output_dim = output_dim or embedding_dim * 2 - self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, output_dim) self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) def forward( - self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None ) -> torch.Tensor: - input_ndim = hidden_states.ndim - if self.emb is not None: temb = self.emb(timestep) temb = self.linear(self.silu(temb)) - if input_ndim == 3: + if self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX for now. shift, scale = temb.chunk(2, dim=1) shift = shift[:, None, :] scale = scale[:, None, :] else: scale, shift = temb.chunk(2, dim=0) - hidden_states = self.norm(hidden_states) * (1 + scale) + shift - return hidden_states + x = self.norm(x) * (1 + scale) + shift + return x class FP32LayerNorm(nn.LayerNorm): diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index cce0f8dc459a..765da1e48239 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -36,10 +36,11 @@ EXAMPLE_DOC_STRING = """ Examples: ```python + >>> import torch >>> from diffusers import CogVideoXPipeline >>> from diffusers.utils import export_to_video - >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda") + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") >>> prompt = ( ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " @@ -48,9 +49,7 @@ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " ... "atmosphere of this unique musical performance." ... ) - >>> video = pipe( - ... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=20 - ... ).frames[0] + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] >>> export_to_video(video, "output.mp4", fps=8) ``` """ From 5e4dd15168e3382306fb8fb6e344076798f781c3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 Aug 2024 10:07:22 +0200 Subject: [PATCH 89/94] update tests --- .../autoencoders/autoencoder_kl_cogvideox.py | 16 +- src/diffusers/models/normalization.py | 6 +- .../transformers/cogvideox_transformer_3d.py | 6 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 14 +- tests/pipelines/cogvideox/test_cogvideox.py | 193 +++++++++--------- 5 files changed, 122 insertions(+), 113 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 4561e0d58305..16db002eb467 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -163,9 +163,10 @@ def __init__( self, f_channels: int, zq_channels: int, + groups: int = 32, ): super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) @@ -232,10 +233,12 @@ def __init__( self.norm1 = CogVideoXSpatialNorm3D( f_channels=in_channels, zq_channels=spatial_norm_dim, + groups=groups, ) self.norm2 = CogVideoXSpatialNorm3D( f_channels=out_channels, zq_channels=spatial_norm_dim, + groups=groups, ) self.conv1 = CogVideoXCausalConv3d( @@ -537,6 +540,8 @@ def __init__( ] ) + self.gradient_checkpointing = False + def forward( self, hidden_states: torch.Tensor, @@ -803,7 +808,7 @@ def __init__( self.up_blocks.append(up_block) - self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels) + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups) self.conv_act = nn.SiLU() self.conv_out = CogVideoXCausalConv3d( reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode @@ -852,7 +857,8 @@ def custom_forward(*inputs): class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" - A VAE model with KL loss for encodfing images into latents and decoding latent representations into images. + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [CogVideoX](https://github.com/THUDM/CogVideo). This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -879,9 +885,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE can be fine-tuned / trained to a lower range without loosing too much precision in which case `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - mid_block_add_attention (`bool`, *optional*, default to `True`): - If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the - mid_block will only have resnet blocks """ _supports_gradient_checkpointing = True @@ -919,7 +922,6 @@ def __init__( force_upcast: float = True, use_quant_conv: bool = False, use_post_quant_conv: bool = False, - mid_block_add_attention: bool = True, ): super().__init__() diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 54a1d511966c..5740fed9f30c 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -34,7 +34,11 @@ class AdaLayerNorm(nn.Module): Parameters: embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the embeddings dictionary. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. + output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): """ def __init__( diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index a27172ba5a3b..9eae35d62e69 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -34,7 +34,7 @@ @maybe_allow_in_graph class CogVideoXBlock(nn.Module): r""" - Transformer block used in CogVideoX model. TODO: add link to CogVideoX upon release + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. Parameters: dim (`int`): The number of channels in the input and output. @@ -144,7 +144,7 @@ def forward( class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): """ - A Transformer model for video-like data in CogVideoX. TODO: add link to CogVideoX upon release + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. @@ -262,7 +262,7 @@ def __init__( output_dim=2 * inner_dim, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, - use_embedding=False, + chunk_dim=1, ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 765da1e48239..6b124e518ae7 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -153,7 +153,7 @@ class CogVideoXPipeline(DiffusionPipeline): A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = [ @@ -181,9 +181,6 @@ def __init__( self.vae_scale_factor_temporal = ( self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 ) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226 - ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -213,7 +210,7 @@ def _get_t5_prompt_embeds( untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" @@ -459,6 +456,7 @@ def __call__( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, ) -> Union[CogVideoXPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -524,6 +522,9 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. Examples: @@ -580,6 +581,7 @@ def __call__( num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, device=device, ) if do_classifier_free_guidance: @@ -670,7 +672,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if not output_type == "latents": + if not output_type == "latent": video = self.decode_latents(latents, num_frames // fps) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index b3402dedef65..2219cde57088 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -14,14 +14,13 @@ import gc import inspect -import tempfile import unittest import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKL, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler +from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, @@ -43,41 +42,71 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS - - required_optional_params = PipelineTesterMixin.required_optional_params + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) def get_dummy_components(self): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( - sample_size=8, - num_layers=1, - patch_size=2, + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings + # But, since we are using tiny-random-t5 here, we need the internal dim of CogVideoXTransformer3DModel + # to be 32. The internal dim is product of num_attention_heads and attention_head_dim + num_attention_heads=4, attention_head_dim=8, - num_attention_heads=3, - caption_channels=32, in_channels=4, - cross_attention_dim=24, - out_channels=8, - attention_bias=True, - activation_fn="gelu-approximate", - num_embeds_ada_norm=1000, - norm_type="ada_norm_single", - norm_elementwise_affine=False, - norm_eps=1e-6, + out_channels=4, + time_embed_dim=2, + text_embed_dim=32, # Must match with tiny-random-t5 + num_layers=1, + sample_width=16, # latent width: 2 -> final width: 16 + sample_height=16, # latent height: 2 -> final height: 16 + sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 + patch_size=2, + temporal_compression_ratio=4, + max_text_seq_length=16, ) + torch.manual_seed(0) - vae = AutoencoderKL() + vae = AutoencoderKLCogVideoX( + in_channels=3, + out_channels=3, + down_block_types=( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types=( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + torch.manual_seed(0) scheduler = DDIMScheduler() text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") components = { - "transformer": transformer.eval(), - "vae": vae.eval(), + "transformer": transformer, + "vae": vae, "scheduler": scheduler, - "text_encoder": text_encoder.eval(), + "text_encoder": text_encoder, "tokenizer": tokenizer, } return components @@ -88,16 +117,22 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { - "prompt": "A painting of a squirrel eating a burger", - "negative_prompt": "low quality", + "prompt": "dance monkey", + "negative_prompt": "", "generator": generator, "num_inference_steps": 2, - "guidance_scale": 5.0, - "height": 8, - "width": 8, - "video_length": 1, + "guidance_scale": 6.0, + # Cannot reduce because convolution kernel becomes bigger than sample + "height": 16, + "width": 16, + # TODO(aryan): improve this + # Cannot make this lower due to assert condition in pipeline at the moment. + # The reason why 8 can't be used here is due to how context-parallel cache works where the first + # second of video is decoded from latent frames (0, 3) instead of [(0, 2), (2, 3)]. If 8 is used, + # the number of output frames that you get are 5. + "num_frames": 8, + "max_sequence_length": 16, "output_type": "pt", - "clean_caption": False, } return inputs @@ -113,8 +148,8 @@ def test_inference(self): video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (1, 3, 8, 8)) - expected_video = torch.randn(1, 3, 8, 8) + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) max_diff = np.abs(generated_video - expected_video).max() self.assertLessEqual(max_diff, 1e10) @@ -180,76 +215,41 @@ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) - def test_attention_slicing_forward_pass(self): - pass - - def test_save_load_optional_components(self): - if not hasattr(self.pipeline_class, "_optional_components"): + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: return components = self.get_dummy_components() pipe = self.pipeline_class(**components) - for component in pipe.components.values(): if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(torch_device) - - prompt = inputs["prompt"] - generator = inputs["generator"] - - ( - prompt_embeds, - negative_prompt_embeds, - ) = pipe.encode_prompt(prompt) - - # inputs with prompt converted to embeddings - inputs = { - "prompt_embeds": prompt_embeds, - "negative_prompt": None, - "negative_prompt_embeds": negative_prompt_embeds, - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 5.0, - "height": 8, - "width": 8, - "video_length": 1, - "mask_feature": False, - "output_type": "pt", - "clean_caption": False, - } - - # set all optional components to None - for optional_component in pipe._optional_components: - setattr(pipe, optional_component, None) - - output = pipe(**inputs)[0] - - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir, safe_serialization=False) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) - pipe_loaded.to(torch_device) - - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe_loaded.set_progress_bar_config(disable=None) - - for optional_component in pipe._optional_components: - self.assertTrue( - getattr(pipe_loaded, optional_component) is None, - f"`{optional_component}` did not stay set to None after loading.", + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", ) - output_loaded = pipe_loaded(**inputs)[0] - - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, 1.0) - @slow @require_torch_gpu @@ -269,21 +269,22 @@ def tearDown(self): def test_cogvideox(self): generator = torch.Generator("cpu").manual_seed(0) - pipe = CogVideoXPipeline.from_pretrained("THUDM/cogvideox-2b", torch_dtype=torch.float16) + pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16) pipe.enable_model_cpu_offload() prompt = self.prompt videos = pipe( prompt=prompt, - height=512, - width=512, + height=480, + width=720, + num_frames=16, generator=generator, num_inference_steps=2, - clean_caption=False, + output_type="pt", ).frames video = videos[0] - expected_video = torch.randn(1, 512, 512, 3).numpy() + expected_video = torch.randn(1, 16, 480, 720, 3).numpy() - max_diff = numpy_cosine_similarity_distance(video.fCogVideoXn(), expected_video) - assert max_diff < 1e-3, f"Max diff is too high. got {video.fCogVideoXn()}" + max_diff = numpy_cosine_similarity_distance(video, expected_video) + assert max_diff < 1e-3, f"Max diff is too high. got {video}" From d1c575ad7ee0390c2735f50cc59a79aae666567a Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 Aug 2024 10:19:34 +0200 Subject: [PATCH 90/94] fix docs error --- docs/source/en/_toctree.yml | 6 ++++++ docs/source/en/api/pipelines/cogvideox.md | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b32c8ffe6c1b..5750c518e83d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -235,6 +235,8 @@ title: VQModel - local: api/models/autoencoderkl title: AutoencoderKL + - local: api/models/autoencoderkl_cogvideox + title: AutoencoderKLCogVideoX - local: api/models/asymmetricautoencoderkl title: AsymmetricAutoencoderKL - local: api/models/stable_cascade_unet @@ -259,6 +261,8 @@ title: FluxTransformer2DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel + - local: api/models/cogvideox_transformer3d + title: CogVideoXTransformer3DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel - local: api/models/transformer_temporal @@ -298,6 +302,8 @@ title: AutoPipeline - local: api/pipelines/blip_diffusion title: BLIP-Diffusion + - local: api/pipelines/cogvideox + title: CogVideoX - local: api/pipelines/consistency_models title: Consistency Models - local: api/pipelines/controlnet diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 21a87576696f..02fe6d883f5d 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -85,4 +85,5 @@ With torch.compile(): Average inference time: TODO seconds. - __call__ ## CogVideoXPipelineOutput -[[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput + +[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput From 11224d953b59e17f1dea17006e8832ecd372e811 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 Aug 2024 15:39:00 +0200 Subject: [PATCH 91/94] alternative implementation to context parallel cache --- .../autoencoders/autoencoder_kl_cogvideox.py | 93 ++++++++++--------- .../pipelines/cogvideo/pipeline_cogvideox.py | 6 +- 2 files changed, 50 insertions(+), 49 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 16db002eb467..dad3d3874ad1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..downsampling import CogVideoXDownsample3D @@ -31,6 +32,9 @@ from .vae import DecoderOutput, DiagonalGaussianDistribution +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class CogVideoXSafeConv3d(nn.Conv3d): """ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. @@ -129,13 +133,15 @@ def fake_cp_pass_from_previous_rank(self, inputs: torch.Tensor) -> torch.Tensor: inputs = inputs.transpose(0, dim).contiguous() return inputs - def forward(self, inputs: torch.Tensor, clear_fake_cp_cache: bool = True): - input_parallel = self.fake_cp_pass_from_previous_rank(inputs) - + def _clear_fake_context_parallel_cache(self): del self.conv_cache self.conv_cache = None - if not clear_fake_cp_cache: - self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + input_parallel = self.fake_cp_pass_from_previous_rank(inputs) + + self._clear_fake_context_parallel_cache() + self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) @@ -268,15 +274,16 @@ def forward( inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, - clear_fake_cp_cache: bool = True, ) -> torch.Tensor: hidden_states = inputs + if zq is not None: hidden_states = self.norm1(hidden_states, zq) else: hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv1(hidden_states) if temb is not None: hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] @@ -288,16 +295,13 @@ def forward( hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv2(hidden_states) if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - inputs = self.conv_shortcut(inputs, clear_fake_cp_cache=clear_fake_cp_cache) - else: - inputs = self.conv_shortcut(inputs) + inputs = self.conv_shortcut(inputs) - output_tensor = inputs + hidden_states - return output_tensor + hidden_states = hidden_states + inputs + return hidden_states class CogVideoXDownBlock3D(nn.Module): @@ -373,7 +377,6 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, - clear_fake_cp_cache: bool = False, ) -> torch.Tensor: for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -385,10 +388,10 @@ def create_forward(*inputs): return create_forward hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache + create_custom_forward(resnet), hidden_states, temb, zq ) else: - hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) + hidden_states = resnet(hidden_states, temb, zq) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -453,7 +456,6 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, - clear_fake_cp_cache: bool = False, ) -> torch.Tensor: for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -465,10 +467,10 @@ def create_forward(*inputs): return create_forward hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache + create_custom_forward(resnet), hidden_states, temb, zq ) else: - hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) + hidden_states = resnet(hidden_states, temb, zq) return hidden_states @@ -547,7 +549,6 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, - clear_fake_cp_cache: bool = False, ) -> torch.Tensor: r"""Forward method of the `CogVideoXUpBlock3D` class.""" for resnet in self.resnets: @@ -560,10 +561,10 @@ def create_forward(*inputs): return create_forward hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache + create_custom_forward(resnet), hidden_states, temb, zq ) else: - hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache) + hidden_states = resnet(hidden_states, temb, zq) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -671,11 +672,9 @@ def __init__( self.gradient_checkpointing = False - def forward( - self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True - ) -> torch.Tensor: + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: r"""The forward method of the `CogVideoXEncoder3D` class.""" - hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv_in(sample) if self.training and self.gradient_checkpointing: @@ -688,25 +687,25 @@ def custom_forward(*inputs): # 1. Down for down_block in self.down_blocks: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), hidden_states, temb, None, clear_fake_cp_cache + create_custom_forward(down_block), hidden_states, temb, None ) # 2. Mid hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb, None, clear_fake_cp_cache + create_custom_forward(self.mid_block), hidden_states, temb, None ) else: # 1. Down for down_block in self.down_blocks: - hidden_states = down_block(hidden_states, temb, None, clear_fake_cp_cache) + hidden_states = down_block(hidden_states, temb, None) # 2. Mid - hidden_states = self.mid_block(hidden_states, temb, None, clear_fake_cp_cache) + hidden_states = self.mid_block(hidden_states, temb, None) # 3. Post-process hidden_states = self.norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv_out(hidden_states) return hidden_states @@ -816,11 +815,9 @@ def __init__( self.gradient_checkpointing = False - def forward( - self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True - ) -> torch.Tensor: + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: r"""The forward method of the `CogVideoXDecoder3D` class.""" - hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv_in(sample) if self.training and self.gradient_checkpointing: @@ -832,26 +829,26 @@ def custom_forward(*inputs): # 1. Mid hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb, sample, clear_fake_cp_cache + create_custom_forward(self.mid_block), hidden_states, temb, sample ) # 2. Up for up_block in self.up_blocks: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), hidden_states, temb, sample, clear_fake_cp_cache + create_custom_forward(up_block), hidden_states, temb, sample ) else: # 1. Mid - hidden_states = self.mid_block(hidden_states, temb, sample, clear_fake_cp_cache) + hidden_states = self.mid_block(hidden_states, temb, sample) # 2. Up for up_block in self.up_blocks: - hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache) + hidden_states = up_block(hidden_states, temb, sample) # 3. Post-process hidden_states = self.norm_out(hidden_states, sample) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache) + hidden_states = self.conv_out(hidden_states) return hidden_states @@ -966,6 +963,12 @@ def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): module.gradient_checkpointing = value + def clear_fake_context_parallel_cache(self): + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") + module._clear_fake_context_parallel_cache() + def enable_tiling(self, use_tiling: bool = True): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to @@ -1013,7 +1016,7 @@ def encode( The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - h = self.encoder(x, clear_fake_cp_cache=not fake_cp) + h = self.encoder(x) if self.quant_conv is not None: h = self.quant_conv(h) posterior = DiagonalGaussianDistribution(h) @@ -1022,9 +1025,7 @@ def encode( return AutoencoderKLOutput(latent_dist=posterior) @apply_forward_hook - def decode( - self, z: torch.FloatTensor, return_dict: bool = True, fake_cp: bool = False - ) -> Union[DecoderOutput, torch.FloatTensor]: + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. @@ -1043,7 +1044,7 @@ def decode( """ if self.post_quant_conv is not None: z = self.post_quant_conv(z) - dec = self.decoder(z, clear_fake_cp_cache=not fake_cp) + dec = self.decoder(z) if not return_dict: return (dec,) return DecoderOutput(sample=dec) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 6b124e518ae7..04f2752175af 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -338,13 +338,13 @@ def decode_latents(self, latents: torch.Tensor, num_seconds: int): frames = [] for i in range(num_seconds): - # Whether or not to clear fake context parallel cache - fake_cp = i + 1 < num_seconds start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) - current_frames = self.vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample + current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample frames.append(current_frames) + self.vae.clear_fake_context_parallel_cache() + frames = torch.cat(frames, dim=2) return frames From 70cea9154b7d5b77138ce02217a04bedcf70ddf5 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 6 Aug 2024 04:33:44 -1000 Subject: [PATCH 92/94] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Sayak Paul --- docs/source/en/api/pipelines/cogvideox.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 02fe6d883f5d..96138a0ecb11 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -29,6 +29,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m +This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). + ### Inference Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. From cbc4d32d6c7123fea0f284a18057a77925e3a547 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 6 Aug 2024 21:38:43 +0200 Subject: [PATCH 93/94] remove tiling and slicing until their implementations are complete --- .../autoencoders/autoencoder_kl_cogvideox.py | 39 ++----------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index dad3d3874ad1..6aad4d63410f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -117,7 +117,7 @@ def __init__( self.conv_cache = None - def fake_cp_pass_from_previous_rank(self, inputs: torch.Tensor) -> torch.Tensor: + def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: dim = self.temporal_dim kernel_size = self.time_kernel_size if kernel_size == 1: @@ -138,7 +138,7 @@ def _clear_fake_context_parallel_cache(self): self.conv_cache = None def forward(self, inputs: torch.Tensor) -> torch.Tensor: - input_parallel = self.fake_cp_pass_from_previous_rank(inputs) + input_parallel = self.fake_context_parallel_forward(inputs) self._clear_fake_context_parallel_cache() self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() @@ -969,38 +969,9 @@ def clear_fake_context_parallel_cache(self): logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") module._clear_fake_context_parallel_cache() - def enable_tiling(self, use_tiling: bool = True): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = use_tiling - - def disable_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.enable_tiling(False) - - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True, fake_cp: bool = False + self, x: torch.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -1009,8 +980,6 @@ def encode( x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. - fake_cp (`bool`, *optional*, defaults to `True`): - If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). Returns: The latent representations of the encoded images. If `return_dict` is True, a @@ -1033,8 +1002,6 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode z (`torch.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - fake_cp (`bool`, *optional*, defaults to `True`): - If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work). Returns: [`~models.vae.DecoderOutput`] or `tuple`: From 827a70ae2aecaf08af0484ebaef22a592efe67b5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 7 Aug 2024 12:30:50 +0530 Subject: [PATCH 94/94] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: YiYi Xu --- docs/source/en/api/pipelines/cogvideox.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 96138a0ecb11..51026091b348 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -31,7 +31,7 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). -### Inference +## Inference Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. @@ -42,7 +42,7 @@ import torch from diffusers import CogVideoXPipeline from diffusers.utils import export_to_video -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda") +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda") prompt = ( "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "