-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathdemo_util.py
108 lines (86 loc) · 3.51 KB
/
demo_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Demo file for sampling images from TiTok.
Copyright (2024) Bytedance Ltd. and/or its affiliates
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from omegaconf import OmegaConf
from modeling.titok import TiTok
from modeling.maskgit import ImageBert, UViTBert
from modeling.rar import RAR
def get_config_cli():
    """Build the run configuration from command-line arguments and a YAML file.

    The command line is parsed with ``OmegaConf.from_cli()``. Its ``config``
    entry names the YAML file to load. The YAML config and the command-line
    config are then merged, with the command-line config passed last to
    ``OmegaConf.merge`` so that its values override the file's.

    Returns:
        The merged OmegaConf configuration object.
    """
    overrides = OmegaConf.from_cli()
    # The YAML path itself comes from the command line (``config=...``).
    base = OmegaConf.load(overrides.config)
    return OmegaConf.merge(base, overrides)
def get_config(config_path):
    """Load a configuration file with ``OmegaConf.load``.

    Args:
        config_path: Location of the configuration file to read.

    Returns:
        The OmegaConf configuration object produced by ``OmegaConf.load``.
    """
    return OmegaConf.load(config_path)
def get_titok_tokenizer(config):
    """Create a ``TiTok`` tokenizer and load its checkpoint for inference.

    Args:
        config: Configuration passed to the ``TiTok`` constructor. The
            checkpoint path is read from
            ``config.experiment.tokenizer_checkpoint``.

    Returns:
        The ``TiTok`` instance with weights loaded (mapped to CPU), switched
        to eval mode, and with gradients disabled.
    """
    titok = TiTok(config)

    checkpoint_path = config.experiment.tokenizer_checkpoint
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    titok.load_state_dict(state_dict)

    # Inference only: eval mode and no gradient tracking on the parameters.
    titok.eval()
    titok.requires_grad_(False)
    return titok
def get_titok_generator(config):
    """Create the generator selected by the config and load its checkpoint.

    Args:
        config: Configuration passed to the model constructor.
            ``config.model.generator.model_type`` picks the class
            (``"ViT"`` -> ``ImageBert``, ``"UViT"`` -> ``UViTBert``) and
            ``config.experiment.generator_checkpoint`` is the weights path.

    Returns:
        The generator with weights loaded (mapped to CPU), switched to eval
        mode, and with gradients disabled.

    Raises:
        ValueError: If the model type is neither ``"ViT"`` nor ``"UViT"``.
    """
    # Reject unknown types up front, before any model class is touched.
    if config.model.generator.model_type not in ("ViT", "UViT"):
        raise ValueError(f"Unsupported model type {config.model.generator.model_type}")

    is_vit = config.model.generator.model_type == "ViT"
    model_cls = ImageBert if is_vit else UViTBert

    net = model_cls(config)
    checkpoint_path = config.experiment.generator_checkpoint
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(state_dict)

    # Inference only: eval mode and no gradient tracking on the parameters.
    net.eval()
    net.requires_grad_(False)
    return net
def get_rar_generator(config):
    """Create a ``RAR`` generator and load its checkpoint for inference.

    Args:
        config: Configuration passed to the ``RAR`` constructor. The
            checkpoint path is read from
            ``config.experiment.generator_checkpoint``.

    Returns:
        The ``RAR`` instance with weights loaded (mapped to CPU), switched to
        eval mode, with gradients disabled and its random ratio set to 0.
    """
    rar = RAR(config)

    checkpoint_path = config.experiment.generator_checkpoint
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    rar.load_state_dict(state_dict)

    # Inference only: eval mode and no gradient tracking on the parameters.
    rar.eval()
    rar.requires_grad_(False)
    # NOTE(review): presumably turns off randomized ordering for sampling;
    # confirm against RAR.set_random_ratio.
    rar.set_random_ratio(0)
    return rar
@torch.no_grad()
def sample_fn(generator,
              tokenizer,
              labels=None,
              guidance_scale=3.0,
              guidance_decay="constant",
              guidance_scale_pow=3.0,
              randomize_temperature=2.0,
              softmax_temperature_annealing=False,
              num_sample_steps=8,
              device="cuda",
              return_tensor=False):
    """Generate images for the given class labels.

    Tokens are produced by ``generator.generate`` and decoded to images by
    ``tokenizer.decode_tokens``. Both models are put into eval mode first, and
    the whole call runs under ``torch.no_grad()``.

    Args:
        generator: Object exposing ``eval()`` and ``generate(...)``.
        tokenizer: Object exposing ``eval()`` and ``decode_tokens(tokens)``.
        labels: Class labels to condition on. A ``torch.Tensor`` is passed
            through unchanged; any other sequence is converted to a
            ``LongTensor`` on ``device``. When ``None``, nine fixed demo
            labels plus one random label are used.
        guidance_scale: Forwarded unchanged to ``generator.generate``.
        guidance_decay: Forwarded unchanged to ``generator.generate``.
        guidance_scale_pow: Forwarded unchanged to ``generator.generate``.
        randomize_temperature: Forwarded unchanged to ``generator.generate``.
        softmax_temperature_annealing: Forwarded unchanged to
            ``generator.generate``.
        num_sample_steps: Forwarded unchanged to ``generator.generate``.
        device: Device that non-tensor ``labels`` are moved to.
        return_tensor: If true, return the decoder output as-is.

    Returns:
        If ``return_tensor`` is true, the raw tensor returned by
        ``tokenizer.decode_tokens``. Otherwise a ``uint8`` NumPy array obtained
        by clamping that tensor to [0, 1], scaling by 255 and permuting
        dimensions ``(0, 2, 3, 1)`` (assumes a channels-first image batch —
        TODO confirm against the tokenizer).
    """
    generator.eval()
    tokenizer.eval()

    if labels is None:
        # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
        # torch.randint excludes its upper bound, so 1000 is required for the
        # random label to cover ids 0..999 (assumes a 1000-class label space,
        # as in ImageNet-1k — confirm against the generator's class count).
        random_label = int(torch.randint(0, 1000, size=(1,)).item())
        labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random_label]

    # Tensors supplied by the caller are used as-is (not moved to `device`).
    if not isinstance(labels, torch.Tensor):
        labels = torch.LongTensor(labels).to(device)

    generated_tokens = generator.generate(
        condition=labels,
        guidance_scale=guidance_scale,
        guidance_decay=guidance_decay,
        guidance_scale_pow=guidance_scale_pow,
        randomize_temperature=randomize_temperature,
        softmax_temperature_annealing=softmax_temperature_annealing,
        num_sample_steps=num_sample_steps)

    # Flatten everything after the batch dimension before decoding.
    generated_image = tokenizer.decode_tokens(
        generated_tokens.view(generated_tokens.shape[0], -1)
    )
    if return_tensor:
        return generated_image

    generated_image = torch.clamp(generated_image, 0.0, 1.0)
    generated_image = (generated_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
    return generated_image