Skip to content

Commit

Permalink
Big Model Renaming (open-mmlab#109)
Browse files Browse the repository at this point in the history
* up

* change model name

* renaming

* more changes

* up

* up

* up

* save checkpoint

* finish api / naming

* finish config renaming

* rename all weights

* finish really
  • Loading branch information
patrickvonplaten authored Jul 20, 2022
1 parent 13e37ca commit 9c3820d
Show file tree
Hide file tree
Showing 24 changed files with 591 additions and 655 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ For more examples see [schedulers](https://github.com/huggingface/diffusers/tree

```python
import torch
from diffusers import UNetUnconditionalModel, DDIMScheduler
from diffusers import UNet2DModel, DDIMScheduler
import PIL.Image
import numpy as np
import tqdm
Expand All @@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load models
scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)

# 2. Sample gaussian noise
generator = torch.manual_seed(23)
Expand Down
38 changes: 19 additions & 19 deletions scripts/convert_ddpm_original_checkpoint_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
import argparse
import json
import torch
Expand Down Expand Up @@ -80,7 +80,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
continue

new_path = new_path.replace('down.', 'downsample_blocks.')
new_path = new_path.replace('up.', 'upsample_blocks.')
new_path = new_path.replace('up.', 'up_blocks.')

if additional_replacements is not None:
for replacement in additional_replacements:
Expand Down Expand Up @@ -114,8 +114,8 @@ def convert_ddpm_checkpoint(checkpoint, config):
num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}

num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)}
num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

for i in range(num_downsample_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1)
Expand Down Expand Up @@ -164,34 +164,34 @@ def convert_ddpm_checkpoint(checkpoint, config):
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
])

for i in range(num_upsample_blocks):
block_id = num_upsample_blocks - 1 - i
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i

if any('upsample' in layer for layer in upsample_blocks[i]):
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
if any('upsample' in layer for layer in up_blocks[i]):
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']

if any('block' in layer for layer in upsample_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('block' in layer for layer in up_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_blocks > 0:
for j in range(config['num_res_blocks'] + 1):
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

if any('attn' in layer for layer in upsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('attn' in layer for layer in up_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_attn > 0:
for j in range(config['num_res_blocks'] + 1):
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
return new_checkpoint


Expand Down Expand Up @@ -225,7 +225,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
if "ddpm" in config:
del config["ddpm"]

model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)

scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
Expand Down
38 changes: 19 additions & 19 deletions scripts/convert_ldm_original_checkpoint_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import argparse
import json
import torch
from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline


def shave_segments(path, n_shave_prefix_segments=1):
Expand Down Expand Up @@ -207,14 +207,14 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions_paths = renew_attention_paths(attentions)
to_split = {
'middle_block.1.qkv.bias': {
'key': 'mid.attentions.0.key.bias',
'query': 'mid.attentions.0.query.bias',
'value': 'mid.attentions.0.value.bias',
'key': 'mid_block.attentions.0.key.bias',
'query': 'mid_block.attentions.0.query.bias',
'value': 'mid_block.attentions.0.value.bias',
},
'middle_block.1.qkv.weight': {
'key': 'mid.attentions.0.key.weight',
'query': 'mid.attentions.0.query.weight',
'value': 'mid.attentions.0.value.weight',
'key': 'mid_block.attentions.0.key.weight',
'query': 'mid_block.attentions.0.query.weight',
'value': 'mid_block.attentions.0.value.weight',
},
}
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
Expand All @@ -239,13 +239,13 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)

meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

if ['conv.weight', 'conv.bias'] in output_block_list.values():
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']

# Clear attentions as they have been attributed above.
if len(attentions) == 2:
Expand All @@ -255,18 +255,18 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attentions)
meta_path = {
'old': f'output_blocks.{i}.1',
'new': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}'
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
}
to_split = {
f'output_blocks.{i}.1.qkv.bias': {
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
},
f'output_blocks.{i}.1.qkv.weight': {
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
},
}
assign_to_checkpoint(
Expand All @@ -281,7 +281,7 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = '.'.join(['output_blocks', str(i), path['old']])
new_path = '.'.join(['upsample_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])

new_checkpoint[new_path] = checkpoint[old_path]

Expand Down Expand Up @@ -319,7 +319,7 @@ def convert_ldm_checkpoint(checkpoint, config):
if "ldm" in config:
del config["ldm"]

model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)

try:
Expand Down
20 changes: 10 additions & 10 deletions scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
import argparse
import json
import torch
from diffusers import UNetUnconditionalModel
from diffusers import UNet2DModel


def convert_ncsnpp_checkpoint(checkpoint, config):
"""
Takes a state dict and the path to
"""
new_model_architecture = UNetUnconditionalModel(**config)
new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture = UNet2DModel(**config)
new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

Expand Down Expand Up @@ -92,14 +92,14 @@ def set_resnet_weights(new_layer, old_checkpoint, index):
block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
module_index += 1

set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
module_index += 1
set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
module_index += 1
set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
module_index += 1

for i, block in enumerate(new_model_architecture.upsample_blocks):
for i, block in enumerate(new_model_architecture.up_blocks):
has_attentions = hasattr(block, "attentions")
for j in range(len(block.resnets)):
set_resnet_weights(block.resnets[j], checkpoint, module_index)
Expand Down Expand Up @@ -134,7 +134,7 @@ def set_resnet_weights(new_layer, old_checkpoint, index):

parser.add_argument(
"--checkpoint_path",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
type=str,
required=False,
help="Path to the checkpoint to convert.",
Expand Down Expand Up @@ -171,7 +171,7 @@ def set_resnet_weights(new_layer, old_checkpoint, index):
if "sde" in config:
del config["sde"]

model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)

try:
Expand Down
17 changes: 10 additions & 7 deletions scripts/generate_logits.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from huggingface_hub import HfApi
from transformers.file_utils import has_file
from diffusers import UNetUnconditionalModel
from diffusers import UNet2DModel
import random
import torch
api = HfApi()
Expand Down Expand Up @@ -70,19 +70,22 @@
models = api.list_models(filter="diffusers")
for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":

if mod.modelId == "CompVis/ldm-celebahq-256" or not has_file(mod.modelId, "config.json"):
model = UNetUnconditionalModel.from_pretrained(mod.modelId, subfolder = "unet")
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]

print(f"Started running {mod.modelId}!!!")

if mod.modelId.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
else:
model = UNetUnconditionalModel.from_pretrained(mod.modelId)
model = UNet2DModel.from_pretrained(local_checkpoint)

torch.manual_seed(0)
random.seed(0)

noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)['sample']

torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
print(f"{mod.modelId} has passed succesfully!!!")
2 changes: 1 addition & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__version__ = "0.0.4"

from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
Expand Down
8 changes: 4 additions & 4 deletions src/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ def get_config_dict(

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
" login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
Expand Down
Loading

0 comments on commit 9c3820d

Please sign in to comment.