From 2f3d6169833717e6ccb0f5c1d0b5c4abd853de32 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Wed, 27 Nov 2024 16:49:09 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20Replace=20`imgaug`=20with=20Nati?= =?UTF-8?q?ve=20PyTorch=20Transforms=20(#2436)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add multi random choice transform * Add DRAEMAugmenter class and Perlin noise generation to new_perlin.py - Introduced DRAEMAugmenter for advanced image augmentations using torchvision v2. - Implemented various augmentation techniques including ColorJitter, RandomAdjustSharpness, and custom transformations. - Added functionality for comparing augmentation methods and visualizing results. - Included utility functions for metrics computation and image processing. - Established logging for better traceability of operations. This commit enhances the image processing capabilities within the Anomalib framework, facilitating more robust anomaly detection workflows. * Add the new perlin noise Signed-off-by: Samet Akcay * Add the new perlin noise Signed-off-by: Samet Akcay * add generate_perlin_noise relative import Signed-off-by: Samet Akcay * add tiffile as a dependency Signed-off-by: Samet Akcay * Remove upper bound from wandb Signed-off-by: Samet Akcay * Added skimage Signed-off-by: Samet Akcay * add scikit-learn as a dependency Signed-off-by: Samet Akcay * limit ollama to < 0.4.0 as it has breaking changes Signed-off-by: Samet Akcay * Fix data generators in test helpers Signed-off-by: Samet Akcay * Update the perlin augmenters Signed-off-by: Samet Akcay * Fix numpy validator tests caused by numpy upgrade Signed-off-by: Samet Akcay * Fix CS-Flow tests Signed-off-by: Samet Akcay * Fix the tests Signed-off-by: Samet Akcay --------- Signed-off-by: Samet Akcay --- pyproject.toml | 7 +- src/anomalib/data/transforms/__init__.py | 3 +- .../data/transforms/multi_random_choice.py | 82 ++++ src/anomalib/data/utils/__init__.py | 6 +- src/anomalib/data/utils/augmenter.py | 171 -------- .../data/utils/generators/__init__.py | 4 +- src/anomalib/data/utils/generators/perlin.py | 401 ++++++++++++------ src/anomalib/data/utils/synthetic.py | 7 +- .../models/image/draem/lightning_model.py | 6 +- .../models/image/dsr/anomaly_generator.py | 29 +- .../models/image/dsr/lightning_model.py | 6 +- tests/helpers/data.py | 7 +- tests/integration/model/test_models.py | 27 +- .../unit/data/validators/numpy/test_image.py | 9 +- .../unit/data/validators/numpy/test_video.py | 24 +- tests/unit/metrics/test_pro.py | 6 +- 16 files changed, 453 insertions(+), 342 deletions(-) create mode 100644 src/anomalib/data/transforms/multi_random_choice.py delete mode 100644 src/anomalib/data/utils/augmenter.py diff --git a/pyproject.toml b/pyproject.toml index e47f7e55d8..805795da40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,11 +42,12 @@ core = [ "av>=10.0.0", "einops>=0.3.2", "freia>=0.2", - "imgaug==0.4.0", "kornia>=0.6.6", "matplotlib>=3.4.3", "opencv-python>=4.5.3.56", "pandas>=1.1.0", + "scikit-image", # NOTE: skimage should be removed as part of dependency cleanup + "tifffile", "timm", "lightning>=2.2", "torch>=2", @@ -57,12 +58,12 @@ core = [ "open-clip-torch>=2.23.0,<2.26.1", ] openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"] -vlm = ["ollama", "openai", "python-dotenv","transformers"] +vlm = ["ollama<0.4.0", "openai", "python-dotenv","transformers"] loggers = [ "comet-ml>=3.31.7", "gradio>=4", "tensorboard", - "wandb>=0.12.17,<=0.15.9", + "wandb", "mlflow >=1.0.0", ] 
notebooks = ["gitpython", "ipykernel", "ipywidgets", "notebook"] diff --git a/src/anomalib/data/transforms/__init__.py b/src/anomalib/data/transforms/__init__.py index 146fb19e15..89a5c673d2 100644 --- a/src/anomalib/data/transforms/__init__.py +++ b/src/anomalib/data/transforms/__init__.py @@ -4,5 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 from .center_crop import ExportableCenterCrop +from .multi_random_choice import MultiRandomChoice -__all__ = ["ExportableCenterCrop"] +__all__ = ["ExportableCenterCrop", "MultiRandomChoice"] diff --git a/src/anomalib/data/transforms/multi_random_choice.py b/src/anomalib/data/transforms/multi_random_choice.py new file mode 100644 index 0000000000..1d507c17a2 --- /dev/null +++ b/src/anomalib/data/transforms/multi_random_choice.py @@ -0,0 +1,82 @@ +"""Multi random choice transform.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Callable, Sequence + +import torch +from torchvision.transforms import v2 + + +class MultiRandomChoice(v2.Transform): + """Apply multiple transforms randomly picked from a list. + + This transform does not support torchscript. + + Args: + transforms (sequence or torch.nn.Module): List of transformations to choose from. + probabilities (list[float] | None, optional): Probability of each transform being picked. + If None (default), all transforms have equal probability. If provided, probabilities + will be normalized to sum to 1. + num_transforms (int): Maximum number of transforms to apply at once. + Defaults to ``1``. + fixed_num_transforms (bool): If ``True``, always applies exactly ``num_transforms`` transforms. + If ``False``, randomly picks between 1 and ``num_transforms``. + Defaults to ``False``. + + Examples: + >>> import torchvision.transforms.v2 as v2 + >>> transforms = [ + ... v2.RandomHorizontalFlip(p=1.0), + ... v2.ColorJitter(brightness=0.5), + ... v2.RandomRotation(10), + ... ] + >>> # Apply 1-2 random transforms with equal probability + >>> transform = MultiRandomChoice(transforms, num_transforms=2) + + >>> # Always apply exactly 2 transforms with custom probabilities + >>> transform = MultiRandomChoice( + ... transforms, + ... probabilities=[0.5, 0.3, 0.2], + ... num_transforms=2, + ... fixed_num_transforms=True + ... 
) + """ + + def __init__( + self, + transforms: Sequence[Callable], + probabilities: list[float] | None = None, + num_transforms: int = 1, + fixed_num_transforms: bool = False, + ) -> None: + if not isinstance(transforms, Sequence): + msg = "Argument transforms should be a sequence of callables" + raise TypeError(msg) + + if probabilities is None: + probabilities = [1.0] * len(transforms) + elif len(probabilities) != len(transforms): + msg = f"Length of p doesn't match the number of transforms: {len(probabilities)} != {len(transforms)}" + raise ValueError(msg) + + super().__init__() + + self.transforms = transforms + total = sum(probabilities) + self.probabilities = [probability / total for probability in probabilities] + + self.num_transforms = num_transforms + self.fixed_num_transforms = fixed_num_transforms + + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Apply randomly selected transforms to the input.""" + # First determine number of transforms to apply + num_transforms = ( + self.num_transforms if self.fixed_num_transforms else int(torch.randint(self.num_transforms, (1,)) + 1) + ) + # Get transforms + idx = torch.multinomial(torch.tensor(self.probabilities), num_transforms).tolist() + transform = v2.Compose([self.transforms[i] for i in idx]) + return transform(*inputs) diff --git a/src/anomalib/data/utils/__init__.py b/src/anomalib/data/utils/__init__.py index e75ba5bf49..570c45af4a 100644 --- a/src/anomalib/data/utils/__init__.py +++ b/src/anomalib/data/utils/__init__.py @@ -3,10 +3,9 @@ # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .augmenter import Augmenter from .boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes from .download import DownloadInfo, download_and_extract -from .generators import random_2d_perlin +from .generators import generate_perlin_noise from .image import ( generate_output_image_filename, get_image_filenames, @@ -30,7 +29,7 @@ "generate_output_image_filename", "get_image_filenames", "get_image_height_and_width", - "random_2d_perlin", + "generate_perlin_noise", "read_image", "read_mask", "read_depth_image", @@ -42,7 +41,6 @@ "TestSplitMode", "LabelName", "DirType", - "Augmenter", "masks_to_boxes", "boxes_to_masks", "boxes_to_anomaly_maps", diff --git a/src/anomalib/data/utils/augmenter.py b/src/anomalib/data/utils/augmenter.py deleted file mode 100644 index aa35434773..0000000000 --- a/src/anomalib/data/utils/augmenter.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Augmenter module to generates out-of-distribution samples for the DRAEM implementation.""" - -# Original Code -# Copyright (c) 2021 VitjanZ -# https://github.com/VitjanZ/DRAEM. -# SPDX-License-Identifier: MIT -# -# Modified -# Copyright (C) 2022-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import math -import random -from pathlib import Path - -import cv2 -import imgaug.augmenters as iaa -import numpy as np -import torch -from PIL import Image -from torchvision.datasets.folder import IMG_EXTENSIONS - -from anomalib.data.utils.generators.perlin import random_2d_perlin - - -def nextpow2(value: int) -> int: - """Return the smallest power of 2 greater than or equal to the input value.""" - return 2 ** (math.ceil(math.log(value, 2))) - - -class Augmenter: - """Class that generates noisy augmentations of input images. - - Args: - anomaly_source_path (str | None): Path to a folder of images that will be used as source of the anomalous - noise. 
If not specified, random noise will be used instead. - p_anomalous (float): Probability that the anomalous perturbation will be applied to a given image. - beta (float): Parameter that determines the opacity of the noise mask. - """ - - def __init__( - self, - anomaly_source_path: str | None = None, - p_anomalous: float = 0.5, - beta: float | tuple[float, float] = (0.2, 1.0), - ) -> None: - self.p_anomalous = p_anomalous - self.beta = beta - - self.anomaly_source_paths: list[Path] = [] - if anomaly_source_path is not None: - for img_ext in IMG_EXTENSIONS: - self.anomaly_source_paths.extend(Path(anomaly_source_path).rglob("*" + img_ext)) - - self.augmenters = [ - iaa.GammaContrast((0.5, 2.0), per_channel=True), - iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)), - iaa.pillike.EnhanceSharpness(), - iaa.AddToHueAndSaturation((-50, 50), per_channel=True), - iaa.Solarize(0.5, threshold=(32, 128)), - iaa.Posterize(), - iaa.Invert(), - iaa.pillike.Autocontrast(), - iaa.pillike.Equalize(), - iaa.Affine(rotate=(-45, 45)), - ] - self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) - - def rand_augmenter(self) -> iaa.Sequential: - """Select 3 random transforms that will be applied to the anomaly source images. - - Returns: - A selection of 3 transforms. - """ - aug_ind = np.random.default_rng().choice(np.arange(len(self.augmenters)), 3, replace=False) - return iaa.Sequential([self.augmenters[aug_ind[0]], self.augmenters[aug_ind[1]], self.augmenters[aug_ind[2]]]) - - def generate_perturbation( - self, - height: int, - width: int, - anomaly_source_path: Path | str | None = None, - ) -> tuple[np.ndarray, np.ndarray]: - """Generate an image containing a random anomalous perturbation using a source image. - - Args: - height (int): height of the generated image. - width: (int): width of the generated image. - anomaly_source_path (Path | str | None): Path to an image file. If not provided, random noise will be used - instead. - - Returns: - Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask. 
- """ - # Generate random perlin noise - perlin_scale = 6 - min_perlin_scale = 0 - - perlin_scalex = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale) - perlin_scaley = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale) - - perlin_noise = random_2d_perlin((nextpow2(height), nextpow2(width)), (perlin_scalex, perlin_scaley))[ - :height, - :width, - ] - perlin_noise = self.rot(image=perlin_noise) - - # Create mask from perlin noise - mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) - mask = np.expand_dims(mask, axis=2).astype(np.float32) - - # Load anomaly source image - if anomaly_source_path: - anomaly_source_img = np.array(Image.open(anomaly_source_path)) - anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height)) - else: # if no anomaly source is specified, we use the perlin noise as anomalous source - anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2) - anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8) - - # Augment anomaly source image - aug = self.rand_augmenter() - anomaly_img_augmented = aug(image=anomaly_source_img) - - # Create anomalous perturbation that we will apply to the image - perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0 - - return perturbation, mask - - def augment_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Generate anomalous augmentations for a batch of input images. - - Args: - batch (torch.Tensor): Batch of input images - - Returns: - - Augmented image to which anomalous perturbations have been added. - - Ground truth masks corresponding to the anomalous perturbations. - """ - batch_size, channels, height, width = batch.shape - - # Collect perturbations - perturbations_list = [] - masks_list = [] - for _ in range(batch_size): - if torch.rand(1) > self.p_anomalous: # include normal samples - perturbations_list.append(torch.zeros((channels, height, width))) - masks_list.append(torch.zeros((1, height, width))) - else: - anomaly_source_path = ( - random.sample(self.anomaly_source_paths, 1)[0] if len(self.anomaly_source_paths) > 0 else None - ) - perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path) - perturbations_list.append(torch.Tensor(perturbation).permute((2, 0, 1))) - masks_list.append(torch.Tensor(mask).permute((2, 0, 1))) - - perturbations = torch.stack(perturbations_list).to(batch.device) - masks = torch.stack(masks_list).to(batch.device) - - # Apply perturbations batch wise - if isinstance(self.beta, float): - beta = self.beta - elif isinstance(self.beta, tuple): - beta = torch.rand(batch_size) * (self.beta[1] - self.beta[0]) + self.beta[0] - beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device) # type: ignore[attr-defined] - else: - msg = "Beta must be either float or tuple of floats" - raise TypeError(msg) - - augmented_batch = batch * (1 - masks) + (beta) * perturbations + (1 - beta) * batch * (masks) - - return augmented_batch, masks diff --git a/src/anomalib/data/utils/generators/__init__.py b/src/anomalib/data/utils/generators/__init__.py index a79bad9770..c46f30d08e 100644 --- a/src/anomalib/data/utils/generators/__init__.py +++ b/src/anomalib/data/utils/generators/__init__.py @@ -3,6 +3,6 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .perlin import random_2d_perlin +from .perlin import PerlinAnomalyGenerator, generate_perlin_noise -__all__ = ["random_2d_perlin"] +__all__ = 
["PerlinAnomalyGenerator", "generate_perlin_noise"] diff --git a/src/anomalib/data/utils/generators/perlin.py b/src/anomalib/data/utils/generators/perlin.py index fa683d7546..052d565121 100644 --- a/src/anomalib/data/utils/generators/perlin.py +++ b/src/anomalib/data/utils/generators/perlin.py @@ -1,160 +1,317 @@ -"""Helper functions for generating Perlin noise.""" - -# Original Code -# Copyright (c) 2021 VitjanZ -# https://github.com/VitjanZ/DRAEM. -# SPDX-License-Identifier: MIT -# -# Modified +"""Perlin noise-based synthetic anomaly generator.""" + # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# ruff: noqa - -import math +from pathlib import Path -import numpy as np import torch +from torchvision import io +from torchvision.datasets.folder import IMG_EXTENSIONS +from torchvision.transforms import v2 +from anomalib.data.transforms import MultiRandomChoice -def lerp_np(x, y, w): - """Helper function.""" - return (y - x) * w + x - - -def rand_perlin_2d_octaves_np(shape, res, octaves=1, persistence=0.5): - """Generate Perlin noise parameterized by the octaves method. Numpy version.""" - noise = np.zeros(shape) - frequency = 1 - amplitude = 1 - for _ in range(octaves): - noise += amplitude * generate_perlin_noise_2d(shape, (frequency * res[0], frequency * res[1])) - frequency *= 2 - amplitude *= persistence - return noise +def generate_perlin_noise( + height: int, + width: int, + scale: tuple[int, int] | None = None, + device: torch.device | None = None, +) -> torch.Tensor: + """Generate a Perlin noise pattern. -def generate_perlin_noise_2d(shape, res): - """Fractal perlin noise.""" - - def f(t): - return 6 * t**5 - 15 * t**4 + 10 * t**3 - - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 - # Gradients - angles = 2 * np.pi * np.random.default_rng().random(res[0] + 1, res[1] + 1) - gradients = np.dstack((np.cos(angles), np.sin(angles))) - g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) - g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) - g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) - g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) - # Ramps - n00 = np.sum(grid * g00, 2) - n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) - n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) - n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) - # Interpolation - t = f(grid) - n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 - n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 - return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) - - -def random_2d_perlin( - shape: tuple, - res: tuple[int | torch.Tensor, int | torch.Tensor], - fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3, -) -> np.ndarray | torch.Tensor: - """Returns a random 2d perlin noise array. + This function generates a Perlin noise pattern using a grid-based gradient noise approach. + The noise is generated by interpolating between randomly generated gradient vectors at grid vertices. + The interpolation uses a quintic curve for smooth transitions. Args: - shape (tuple): Shape of the 2d map. - res (tuple[int | torch.Tensor, int | torch.Tensor]): Tuple of scales for perlin noise for height and width dimension. - fade (_type_, optional): Function used for fading the resulting 2d map. - Defaults to equation 6*t**5-15*t**4+10*t**3. 
+ height: Desired height of the noise pattern + width: Desired width of the noise pattern + scale: Tuple of (scale_x, scale_y) for noise granularity. If None, random scales will be used. + Larger scales produce coarser noise patterns, while smaller scales produce finer patterns. + device: Device to generate the noise on. If None, uses current default device Returns: - np.ndarray | torch.Tensor: Random 2d-array/tensor generated using perlin noise. - """ - if isinstance(res[0], int | np.integer): - result = _rand_perlin_2d_np(shape, res, fade) - elif isinstance(res[0], torch.Tensor): - result = _rand_perlin_2d(shape, res, fade) - else: - msg = f"got scales of type {type(res[0])}" - raise TypeError(msg) - return result + Tensor of shape [height, width] containing the noise pattern, with values roughly in [-1, 1] range + Examples: + >>> # Generate 256x256 noise with default random scale + >>> noise = generate_perlin_noise(256, 256) + >>> print(noise.shape) + torch.Size([256, 256]) -def _rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): - """Generate a random image containing Perlin noise. Numpy version.""" - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 + >>> # Generate 512x512 noise with fixed scale + >>> noise = generate_perlin_noise(512, 512, scale=(8, 8)) + >>> print(noise.shape) + torch.Size([512, 512]) - angles = 2 * math.pi * np.random.default_rng().random((res[0] + 1, res[1] + 1)) - gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) + >>> # Generate noise on GPU if available + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + >>> noise = generate_perlin_noise(128, 128, device=device) + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def tile_grads(slice1, slice2): - return np.repeat(np.repeat(gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]], d[0], axis=0), d[1], axis=1) + # Handle scale parameter + if scale is None: + min_scale, max_scale = 0, 6 + scalex = 2 ** torch.randint(min_scale, max_scale, (1,), device=device).item() + scaley = 2 ** torch.randint(min_scale, max_scale, (1,), device=device).item() + else: + scalex, scaley = scale - def dot(grad, shift): - return ( - np.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), axis=-1) - * grad[: shape[0], : shape[1]] - ).sum(axis=-1) + # Ensure dimensions are powers of 2 for proper noise generation + def nextpow2(value: int) -> int: + return int(2 ** torch.ceil(torch.log2(torch.tensor(value))).int().item()) - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) - n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) - n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[: shape[0], : shape[1]]) - return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) + pad_h = nextpow2(height) + pad_w = nextpow2(width) + # Generate base grid + delta = (scalex / pad_h, scaley / pad_w) + d = (pad_h // scalex, pad_w // scaley) -def _rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): - """Generate a random image containing Perlin noise. 
PyTorch version.""" - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) + grid = ( + torch.stack( + torch.meshgrid( + torch.arange(0, scalex, delta[0], device=device), + torch.arange(0, scaley, delta[1], device=device), + indexing="ij", + ), + dim=-1, + ) + % 1 + ) - grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 - angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) + # Generate random gradients + angles = 2 * torch.pi * torch.rand(int(scalex) + 1, int(scaley) + 1, device=device) gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) - def tile_grads(slice1, slice2): + def tile_grads(slice1: list[int | None], slice2: list[int | None]) -> torch.Tensor: return ( gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] - .repeat_interleave(d[0], 0) - .repeat_interleave(d[1], 1) + .repeat_interleave(int(d[0]), 0) + .repeat_interleave(int(d[1]), 1) ) - def dot(grad, shift): + def dot(grad: torch.Tensor, shift: list[float]) -> torch.Tensor: return ( torch.stack( - (grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), + (grid[:pad_h, :pad_w, 0] + shift[0], grid[:pad_h, :pad_w, 1] + shift[1]), dim=-1, ) - * grad[: shape[0], : shape[1]] + * grad[:pad_h, :pad_w] ).sum(dim=-1) + # Calculate noise values at grid points n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[: shape[0], : shape[1]]) - return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) - - -def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): - """Generate Perlin noise parameterized by the octaves method. PyTorch version.""" - noise = torch.zeros(shape) - frequency = 1 - amplitude = 1 - for _ in range(octaves): - noise += amplitude * _rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) - frequency *= 2 - amplitude *= persistence - return noise + + # Interpolate between grid points using quintic curve + def fade(t: torch.Tensor) -> torch.Tensor: + return 6 * t**5 - 15 * t**4 + 10 * t**3 + + t = fade(grid[:pad_h, :pad_w]) + noise = torch.sqrt(torch.tensor(2.0, device=device)) * torch.lerp( + torch.lerp(n00, n10, t[..., 0]), + torch.lerp(n01, n11, t[..., 0]), + t[..., 1], + ) + + # Crop to desired dimensions + return noise[:height, :width] + + +class PerlinAnomalyGenerator(v2.Transform): + """Perlin noise-based synthetic anomaly generator. + + Examples: + >>> # Single image usage with default parameters + >>> transform = PerlinAnomalyGenerator() + >>> image = torch.randn(3, 256, 256) # [C, H, W] + >>> augmented_image, anomaly_mask = transform(image) + >>> print(augmented_image.shape) # [C, H, W] + >>> print(anomaly_mask.shape) # [1, H, W] + + >>> # Batch usage with custom parameters + >>> transform = PerlinAnomalyGenerator( + ... probability=0.8, + ... blend_factor=0.5 + ... ) + >>> batch = torch.randn(4, 3, 256, 256) # [B, C, H, W] + >>> augmented_batch, anomaly_masks = transform(batch) + >>> print(augmented_batch.shape) # [B, C, H, W] + >>> print(anomaly_masks.shape) # [B, 1, H, W] + + >>> # Using anomaly source images + >>> transform = PerlinAnomalyGenerator( + ... anomaly_source_path='path/to/anomaly/images', + ... probability=0.7, + ... blend_factor=(0.3, 0.9), + ... rotation_range=(-45, 45) + ... 
) + >>> augmented_image, anomaly_mask = transform(image) + """ + + def __init__( + self, + anomaly_source_path: str | None = None, + probability: float = 0.5, + blend_factor: float | tuple[float, float] = (0.2, 1.0), + rotation_range: tuple[float, float] = (-90, 90), + ) -> None: + super().__init__() + self.probability = probability + self.blend_factor = blend_factor + + # Load anomaly source paths + self.anomaly_source_paths: list[Path] = [] + if anomaly_source_path is not None: + for img_ext in IMG_EXTENSIONS: + self.anomaly_source_paths.extend(Path(anomaly_source_path).rglob("*" + img_ext)) + + # Initialize perlin rotation transform + self.perlin_rotation_transform = v2.RandomAffine( + degrees=rotation_range, + interpolation=v2.InterpolationMode.BILINEAR, + fill=0, + ) + + # Initialize augmenters + self.augmenters = MultiRandomChoice( + transforms=[ + v2.ColorJitter(contrast=(0.5, 2.0)), + v2.RandomPhotometricDistort( + brightness=(0.8, 1.2), + contrast=(1.0, 1.0), # No contrast change + saturation=(1.0, 1.0), # No saturation change + hue=(0.0, 0.0), # No hue change + p=1.0, + ), + v2.RandomAdjustSharpness(sharpness_factor=2.0, p=1.0), + v2.ColorJitter(hue=[-50 / 360, 50 / 360], saturation=[0.5, 1.5]), + v2.RandomSolarize(threshold=torch.empty(1).uniform_(32 / 255, 128 / 255).item(), p=1.0), + v2.RandomPosterize(bits=4, p=1.0), + v2.RandomInvert(p=1.0), + v2.AutoAugment(), + v2.RandomEqualize(p=1.0), + v2.RandomAffine(degrees=(-45, 45), interpolation=v2.InterpolationMode.BILINEAR, fill=0), + ], + probabilities=None, + num_transforms=3, + fixed_num_transforms=True, + ) + + def generate_perturbation( + self, + height: int, + width: int, + device: torch.device | None = None, + anomaly_source_path: Path | str | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Generate perturbed image and mask. 
+ + Args: + height: Height of the output image + width: Width of the output image + device: Device to generate the perturbation on + anomaly_source_path: Optional path to source image for anomaly + + Returns: + tuple[torch.Tensor, torch.Tensor]: Perturbation and mask tensors + """ + # Generate perlin noise + perlin_noise = generate_perlin_noise(height, width, device=device) + + # Create rotated noise pattern + perlin_noise = perlin_noise.unsqueeze(0) # [1, H, W] + perlin_noise = self.perlin_rotation_transform(perlin_noise).squeeze(0) # [H, W] + + # Generate binary mask from perlin noise + mask = torch.where( + perlin_noise > 0.5, + torch.ones_like(perlin_noise, device=device), + torch.zeros_like(perlin_noise, device=device), + ).unsqueeze(-1) # [H, W, 1] + + # Generate anomaly source image + if anomaly_source_path: + anomaly_source_img = ( + io.read_image(str(anomaly_source_path), mode=io.ImageReadMode.RGB).float().to(device) / 255.0 + ) + if anomaly_source_img.shape[-2:] != (height, width): + anomaly_source_img = v2.functional.resize(anomaly_source_img, [height, width], antialias=True) + anomaly_source_img = anomaly_source_img.permute(1, 2, 0) # [H, W, C] + else: + anomaly_source_img = perlin_noise.unsqueeze(-1).repeat(1, 1, 3) # [H, W, C] + anomaly_source_img = (anomaly_source_img * 0.5) + 0.25 # Adjust intensity range + + # Apply augmentations to source image + anomaly_augmented = self.augmenters(anomaly_source_img.permute(2, 0, 1)) # [C, H, W] + anomaly_augmented = anomaly_augmented.permute(1, 2, 0) # [H, W, C] + + # Create final perturbation by applying mask + perturbation = anomaly_augmented * mask + + return perturbation, mask + + def _transform_image( + self, + img: torch.Tensor, + h: int, + w: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Transform a single image.""" + if torch.rand(1, device=device) > self.probability: + return img, torch.zeros((1, h, w), device=device) + + anomaly_source_path = ( + list(self.anomaly_source_paths)[int(torch.randint(len(self.anomaly_source_paths), (1,)).item())] + if self.anomaly_source_paths + else None + ) + + perturbation, mask = self.generate_perturbation(h, w, device, anomaly_source_path) + perturbation = perturbation.permute(2, 0, 1) + mask = mask.permute(2, 0, 1) + + beta = ( + self.blend_factor + if isinstance(self.blend_factor, float) + else torch.rand(1, device=device) * (self.blend_factor[1] - self.blend_factor[0]) + self.blend_factor[0] + if isinstance(self.blend_factor, tuple) + # Add type guard + else torch.tensor(0.5, device=device) # Fallback value + ) + + if not isinstance(beta, float): + beta = beta.view(-1, 1, 1).expand_as(img) + + augmented_img = img * (1 - mask) + beta * perturbation + (1 - beta) * img * mask + return augmented_img, mask + + def forward(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Apply augmentation using the mask for single image or batch.""" + device = img.device + is_batch = len(img.shape) == 4 + + if is_batch: + batch, _, height, width = img.shape + # Initialize batch outputs + batch_augmented = [] + batch_masks = [] + + for i in range(batch): + # Apply transform to each image in batch + augmented, mask = self._transform_image(img[i], height, width, device) + batch_augmented.append(augmented) + batch_masks.append(mask) + + return torch.stack(batch_augmented), torch.stack(batch_masks) + + # Handle single image + return self._transform_image(img, img.shape[1], img.shape[2], device) diff --git a/src/anomalib/data/utils/synthetic.py 
b/src/anomalib/data/utils/synthetic.py index 16aa20d83d..7d2b340e33 100644 --- a/src/anomalib/data/utils/synthetic.py +++ b/src/anomalib/data/utils/synthetic.py @@ -20,7 +20,8 @@ from anomalib import TaskType from anomalib.data.datasets.base.image import AnomalibDataset -from anomalib.data.utils import Augmenter, Split, read_image +from anomalib.data.utils import Split, read_image +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ def make_synthetic_dataset( anomalous_samples = anomalous_samples.reset_index(drop=True) # initialize augmenter - augmenter = Augmenter("./datasets/dtd", p_anomalous=1.0, beta=(0.01, 0.2)) + augmenter = PerlinAnomalyGenerator(anomaly_source_path="./datasets/dtd", probability=1.0, blend_factor=(0.01, 0.2)) def augment(sample: Series) -> Series: """Apply synthetic anomalous augmentation to a sample from a dataframe. @@ -83,7 +84,7 @@ def augment(sample: Series) -> Series: # read and transform image image = read_image(sample.image_path, as_tensor=True) # apply anomalous perturbation - aug_im, mask = augmenter.augment_batch(image.unsqueeze(0)) + aug_im, mask = augmenter(image) # target file name with leading zeros file_name = f"{str(sample.name).zfill(int(math.log10(n_anomalous)) + 1)}.png" # write image diff --git a/src/anomalib/models/image/draem/lightning_model.py b/src/anomalib/models/image/draem/lightning_model.py index dd02fd168a..66e87a904b 100644 --- a/src/anomalib/models/image/draem/lightning_model.py +++ b/src/anomalib/models/image/draem/lightning_model.py @@ -16,7 +16,7 @@ from anomalib import LearningType from anomalib.data import Batch -from anomalib.data.utils import Augmenter +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor @@ -56,7 +56,7 @@ def __init__( ) -> None: super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) - self.augmenter = Augmenter(anomaly_source_path, beta=beta) + self.augmenter = PerlinAnomalyGenerator(anomaly_source_path=anomaly_source_path, blend_factor=beta) self.model = DraemModel(sspcab=enable_sspcab) self.loss = DraemLoss() self.sspcab = enable_sspcab @@ -110,7 +110,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: input_image = batch.image # Apply corruption to input image - augmented_image, anomaly_mask = self.augmenter.augment_batch(input_image) + augmented_image, anomaly_mask = self.augmenter(input_image) # Generate model prediction reconstruction, prediction = self.model(augmented_image) # Compute loss diff --git a/src/anomalib/models/image/dsr/anomaly_generator.py b/src/anomalib/models/image/dsr/anomaly_generator.py index 396019de39..9bb262500c 100644 --- a/src/anomalib/models/image/dsr/anomaly_generator.py +++ b/src/anomalib/models/image/dsr/anomaly_generator.py @@ -3,12 +3,11 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import imgaug.augmenters as iaa -import numpy as np import torch from torch import Tensor, nn +from torchvision.transforms import v2 -from anomalib.data.utils.generators.perlin import _rand_perlin_2d_np +from anomalib.data.utils.generators.perlin import generate_perlin_noise class DsrAnomalyGenerator(nn.Module): @@ -29,7 +28,8 @@ def __init__( super().__init__() self.p_anomalous = p_anomalous - self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) + 
# Replace imgaug with torchvision transform + self.rot = v2.RandomAffine(degrees=(-90, 90)) def generate_anomaly(self, height: int, width: int) -> Tensor: """Generate an anomalous mask. @@ -43,15 +43,20 @@ def generate_anomaly(self, height: int, width: int) -> Tensor: """ min_perlin_scale = 0 perlin_scale = 6 - perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) - perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) + perlin_scalex = int(2 ** torch.randint(min_perlin_scale, perlin_scale, (1,)).item()) + perlin_scaley = int(2 ** torch.randint(min_perlin_scale, perlin_scale, (1,)).item()) threshold = 0.5 - perlin_noise_np = _rand_perlin_2d_np((height, width), (perlin_scalex, perlin_scaley)) - perlin_noise_np = self.rot(image=perlin_noise_np) - mask = np.where(perlin_noise_np > threshold, np.ones_like(perlin_noise_np), np.zeros_like(perlin_noise_np)) - mask = np.expand_dims(mask, axis=2).astype(np.float32) - return torch.from_numpy(mask) + # Generate perlin noise using the new function + perlin_noise = generate_perlin_noise(height, width, scale=(perlin_scalex, perlin_scaley)) + + # Apply random rotation + perlin_noise = perlin_noise.unsqueeze(0) # Add channel dimension for transform + perlin_noise = self.rot(perlin_noise).squeeze(0) # Remove channel dimension + + # Create binary mask + mask = (perlin_noise > threshold).float() + return mask.unsqueeze(0) # Add channel dimension [1, H, W] def augment_batch(self, batch: Tensor) -> Tensor: """Generate anomalous augmentations for a batch of input images. @@ -71,6 +76,6 @@ def augment_batch(self, batch: Tensor) -> Tensor: masks_list.append(torch.zeros((1, height, width))) else: mask = self.generate_anomaly(height, width) - masks_list.append(mask.permute((2, 0, 1))) + masks_list.append(mask) return torch.stack(masks_list).to(batch.device) diff --git a/src/anomalib/models/image/dsr/lightning_model.py b/src/anomalib/models/image/dsr/lightning_model.py index 23daa6b95d..8aa3de08e2 100644 --- a/src/anomalib/models/image/dsr/lightning_model.py +++ b/src/anomalib/models/image/dsr/lightning_model.py @@ -17,7 +17,7 @@ from anomalib import LearningType from anomalib.data import Batch from anomalib.data.utils import DownloadInfo, download_and_extract -from anomalib.data.utils.augmenter import Augmenter +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule from anomalib.models.image.dsr.anomaly_generator import DsrAnomalyGenerator @@ -62,7 +62,7 @@ def __init__( self.upsampling_train_ratio = upsampling_train_ratio self.quantized_anomaly_generator = DsrAnomalyGenerator() - self.perlin_generator = Augmenter() + self.perlin_generator = PerlinAnomalyGenerator() self.model = DsrModel(latent_anomaly_strength) self.second_stage_loss = DsrSecondStageLoss() self.third_stage_loss = DsrThirdStageLoss() @@ -158,7 +158,7 @@ def training_step(self, batch: Batch) -> STEP_OUTPUT: # we are training the upsampling module input_image = batch.image # Generate anomalies - input_image, anomaly_maps = self.perlin_generator.augment_batch(input_image) + input_image, anomaly_maps = self.perlin_generator(input_image) # Get model prediction model_outputs = self.model(input_image) # Calculate loss diff --git a/tests/helpers/data.py b/tests/helpers/data.py index 60433df9eb..e1efccc1b1 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -18,7 +18,8 @@ from skimage.io import imsave from anomalib.data 
import DataFormat -from anomalib.data.utils import Augmenter, LabelName +from anomalib.data.utils import LabelName +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator class DummyImageGenerator: @@ -47,7 +48,7 @@ class DummyImageGenerator: def __init__(self, image_shape: tuple[int, int] = (256, 256), rng: np.random.Generator | None = None) -> None: self.image_shape = image_shape - self.augmenter = Augmenter() + self.augmenter = PerlinAnomalyGenerator() self.rng = rng if rng else np.random.default_rng() def generate_normal_image(self) -> tuple[np.ndarray, np.ndarray]: @@ -72,6 +73,8 @@ def generate_abnormal_image(self, beta: float = 0.2) -> tuple[np.ndarray, np.nda # Generate perturbation. perturbation, mask = self.augmenter.generate_perturbation(height=self.image_shape[0], width=self.image_shape[1]) + perturbation = perturbation.cpu().numpy() + mask = mask.cpu().numpy() # Superimpose perturbation on image ``img``. abnormal_image = (image * (1 - mask) + (beta) * perturbation + (1 - beta) * image * (mask)).astype(np.uint8) diff --git a/tests/integration/model/test_models.py b/tests/integration/model/test_models.py index 464c2cb1da..e78ad19fe0 100644 --- a/tests/integration/model/test_models.py +++ b/tests/integration/model/test_models.py @@ -6,6 +6,9 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import contextlib +import sys +from collections.abc import Generator from pathlib import Path from unittest.mock import MagicMock @@ -28,6 +31,17 @@ def export_types() -> list[ExportType]: return list(ExportType) +@contextlib.contextmanager +def increased_recursion_limit(limit: int = 10000) -> Generator[None, None, None]: + """Temporarily increase the recursion limit.""" + old_limit = sys.getrecursionlimit() + try: + sys.setrecursionlimit(limit) + yield + finally: + sys.setrecursionlimit(old_limit) + + class TestAPI: """Do sanity check on all models.""" @@ -154,11 +168,14 @@ def test_export( dataset_path=dataset_path, project_path=project_path, ) - engine.export( - model=model, - ckpt_path=f"{project_path}/{model.name}/{dataset.name}/dummy/v0/weights/lightning/model.ckpt", - export_type=export_type, - ) + + # Use context manager only for CSFlow + with increased_recursion_limit() if model_name == "csflow" else contextlib.nullcontext(): + engine.export( + model=model, + ckpt_path=f"{project_path}/{model.name}/{dataset.name}/dummy/v0/weights/lightning/model.ckpt", + export_type=export_type, + ) @staticmethod def _get_objects( diff --git a/tests/unit/data/validators/numpy/test_image.py b/tests/unit/data/validators/numpy/test_image.py index 008bc4dff6..e81793aeb7 100644 --- a/tests/unit/data/validators/numpy/test_image.py +++ b/tests/unit/data/validators/numpy/test_image.py @@ -180,12 +180,13 @@ def test_validate_gt_label_none(self) -> None: """Test validation of None ground truth labels.""" assert self.validator.validate_gt_label(None) is None - def test_validate_gt_label_valid_string_input(self) -> None: - """Test validation of ground truth labels with string input.""" - validated_labels = self.validator.validate_gt_label(["0", "1"]) + def test_validate_gt_label_valid_sequence(self) -> None: + """Test validation of ground truth labels with sequence input.""" + # Test with binary labels (0: normal, 1: anomaly) + validated_labels = self.validator.validate_gt_label([0, 1, 1, 0]) assert isinstance(validated_labels, np.ndarray) assert validated_labels.dtype == bool - assert np.array_equal(validated_labels, np.array([False, True])) + assert 
np.array_equal(validated_labels, np.array([False, True, True, False])) def test_validate_gt_label_invalid_dimensions(self) -> None: """Test validation of ground truth labels with invalid dimensions.""" diff --git a/tests/unit/data/validators/numpy/test_video.py b/tests/unit/data/validators/numpy/test_video.py index abf29d31d9..675a462630 100644 --- a/tests/unit/data/validators/numpy/test_video.py +++ b/tests/unit/data/validators/numpy/test_video.py @@ -76,6 +76,15 @@ def test_validate_target_frame_negative(self) -> None: with pytest.raises(ValueError, match="Target frame index must be non-negative"): self.validator.validate_target_frame(-1) + def test_validate_gt_label_valid(self) -> None: + """Test validation of a valid ground truth label.""" + # Test with binary label (0: normal, 1: anomaly) + label = 1 + validated_label = self.validator.validate_gt_label(label) + assert isinstance(validated_label, np.ndarray) + assert validated_label.dtype == bool + assert validated_label.item() is True + class TestNumpyVideoBatchValidator: """Test NumpyVideoBatchValidator.""" @@ -141,14 +150,21 @@ def test_validate_gt_label_none(self) -> None: """Test validation of None ground truth labels.""" assert self.validator.validate_gt_label(None) is None - def test_validate_gt_label_invalid_type(self) -> None: - """Test validation of ground truth labels with invalid type.""" - validated_labels = self.validator.validate_gt_label(["0", "1"]) - assert validated_labels is not None + def test_validate_gt_label_valid_sequence(self) -> None: + """Test validation of ground truth labels with sequence input.""" + # Test with binary labels (0: normal, 1: anomaly) + labels = [0, 1] + validated_labels = self.validator.validate_gt_label(labels) assert isinstance(validated_labels, np.ndarray) assert validated_labels.dtype == bool assert np.array_equal(validated_labels, np.array([False, True])) + def test_validate_gt_label_invalid_type(self) -> None: + """Test validation of ground truth labels with invalid type.""" + # Test with a non-sequence, non-array type + with pytest.raises(TypeError, match="Ground truth label batch must be a numpy.ndarray"): + self.validator.validate_gt_label(3.14) + def test_validate_gt_label_invalid_dimensions(self) -> None: """Test validation of ground truth labels with invalid dimensions.""" with pytest.raises(ValueError, match="Ground truth label batch must be 1-dimensional, got shape \\(2, 2\\)"): diff --git a/tests/unit/metrics/test_pro.py b/tests/unit/metrics/test_pro.py index 21f26c3349..fe6e149cb1 100644 --- a/tests/unit/metrics/test_pro.py +++ b/tests/unit/metrics/test_pro.py @@ -6,7 +6,7 @@ import torch from torchvision.transforms import RandomAffine -from anomalib.data.utils import random_2d_perlin +from anomalib.data.utils.generators.perlin import generate_perlin_noise from anomalib.metrics.pro import _PRO as PRO from anomalib.metrics.pro import connected_components_cpu, connected_components_gpu @@ -50,7 +50,7 @@ def test_device_consistency() -> None: batch = torch.zeros((32, 256, 256)) for i in range(batch.shape[0]): - batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5 + batch[i, ...] = generate_perlin_noise(256, 256, scale=(4, 4)) > 0.5 # ground truth mask is int type batch = batch.type(torch.int32) @@ -70,7 +70,7 @@ def test_connected_component_labeling() -> None: # generate batch of random binary images using perlin noise batch = torch.zeros((32, 1, 256, 256)) for i in range(batch.shape[0]): - batch[i, ...] 
= random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5 + batch[i, ...] = generate_perlin_noise(256, 256, scale=(4, 4)) > 0.5 # get connected component results on both cpu and gpu cc_cpu = connected_components_cpu(batch.cpu())
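For downstream users, a minimal migration sketch based on the interfaces introduced in this patch; the source-image folder, probability, blend range, and tensor shapes below are illustrative, not part of the change:

import torch
from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator

# Previously: augmenter = Augmenter(anomaly_source_path, p_anomalous=..., beta=...)
#             augmented, masks = augmenter.augment_batch(batch)
# Now: PerlinAnomalyGenerator is a torchvision v2 transform that accepts a single
# image [C, H, W] or a batch [B, C, H, W] and returns (augmented, mask).
augmenter = PerlinAnomalyGenerator(
    anomaly_source_path=None,       # or a folder of anomaly source images, e.g. "./datasets/dtd"
    probability=1.0,                # replaces the former `p_anomalous`
    blend_factor=(0.01, 0.2),       # replaces the former `beta`
)

batch = torch.rand(4, 3, 256, 256)   # [B, C, H, W], values in [0, 1]
augmented, masks = augmenter(batch)  # [B, C, H, W] and [B, 1, H, W]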
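Likewise, a short sketch of the two lower-level building blocks added here, reusing the parameter values from their docstring examples (the weights and sizes are illustrative):

import torch
from torchvision.transforms import v2
from anomalib.data.transforms import MultiRandomChoice
from anomalib.data.utils.generators.perlin import generate_perlin_noise

# Apply exactly two of the three transforms per call; weights are normalized internally.
augment = MultiRandomChoice(
    transforms=[
        v2.RandomHorizontalFlip(p=1.0),
        v2.ColorJitter(brightness=0.5),
        v2.RandomRotation(10),
    ],
    probabilities=[0.5, 0.3, 0.2],
    num_transforms=2,
    fixed_num_transforms=True,
)
image = torch.rand(3, 256, 256)
augmented = augment(image)

# Standalone Perlin noise: a [H, W] tensor with values roughly in [-1, 1];
# thresholding it gives the kind of binary mask used by the anomaly generators above.
noise = generate_perlin_noise(256, 256, scale=(4, 4))
mask = (noise > 0.5).float()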