Removing .float() (autocast in fp16 will discard this (I think)).
Narsil authored Sep 14, 2022
1 parent ab7a78e commit 7c4b38b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/diffusers/models/resnet.py
@@ -333,7 +333,7 @@ def forward(self, x, temb):
 
         # make sure hidden states is in float32
         # when running in half-precision
-        hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
+        hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
         hidden_states = self.nonlinearity(hidden_states)
 
         if self.upsample is not None:
@@ -351,7 +351,7 @@ def forward(self, x, temb):
 
         # make sure hidden states is in float32
        # when running in half-precision
-        hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
+        hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
         hidden_states = self.nonlinearity(hidden_states)
 
         hidden_states = self.dropout(hidden_states)
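The change rests on how torch.autocast picks compute dtypes: norm1 and norm2 here are GroupNorm layers, and under fp16 autocast the group_norm op is (as far as I can tell from the autocast op policy) run in float32 regardless, so the manual .float() upcast plus the cast back only adds dtype round-trips. Below is a minimal sketch of that behavior, not part of this commit; it assumes a CUDA device, and the GroupNorm sizes are arbitrary.

# Sketch only, not part of the commit: compares the old and new paths under
# fp16 autocast for a standalone GroupNorm (the norm type used by norm1/norm2
# in this ResnetBlock). Requires a CUDA device.
import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=8, num_channels=32).cuda()
hidden_states = torch.randn(2, 32, 16, 16, device="cuda", dtype=torch.float16)

with torch.autocast(device_type="cuda", dtype=torch.float16):
    # Old path: manual upcast to float32, then cast back to the input dtype.
    out_old = norm(hidden_states.float()).type(hidden_states.dtype)
    # New path: let autocast decide the compute dtype for group_norm.
    out_new = norm(hidden_states).type(hidden_states.dtype)

print(out_old.dtype, out_new.dtype)            # both torch.float16
print((out_old - out_new).abs().max().item())  # expected ~0 if autocast upcasts group_norm either way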
