Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/dick/anomaly score normalization #35

Merged
merged 44 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
1e2ac6e
implement anomaly score and anomaly map normalization
djdameln Dec 14, 2021
95c15c1
switch back to callback design
djdameln Dec 15, 2021
cd1612b
fix docstrings and typing
djdameln Dec 15, 2021
149e026
always cast thresholds to float
djdameln Dec 15, 2021
0cdb25c
improve logic of training stats metric
djdameln Dec 15, 2021
122bbb6
switch to torchvision feature extractor
djdameln Dec 15, 2021
5d82316
update configs
djdameln Dec 15, 2021
b4246ab
merge development
djdameln Dec 16, 2021
891c709
switch back to saving training stats in model
djdameln Dec 16, 2021
95de8fc
Tensor -> tensor
djdameln Dec 16, 2021
d84ea11
loading checkpoint no longer needed
djdameln Dec 16, 2021
1cb7a15
add normalization to inferencer
djdameln Dec 16, 2021
21dfefb
subtract image mean from anomaly maps
djdameln Dec 16, 2021
ae70160
small refactor
djdameln Dec 16, 2021
8416336
switch to torchmetrics design for threshold computation
djdameln Dec 16, 2021
a41d7c1
import training stats from init
djdameln Dec 16, 2021
1494e5a
small fix
djdameln Dec 16, 2021
9a196b2
switch to torchmetrics for persisting training stats
djdameln Dec 17, 2021
ac4573e
add test case for normalization callback
djdameln Dec 17, 2021
d05acb9
fix visualizer
djdameln Dec 17, 2021
54cd83e
fix mypy issues
djdameln Dec 17, 2021
4c1a756
fix compression tests
djdameln Dec 17, 2021
cecb23e
remove print statement
djdameln Dec 17, 2021
721ea0b
revert checkpoint loading
djdameln Dec 17, 2021
4a3b2b1
revert changing weight path
djdameln Dec 17, 2021
416a4c8
Merge branch 'development' into feature/dick/anomaly-score-normalization
djdameln Dec 17, 2021
95067fe
rename normalization callback
djdameln Dec 17, 2021
8453260
rename anomaly score dsitrbution class
djdameln Dec 17, 2021
e30d870
change function ordering
djdameln Dec 17, 2021
895bf1e
remove cuda version from torch and torchvision
djdameln Dec 17, 2021
5d46732
add deprecation warning to feature extractor.
djdameln Dec 17, 2021
918f422
training_stats -> training_distribution
djdameln Dec 17, 2021
a55e6d5
update requirements
djdameln Dec 20, 2021
7a07fa4
revert to anomalib feature extractor
djdameln Dec 20, 2021
9853828
workaround for torch 1.8.1 compatibility
djdameln Dec 20, 2021
aee4802
rename normalization callback
djdameln Dec 20, 2021
497d3bc
merge development
djdameln Dec 20, 2021
c002f2a
Revert "add deprecation warning to feature extractor."
djdameln Dec 20, 2021
fc01635
Add batch size support to patchcore
samet-akcay Dec 21, 2021
fc8b3d9
Score Normalization doesnt work for Patchcore
samet-akcay Dec 21, 2021
00dc01f
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay Dec 21, 2021
a56ba72
check to prevent using both normalization and nncf
djdameln Dec 21, 2021
e23f82d
use get_dataset_path
djdameln Dec 21, 2021
1ffb0fb
Merge branch 'feature/dick/anomaly-score-normalization' of github.com…
djdameln Dec 21, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions anomalib/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,8 @@ def get_configurable_parameters(
config = update_nncf_config(config)
config = update_device_config(config, openvino)

# thresholding
if "pixel_default" not in config.model.threshold.keys():
config.model.threshold.pixel_default = config.model.threshold.image_default

return config
13 changes: 9 additions & 4 deletions anomalib/core/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .compress import CompressModelCallback
from .model_loader import LoadModelCallback
from .normalization import OutputNormalizationCallback
from .save_to_csv import SaveToCSVCallback
from .timer import TimerCallback
from .visualizer_callback import VisualizerCallback
Expand Down Expand Up @@ -46,8 +47,15 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:

callbacks.extend([checkpoint, TimerCallback()])

if "weight_file" in config.model.keys():
load_model = LoadModelCallback(os.path.join(config.project.path, config.model.weight_file))
callbacks.append(load_model)

if "normalize_scores" in config.model.keys() and config.model.normalize_scores:
callbacks.append(OutputNormalizationCallback())

if not config.project.log_images_to == []:
callbacks.append(VisualizerCallback())
callbacks.append(VisualizerCallback(inputs_are_normalized=config.model.normalize_scores))

if "optimization" in config.keys():
if config.optimization.nncf.apply:
Expand All @@ -70,9 +78,6 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
filename="compressed_model",
)
)
if "weight_file" in config.model.keys():
load_model = LoadModelCallback(os.path.join(config.project.path, config.model.weight_file))
callbacks.append(load_model)

