From 38a3e4df926c59bc122191c0fc8066755e98b6d2 Mon Sep 17 00:00:00 2001 From: Subho Ghosh <93722719+ighoshsubho@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:59:02 +0530 Subject: [PATCH] flux controlnet control_guidance_start and control_guidance_end implement (#9571) * flux controlnet control_guidance_start and control_guidance_end implement * minor fix - added docstrings, consistent controlnet scale flux and SD3 --- .../flux/pipeline_flux_controlnet.py | 42 +++++++++++++++++-- ...pipeline_flux_controlnet_image_to_image.py | 35 +++++++++++++++- .../pipeline_flux_controlnet_inpainting.py | 37 +++++++++++++++- 3 files changed, 108 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 6c072c4820206..a301f6742c057 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -72,7 +72,9 @@ >>> image = pipe( ... prompt, ... control_image=control_image, - ... controlnet_conditioning_scale=0.6, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=1.0, ... num_inference_steps=28, ... guidance_scale=3.5, ... ).images[0] @@ -572,6 +574,8 @@ def __call__( num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, control_image: PipelineImageInput = None, control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, @@ -614,6 +618,10 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is @@ -674,6 +682,17 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -839,7 +858,16 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # 6. Denoising loop + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -856,12 +884,20 @@ def __call__( guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + # controlnet controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, controlnet_mode=control_mode, - conditioning_scale=controlnet_conditioning_scale, + conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index deeb9e3f546a8..b5ff2236c4d0d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -69,7 +69,9 @@ ... prompt, ... image=init_image, ... control_image=control_image, - ... controlnet_conditioning_scale=0.6, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=1.0, ... strength=0.7, ... num_inference_steps=2, ... guidance_scale=3.5, @@ -631,6 +633,8 @@ def __call__( num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, num_images_per_prompt: Optional[int] = 1, @@ -710,6 +714,17 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + self.check_inputs( prompt, prompt_2, @@ -862,6 +877,14 @@ def __call__( latents, ) + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -877,11 +900,19 @@ def __call__( ) guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, controlnet_mode=control_mode, - conditioning_scale=controlnet_conditioning_scale, + conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index e763200155f6a..1c1c253021006 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -71,6 +71,8 @@ ... image=init_image, ... mask_image=mask_image, ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, ... controlnet_conditioning_scale=0.7, ... strength=0.7, ... num_inference_steps=28, @@ -737,6 +739,8 @@ def __call__( timesteps: List[int] = None, num_inference_steps: int = 28, guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, num_images_per_prompt: Optional[int] = 1, @@ -783,6 +787,10 @@ def __call__( Custom timesteps to use for the denoising process. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. control_mode (`int` or `List[int]`, *optional*): The mode for the ControlNet. If multiple ControlNets are used, this should be a list. controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): @@ -826,6 +834,17 @@ def __call__( global_height = height global_width = width + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + # 1. Check inputs self.check_inputs( prompt, @@ -1031,6 +1050,14 @@ def __call__( generator, ) + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + # 9. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -1049,11 +1076,19 @@ def __call__( else: guidance = None + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, controlnet_mode=control_mode, - conditioning_scale=controlnet_conditioning_scale, + conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds,