Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PNDM in LDM pipeline] use inspect in pipeline instead of unused kwargs #167

Merged
merged 1 commit into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -59,6 +60,12 @@ def __call__(

self.scheduler.set_timesteps(num_inference_steps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwrags = {}
if not accepts_eta:
extra_kwrags["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
if guidance_scale == 1.0:
# guidance_scale of 1 means no guidance
Expand All @@ -79,7 +86,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, eta=eta)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect

import torch

from tqdm.auto import tqdm
Expand Down Expand Up @@ -31,11 +33,17 @@ def __call__(

self.scheduler.set_timesteps(num_inference_steps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwrags = {}
if not accepts_eta:
extra_kwrags["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
# predict the noise residual
noise_prediction = self.unet(latents, t)["sample"]
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]

# decode the image latents with the VAE
image = self.vqvae.decode(latents)
Expand Down
1 change: 0 additions & 1 deletion src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def step(
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<3

):
if self.counter < len(self.prk_timesteps):
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
Expand Down