Add FiLM layers to ConvLSTM #32

Draft · wants to merge 4 commits into base: main
53 changes: 53 additions & 0 deletions metnet/layers/ConditionTimeMetNet2.py
@@ -0,0 +1,53 @@
from typing import Tuple

import einops
import torch
from torch import nn


class ConditionWithTimeMetNet2(nn.Module):
    """Compute scale and bias for conditioning on time"""

    def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):
        """
        Compute the scale and bias factors for conditioning convolutional blocks on the forecast time

        Args:
            forecast_steps: Number of forecast steps
            hidden_dim: Hidden dimension size
            num_feature_maps: Maximum number of channels across the blocks, so that enough
                scale and bias values are generated. Extra values are generated for smaller
                blocks, but this keeps the implementation simple.
        """
        super().__init__()
        self.forecast_steps = forecast_steps
        self.num_feature_maps = num_feature_maps
        self.lead_time_network = nn.ModuleList(
            [
                nn.Linear(in_features=forecast_steps, out_features=hidden_dim),
                nn.Linear(in_features=hidden_dim, out_features=2 * num_feature_maps),
            ]
        )

    def forward(self, x: torch.Tensor, timestep: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the scale and bias for the conditioning layers

        Following the FiLM paper, each feature map (i.e. channel) gets its own scale and bias,
        so one scale and one bias value is generated per feature map.

        Args:
            x: Input tensor, used only for its batch size, dtype, and device
            timestep: Index of the timestep to use, in [0, forecast_steps)

        Returns:
            Two tensors (scales, biases), each of shape (batch, num_feature_maps)
        """
        # One-hot encode the timestep, on the same device/dtype as the input
        timesteps = torch.zeros(x.shape[0], self.forecast_steps, dtype=x.dtype, device=x.device)
        timesteps[:, timestep] = 1
        # Pass the one-hot lead time through the two linear layers
        for layer in self.lead_time_network:
            timesteps = layer(timesteps)
        # Split the 2 * num_feature_maps outputs into per-channel (scale, bias) pairs
        scales_and_biases = einops.rearrange(
            timesteps, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2
        )
        return scales_and_biases[:, :, 0], scales_and_biases[:, :, 1]
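For reference, a minimal usage sketch of the module above, applying the returned per-channel scales and biases FiLM-style to a feature map. The shapes and the `features` tensor are illustrative assumptions, not part of this diff:

```python
import torch

from metnet.layers.ConditionTimeMetNet2 import ConditionWithTimeMetNet2

# Hypothetical shapes: batch of 2, 16 channels, 32x32 maps, 8 forecast steps.
conditioner = ConditionWithTimeMetNet2(forecast_steps=8, hidden_dim=64, num_feature_maps=16)
features = torch.randn(2, 16, 32, 32)

scale, bias = conditioner(features, timestep=3)  # each of shape (2, 16)
# FiLM modulation: broadcast per-channel scale/bias over the spatial dims.
out = scale[:, :, None, None] * features + bias[:, :, None, None]
print(out.shape)  # torch.Size([2, 16, 32, 32])
```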
62 changes: 11 additions & 51 deletions metnet/models/metnet2.py
@@ -1,13 +1,13 @@
"""MetNet-2 model for weather forecasting"""
from typing import List

import einops
import torch
import torch.nn as nn
import torchvision.transforms
from huggingface_hub import PyTorchModelHubMixin

from metnet.layers import DownSampler, MetNetPreprocessor, TimeDistributed
from metnet.layers.ConditionTimeMetNet2 import ConditionWithTimeMetNet2
from metnet.layers.ConvLSTM import ConvLSTM
from metnet.layers.DilatedCondConv import DilatedResidualConv, UpsampleResidualConv

@@ -192,6 +192,16 @@ def __init__(
            ]
        )

        # One FiLM-style time conditioner per ConvLSTM layer
        self.conv_lstm_time_conditioners = nn.ModuleList()
        for layer in range(self.conv_lstm.num_layers):
            self.conv_lstm_time_conditioners.append(
                ConditionWithTimeMetNet2(
                    forecast_steps=forecast_steps,
                    hidden_dim=self.conv_lstm.cell_list[layer].hidden_dim,
                    # One scale/bias pair per hidden channel of this cell
                    num_feature_maps=self.conv_lstm.cell_list[layer].hidden_dim,
                )
            )

        self.time_conditioners = nn.ModuleList()
        # Go through each set of blocks and add a conditioner
        # Context Stack
@@ -286,53 +296,3 @@ def forward(self, x: torch.Tensor, lead_time: int = 0):

        # Softmax for rain forecasting
        return res
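The hunks above construct one conditioner per ConvLSTM layer but do not show where the scales and biases are consumed. Below is a minimal sketch of how such per-layer FiLM conditioning could be applied to the ConvLSTM hidden states; the function, its name, and the `hidden_states` structure are assumptions for illustration, not code from this PR:

```python
import torch
from torch import nn


def film_condition_hidden_states(
    conditioners: nn.ModuleList,
    hidden_states: list,  # one (B, C, H, W) tensor per ConvLSTM layer
    x: torch.Tensor,
    lead_time: int,
) -> list:
    """Apply one FiLM (scale, bias) modulation per ConvLSTM layer (illustrative)."""
    out = []
    for conditioner, h in zip(conditioners, hidden_states):
        scale, bias = conditioner(x, lead_time)  # each (B, C)
        # Broadcast the per-channel scale/bias over the spatial dimensions
        out.append(scale[:, :, None, None] * h + bias[:, :, None, None])
    return out
```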


[Deleted: the original in-file `ConditionWithTimeMetNet2` class, identical to the code moved into `metnet/layers/ConditionTimeMetNet2.py` above.]