👕 Complete metric overhaul, improve PP handling & fix Laplace #116

Merged 5 commits on Oct 2, 2024
docs/source/conf.py (1 addition, 1 deletion)
@@ -15,7 +15,7 @@
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
)
author = "Adrien Lafage and Olivier Laurent"
release = "0.2.2.post1"
release = "0.2.2.post2"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
pyproject.toml (2 additions, 2 deletions)
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "torch_uncertainty"
version = "0.2.2.post1"
version = "0.2.2.post2"
authors = [
{ name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" },
{ name = "Adrien Lafage", email = "adrienlafage@outlook.com" },
@@ -18,7 +18,6 @@ keywords = [
"ensembles",
"neural-networks",
"predictive-uncertainty",
"pytorch",
"reliable-ai",
"trustworthy-machine-learning",
"uncertainty",
@@ -44,6 +43,7 @@ dependencies = [
"numpy<2",
"opencv-python",
"glest==0.0.1a0",
"rich>=10.2.2",
]

[project.optional-dependencies]
torch_uncertainty/post_processing/laplace.py (3 additions, 1 deletion)
@@ -64,6 +64,7 @@ def __init__(
self.weight_subset = weight_subset
self.hessian_struct = hessian_struct
self.batch_size = batch_size
+self.optimize_prior_precision = optimize_prior_precision

if model is not None:
self.set_model(model)
@@ -80,7 +81,8 @@ def set_model(self, model: nn.Module) -> None:
def fit(self, dataset: Dataset) -> None:
dl = DataLoader(dataset, batch_size=self.batch_size)
self.la.fit(train_loader=dl)
-self.la.optimize_prior_precision(method="marglik")
+if self.optimize_prior_precision:
+    self.la.optimize_prior_precision(method="marglik")

def forward(
self,
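Editor's note: the new `optimize_prior_precision` flag stored in `__init__` simply gates the marginal-likelihood tuning of the prior precision during `fit`. Below is a minimal usage sketch written directly against the laplace-torch API that the wrapper calls; the toy model, dataset, and the standalone `optimize_prior_precision` variable are illustrative assumptions, not code from this PR.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

# Toy last-layer Laplace setup (illustrative only).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
dl = DataLoader(dataset, batch_size=16)

la = Laplace(
    model,
    likelihood="classification",
    subset_of_weights="last_layer",
    hessian_structure="kron",
)
la.fit(train_loader=dl)

# Mirrors the change above: prior-precision tuning only runs when requested.
optimize_prior_precision = False  # hypothetical user-facing flag
if optimize_prior_precision:
    la.optimize_prior_precision(method="marglik")
```

Skipping the tuning step is useful when the prior precision is fixed manually or tuned by some other procedure.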
torch_uncertainty/routines/classification.py (8 additions, 8 deletions)
@@ -196,23 +196,23 @@ def _init_metrics(self) -> None:
),
"sc/AURC": AURC(),
"sc/AUGRC": AUGRC(),
"sc/CovAt5Risk": CovAt5Risk(),
"sc/RiskAt80Cov": RiskAt80Cov(),
"sc/Cov@5Risk": CovAt5Risk(),
"sc/Risk@80Cov": RiskAt80Cov(),
},
compute_groups=[
["cls/Acc"],
["cls/Brier"],
["cls/NLL"],
["cal/ECE", "cal/aECE"],
["sc/AURC", "sc/AUGRC", "sc/CovAt5Risk", "sc/RiskAt80Cov"],
["sc/AURC", "sc/AUGRC", "sc/Cov@5Risk", "sc/Risk@80Cov"],
],
)

self.val_cls_metrics = cls_metrics.clone(prefix="val/")
self.test_cls_metrics = cls_metrics.clone(prefix="test/")

if self.post_processing is not None:
-self.ts_cls_metrics = cls_metrics.clone(prefix="test/ts_")
+self.post_cls_metrics = cls_metrics.clone(prefix="test/post/")

self.test_id_entropy = Entropy()

@@ -463,7 +463,7 @@ def test_step(
)
self.test_id_entropy(probs)
self.log(
"test/cls/entropy",
"test/cls/Entropy",
self.test_id_entropy,
on_epoch=True,
add_dataloader_idx=False,
@@ -486,7 +486,7 @@ def test_step(
pp_probs = F.softmax(pp_logits, dim=-1)
else:
pp_probs = pp_logits
-self.ts_cls_metrics.update(pp_probs, targets)
+self.post_cls_metrics.update(pp_probs, targets)

elif self.eval_ood and dataloader_idx == 1:
self.test_ood_metrics.update(ood_scores, torch.ones_like(targets))
@@ -529,7 +529,7 @@ def on_test_epoch_end(self) -> None:
)

if self.post_processing is not None:
-tmp_metrics = self.ts_cls_metrics.compute()
+tmp_metrics = self.post_cls_metrics.compute()
self.log_dict(tmp_metrics, sync_dist=True)
result_dict.update(tmp_metrics)

@@ -573,7 +573,7 @@ def on_test_epoch_end(self) -> None:
if self.post_processing is not None:
self.logger.experiment.add_figure(
"Reliabity diagram after calibration",
self.ts_cls_metrics["cal/ECE"].plot()[0],
self.post_cls_metrics["cal/ECE"].plot()[0],
)

# plot histograms of logits and likelihoods
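Editor's note: an explicit `compute_groups` argument must reference the collection's keys verbatim, which is why renaming "sc/CovAt5Risk" to "sc/Cov@5Risk" touches both the metric dict and the group list, and why the post-processing clone now lives under "test/post/" instead of "test/ts_". A minimal sketch of that coupling, using stock torchmetrics classes as stand-ins for the TorchUncertainty metrics (the metric choices and tensors are illustrative):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassCalibrationError

cls_metrics = MetricCollection(
    {
        "cls/Acc": MulticlassAccuracy(num_classes=10),
        "cal/ECE": MulticlassCalibrationError(num_classes=10),
    },
    # Group entries must match the dict keys exactly, so a key rename has to
    # be mirrored here as well.
    compute_groups=[["cls/Acc"], ["cal/ECE"]],
)

# Post-processed metrics are cloned under the new "test/post/" prefix.
post_cls_metrics = cls_metrics.clone(prefix="test/post/")

probs = torch.softmax(torch.randn(16, 10), dim=-1)
targets = torch.randint(0, 10, (16,))
post_cls_metrics.update(probs, targets)
print(post_cls_metrics.compute())
# e.g. {'test/post/cls/Acc': tensor(...), 'test/post/cal/ECE': tensor(...)}
```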
torch_uncertainty/routines/pixel_regression.py (12 additions, 12 deletions)
@@ -99,17 +99,17 @@ def __init__(

depth_metrics = MetricCollection(
{
"SILog": SILog(),
"log10": Log10(),
"ARE": MeanGTRelativeAbsoluteError(),
"RSRE": MeanGTRelativeSquaredError(squared=False),
"RMSE": MeanSquaredError(squared=False),
"RMSELog": MeanSquaredLogError(squared=False),
"iMAE": MeanAbsoluteErrorInverse(),
"iRMSE": MeanSquaredErrorInverse(squared=False),
"d1": ThresholdAccuracy(power=1),
"d2": ThresholdAccuracy(power=2),
"d3": ThresholdAccuracy(power=3),
"reg/SILog": SILog(),
"reg/log10": Log10(),
"reg/ARE": MeanGTRelativeAbsoluteError(),
"reg/RSRE": MeanGTRelativeSquaredError(squared=False),
"reg/RMSE": MeanSquaredError(squared=False),
"reg/RMSELog": MeanSquaredLogError(squared=False),
"reg/iMAE": MeanAbsoluteErrorInverse(),
"reg/iRMSE": MeanSquaredErrorInverse(squared=False),
"reg/d1": ThresholdAccuracy(power=1),
"reg/d2": ThresholdAccuracy(power=2),
"reg/d3": ThresholdAccuracy(power=3),
},
compute_groups=False,
)
@@ -119,7 +119,7 @@ def __init__(

if self.probabilistic:
depth_prob_metrics = MetricCollection(
{"NLL": DistributionNLL(reduction="mean")}
{"reg/NLL": DistributionNLL(reduction="mean")}
)
self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/")
self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/")
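Editor's note: the new "reg/" keys compose with the "val/"/"test/" prefixes added by `clone`, so the depth metrics are logged as "val/reg/RMSE", "test/reg/SILog", and so on. A minimal sketch of that behaviour with stock torchmetrics stand-ins (the metrics and tensors are illustrative, not the routine's own):

```python
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

depth_metrics = MetricCollection(
    {
        "reg/MAE": MeanAbsoluteError(),
        "reg/RMSE": MeanSquaredError(squared=False),
    },
    compute_groups=False,
)
val_metrics = depth_metrics.clone(prefix="val/")

# Dense predictions and targets, e.g. depth maps.
preds = torch.rand(4, 1, 32, 32)
target = torch.rand(4, 1, 32, 32)
val_metrics.update(preds, target)
print(val_metrics.compute())
# e.g. {'val/reg/MAE': tensor(...), 'val/reg/RMSE': tensor(...)}
```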
torch_uncertainty/routines/regression.py (4 additions, 4 deletions)
@@ -84,9 +84,9 @@ def __init__(

reg_metrics = MetricCollection(
{
"MAE": MeanAbsoluteError(),
"MSE": MeanSquaredError(squared=True),
"RMSE": MeanSquaredError(squared=False),
"reg/MAE": MeanAbsoluteError(),
"reg/MSE": MeanSquaredError(squared=True),
"reg/RMSE": MeanSquaredError(squared=False),
},
compute_groups=True,
)
@@ -96,7 +96,7 @@ def __init__(

if self.probabilistic:
reg_prob_metrics = MetricCollection(
{"NLL": DistributionNLL(reduction="mean")}
{"reg/NLL": DistributionNLL(reduction="mean")}
)
self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/")
self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/")
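Editor's note: the probabilistic branch follows the same naming scheme, keying the negative log-likelihood under "reg/NLL" so it surfaces as "test/reg/NLL" after cloning. The sketch below uses a purely illustrative stand-in metric class and assumes an `update(dist, target)` interface; it is not the library's `DistributionNLL` implementation.

```python
import torch
from torch.distributions import Distribution, Normal
from torchmetrics import Metric, MetricCollection


class SketchDistributionNLL(Metric):
    """Illustrative stand-in: mean negative log-likelihood of a predictive
    distribution, assuming an update(dist, target) signature."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, dist: Distribution, target: torch.Tensor) -> None:
        self.total += -dist.log_prob(target).sum()
        self.count += target.numel()

    def compute(self) -> torch.Tensor:
        return self.total / self.count


prob_metrics = MetricCollection({"reg/NLL": SketchDistributionNLL()})
test_prob_metrics = prob_metrics.clone(prefix="test/")

dist = Normal(loc=torch.zeros(8), scale=torch.ones(8))
test_prob_metrics.update(dist, torch.randn(8))
print(test_prob_metrics.compute())  # {'test/reg/NLL': tensor(...)}
```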