if "save_to_csv" in config.project.keys():
if config.project.save_to_csv:
Expand Down
111 changes: 111 additions & 0 deletions anomalib/core/callbacks/normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Anomaly Score Normalization Callback."""
import copy
from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.distributions import LogNormal, Normal


class OutputNormalizationCallback(Callback):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about the name of your PR? :) Like AnomalyScoreNormalizationCallback or ScoreNormalizationCallback. It doesn't matter if it's verbose. What's more important is its readability

"""Callback that standardizes the image-level and pixel-level anomaly scores."""

def __init__(self):
self.image_dist: Optional[LogNormal] = None
self.pixel_dist: Optional[LogNormal] = None

def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Called when the test begins."""
pl_module.image_metrics.F1.threshold = 0.5
pl_module.pixel_metrics.F1.threshold = 0.5

def on_train_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, _unused: Optional[Any] = None
) -> None:
"""Called when the train epoch ends.

Use the current model to compute the anomaly score distributions
of the normal training data. This is needed after every epoch, because the statistics must be
stored in the state dict of the checkpoint file.
"""
self._collect_stats(trainer, pl_module)

def on_validation_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
_dataloader_idx: int,
) -> None:
"""Called when the validation batch ends, standardizes the predicted scores and anomaly maps."""
self._standardize(outputs, pl_module)

def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
_dataloader_idx: int,
) -> None:
"""Called when the test batch ends, normalizes the predicted scores and anomaly maps."""
self._standardize(outputs, pl_module)
self._normalize(outputs, pl_module)

def on_predict_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: Dict,
_batch: Any,
_batch_idx: int,
_dataloader_idx: int,
) -> None:
"""Called when the predict batch ends, normalizes the predicted scores and anomaly maps."""
self._standardize(outputs, pl_module)
self._normalize(outputs, pl_module)
outputs["pred_labels"] = outputs["pred_scores"] >= 0.5

def _collect_stats(self, trainer, pl_module):
"""Collect the statistics of the normal training data.

Create a trainer and use it to predict the anomaly maps and scores of the normal training data. Then
estimate the distribution of anomaly scores for normal data at the image and pixel level by computing
the mean and standard deviations. A dictionary containing the computed statistics is stored in self.stats.
"""
predictions = Trainer(gpus=trainer.gpus).predict(
model=copy.deepcopy(pl_module), dataloaders=trainer.datamodule.train_dataloader()
)
pl_module.training_stats.reset()
for batch in predictions:
if "pred_scores" in batch.keys():
pl_module.training_stats.update(anomaly_scores=batch["pred_scores"])
if "anomaly_maps" in batch.keys():
pl_module.training_stats.update(anomaly_maps=batch["anomaly_maps"])
pl_module.training_stats.compute()

def _standardize(self, outputs: STEP_OUTPUT, pl_module) -> None:
"""Standardize the predicted scores and anomaly maps to the z-domain."""
stats = pl_module.training_stats.to(outputs["pred_scores"].device)

outputs["pred_scores"] = torch.log(outputs["pred_scores"])
outputs["pred_scores"] = (outputs["pred_scores"] - stats.image_mean) / stats.image_std
if "anomaly_maps" in outputs.keys():
outputs["anomaly_maps"] = (torch.log(outputs["anomaly_maps"]) - stats.pixel_mean) / stats.pixel_std
outputs["anomaly_maps"] -= (stats.image_mean - stats.pixel_mean) / stats.pixel_std

def _normalize(self, outputs: STEP_OUTPUT, pl_module: pl.LightningModule) -> None:
"""Normalize the predicted scores and anomaly maps by first standardizing and then computing the CDF."""
device = outputs["pred_scores"].device
image_threshold = pl_module.image_threshold.value.cpu()
pixel_threshold = pl_module.pixel_threshold.value.cpu()

norm = Normal(torch.Tensor([0]), torch.Tensor([1]))
outputs["pred_scores"] = norm.cdf(outputs["pred_scores"].cpu() - image_threshold).to(device)
if "anomaly_maps" in outputs.keys():
outputs["anomaly_maps"] = norm.cdf(outputs["anomaly_maps"].cpu() - pixel_threshold).to(device)
18 changes: 13 additions & 5 deletions anomalib/core/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class VisualizerCallback(Callback):
config.yaml file.
"""

def __init__(self):
def __init__(self, inputs_are_normalized: bool = True):
"""Visualizer callback."""
self.inputs_are_normalized = inputs_are_normalized

def _add_images(
self,
Expand Down Expand Up @@ -80,15 +81,22 @@ def on_test_batch_end(
"""
assert outputs is not None

if self.inputs_are_normalized:
threshold = 0.5
normalize = False # anomaly maps are already normalized
else:
threshold = pl_module.pixel_threshold.value.item()
normalize = True # raw anomaly maps. Still need to normalize

