Skip to content

Commit

Permalink
Added ability to pair samples with a closer noise with optimal_noise_…
Browse files Browse the repository at this point in the history
…pairing_samples
  • Loading branch information
jaretburkett committed Jan 22, 2025
1 parent 29122b1 commit 89dd041
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
36 changes: 29 additions & 7 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,15 +778,37 @@ def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma

def get_optimal_noise(self, latents, dtype=torch.float32):
batch_num = latents.shape[0]
chunks = torch.chunk(latents, batch_num, dim=0)
noise_chunks = []
for chunk in chunks:
noise_samples = [torch.randn_like(chunk, device=chunk.device, dtype=dtype) for _ in range(self.train_config.optimal_noise_pairing_samples)]
# find the one most similar to the chunk
lowest_loss = 999999999999
best_noise = None
for noise in noise_samples:
loss = torch.nn.functional.mse_loss(chunk, noise)
if loss < lowest_loss:
lowest_loss = loss
best_noise = noise
noise_chunks.append(best_noise)
noise = torch.cat(noise_chunks, dim=0)
return noise


def get_noise(self, latents, batch_size, dtype=torch.float32):
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.optimal_noise_pairing_samples > 1:
noise = self.get_optimal_noise(latents, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)

if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
Expand Down
3 changes: 3 additions & 0 deletions toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,9 @@ def __init__(self, **kwargs):

# diffusion feature extractor
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)

# optimal noise pairing
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)


class ModelConfig:
Expand Down

0 comments on commit 89dd041

Please sign in to comment.