Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6676 port diffusion schedulers #7364

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions monai/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,10 @@
import numpy as np
import torch

from monai.utils import StrEnum

from .ddpm import DDPMPredictionType
from .scheduler import Scheduler


class DDIMPredictionType(StrEnum):
"""
Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument.

epsilon: predicting the noise of the diffusion process
sample: directly predicting the noisy sample
v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
"""

EPSILON = "epsilon"
SAMPLE = "sample"
V_PREDICTION = "v_prediction"
DDIMPredictionType = DDPMPredictionType


class DDIMScheduler(Scheduler):
Expand Down Expand Up @@ -126,6 +113,13 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N

self.num_inference_steps = num_inference_steps
step_ratio = self.num_train_timesteps // self.num_inference_steps
if self.steps_offset >= step_ratio:
raise ValueError(
f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
f" the max train timestep."
)

# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
Expand Down Expand Up @@ -159,7 +153,6 @@ def step(
timestep: current discrete timestep in the diffusion chain.
sample: current instance of sample being created by diffusion process.
eta: weight of noise for added noise in diffusion step.
predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.

Returns:
Expand Down
Loading