🐛 Bug
Using the MinMaxMetric wrapper in a LightningModule, as shown in the TorchMetrics docs, raises a ValueError:
ValueError: The '.compute()' return of the metric logged as 'val/macro_jaccard' must be a tensor. Found {'raw': tensor(0.3260, device='cuda:0'), 'max': tensor(0.3260, device='cuda:0'), 'min': tensor(0.3231, device='cuda:0')}
To Reproduce
Here is sample code to reproduce the error:
Code sample
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import JaccardIndex, MetricCollection
from torchmetrics.wrappers import MinMaxMetric


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Simple model (1 conv layer for demonstration)
        self.model = torch.nn.Conv2d(3, 2, kernel_size=3, padding=1)
        # Define train metrics and wrap them with MinMaxMetric
        self.train_metrics = MetricCollection({
            'weighted_jaccard': MinMaxMetric(JaccardIndex(task='multiclass', num_classes=2, average='weighted')),
            'macro_jaccard': MinMaxMetric(JaccardIndex(task='multiclass', num_classes=2, average='macro'))
        })
        # Clone metrics for validation
        self.val_metrics = self.train_metrics.clone(prefix='val/')

    def forward(self, x):
        return self.model(x)

    def step(self, batch):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return logits, y, loss

    def training_step(self, batch, batch_idx):
        logits, y, loss = self.step(batch)
        # Log loss
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        # Update and log metrics
        self.train_metrics(logits.argmax(dim=1), y)
        self.log_dict(self.train_metrics, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self.step(batch)
        # Log validation loss
        self.log("val/loss", loss, on_step=False, on_epoch=True)
        # Update and log validation metrics
        self.val_metrics(logits.argmax(dim=1), y)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# Simulated data for training and validation
def generate_fake_data(image_size=(3, 64, 64), num_classes=2):
    x = torch.randn(*image_size)  # Fake single image input
    y = torch.randint(0, num_classes, (image_size[1], image_size[2]))  # Fake segmentation labels
    return x, y


# Create DataLoaders
train_data = [generate_fake_data() for _ in range(100)]  # 100 batches of training data
val_data = [generate_fake_data() for _ in range(20)]  # 20 batches of validation data
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=16)

# Initialize the model
model = LitModel()

# Initialize a PyTorch Lightning Trainer and fit the model
trainer = pl.Trainer(max_epochs=3, log_every_n_steps=10, devices=[0])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
Expected behavior
It should not generate a ValueError and should log 3 metrics: the actual one, the min one, and the max one.
Environment
TorchMetrics version (if build from source, add commit SHA): 1.4.2
Python & PyTorch Version (e.g., 1.0): 3.11.9 & 2.4.0
Any other relevant information such as OS (e.g., Linux): Ubuntu VM
Additional context
If the MinMaxMetric is used outside of a MetricCollection, it generates another error. The error doesn't happen at .compute time but before, when doing .log. Here is the traceback (but maybe I should open a second issue?):
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/patron/mambaforge/envs/segmentator/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 608, in log_dict
for k, v in dictionary.items(**kwargs):
^^^^^^^^^^^^^^^^
File "/home/patron/mambaforge/envs/segmentator/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'MinMaxMetric' object has no attribute 'items'
Hi @patrontheo, thanks for creating this issue.
I am going to refer you to this comment I made on a similar issue about logging using MetricCollection + ClasswiseWrapper, instead of MetricCollection + MinMaxMetric. Different wrapper but the issue is the same and the answer is therefore also the same.
In short, this kind of nested metric structure is not something we can support directly, but that does not mean it cannot be logged. It just requires you to log the computed value instead of the metric object. I have created PR #2772, which includes an integration test (for future reference) on how to log MetricCollection + MinMaxMetric in Lightning, and I will be adding some more text to the documentation to further explain what to do in this situation.
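To illustrate "log the computed value instead of the metric object", here is a minimal sketch, not the code from PR #2772. It assumes the computed result is a dict that may contain nested dicts shaped like the one in the error message ({'raw': ..., 'max': ..., 'min': ...}). The helper name flatten_metric_output and the "_" separator are hypothetical choices for illustration. Plain floats stand in for tensors so the snippet runs without torch.

```python
from typing import Any, Dict, Mapping


def flatten_metric_output(output: Mapping[str, Any], sep: str = "_") -> Dict[str, Any]:
    """Flatten a possibly nested dict of metric results into one flat dict.

    Hypothetical helper (not part of TorchMetrics or Lightning).
    Already-flat dicts are returned with the same keys and values.
    """
    flat: Dict[str, Any] = {}
    for name, value in output.items():
        if isinstance(value, Mapping):
            # Nested result such as {'raw': ..., 'max': ..., 'min': ...}:
            # prefix each inner key with the outer metric name.
            for sub_name, sub_value in flatten_metric_output(value, sep).items():
                flat[f"{name}{sep}{sub_name}"] = sub_value
        else:
            flat[name] = value
    return flat


# Floats stand in for the tensors shown in the error message.
nested = {"val/macro_jaccard": {"raw": 0.3260, "max": 0.3260, "min": 0.3231}}
flat = flatten_metric_output(nested)
print(sorted(flat))
```

In a LightningModule, one would presumably call self.val_metrics.update(preds, y) in validation_step and then, for example in on_validation_epoch_end, pass the flattened result of self.val_metrics.compute() to self.log_dict. Whether compute() on a MetricCollection already returns a flat dict may depend on the TorchMetrics version; the helper leaves flat dicts unchanged, so it should be harmless either way. For a standalone MinMaxMetric, wrapping the result under a name of your choice, e.g. flatten_metric_output({"val/jaccard": metric.compute()}), gives log_dict the mapping it expects and avoids the AttributeError about 'items'.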