Skip to content

Commit

Permalink
model can run in other precisions without autocast
Browse files Browse the repository at this point in the history
  • Loading branch information
NouamaneTazi committed Sep 13, 2022
1 parent 39994cc commit d30f968
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def forward(self, x, temb):

# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)

if self.upsample is not None:
Expand All @@ -351,7 +351,7 @@ def forward(self, x, temb):

# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)

hidden_states = self.dropout(hidden_states)
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def forward(
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
emb = self.time_embedding(t_emb.to(self.dtype))

# 2. pre-process
sample = self.conv_in(sample)
Expand Down Expand Up @@ -215,7 +215,7 @@ def forward(
# 6. post-process
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __call__(
latents_shape,
generator=generator,
device=self.device,
dtype=text_embeddings.dtype,
)
else:
if latents.shape != latents_shape:
Expand Down Expand Up @@ -263,7 +264,7 @@ def __call__(

# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype))

if output_type == "pil":
image = self.numpy_to_pil(image)
Expand Down

0 comments on commit d30f968

Please sign in to comment.