Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change regression task to timm support #854

Merged
merged 13 commits into from
Dec 7, 2022
5 changes: 4 additions & 1 deletion tests/conf/cowc_counting.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
experiment:
task: cowc_counting
module:
model: resnet18
regression_model: resnet18
weights: "random"
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: True
Expand Down
5 changes: 4 additions & 1 deletion tests/conf/cyclone.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
experiment:
task: "cyclone"
module:
model: "resnet18"
regression_model: "resnet18"
weights: "random"
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: False
Expand Down
38 changes: 34 additions & 4 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,37 @@ def test_no_logger(self) -> None:
)
trainer.fit(model=model, datamodule=datamodule)

def test_invalid_model(self) -> None:
    """An unknown torchvision model name must raise ``AttributeError``."""
    expected_message = (
        "module 'torchvision.models' has no attribute 'invalid_model'"
    )
    with pytest.raises(AttributeError, match=expected_message):
        RegressionTask(model="invalid_model", pretrained=False)
@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
    """Provide baseline keyword arguments for building a ``RegressionTask``."""
    kwargs: Dict[Any, Any] = dict(
        regression_model="resnet18",
        weights="random",
        num_outputs=1,
        in_channels=3,
    )
    return kwargs

def test_invalid_pretrained(
    self, model_kwargs: Dict[Any, Any], checkpoint: str
) -> None:
    """Loading checkpoint weights into a mismatched model must fail."""
    # Point the task at the checkpoint file while requesting another model.
    model_kwargs.update(weights=checkpoint, regression_model="resnet50")
    expected_message = "Trying to load resnet18 weights into a resnet50"
    with pytest.raises(ValueError, match=expected_message):
        RegressionTask(**model_kwargs)

def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
    """Building the task from a checkpoint path must emit a ``UserWarning``."""
    model_kwargs.update(weights=checkpoint)
    with pytest.warns(UserWarning):
        RegressionTask(**model_kwargs)

def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
    """An unrecognised model name must be rejected with ``ValueError``."""
    model_kwargs.update(regression_model="invalid_model")
    expected_message = "Model type 'invalid_model' is not a valid timm model."
    with pytest.raises(ValueError, match=expected_message):
        RegressionTask(**model_kwargs)

def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
    """An unrecognised ``weights`` value must be rejected with ``ValueError``."""
    model_kwargs.update(weights="invalid_weights")
    expected_message = "Weight type 'invalid_weights' is not valid."
    with pytest.raises(ValueError, match=expected_message):
        RegressionTask(**model_kwargs)
84 changes: 62 additions & 22 deletions torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

"""Regression tasks."""

import os
from typing import Any, Dict, cast

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from packaging.version import parse
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

from ..datasets.utils import unbind_samples
from . import utils

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Expand All @@ -26,36 +26,76 @@


class RegressionTask(pl.LightningModule):
"""LightningModule for training models on regression datasets."""
"""LightningModule for training models on regression datasets.

Supports any available `Timm model
<https://rwightman.github.io/pytorch-image-models/>`_
as an architecture choice. To see a list of available pretrained
models, you can do:

.. code-block:: python

import timm
print(timm.list_models(pretrained=True))
"""

def config_task(self) -> None:
    """Configure the timm model according to ``self.hyperparams``.

    Reads the ``regression_model``, ``weights``, ``num_outputs`` and
    ``in_channels`` hyperparameters, builds the model with
    ``timm.create_model`` and, when ``weights`` is a path that exists on
    disk, loads the encoder weights extracted from that file.

    Raises:
        ValueError: if ``weights`` is not "imagenet", "random" or an existing
            path; if ``regression_model`` is not in the list returned by
            ``timm.list_models(pretrained=True)``; or if the checkpoint was
            extracted from a model whose name differs from
            ``regression_model``.
    """
    in_channels = self.hyperparams["in_channels"]
    regression_model = self.hyperparams["regression_model"]
    weights = self.hyperparams["weights"]

    # An existing path always takes precedence over the named weight types.
    # Empty/None weights previously fell through to the checkpoint branch and
    # failed inside the loader; they are now rejected here together with
    # unknown weight names.
    custom_pretrained = bool(weights) and os.path.exists(weights)
    if not custom_pretrained and weights not in ("imagenet", "random"):
        raise ValueError(f"Weight type '{weights}' is not valid.")
    imagenet_pretrained = not custom_pretrained and weights == "imagenet"

    # Create the model
    # NOTE(review): only architectures listed with pretrained weights pass
    # this check, even when weights="random" -- confirm this is intended.
    if regression_model not in timm.list_models(pretrained=True):
        raise ValueError(
            f"Model type '{regression_model}' is not a valid timm model."
        )

    self.model = timm.create_model(
        regression_model,
        num_classes=self.hyperparams["num_outputs"],
        in_chans=in_channels,
        pretrained=imagenet_pretrained,
    )

    if custom_pretrained:
        name, state_dict = utils.extract_encoder(weights)

        # Refuse to load weights that were saved from another architecture.
        if regression_model != name:
            raise ValueError(
                f"Trying to load {name} weights into a {regression_model}"
            )
        self.model = utils.load_state_dict(self.model, state_dict)

def __init__(self, **kwargs: Any) -> None:
"""Initialize a new LightningModule for training simple regression models.

Keyword Args:
model: Name of the model to use
learning_rate: Initial learning rate to use in the optimizer
learning_rate_schedule_patience: Patience parameter for the LR scheduler
regression_model: Name of the model to use
weights: Either "random", "imagenet", or a path to a checkpoint file to load pretrained weights from
num_outputs: Number of prediction outputs
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler

.. versionchanged:: 0.4
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
Change regression model support from torchvision.models to timm
"""
super().__init__()

Expand Down