Normalization in Datamodules does not use correct statistics #2175

Closed
nilsleh opened this issue Jul 18, 2024 · 1 comment · Fixed by #2176

@nilsleh
Collaborator

nilsleh commented Jul 18, 2024

Description

For some datasets, such as EuroSAT, the datamodule defines per-band normalization statistics (self.mean and self.std). However, these statistics are never passed to the self.aug transform (or the other transforms), so the images are normalized incorrectly: the applied augmentation still uses the defaults mean=0, std=255. I expect the same problem affects other datamodules.
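The mechanism can be reproduced without torchgeo or kornia. The sketch below uses hypothetical toy classes (`Normalize`, `BaseDataModule`, `BuggyDataModule`, and the statistics 1000/500 are illustrative assumptions, not torchgeo's actual code). The base class builds its normalization transform inside `__init__` from whatever `self.mean`/`self.std` exist at that moment, so a subclass that assigns its statistics after `super().__init__()` only updates the attributes, not the transform that was already built.

```python
class Normalize:
    """Toy stand-in for a normalize transform: stores mean/std when built."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def __call__(self, values: list[float]) -> list[float]:
        return [(v - self.mean) / self.std for v in values]


class BaseDataModule:
    # Class-level defaults, mirroring the mean=0, std=255 printed by dm.aug.
    mean = 0.0
    std = 255.0

    def __init__(self) -> None:
        # The transform captures the statistics visible at construction time.
        self.aug = Normalize(self.mean, self.std)


class BuggyDataModule(BaseDataModule):
    def __init__(self) -> None:
        super().__init__()  # self.aug is built here with the defaults...
        self.mean = 1000.0  # ...so these hypothetical statistics arrive too late.
        self.std = 500.0


dm = BuggyDataModule()
print(dm.mean, dm.std)          # 1000.0 500.0: the attributes look correct
print(dm.aug.mean, dm.aug.std)  # 0.0 255.0: the transform still uses defaults
```

Rebinding `self.mean` after construction cannot reach back into an object that already copied the old value, which is why the printed transform in the issue still shows mean=0, std=255.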

Steps to reproduce

```python
from copy import deepcopy

import kornia.augmentation as K
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT  # needed for the EuroSAT(...) call below
from torchgeo.transforms import AugmentationSequential

ds = EuroSAT(root="/mnt/SSD2/nils/projects/vae/data", bands=["B04", "B03", "B02"])
dm = EuroSATDataModule(root="/mnt/SSD2/nils/projects/vae/data", bands=["B04", "B03", "B02"], batch_size=256)
dm.setup("fit")
train_loader = dm.train_dataloader()
orig_batch = next(iter(train_loader))

# Augmentation built by the datamodule
print(dm.aug)
batch = dm.aug(deepcopy(orig_batch))

print("Default")
print("mean", batch["image"].mean())
print("std", batch["image"].std())

# Same normalization, but built explicitly from the datamodule's statistics
normalize_transform = AugmentationSequential(
    K.Normalize(mean=dm.mean, std=dm.std),
    data_keys=["image"],
)

batch = normalize_transform(deepcopy(orig_batch))

print("Updated")
print("mean", batch["image"].mean())
print("std", batch["image"].std())
```
Output:

```
AugmentationSequential(
  (augs): AugmentationSequential(
    (Normalize_0): Normalize(p=1.0, p_batch=1.0, same_on_batch=True, mean=0, std=255)
  )
)
Default
mean tensor(4.0712)
std tensor(1.7841)
Updated
mean tensor(0.0053)
std tensor(0.9901)
```
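Because the default transform is Normalize(mean=0, std=255), the "Default" numbers can be read back as raw-pixel statistics. Dividing by a positive constant scales both the mean and the standard deviation by that constant, so, pooled over all bands and pixels of the batch, the printed values imply the following (a back-of-the-envelope reading of the output above, not a separately measured result):

$$
\mu_{\text{raw}} = 255 \times 4.0712 \approx 1038.2, \qquad
\sigma_{\text{raw}} = 255 \times 1.7841 \approx 454.9
$$

Raw values of this size are far outside the 0 to 255 range that std=255 implicitly assumes, which is why the default-normalized batch has a mean near 4 instead of near 0. The "Updated" run, which standardizes each band with the datamodule's own statistics, lands near mean 0 and std 1 as expected.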

Version

'0.6.0.dev0'

@adamjstewart
Collaborator

Oops. Luckily this is easy to fix: just move the assignments to self.mean and self.std BEFORE the call to super().__init__(), so that self.aug is built with the dataset's statistics. Want to submit a PR and check for other places where this happens?
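A minimal sketch of the suggested reordering, using hypothetical toy classes that mimic the pattern (`Normalize`, `BaseDataModule`, `uses_own_stats`, and the statistics 1000/500 are illustrative assumptions, not torchgeo code). It also includes one possible check for auditing other datamodules: if the transform was built from the module's own statistics, normalizing the mean itself should give exactly 0.

```python
class Normalize:
    """Toy stand-in for a normalize transform: stores mean/std when built."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def __call__(self, values: list[float]) -> list[float]:
        return [(v - self.mean) / self.std for v in values]


class BaseDataModule:
    # Class-level defaults used when a subclass provides no statistics.
    mean = 0.0
    std = 255.0

    def __init__(self) -> None:
        # Build the transform from the statistics visible at this point.
        self.aug = Normalize(self.mean, self.std)


class BuggyDataModule(BaseDataModule):
    def __init__(self) -> None:
        super().__init__()  # transform built with the defaults
        self.mean = 1000.0  # hypothetical statistics, assigned too late
        self.std = 500.0


class FixedDataModule(BaseDataModule):
    def __init__(self) -> None:
        # Assign the (hypothetical) dataset statistics BEFORE the base
        # constructor, so the transform built there picks them up.
        self.mean = 1000.0
        self.std = 500.0
        super().__init__()


def uses_own_stats(dm: BaseDataModule) -> bool:
    """Hypothetical audit: True if dm.aug was built from dm.mean."""
    return dm.aug([dm.mean]) == [0.0]


print(FixedDataModule().aug([1000.0, 1500.0]))  # [0.0, 1.0]
print(uses_own_stats(FixedDataModule()))        # True
print(uses_own_stats(BuggyDataModule()))        # False
```

The fix works because Python looks up `self.mean` at the moment the base constructor runs; instance attributes set before `super().__init__()` shadow the class-level defaults, while those set afterwards do not affect the already-built transform.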

@adamjstewart adamjstewart added this to the 0.5.3 milestone Jul 18, 2024
@adamjstewart adamjstewart modified the milestones: 0.5.3, 0.6.0 Aug 6, 2024
@adamjstewart adamjstewart removed this from the 0.6.0 milestone Aug 29, 2024