Stable diffusion pipeline #168

Merged
merged 7 commits into from Aug 14, 2022
Changes from 6 commits
3 changes: 2 additions & 1 deletion src/diffusers/__init__.py
@@ -31,6 +31,7 @@


if is_transformers_available():
- from .pipelines import LDMTextToImagePipeline
+ from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline

else:
from .utils.dummy_transformers_objects import *
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -9,3 +9,4 @@

if is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import StableDiffusionPipeline
@@ -62,9 +62,9 @@ def __call__(

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_kwrags = {}
+ extra_kwargs = {}
if not accepts_eta:
- extra_kwrags["eta"] = eta
+ extra_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
if guidance_scale == 1.0:
@@ -86,7 +86,7 @@
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
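The `extra_kwrags` -> `extra_kwargs` rename above fixes a typo in a pattern shared across pipelines: inspect the scheduler's `step` signature and forward `eta` only when it is accepted. A standalone sketch of that pattern (illustrative, not part of the diff; note that the new StableDiffusionPipeline below gates on `accepts_eta`, while the two older pipelines in this diff still use `not accepts_eta`):

import inspect

def build_scheduler_kwargs(scheduler, eta: float) -> dict:
    # Forward `eta` only to schedulers whose `step` signature accepts it
    # (e.g. DDIM); schedulers without an `eta` parameter get no extra kwargs.
    accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
    extra_kwargs = {}
    if accepts_eta:
        extra_kwargs["eta"] = eta
    return extra_kwargs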
@@ -35,15 +35,15 @@ def __call__(

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_kwrags = {}
+ extra_kwargs = {}
if not accepts_eta:
- extra_kwrags["eta"] = eta
+ extra_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
# predict the noise residual
noise_prediction = self.unet(latents, t)["sample"]
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]

# decode the image latents with the VAE
image = self.vqvae.decode(latents)
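Both of the pipelines touched above share the same denoising-loop skeleton; a schematic sketch (the function below and its argument names are illustrative, not the library API):

import torch

def denoising_loop(unet, scheduler, latents: torch.Tensor, extra_kwargs: dict) -> torch.Tensor:
    # Shared skeleton: predict the noise residual, then step x_t -> x_t-1.
    for t in scheduler.timesteps:
        noise_pred = unet(latents, t)["sample"]
        latents = scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
    return latents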
5 changes: 5 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,5 @@
from ...utils import is_transformers_available


if is_transformers_available():
from .pipeline_stable_diffusion import StableDiffusionPipeline
115 changes: 115 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,115 @@
import inspect
from typing import List, Optional, Union

import torch

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
torch_device: Optional[Union[str, torch.device]] = None,
output_type: Optional[str] = "pil",
):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

if isinstance(prompt, str):
# [review note] @patil-suraj -> made sure a string can be passed as input
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

self.unet.to(torch_device)
self.vae.to(torch_device)
self.text_encoder.to(torch_device)

# get prompt text embeddings
text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
# [review note] @patil-suraj removed the weird 77 max length here
text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier-free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier-free guidance
if do_classifier_free_guidance:
max_length = text_input.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

# Classifier-free guidance normally requires two forward passes. Here we
# concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# get the initial random noise
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
latents = latents.to(torch_device)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
# eta corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
# and should be in [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwargs = {}
if accepts_eta:
extra_kwargs["eta"] = eta

self.scheduler.set_timesteps(num_inference_steps)

for t in tqdm(self.scheduler.timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# [review note] @patil-suraj there was a bug with the naming here -> corrected it

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

# scale and decode the image latents with the VAE
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)

return {"sample": image}
@@ -10,11 +10,12 @@

class KarrasVePipeline(DiffusionPipeline):
"""
- Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
- Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.

- [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
- [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
"""

unet: UNet2DModel
20 changes: 10 additions & 10 deletions src/diffusers/schedulers/scheduling_karras_ve.py
@@ -24,11 +24,12 @@

class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
- Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
- Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.

- [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
- [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
"""

@register_to_config
@@ -43,10 +43,9 @@ def __init__(
tensor_format="pt",
):
"""
- For more details on the parameters, see the original paper's Appendix E.:
- "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
- The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
- are described in Table 5 of the paper.
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.

Args:
sigma_min (`float`): minimum noise magnitude
@@ -81,8 +81,8 @@ def set_timesteps(self, num_inference_steps):

def add_noise_to_input(self, sample, sigma, generator=None):
"""
- Explicit Langevin-like "churn" step of adding noise to the sample according to
- a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
"""
if self.s_min <= sigma <= self.s_max:
gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
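To make the churn step concrete, here is a sketch of the full `add_noise_to_input` computation under the docstring's definitions (illustrative; the scaling of the fresh noise follows Algorithm 2 of Karras et al., and the `s_noise` multiplier on the noise is omitted for brevity):

import torch

def add_churn_noise(sample, sigma, s_churn, s_min, s_max, num_inference_steps, generator=None):
    # gamma_i >= 0 only inside the [s_min, s_max] band, capped at sqrt(2) - 1
    if s_min <= sigma <= s_max:
        gamma = min(s_churn / num_inference_steps, 2**0.5 - 1)
    else:
        gamma = 0.0
    sigma_hat = sigma + gamma * sigma  # raised noise level
    # add just enough Gaussian noise to move the sample from level sigma to sigma_hat
    eps = torch.randn(sample.shape, generator=generator)
    sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps
    return sample_hat, sigma_hat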
7 changes: 7 additions & 0 deletions src/diffusers/utils/dummy_transformers_objects.py
@@ -8,3 +8,10 @@ class LDMTextToImagePipeline(metaclass=DummyObject):

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])


class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
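
The new dummy entry means that, without `transformers` installed, importing `StableDiffusionPipeline` still succeeds and the failure is deferred to instantiation, which points at the missing backend. Roughly (a sketch of the intended behavior, not the exact `utils` code):

# Illustrative only: in an environment where `transformers` is missing,
# the import resolves to the dummy class defined above and still succeeds.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline()  # raises, asking for the "transformers" backend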
34 changes: 34 additions & 0 deletions tests/test_modeling_utils.py
@@ -45,6 +45,8 @@
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel

from ..src.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline


torch.backends.cuda.matmul.allow_tf32 = False

@@ -839,6 +841,38 @@ def test_ldm_text2img_fast(self):
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
def test_stable_diffusion(self):
ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
"sample"
]

image_slice = image[0, -3:, -3:, -1]

# TODO: update the expected_slice
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
def test_stable_diffusion_fast(self):
ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]

image_slice = image[0, -3:, -3:, -1]

# TODO: update the expected_slice
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
def test_score_sde_ve_pipeline(self):
model_id = "google/ncsnpp-church-256"