Stable diffusion pipeline #168

Merged (7 commits) on Aug 14, 2022
Changes from 1 commit
@@ -35,47 +35,49 @@ def __call__(
         self.vae.to(torch_device)
         self.text_encoder.to(torch_device)

+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+
+        do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
-        if guidance_scale != 1.0:
+        if do_classifier_free_guidance:
             uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

-        # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
-        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat((uncond_embeddings, text_embeddings), dim=0)
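The fixed padding to max_length=77 is what makes this concatenation safe: every prompt, including the empty unconditional one, tokenizes to the same length. A minimal sketch of just that property, assuming the CLIP tokenizer from the transformers library (the checkpoint name is illustrative, not taken from this PR):

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_input = tokenizer("a photograph of an astronaut", padding="max_length", max_length=77, return_tensors="pt")
uncond_input = tokenizer("", padding="max_length", max_length=77, return_tensors="pt")
# identical shapes make the conditional and unconditional batches safe to concatenate
assert text_input.input_ids.shape == uncond_input.input_ids.shape == (1, 77)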

         # get the initial random noise
         latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         latents = latents.to(torch_device)

+        self.scheduler.set_timesteps(num_inference_steps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         extra_kwargs = {}
-        if not accepts_eta:
+        if accepts_eta:
             extra_kwargs["eta"] = eta

-        self.scheduler.set_timesteps(num_inference_steps)
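The accepts_eta check is a general keyword-filtering pattern that can be tried in isolation. A minimal sketch, assuming only the standard library's inspect module; call_step is a hypothetical helper for illustration, not part of this PR:

import inspect

def call_step(scheduler, noise_pred, t, latents, eta=0.0):
    # Forward eta only when the scheduler's .step() signature declares it
    # (e.g. DDIMScheduler); schedulers without the parameter never see it.
    extra_kwargs = {}
    if "eta" in inspect.signature(scheduler.step).parameters:
        extra_kwargs["eta"] = eta
    return scheduler.step(noise_pred, t, latents, **extra_kwargs)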

         for t in tqdm(self.scheduler.timesteps):
-            if guidance_scale == 1.0:
-                # guidance_scale of 1 means no guidance
-                latents_input = latents
-                context = text_embeddings
-            else:
-                # For classifier free guidance, we need to do two forward passes.
-                # Here we concatenate the unconditional and text embeddings into a single batch
-                # to avoid doing two forward passes
-                latents_input = torch.cat([latents] * 2)
-                context = torch.cat([uncond_embeddings, text_embeddings])
+            # expand the latents if we are doing classifier free guidance
+            if do_classifier_free_guidance:
+                latent_model_input = torch.cat((latents, latents), dim=0)
+            else:
+                latent_model_input = latents
Contributor Author: Since text_embeddings are already expanded, we only need to expand the latents here.

Contributor: That gave a problem with naming, but it should be fixed now.
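To make the shape bookkeeping in this thread concrete, here is a minimal sketch with dummy tensors; the sizes (77 tokens, 768-dim embeddings, 4x64x64 latents) follow CLIP and Stable Diffusion defaults but are assumptions here, and unet_input is an illustrative name:

import torch

batch_size = 2
latents = torch.randn(batch_size, 4, 64, 64)
# the [uncond, cond] halves are already concatenated, as in the diff above
text_embeddings = torch.randn(2 * batch_size, 77, 768)

# doubling the latents keeps them aligned with the embeddings for one forward pass
unet_input = torch.cat((latents, latents), dim=0)
assert unet_input.shape[0] == text_embeddings.shape[0]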


             # predict the noise residual
-            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

             # perform guidance
-            if guidance_scale != 1.0:
-                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
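For reference, the guidance update can be exercised on its own; a toy sketch with random tensors (shapes illustrative only):

import torch

guidance_scale = 7.5
noise_pred = torch.randn(4, 4, 64, 64)                    # doubled batch from the single forward pass
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # first half uncond, second half text
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

At guidance_scale = 1.0 the update collapses to the text-conditional prediction alone, which is why do_classifier_free_guidance only switches on for scales above 1.0.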