Skip to content

Commit

Permalink
[PNDM in LDM pipeline] use inspect in pipeline instead of unused kwar…
Browse files Browse the repository at this point in the history
…gs (#167)

use inspect instead of unused kwargs
  • Loading branch information
patil-suraj authored Aug 12, 2022
1 parent 3228eb1 commit c72e343
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -59,6 +60,12 @@ def __call__(

self.scheduler.set_timesteps(num_inference_steps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwrags = {}
if not accepts_eta:
extra_kwrags["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
if guidance_scale == 1.0:
# guidance_scale of 1 means no guidance
Expand All @@ -79,7 +86,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, eta=eta)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect

import torch

from tqdm.auto import tqdm
Expand Down Expand Up @@ -31,11 +33,17 @@ def __call__(

self.scheduler.set_timesteps(num_inference_steps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwrags = {}
if not accepts_eta:
extra_kwrags["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
# predict the noise residual
noise_prediction = self.unet(latents, t)["sample"]
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]

# decode the image latents with the VAE
image = self.vqvae.decode(latents)
Expand Down
1 change: 0 additions & 1 deletion src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def step(
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
**kwargs,
):
if self.counter < len(self.prk_timesteps):
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
Expand Down

0 comments on commit c72e343

Please sign in to comment.