Skip to content

Commit

Permalink
[Core] introduce controlnet module (#8768)
Browse files (browse the repository at this point in the history)
* move vae flax module.

* controlnet module.

* prepare for PR.

* revert a commit

* gracefully deprecate controlnet deps.

* fix

* fix doc path

* fix-copies

* fix path

* style

* style

* conflicts

* fix

* fix-copies

* sparsectrl.

* updates

* fix

* updates

* updates

* updates

* fix

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
  • Loading branch information
sayakpaul and DN6 committed Dec 23, 2024
1 parent 221d6db commit e92bbf4
Show file tree
Hide file tree
Showing 26 changed files with 2,970 additions and 2,752 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/api/models/controlnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro

## ControlNetOutput

[[autodoc]] models.controlnet.ControlNetOutput
[[autodoc]] models.controlnets.controlnet.ControlNetOutput

## FlaxControlNetModel

[[autodoc]] FlaxControlNetModel

## FlaxControlNetOutput

[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput
2 changes: 1 addition & 1 deletion docs/source/en/api/models/controlnet_sd3.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di

## SD3ControlNetOutput

[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput

Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ def forward(
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
[`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@


else:
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
Expand Down Expand Up @@ -914,7 +914,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
from .models.controlnet_flax import FlaxControlNetModel
from .models.controlnets.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
Expand Down
39 changes: 25 additions & 14 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DMultiControlNetModel",
]
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
Expand Down Expand Up @@ -74,7 +78,7 @@
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]

if is_flax_available():
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]

Expand All @@ -94,12 +98,19 @@
ConsistencyDecoderVAE,
VQModel,
)
from .controlnet import ControlNetModel
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_sparsectrl import SparseControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .controlnets import (
ControlNetModel,
ControlNetXSAdapter,
FluxControlNetModel,
FluxMultiControlNetModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DMultiControlNetModel,
MultiControlNetModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SparseControlNetModel,
UNetControlNetXSModel,
)
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
Expand Down Expand Up @@ -137,7 +148,7 @@
)

if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
from .controlnets import FlaxControlNetModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL

Expand Down
Loading

0 comments on commit e92bbf4

Please sign in to comment.