Skip to content

Commit

Permalink
Handle batches and Tensors in `pipeline_stable_diffusion_inpaint.py:p…
Browse files Browse the repository at this point in the history
…repare_mask_and_masked_image` (open-mmlab#1003)

* Handle batches and Tensors in `prepare_mask_and_masked_image`

* `blackfy`
upgrade `black`

* handle mask as `np.array`

* add docstring

* revert `black` changes with smaller line length

* missing ValueError in docstring

* raise `TypeError` for image as tensor but not mask

* typo in mask shape selection

* check for batch dim

* fix: wrong indentation

* add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
vict0rsch and patrickvonplaten authored Nov 20, 2022
1 parent eb2425b commit 3bec90f
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,93 @@


def prepare_mask_and_masked_image(image, mask):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline.
This means that those inputs will be converted to ``torch.Tensor`` with
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
the ``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to
``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array``
or a ``channels x height x width`` ``torch.Tensor`` or a
``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or
a ``1 x height x width`` ``torch.Tensor`` or a
``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

# Batch single image
if image.ndim == 3:
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)

# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)

# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)

# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)

assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")

# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")

# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1

# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
if isinstance(image, PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
if isinstance(mask, PIL.Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)

masked_image = image * (mask < 0.5)

Expand Down
171 changes: 171 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
UNet2DModel,
VQModel,
)

from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image

from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
Expand Down Expand Up @@ -506,3 +509,171 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9

class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
def test_pil_inputs(self):
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
im = Image.fromarray(im)
mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
mask = Image.fromarray((mask * 255).astype(np.uint8))

t_mask, t_masked = prepare_mask_and_masked_image(im, mask)

self.assertTrue(isinstance(t_mask, torch.Tensor))
self.assertTrue(isinstance(t_masked, torch.Tensor))

self.assertEqual(t_mask.ndim, 4)
self.assertEqual(t_masked.ndim, 4)

self.assertEqual(t_mask.shape, (1, 1, 32, 32))
self.assertEqual(t_masked.shape, (1, 3, 32, 32))

self.assertTrue(t_mask.dtype == torch.float32)
self.assertTrue(t_masked.dtype == torch.float32)

self.assertTrue(t_mask.min() >= 0.0)
self.assertTrue(t_mask.max() <= 1.0)
self.assertTrue(t_masked.min() >= -1.0)
self.assertTrue(t_masked.min() <= 1.0)

self.assertTrue(t_mask.sum() > 0.0)

def test_np_inputs(self):
im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
im_pil = Image.fromarray(im_np)
mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)

self.assertTrue((t_mask_np == t_mask_pil).all())
self.assertTrue((t_masked_np == t_masked_pil).all())

def test_torch_3D_2D_inputs(self):
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_3D_3D_inputs(self):
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_4D_2D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_4D_3D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_4D_4D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0][0]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_batch_4D_3D(self):
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5

im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy() for mask in mask_tensor]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_batch_4D_4D(self):
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5

im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy()[0] for mask in mask_tensor]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_shape_mismatch(self):
# test height and width
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64))
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64))

def test_type_mismatch(self):
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))

def test_channels_first(self):
# test channels first for 3D tensors
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))

def test_tensor_range(self):
# test im <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
# test im >= -1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
# test mask <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
# test mask >= 0
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)

0 comments on commit 3bec90f

Please sign in to comment.