for (filename, image, true_mask, anomaly_map) in zip(
outputs["image_path"], outputs["image"], outputs["mask"], outputs["anomaly_maps"]
):
image = Denormalize()(image.cpu())
true_mask = true_mask.cpu().numpy()
anomaly_map = anomaly_map.cpu().numpy()

heat_map = superimpose_anomaly_map(anomaly_map, image)
pred_mask = compute_mask(anomaly_map, pl_module.threshold.item())
heat_map = superimpose_anomaly_map(anomaly_map, image, normalize=normalize)
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")

visualizer = Visualizer(num_rows=1, num_cols=5, figure_size=(12, 3))
Expand All @@ -100,14 +108,14 @@ def on_test_batch_end(
self._add_images(visualizer, pl_module, Path(filename))
visualizer.close()

def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Sync logs.

Currently only ``AnomalibWandbLogger`` is called from this method. This is because logging as a single batch
ensures that all images appear as part of the same step.

Args:
trainer (pl.Trainer): Pytorch Lightning trainer
_trainer (pl.Trainer): Pytorch Lightning trainer (unused)
pl_module (pl.LightningModule): Anomaly module
"""
if pl_module.logger is not None and isinstance(pl_module.logger, AnomalibWandbLogger):
Expand Down
4 changes: 3 additions & 1 deletion anomalib/core/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Custom anomaly evaluation metrics."""
from .adaptive_threshold import AdaptiveThreshold
from .anomaly_score_distribution import AnomalyScoreDistribution
from .auroc import AUROC
from .optimal_f1 import OptimalF1

__all__ = ["AUROC", "OptimalF1"]
__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution"]
41 changes: 41 additions & 0 deletions anomalib/core/metrics/adaptive_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Implementation of Optimal F1 score based on TorchMetrics."""
import torch
from torchmetrics import Metric, PrecisionRecallCurve


class AdaptiveThreshold(Metric):
"""Optimal F1 Metric.

Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the
predicted anomaly scores.
"""

def __init__(self, default_value: float, **kwargs):
super().__init__(**kwargs)

self.precision_recall_curve = PrecisionRecallCurve(num_classes=1, compute_on_step=False)
self.add_state("value", default=torch.tensor(default_value), persistent=True) # pylint: disable=not-callable
self.value = torch.tensor(default_value) # pylint: disable=not-callable

# pylint: disable=arguments-differ
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore
"""Update the precision-recall curve metric."""
self.precision_recall_curve.update(preds, target)

def compute(self) -> torch.Tensor:
"""Compute the threshold that yields the optimal F1 score.

Compute the F1 scores while varying the threshold. Store the optimal
threshold as attribute and return the maximum value of the F1 score.

Returns:
Value of the F1 score at the optimal threshold.
"""
precision: torch.Tensor
recall: torch.Tensor
thresholds: torch.Tensor

precision, recall, thresholds = self.precision_recall_curve.compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
self.value = thresholds[torch.argmax(f1_score)]
return self.value
52 changes: 52 additions & 0 deletions anomalib/core/metrics/anomaly_score_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Module that computes the parameters of the normal data distribution of the training set."""
from typing import Optional, Tuple

import torch
from torch import Tensor
from torchmetrics import Metric


class AnomalyScoreDistribution(Metric):
"""Mean and standard deviation of the anomaly scores of normal training data."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.anomaly_maps = []
self.anomaly_scores = []

self.add_state("image_mean", torch.empty(0), persistent=True)
self.add_state("image_std", torch.empty(0), persistent=True)
self.add_state("pixel_mean", torch.empty(0), persistent=True)
self.add_state("pixel_std", torch.empty(0), persistent=True)

self.image_mean = torch.empty(0)
self.image_std = torch.empty(0)
self.pixel_mean = torch.empty(0)
self.pixel_std = torch.empty(0)

# pylint: disable=arguments-differ
def update( # type: ignore
self, anomaly_scores: Optional[Tensor] = None, anomaly_maps: Optional[Tensor] = None
) -> None:
"""Update the precision-recall curve metric."""
if anomaly_maps is not None:
self.anomaly_maps.append(anomaly_maps)
if anomaly_scores is not None:
self.anomaly_scores.append(anomaly_scores)

def compute(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Compute stats."""
anomaly_scores = torch.hstack(self.anomaly_scores)
anomaly_scores = torch.log(anomaly_scores)

self.image_mean = anomaly_scores.mean()
self.image_std = anomaly_scores.std()

if self.anomaly_maps:
anomaly_maps = torch.vstack(self.anomaly_maps)
anomaly_maps = torch.log(anomaly_maps).cpu()

self.pixel_mean = anomaly_maps.mean(dim=0).squeeze()
self.pixel_std = anomaly_maps.std(dim=0).squeeze()

return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std
Loading