diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py
index ec47ff8dc6..78e3cc2a0c 100644
--- a/monai/networks/schedulers/ddim.py
+++ b/monai/networks/schedulers/ddim.py
@@ -34,23 +34,10 @@
 import numpy as np
 import torch
 
-from monai.utils import StrEnum
-
+from .ddpm import DDPMPredictionType
 from .scheduler import Scheduler
 
-
-class DDIMPredictionType(StrEnum):
-    """
-    Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument.
-
-    epsilon: predicting the noise of the diffusion process
-    sample: directly predicting the noisy sample
-    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
-    """
-
-    EPSILON = "epsilon"
-    SAMPLE = "sample"
-    V_PREDICTION = "v_prediction"
+DDIMPredictionType = DDPMPredictionType
 
 
 class DDIMScheduler(Scheduler):
@@ -126,6 +113,13 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
 
         self.num_inference_steps = num_inference_steps
         step_ratio = self.num_train_timesteps // self.num_inference_steps
+        if self.steps_offset >= step_ratio:
+            raise ValueError(
+                f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
+                f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
+                f" the max train timestep."
+            )
+
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
         timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
@@ -159,7 +153,6 @@ def step(
             timestep: current discrete timestep in the diffusion chain.
             sample: current instance of sample being created by diffusion process.
             eta: weight of noise for added noise in diffusion step.
-            predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
            generator: random number generator.
 
         Returns:
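
For context, here is a minimal sketch (not part of the patch) of why the new `steps_offset` guard is needed. `set_timesteps` builds evenly spaced timesteps from `step_ratio = num_train_timesteps // num_inference_steps`, and the offset is later added to those timesteps (the addition itself is outside the hunks shown above, so it is an assumption here), so any offset of `step_ratio` or more pushes the largest timestep past `num_train_timesteps - 1`. The concrete values below are illustrative only.

```python
import numpy as np

# Illustrative values only (not from the patch): 1000 train steps, 4 inference steps.
num_train_timesteps, num_inference_steps = 1000, 4
step_ratio = num_train_timesteps // num_inference_steps  # 250

for steps_offset in (0, step_ratio - 1, step_ratio):
    # Same spacing construction as in set_timesteps, followed by the offset shift
    # (assumed here; the shift happens outside the hunks shown above).
    timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
    timesteps = timesteps + steps_offset
    in_range = timesteps.max() <= num_train_timesteps - 1
    print(f"steps_offset={steps_offset}: max timestep {timesteps.max()}, valid={in_range}")
```

With these numbers, `steps_offset` values of 0 and 249 keep the largest timestep at 750 and 999 respectively, while 250 (equal to `step_ratio`) produces 1000, one past the valid training range, which is exactly the case the added `ValueError` rejects.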