-
Notifications
You must be signed in to change notification settings - Fork 696
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🔨 Replace
imgaug
with Native PyTorch Transforms (#2436)
* Add multi random choice transform * Add DRAEMAugmenter class and Perlin noise generation to new_perlin.py - Introduced DRAEMAugmenter for advanced image augmentations using torchvision v2. - Implemented various augmentation techniques including ColorJitter, RandomAdjustSharpness, and custom transformations. - Added functionality for comparing augmentation methods and visualizing results. - Included utility functions for metrics computation and image processing. - Established logging for better traceability of operations. This commit enhances the image processing capabilities within the Anomalib framework, facilitating more robust anomaly detection workflows. * Add the new perlin noise Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Add the new perlin noise Signed-off-by: Samet Akcay <samet.akcay@intel.com> * add generate_perlin_noise relative import Signed-off-by: Samet Akcay <samet.akcay@intel.com> * add tiffile as a dependency Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Remove upper bound from wandb Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Added skimage Signed-off-by: Samet Akcay <samet.akcay@intel.com> * add scikit-learn as a dependency Signed-off-by: Samet Akcay <samet.akcay@intel.com> * limit ollama to < 0.4.0 as it has breaking changes Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Fix data generators in test helpers Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Update the perlin augmenters Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Fix numpy validator tests caused by numpy upgrade Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Fix CS-Flow tests Signed-off-by: Samet Akcay <samet.akcay@intel.com> * Fix the tests Signed-off-by: Samet Akcay <samet.akcay@intel.com> --------- Signed-off-by: Samet Akcay <samet.akcay@intel.com>
- Loading branch information
1 parent
c16f51e
commit 2f3d616
Showing
16 changed files
with
453 additions
and
342 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Multi random choice transform.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from collections.abc import Callable, Sequence | ||
|
||
import torch | ||
from torchvision.transforms import v2 | ||
|
||
|
||
class MultiRandomChoice(v2.Transform):
    """Apply multiple transforms randomly picked from a list.

    This transform does not support torchscript.

    Args:
        transforms (Sequence[Callable]): Sequence of transformations to choose from.
            Must be a sequence; a bare ``torch.nn.Module`` is rejected with ``TypeError``.
        probabilities (list[float] | None, optional): Probability of each transform being picked.
            If None (default), all transforms have equal probability. If provided, probabilities
            will be normalized to sum to 1.
        num_transforms (int): Maximum number of transforms to apply at once.
            Must be between ``1`` and ``len(transforms)`` inclusive.
            Defaults to ``1``.
        fixed_num_transforms (bool): If ``True``, always applies exactly ``num_transforms`` transforms.
            If ``False``, randomly picks between 1 and ``num_transforms``.
            Defaults to ``False``.

    Raises:
        TypeError: If ``transforms`` is not a sequence.
        ValueError: If ``probabilities`` length does not match ``transforms``, or if
            ``num_transforms`` is outside ``[1, len(transforms)]``.

    Examples:
        >>> import torchvision.transforms.v2 as v2
        >>> transforms = [
        ...     v2.RandomHorizontalFlip(p=1.0),
        ...     v2.ColorJitter(brightness=0.5),
        ...     v2.RandomRotation(10),
        ... ]
        >>> # Apply 1-2 random transforms with equal probability
        >>> transform = MultiRandomChoice(transforms, num_transforms=2)
        >>> # Always apply exactly 2 transforms with custom probabilities
        >>> transform = MultiRandomChoice(
        ...     transforms,
        ...     probabilities=[0.5, 0.3, 0.2],
        ...     num_transforms=2,
        ...     fixed_num_transforms=True
        ... )
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        probabilities: list[float] | None = None,
        num_transforms: int = 1,
        fixed_num_transforms: bool = False,
    ) -> None:
        if not isinstance(transforms, Sequence):
            msg = "Argument transforms should be a sequence of callables"
            raise TypeError(msg)

        if probabilities is None:
            probabilities = [1.0] * len(transforms)
        elif len(probabilities) != len(transforms):
            msg = (
                "Length of probabilities doesn't match the number of transforms: "
                f"{len(probabilities)} != {len(transforms)}"
            )
            raise ValueError(msg)

        # Fail fast at construction time: ``torch.multinomial`` samples indices
        # without replacement in ``forward``, so we can never draw more
        # transforms than are available. Without this check the error would
        # only surface as an opaque RuntimeError on the first forward call.
        if not 1 <= num_transforms <= len(transforms):
            msg = f"num_transforms must be between 1 and {len(transforms)}, got {num_transforms}"
            raise ValueError(msg)

        super().__init__()

        self.transforms = transforms
        # Normalize the weights to sum to 1. ``torch.multinomial`` only needs
        # relative weights, but storing the normalized form keeps
        # ``self.probabilities`` meaningful for introspection.
        total = sum(probabilities)
        self.probabilities = [probability / total for probability in probabilities]

        self.num_transforms = num_transforms
        self.fixed_num_transforms = fixed_num_transforms

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply randomly selected transforms to the input."""
        # Number of transforms to apply: exactly ``num_transforms`` when fixed,
        # otherwise uniformly drawn from 1..num_transforms.
        num_transforms = (
            self.num_transforms if self.fixed_num_transforms else int(torch.randint(self.num_transforms, (1,)) + 1)
        )
        # Sample distinct transform indices according to the (normalized)
        # probabilities; multinomial without replacement guarantees no repeats.
        idx = torch.multinomial(torch.tensor(self.probabilities), num_transforms).tolist()
        transform = v2.Compose([self.transforms[i] for i in idx])
        return transform(*inputs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.