Skip to content

Commit

Permalink
allow custom height, width in StableDiffusionPipeline (#179)
Browse files Browse the repository at this point in the history
* allow custom height width

* raise if height width are not mul of 8
  • Loading branch information
patil-suraj authored Aug 15, 2022
1 parent c25d8c9 commit 5f25818
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0,
Expand All @@ -45,6 +47,9 @@ def __call__(
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

self.unet.to(torch_device)
self.vae.to(torch_device)
self.text_encoder.to(torch_device)
Expand Down Expand Up @@ -72,7 +77,7 @@ def __call__(

# get the intial random noise
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
(batch_size, self.unet.in_channels, height // 8, width // 8),
generator=generator,
)
latents = latents.to(torch_device)
Expand Down

0 comments on commit 5f25818

Please sign in to comment.