diff --git a/mmagic/models/editors/controlnet/controlnet.py b/mmagic/models/editors/controlnet/controlnet.py
index 272eaee079..77a61346f4 100644
--- a/mmagic/models/editors/controlnet/controlnet.py
+++ b/mmagic/models/editors/controlnet/controlnet.py
@@ -46,6 +46,9 @@ class ControlStableDiffusion(StableDiffusion):
         dtype (str, optional): The dtype for the model. Defaults to 'fp16'.
         enable_xformers (bool, optional): Whether to use xformers.
             Defaults to True.
+        noise_offset_weight (float, optional): The weight of noise offset
+            introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise  # noqa
+            Defaults to 0.
         data_preprocessor (dict, optional): The pre-process config of
             :class:`BaseDataPreprocessor`. Defaults to
             dict(type='DataPreprocessor').
@@ -63,12 +66,14 @@ def __init__(self,
                  test_scheduler: Optional[ModelType] = None,
                  dtype: str = 'fp32',
                  enable_xformers: bool = True,
+                 noise_offset_weight: float = 0,
                  tomesd_cfg: Optional[dict] = None,
                  data_preprocessor=dict(type='DataPreprocessor'),
                  init_cfg: Optional[dict] = None):
         super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
-                         test_scheduler, dtype, enable_xformers, tomesd_cfg,
-                         data_preprocessor, init_cfg)
+                         test_scheduler, dtype, enable_xformers,
+                         noise_offset_weight, tomesd_cfg, data_preprocessor,
+                         init_cfg)
 
         default_args = dict()
         if dtype is not None:
diff --git a/mmagic/models/editors/dreambooth/dreambooth.py b/mmagic/models/editors/dreambooth/dreambooth.py
index 1b10332013..0ad2af3a8f 100644
--- a/mmagic/models/editors/dreambooth/dreambooth.py
+++ b/mmagic/models/editors/dreambooth/dreambooth.py
@@ -51,6 +51,9 @@ class DreamBooth(StableDiffusion):
         dtype (str, optional): The dtype for the model. Defaults to 'fp16'.
         enable_xformers (bool, optional): Whether to use xformers.
             Defaults to True.
+        noise_offset_weight (float, optional): The weight of noise offset
+            introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise  # noqa
+            Defaults to 0.
         data_preprocessor (dict, optional): The pre-process config of
             :class:`BaseDataPreprocessor`. Defaults to
             dict(type='DataPreprocessor').
@@ -73,14 +76,16 @@ def __init__(self,
                  finetune_text_encoder: bool = False,
                  dtype: str = 'fp16',
                  enable_xformers: bool = True,
+                 noise_offset_weight: float = 0,
                  tomesd_cfg: Optional[dict] = None,
                  data_preprocessor: Optional[ModelType] = dict(
                      type='DataPreprocessor'),
                  init_cfg: Optional[dict] = None):
         super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
-                         test_scheduler, dtype, enable_xformers, tomesd_cfg,
-                         data_preprocessor, init_cfg)
+                         test_scheduler, dtype, enable_xformers,
+                         noise_offset_weight, tomesd_cfg, data_preprocessor,
+                         init_cfg)
 
         self.num_class_images = num_class_images
         self.class_prior_prompt = class_prior_prompt
         self.prior_loss_weight = prior_loss_weight
diff --git a/mmagic/models/editors/stable_diffusion/stable_diffusion.py b/mmagic/models/editors/stable_diffusion/stable_diffusion.py
index cda4ff5c0b..38cbdff289 100644
--- a/mmagic/models/editors/stable_diffusion/stable_diffusion.py
+++ b/mmagic/models/editors/stable_diffusion/stable_diffusion.py
@@ -47,6 +47,9 @@ class StableDiffusion(BaseModel):
             when dtype is defined for submodels. Defaults to None.
         enable_xformers (bool, optional): Whether to use xformers.
             Defaults to True.
+        noise_offset_weight (float, optional): The weight of noise offset
+            introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise  # noqa
+            Defaults to 0.
         data_preprocessor (dict, optional): The pre-process config of
             :class:`BaseDataPreprocessor`.
         init_cfg (dict, optional): The weight initialized config for
@@ -62,6 +65,7 @@ def __init__(self,
                  test_scheduler: Optional[ModelType] = None,
                  dtype: Optional[str] = None,
                  enable_xformers: bool = True,
+                 noise_offset_weight: float = 0,
                  tomesd_cfg: Optional[dict] = None,
                  data_preprocessor: Optional[ModelType] = dict(
                      type='DataPreprocessor'),
@@ -102,6 +106,9 @@ def __init__(self,
         self.unet_sample_size = self.unet.sample_size
         self.vae_scale_factor = 2**(len(self.vae.block_out_channels) - 1)
 
+        self.enable_noise_offset = noise_offset_weight > 0
+        self.noise_offset_weight = noise_offset_weight
+
         self.enable_xformers = enable_xformers
         self.set_xformers()
 
@@ -612,6 +619,15 @@ def train_step(self, data, optim_wrapper_dict):
            latents = latents * vae.config.scaling_factor
 
            noise = torch.randn_like(latents)
+
+           if self.enable_noise_offset:
+               noise = noise + self.noise_offset_weight * torch.randn(
+                   latents.shape[0],
+                   latents.shape[1],
+                   1,
+                   1,
+                   device=noise.device)
+
            timesteps = torch.randint(
                0, self.scheduler.num_train_timesteps, (num_batches, ),
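Note for reviewers: the snippet below is a minimal, standalone sketch of the offset-noise trick this diff wires into `train_step`, shown outside mmagic so it can be run directly. The latent shape and the example weight of 0.05 are illustrative assumptions, not values from this PR; only the broadcasting pattern mirrors the added code. Per the linked blog post, drawing one shared Gaussian scalar per (sample, channel) shifts the mean of each channel's noise map, something per-pixel i.i.d. noise almost never does, which helps the fine-tuned model reach very dark or very bright images at sampling time.

```python
import torch

# Hypothetical values for illustration only; the PR defaults
# noise_offset_weight to 0, which disables the feature.
noise_offset_weight = 0.05
latents = torch.randn(4, 4, 64, 64)  # assumed (batch, channel, h, w) latents

noise = torch.randn_like(latents)
if noise_offset_weight > 0:
    # One extra scalar per (sample, channel), broadcast over all spatial
    # positions: it offsets the mean of that channel's entire noise map
    # instead of perturbing individual pixels.
    offset = torch.randn(
        latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
    noise = noise + noise_offset_weight * offset

assert noise.shape == latents.shape  # broadcasting keeps the shape unchanged
```

Because the flag `self.enable_noise_offset` is derived from `noise_offset_weight > 0` at construction time, existing configs that omit the parameter keep the exact previous training behavior.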