-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample_c2i.py
179 lines (159 loc) · 7.38 KB
/
sample_c2i.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Modified from:
# LLaMAGen: https://github.com/FoundationVision/LlamaGen/blob/main/autoregressive/sample/sample_c2i_ddp.py
# DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.nn.functional as F
import torch.distributed as dist
from omegaconf import OmegaConf
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import math
import argparse
import sys
sys.path.append("./")
from RandAR.dataset.builder import build_dataset
from RandAR.utils.distributed import init_distributed_mode, is_main_process
from RandAR.dataset.augmentation import center_crop_arr
from RandAR.util import instantiate_from_config, load_safetensors
def create_npz_from_sample_folder(sample_dir, num=50_000):
"""
Builds a single .npz file from a folder of .png samples.
"""
samples = []
for i in tqdm(range(num), desc="Building .npz file from samples"):
sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
sample_np = np.asarray(sample_pil).astype(np.uint8)
samples.append(sample_np)
samples = np.stack(samples)
assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
npz_path = f"{sample_dir}.npz"
np.savez(npz_path, arr_0=samples)
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
return npz_path
def main(args):
# Setup PyTorch:
assert (
torch.cuda.is_available()
), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
torch.set_grad_enabled(False)
# Setup DDP:
dist.init_process_group("nccl")
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * dist.get_world_size() + rank
torch.manual_seed(seed)
torch.cuda.set_device(device)
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
config = OmegaConf.load(args.config)
# create and load model
tokenizer = instantiate_from_config(config.tokenizer).to(device).eval()
ckpt = torch.load(args.vq_ckpt, map_location="cpu")
if 'model' in ckpt:
state_dict = ckpt['model']
else:
state_dict = ckpt
tokenizer.load_state_dict(state_dict)
# create and load gpt model
precision = {"none": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[
args.precision
]
latent_size = args.image_size // args.downsample_size
gpt_model = instantiate_from_config(config.ar_model).to(device=device, dtype=precision)
model_weight = load_safetensors(args.gpt_ckpt)
gpt_model.load_state_dict(model_weight, strict=True)
gpt_model.eval()
# Create folder to save samples:
ckpt_string_name = (
os.path.basename(args.gpt_ckpt)
.replace(".pth", "")
.replace(".pt", "")
.replace(".safetensors", "")
)
folder_name = (
f"{args.exp_name}-{ckpt_string_name}-size-{args.image_size}-size-{args.image_size_eval}-"
f"cfg-{args.cfg_scales}-seed-{args.global_seed}"
)
sample_folder_dir = f"{args.sample_dir}/{folder_name}"
if rank == 0:
os.makedirs(sample_folder_dir, exist_ok=True)
print(f"Saving .png samples at {sample_folder_dir}")
dist.barrier()
# Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
n = args.per_proc_batch_size
global_batch_size = n * dist.get_world_size()
# To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
if rank == 0:
print(f"Total number of images that will be sampled: {total_samples}")
assert (
total_samples % dist.get_world_size() == 0
), "total_samples must be divisible by world_size"
samples_needed_this_gpu = int(total_samples // dist.get_world_size())
assert (
samples_needed_this_gpu % n == 0
), "samples_needed_this_gpu must be divisible by the per-GPU batch size"
iterations = int(samples_needed_this_gpu // n)
pbar = range(iterations)
pbar = tqdm(pbar) if rank == 0 else pbar
total = 0
cur_iter = 0
for _ in pbar:
c_indices = torch.randint(0, args.num_classes, (n,), device=device)
cfg_scales = args.cfg_scales.split(",")
cfg_scales = [float(cfg_scale) for cfg_scale in cfg_scales]
indices = gpt_model.generate(
cond=c_indices,
token_order=None,
cfg_scales=cfg_scales,
num_inference_steps=args.num_inference_steps,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)
samples = tokenizer.decode_codes_to_img(indices, args.image_size_eval)
for i, sample in enumerate(samples):
index = i * dist.get_world_size() + rank + total
Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
total += global_batch_size
cur_iter += 1
# I use this line to look at the initial images to check the correctness
# comment this out if you want to generate more
if args.debug:
import pdb; pdb.set_trace()
# Make sure all processes have finished saving their samples before attempting to convert to .npz
dist.barrier()
if rank == 0:
create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
print("Done.")
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="configs/randar/randar_xl_0.7b.yaml")
parser.add_argument("--exp-name", type=str, required=True)
parser.add_argument("--gpt-ckpt", type=str, default=None)
parser.add_argument("--gpt-type", type=str, choices=["c2i", "t2i"], default="c2i")
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input",)
parser.add_argument("--precision", type=str, default="bf16", choices=["none", "fp16", "bf16"])
parser.add_argument("--compile", action="store_true", default=True)
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
parser.add_argument("--image-size", type=int, choices=[128, 256, 384, 512], default=256)
parser.add_argument("--image-size-eval", type=int, choices=[128, 256, 384, 512], default=256)
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
parser.add_argument("--num-classes", type=int, default=1000)
parser.add_argument("--cfg-scales", type=str, default="1.0, 4.0")
parser.add_argument("--sample-dir", type=str, default="./samples")
parser.add_argument("--num-inference-steps", type=int, default=88)
parser.add_argument("--per-proc-batch-size", type=int, default=32)
parser.add_argument("--num-fid-samples", type=int, default=50000)
parser.add_argument("--global-seed", type=int, default=0)
parser.add_argument("--top-k", type=int, default=0, help="top-k value to sample with")
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
parser.add_argument("--debug", action="store_true", default=False)
args = parser.parse_args()
main(args)