Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸš€Add SuperSimpleNet model #2428

Open
wants to merge 33 commits into
base: release/v2.0.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
90625bb
Initial code structure
blaz-r Oct 29, 2024
4d5e480
Add feature extraction
blaz-r Oct 29, 2024
742e0ca
Add adaptor and segdec
blaz-r Nov 1, 2024
fd95793
Format files
blaz-r Nov 1, 2024
26e6f5a
Add anomaly generation
blaz-r Nov 1, 2024
eeddb15
Add configurable params and docstrings
blaz-r Nov 8, 2024
090120d
Add dosctring and fix lint issue
blaz-r Nov 8, 2024
2be15bd
Implement loss and training step
blaz-r Nov 8, 2024
f986c8e
Update torchFX to also support str weights
blaz-r Nov 11, 2024
e08366d
Update names
blaz-r Nov 17, 2024
e45b6d4
Add supervision based param settings and validation step
blaz-r Nov 17, 2024
8768607
Add optimizer configs
blaz-r Nov 17, 2024
0e5a6c7
Fix loss and types
blaz-r Nov 17, 2024
1238512
Add SSN to init
blaz-r Nov 17, 2024
28006f1
Add SSN description
blaz-r Nov 19, 2024
40e7475
Update supersimplenet README.md
blaz-r Nov 20, 2024
ce6c7f5
Update readme with arch and results
blaz-r Nov 20, 2024
a47fc50
Update SuperSimpleNet README.md
blaz-r Nov 21, 2024
3c65da6
Update architecture image
blaz-r Nov 21, 2024
07e02fb
Add SuperSimpleNet to init
blaz-r Nov 21, 2024
fda2c4e
Fix copyright location and format
blaz-r Nov 21, 2024
8530a28
Merge branch 'feature/v2' into feature/supersimplenet
blaz-r Nov 25, 2024
122ec03
Update SSN lightning model for v2
blaz-r Nov 25, 2024
fe57b27
Fix cls head to support onnx export
blaz-r Nov 25, 2024
952653b
Merge branch 'feature/v2' into feature/supersimplenet
samet-akcay Nov 27, 2024
fd499ad
Merge remote-tracking branch 'refs/remotes/og/release/v2.0.0' into fe…
blaz-r Dec 23, 2024
521aee7
Update SSN ano. gen. with new perlin method
blaz-r Dec 23, 2024
6687abc
Add latest v2 structure to SSN
blaz-r Dec 23, 2024
50032c6
Add docstring at the beginning of SSN files.
blaz-r Dec 23, 2024
5bcb533
Add SSN to init files
blaz-r Dec 23, 2024
a569e03
Include SSN in docs
blaz-r Dec 24, 2024
bb142dd
Verify and update ssn res for refactored perlin
blaz-r Dec 25, 2024
64fba6c
Update readme with res for refactored perlin
blaz-r Dec 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions src/anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ReverseDistillation,
Rkde,
Stfpm,
SuperSimpleNet,
Uflow,
WinClip,
)
Expand Down Expand Up @@ -56,6 +57,7 @@ class UnknownModelError(ModuleNotFoundError):
"ReverseDistillation",
"Rkde",
"Stfpm",
"SuperSimpleNet",
"Uflow",
"AiVad",
"WinClip",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ class can be provided and it will try to load the weights from the provided weig
backbone_class = backbone.class_path
backbone_model = backbone_class(**backbone.init_args)

if isinstance(weights, WeightsEnum): # torchvision models
if isinstance(weights, WeightsEnum) or weights in {
"IMAGENET1K_V1",
"IMAGENET1K_V2",
"DEFAULT",
}: # torchvision models
feature_extractor = create_feature_extractor(model=backbone_model, return_nodes=return_nodes)
elif weights is not None:
if not isinstance(weights, str):
Expand Down
2 changes: 2 additions & 0 deletions src/anomalib/models/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .reverse_distillation import ReverseDistillation
from .rkde import Rkde
from .stfpm import Stfpm
from .supersimplenet import SuperSimpleNet
from .uflow import Uflow
from .winclip import WinClip

Expand All @@ -39,6 +40,7 @@
"ReverseDistillation",
"Rkde",
"Stfpm",
"SuperSimpleNet",
"Uflow",
"WinClip",
]
29 changes: 29 additions & 0 deletions src/anomalib/models/image/supersimplenet/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Copyright (c) 2024 Intel Corporation
SPDX-License-Identifier: Apache-2.0

