
MinMaxMetric not working in PyTorch Lightning #2763

Closed
patrontheo opened this issue Sep 25, 2024 · 2 comments · Fixed by #2772
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed), v1.4.x

@patrontheo

🐛 Bug

When using the MinMaxMetric wrapper in a LightningModule, as shown in the TorchMetrics docs, logging fails with: ValueError: The '.compute()' return of the metric logged as 'val/macro_jaccard' must be a tensor. Found {'raw': tensor(0.3260, device='cuda:0'), 'max': tensor(0.3260, device='cuda:0'), 'min': tensor(0.3231, device='cuda:0')}
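For context, MinMaxMetric.compute() returns a dict with 'raw', 'min', and 'max' keys rather than a single tensor, which is what Lightning's logging rejects. A minimal sketch outside Lightning (inputs are arbitrary random tensors):

import torch
from torchmetrics.classification import MulticlassJaccardIndex
from torchmetrics.wrappers import MinMaxMetric

metric = MinMaxMetric(MulticlassJaccardIndex(num_classes=2))
metric.update(torch.randint(0, 2, (16, 64, 64)), torch.randint(0, 2, (16, 64, 64)))
print(metric.compute())  # {'raw': tensor(...), 'max': tensor(...), 'min': tensor(...)}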

To Reproduce

Here is sample code to reproduce the error:

Code sample
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import JaccardIndex, MetricCollection
from torchmetrics.wrappers import MinMaxMetric

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Simple model (1 conv layer for demonstration)
        self.model = torch.nn.Conv2d(3, 2, kernel_size=3, padding=1)
        
        # Define train metrics and wrap them with MinMaxMetric
        self.train_metrics = MetricCollection({
            'weighted_jaccard': MinMaxMetric(JaccardIndex(task='multiclass', num_classes=2, average='weighted')),
            'macro_jaccard': MinMaxMetric(JaccardIndex(task='multiclass', num_classes=2, average='macro'))
        })

        # Clone metrics for validation
        self.val_metrics = self.train_metrics.clone(prefix='val/')

    def forward(self, x):
        return self.model(x)
    
    def step(self, batch):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return logits, y, loss

    def training_step(self, batch, batch_idx):
        logits, y, loss = self.step(batch)

        # Log loss
        self.log("train/loss", loss, on_step=True, on_epoch=True)

        # Update and log metrics
        self.train_metrics(logits.argmax(dim=1), y)
        self.log_dict(self.train_metrics, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self.step(batch)

        # Log validation loss
        self.log("val/loss", loss, on_step=False, on_epoch=True)

        # Update and log validation metrics
        self.val_metrics(logits.argmax(dim=1), y)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Simulated data for training and validation
def generate_fake_data(image_size=(3, 64, 64), num_classes=2):
    x = torch.randn(*image_size)  # Fake single image input
    y = torch.randint(0, num_classes, (image_size[1], image_size[2]))  # Fake segmentation labels
    return x, y

# Create DataLoaders
train_data = [generate_fake_data() for _ in range(100)]  # 100 training samples
val_data = [generate_fake_data() for _ in range(20)]     # 20 validation samples

train_loader = torch.utils.data.DataLoader(train_data, batch_size=16)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=16)

# Initialize the model
model = LitModel()

# Initialize a PyTorch Lightning Trainer and fit the model
trainer = pl.Trainer(max_epochs=3, log_every_n_steps=10, devices=[0])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Expected behavior

It should not raise a ValueError and should log three values per metric: the raw value, the running minimum, and the running maximum.

Environment

  • TorchMetrics version (if built from source, add commit SHA): 1.4.2
  • Python & PyTorch Version (e.g., 1.0): 3.11.9 & 2.4.0
  • Any other relevant information such as OS (e.g., Linux): Ubuntu VM

Additional context

If the MinMaxMetric is used outside of a MetricCollection, it raises a different error. That error doesn't happen at .compute() time but earlier, when calling .log_dict(). Here is the traceback (but maybe I should open a second issue?):

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/patron/mambaforge/envs/segmentator/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 608, in log_dict
    for k, v in dictionary.items(**kwargs):
                ^^^^^^^^^^^^^^^^
  File "/home/patron/mambaforge/envs/segmentator/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'MinMaxMetric' object has no attribute 'items'
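The traceback suggests why: log_dict expects a mapping (MetricCollection implements .items(), so it qualifies), while a bare MinMaxMetric is a single Metric module with no .items() attribute. A small sketch illustrating the distinction:

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassJaccardIndex
from torchmetrics.wrappers import MinMaxMetric

metric = MinMaxMetric(MulticlassJaccardIndex(num_classes=2))
collection = MetricCollection({'jaccard': metric})

print(hasattr(collection, 'items'))  # True  -- log_dict can iterate it
print(hasattr(metric, 'items'))      # False -- log_dict raises the AttributeError above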
@patrontheo patrontheo added bug / fix Something isn't working help wanted Extra attention is needed labels Sep 25, 2024
@Borda Borda added the v1.4.x label Sep 25, 2024
@SkafteNicki (Member)

Hi @patrontheo, thanks for creating this issue.
I am going to refer you to this comment I made on a similar issue about logging with MetricCollection + ClasswiseWrapper instead of MetricCollection + MinMaxMetric. Different wrapper, but the issue is the same and so is the answer.

In short, this kind of nested metric structure is not something we can support directly, but that does not mean it cannot be logged. It just requires you to log the computed value instead of the metric object. I have created PR #2772, which includes an integration test (for future reference) showing how to log MetricCollection + MinMaxMetric in Lightning, and I will be adding some more text to the documentation to further explain what to do in this situation.
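For illustration, the pattern boils down to computing and flattening the nested results yourself before logging. A sketch of how the validation hooks from the reproduction above could be rewritten (the flattening below is one possible approach, not necessarily the exact code in the PR):

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self.step(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True)
        # Only update the metric states here; don't hand the wrapper to log_dict
        self.val_metrics.update(logits.argmax(dim=1), y)
        return loss

    def on_validation_epoch_end(self):
        # compute() returns e.g. {'val/macro_jaccard': {'raw': ..., 'min': ..., 'max': ...}};
        # flatten the nested dicts into plain tensors before logging
        flat = {
            f"{name}_{key}": value
            for name, inner in self.val_metrics.compute().items()
            for key, value in inner.items()
        }
        self.log_dict(flat)
        self.val_metrics.reset()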

@patrontheo (Author)

@SkafteNicki Thanks for the answer :)
I just had a quick question here about the implementation in your tests.
