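# train.py: fine-tune Stable Diffusion with a prior-preserving loss
# (DreamBooth-style training). The model is split across two GPUs (the UNet
# on cuda:0; the text encoder and VAE on cuda:1), and training is driven by
# Ray AIR's TorchTrainer.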
import itertools
import math

from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from ray.air import session, ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from transformers import CLIPTextModel

from dataset import collate, get_train_dataset
from flags import train_arguments


def prior_preserving_loss(model_pred, target, weight):
    """Compute the instance loss plus the weighted prior-preservation loss."""
    # Chunk the model prediction and the target into two parts and compute the
    # loss on each part separately.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    # Compute the instance loss.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Compute the prior loss.
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

    # Add the prior loss to the instance loss.
    return loss + weight * prior_loss


def get_target(scheduler, noise, latents, timesteps):
    """Get the target for the loss, depending on the scheduler's prediction type."""
    pred_type = scheduler.config.prediction_type
    if pred_type == "epsilon":
        return noise
    if pred_type == "v_prediction":
        return scheduler.get_velocity(latents, noise, timesteps)
    raise ValueError(f"Unknown prediction type {pred_type}")


def load_models(args, cuda):
    """Load the pre-trained Stable Diffusion models."""
    # Load all models in bfloat16 to save GPU RAM. For models that are only
    # used for inference, full precision is not required anyway.
    dtype = torch.bfloat16

    text_encoder = CLIPTextModel.from_pretrained(
        args.model_dir, subfolder="text_encoder", torch_dtype=dtype,
    )
    text_encoder.to(cuda[1])
    text_encoder.train()

    noise_scheduler = DDPMScheduler.from_pretrained(
        args.model_dir, subfolder="scheduler", torch_dtype=dtype,
    )

    # The VAE is only used for inference; keeping its weights in full
    # precision is not required. We are not training the VAE, so freeze it.
    vae = AutoencoderKL.from_pretrained(
        args.model_dir, subfolder="vae", torch_dtype=dtype,
    )
    vae.requires_grad_(False)
    vae.to(cuda[1])

    # Convert the UNet to bf16 to save GPU RAM as well.
    unet = UNet2DConditionModel.from_pretrained(
        args.model_dir, subfolder="unet", torch_dtype=dtype,
    )
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()
    # The UNet is the largest component, so it occupies the first GPU by itself.
    unet.to(cuda[0])
    unet.train()

    torch.cuda.empty_cache()
    return text_encoder, noise_scheduler, vae, unet


def get_cuda_devices():
    devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    assert len(devices) >= 2, "This script requires at least 2 GPU devices."
    return devices


def train_fn(args):
    cuda = get_cuda_devices()

    # Load the pre-trained models.
    text_encoder, noise_scheduler, vae, unet = load_models(args, cuda)

    # Use the regular AdamW optimizer to work with bfloat16 weights.
    optimizer = torch.optim.AdamW(
        itertools.chain(text_encoder.parameters(), unet.parameters()),
        lr=args.lr,
    )
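
    # Each Ray worker receives its shard of the dataset registered under
    # "train" in the TorchTrainer below.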
    train_dataset = session.get_dataset_shard("train")

    # Train!
    num_update_steps_per_epoch = train_dataset.count()
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    total_batch_size = args.train_batch_size
    print(
        f"Running {num_train_epochs} epochs with batch size {total_batch_size}. "
        f"Max training steps: {args.max_train_steps}."
    )

    global_step = 0
    for epoch in range(num_train_epochs):
        for step, batch in enumerate(
            train_dataset.iter_torch_batches(batch_size=args.train_batch_size)
        ):
            # Load the batch onto the second GPU (cuda[1]), where the VAE and
            # the text encoder live.
            batch = collate(batch, cuda[1], torch.bfloat16)

            optimizer.zero_grad()

            # Convert images to latent space.
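            # (0.18215 is the fixed latent scaling factor of the Stable
            # Diffusion VAE.)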
            latents = vae.encode(batch["images"]).latent_dist.sample() * 0.18215

            # Sample noise that we'll add to the latents.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each image.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            )
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each
            # timestep (this is the forward diffusion process).
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning.
            encoder_hidden_states = text_encoder(batch["prompt_ids"])[0]

            # Predict the noise residual. All inputs must be moved to the
            # first GPU (cuda[0]), where the UNet lives.
            model_pred = unet(
                noisy_latents.to(cuda[0]), timesteps.to(cuda[0]), encoder_hidden_states.to(cuda[0])
            ).sample
            target = get_target(noise_scheduler, noise, latents, timesteps).to(cuda[0])

            # Compute the loss on the first GPU, where both the prediction and
            # the target now live.
            loss = prior_preserving_loss(model_pred, target, args.prior_loss_weight)
            loss.backward()

            # Clip gradients before the optimizer step.
            clip_grad_norm_(
                itertools.chain(text_encoder.parameters(), unet.parameters()),
                args.max_grad_norm,
            )
            optimizer.step()  # Updates both the text encoder and the UNet.

            global_step += 1
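
            # Report per-step metrics back to Ray AIR.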
            results = {
                "step": global_step,
                "loss": loss.detach().item(),
            }
            session.report(results)

            if global_step >= args.max_train_steps:
                break

    # Create a pipeline using the trained modules and save it.
    pipeline = DiffusionPipeline.from_pretrained(
        args.model_dir,
        text_encoder=text_encoder,
        unet=unet,
    )
    pipeline.save_pretrained(args.output_dir)
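
    # The saved pipeline can later be reloaded for inference, e.g. (a sketch;
    # the prompt is a placeholder for whatever subject token was trained):
    #   pipe = DiffusionPipeline.from_pretrained(args.output_dir).to("cuda")
    #   image = pipe("a photo of a sks dog").images[0]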


if __name__ == "__main__":
    args = train_arguments().parse_args()

    # Build the training dataset.
    train_dataset = get_train_dataset(args)
    print(f"Loaded training dataset (size: {train_dataset.count()})")
    # Train with Ray AIR TorchTrainer.
    trainer = TorchTrainer(
        train_fn,
        train_loop_config=args,
        scaling_config=ScalingConfig(
            use_gpu=True,
            num_workers=1,
            resources_per_worker={
                "GPU": 2,
            },
        ),
        datasets={
            "train": train_dataset,
        },
    )
    result = trainer.fit()
    print(result)
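
# Example invocation (hypothetical flag names and illustrative values; the
# real flags are defined in flags.train_arguments and are assumed here to
# mirror the attribute names used above):
#   python train.py \
#       --model_dir=<path to the base Stable Diffusion weights> \
#       --output_dir=<where to save the fine-tuned pipeline> \
#       --lr=5e-6 --train_batch_size=2 --max_train_steps=800 \
#       --prior_loss_weight=1.0 --max_grad_norm=1.0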