Skip to content

Commit

Permalink
[Chore] remove deprecation from transformer2d regarding the output cl…
Browse files Browse the repository at this point in the history
…ass. (#8698)

* remove deprecation from transformer2d regarding the output class.

* up

* deprecate more
  • Loading branch information
sayakpaul committed Dec 23, 2024
1 parent 027d73f commit 23145f2
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/api/models/transformer2d.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted

## Transformer2DModelOutput

[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
2 changes: 1 addition & 1 deletion src/diffusers/models/controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
from ..models.attention import JointTransformerBlock
from ..models.attention_processor import Attention, AttentionProcessor
from ..models.modeling_outputs import Transformer2DModelOutput
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from .transformers.transformer_2d import Transformer2DModelOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/models/transformers/transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@


class Transformer2DModelOutput(Transformer2DModelOutput):
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)


class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...models.normalization import AdaLayerNormContinuous
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from .transformer_2d import Transformer2DModelOutput
from ..modeling_outputs import Transformer2DModelOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down

0 comments on commit 23145f2

Please sign in to comment.