Some files in this folder are based on the original SuperSimpleNet implementation by BlaΕΎ Rolih

Original license:
-----------------

MIT License

Copyright (c) 2024 BlaΕΎ Rolih

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
57 changes: 57 additions & 0 deletions src/anomalib/models/image/supersimplenet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection

This is an implementation of the [SuperSimpleNet](https://arxiv.org/pdf/2408.03143) paper, based on the [official code](https://github.com/blaz-r/SuperSimpleNet).

Model Type: Segmentation

## Description

**SuperSimpleNet** is a simple yet strong discriminative defect / anomaly detection model evolved from the SimpleNet architecture. It consists of four components:
feature extractor with upscaling, feature adaptor, synthetic feature-level anomaly generation module, and
segmentation-detection module.

A ResNet-like feature extractor first extracts features, which are then upscaled and
average-pooled to capture neighboring context. Features are further refined for anomaly detection task in the adaptor module.
During training, synthetic anomalies are generated at the feature level by adding Gaussian noise to regions defined by the
binary Perlin noise mask. The perturbed features are then fed into the segmentation-detection
module, which produces the anomaly map and the anomaly score. During inference, anomaly generation is skipped, and the model
directly predicts the anomaly map and score. The predicted anomaly map is upscaled to match the input image size
and refined with a Gaussian filter.

This implementation supports both unsupervised and supervised setting, but Anomalib currently supports only unsupervised learning.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the missing points in Anomalib to support supervised setting

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now I believe there are no standard supervised datasets. Another problem is the Folder dataset as it assumes that abnormal samples are always in test set:

samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN
samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST

Another thing for full reproduction of SuperSimpleNet results is the fixed flipping augmentation and frequency sampling. This is however not necessary, but needed for best results. It's also not SuperSimpleNet specific, so might be worth considering if other supervised model will be supported.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, we would like to diversify the model pool, and include more learning types than one-class models. Thanks for the feedback.
@abc-125, you might want to be aware of this discussion as you have recently worked on this stuff

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for adding me, it would be great to have supervised models and datasets in Anomalib. Recently, I looked at how to add a supervised dataset, and it certainly would require changing some base structures, such as paths to folders (I guess we will need abnormal_train_dir and maybe renaming the rest to make it easier to understand, normal_train_dir, etc.):

normal_dir (str | Path | Sequence): Path to the directory containing normal images.
root (str | Path | None): Root folder of the dataset.
Defaults to ``None``.
abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images.
Defaults to ``None``.
normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing
normal images for the test dataset.
Defaults to ``None``.


## Architecture

![SuperSimpleNet architecture](/docs/source/images/supersimplenet/architecture.png "SuperSimpleNet architecture")

## Usage

`anomalib train --model SuperSimpleNet --data MVTec --data.category <category>`

> It is recommended to train the model for 300 epochs with batch size of 32 to achieve stable training with random anomaly generation. Training with lower parameter values will still work, but might not yield the optimal results.
>
> For supervised learning, refer to the [official code](https://github.com/blaz-r/SuperSimpleNet).

## MVTec AD results

The following results were obtained using this Anomalib implementation trained for 300 epochs with seed 42, default params, and batch size 32.
| | **Image AUROC** | **Pixel AUPRO** |
| ----------- | :-------------: | :-------------: |
| Bottle | 1.000 | 0.914 |
| Cable | 0.981 | 0.895 |
| Capsule | 0.990 | 0.926 |
| Carpet | 0.987 | 0.936 |
| Grid | 0.998 | 0.935 |
| Hazelnut | 0.999 | 0.946 |
| Leather | 1.000 | 0.972 |
| Metal_nut | 0.996 | 0.923 |
| Pill | 0.960 | 0.942 |
| Screw | 0.903 | 0.952 |
| Tile | 0.989 | 0.817 |
| Toothbrush | 0.917 | 0.861 |
| Transistor | 1.000 | 0.909 |
| Wood | 0.996 | 0.868 |
| Zipper | 0.996 | 0.944 |
| **Average** | 0.981 | 0.916 |

For other results on VisA, SensumSODF, and KSDD2, refer to the [paper](https://arxiv.org/pdf/2408.03143).
8 changes: 8 additions & 0 deletions src/anomalib/models/image/supersimplenet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""SuperSimpleNet model."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import SuperSimpleNet

__all__ = ["SuperSimpleNet"]
163 changes: 163 additions & 0 deletions src/anomalib/models/image/supersimplenet/anomaly_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""Anomaly generator for the SuperSimplenet model implementation."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F # noqa: N812
from torch import nn

from anomalib.data.utils.generators.perlin import _rand_perlin_2d


class SSNAnomalyGenerator(nn.Module):
"""Anomaly generator of the SuperSimpleNet model."""

def __init__(
self,
noise_mean: float,
noise_std: float,
threshold: float,
perlin_range: tuple[int, int] = (0, 6),
) -> None:
super().__init__()

self.noise_mean = noise_mean
self.noise_std = noise_std

self.threshold = threshold

self.min_perlin_scale = perlin_range[0]
self.max_perlin_scale = perlin_range[1]

@staticmethod
def next_power_2(num: int) -> int:
"""Get the next power of 2 for given number.

Args:
num (int): value of interest

Returns:
next power of 2 value for given number
"""
return 1 << (num - 1).bit_length()

def generate_perlin(self, batches: int, height: int, width: int) -> torch.Tensor:
"""Generate 2d perlin noise masks with dims [b, 1, h, w].

Args:
batches (int): number of batches (different masks)
height (int): height of features
width (int): width of features

Returns:
tensor with b perlin binarized masks
"""
perlin = []
for _ in range(batches):
# get scale of perlin in x and y direction
perlin_scalex = 2 ** (
torch.randint(
self.min_perlin_scale,
self.max_perlin_scale,
(1,),
).item()
)
perlin_scaley = 2 ** (
torch.randint(
self.min_perlin_scale,
self.max_perlin_scale,
(1,),
).item()
)

perlin_height = self.next_power_2(height)
perlin_width = self.next_power_2(width)

perlin_noise = _rand_perlin_2d(
(perlin_height, perlin_width),
(perlin_scalex, perlin_scaley),
)
# original is power of 2 scale, so fit to our size
perlin_noise = F.interpolate(
perlin_noise.reshape(1, 1, perlin_height, perlin_width),
size=(height, width),
mode="bilinear",
)
# binarize
perlin_thr = torch.where(perlin_noise > self.threshold, 1, 0)

# 50% of anomaly
if torch.rand(1).item() > 0.5:
perlin_thr = torch.zeros_like(perlin_thr)

perlin.append(perlin_thr)
return torch.cat(perlin)

def forward(
self,
features: torch.Tensor,
mask: torch.Tensor,
labels: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate anomaly on features using thresholded perlin noise and Gaussian noise.

Also update GT masks and labels with new anomaly information.

Args:
features: input features.
mask: GT masks.
labels: GT labels.

Returns:
perturbed features, updated GT masks and labels.
"""
b, _, h, w = features.shape

# duplicate
features = torch.cat((features, features))
mask = torch.cat((mask, mask))
labels = torch.cat((labels, labels))

noise = torch.normal(
mean=self.noise_mean,
std=self.noise_std,
size=features.shape,
device=features.device,
requires_grad=False,
)

# mask indicating which regions will have noise applied
# [B * 2, 1, H, W] initial all masked as anomalous
noise_mask = torch.ones(
b * 2,
1,
h,
w,
device=features.device,
requires_grad=False,
)

# no overlap: don't apply to already anomalous regions (mask=1 -> bad)
noise_mask = noise_mask * (1 - mask)

# shape of noise is [B * 2, 1, H, W]
perlin_mask = self.generate_perlin(b * 2, h, w).to(features.device)
# only apply where perlin mask is 1
noise_mask = noise_mask * perlin_mask

# update gt mask
mask = mask + noise_mask
# binarize
mask = torch.where(mask > 0, torch.ones_like(mask), torch.zeros_like(mask))

# make new labels. 1 if any part of mask is 1, 0 otherwise
new_anomalous = noise_mask.reshape(b * 2, -1).any(dim=1).type(torch.float32)
labels = labels + new_anomalous
# binarize
labels = torch.where(labels > 0, torch.ones_like(labels), torch.zeros_like(labels))

# apply masked noise
perturbed = features + noise * noise_mask

return perturbed, mask, labels
Loading
Loading