diff --git a/metnet/layers/ConditionTimeMetNet2.py b/metnet/layers/ConditionTimeMetNet2.py new file mode 100644 index 0000000..18c26dd --- /dev/null +++ b/metnet/layers/ConditionTimeMetNet2.py @@ -0,0 +1,53 @@ +import einops +import torch +from torch import nn as nn + + +class ConditionWithTimeMetNet2(nn.Module): + """Compute Scale and bias for conditioning on time""" + + def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int): + """ + Compute the scale and bias factors for conditioning convolutional blocks on the forecast time + + Args: + forecast_steps: Number of forecast steps + hidden_dim: Hidden dimension size + num_feature_maps: Max number of channels in the blocks, to generate enough scale+bias values + This means extra values will be generated, but keeps implementation simpler + """ + super().__init__() + self.forecast_steps = forecast_steps + self.num_feature_maps = num_feature_maps + self.lead_time_network = nn.ModuleList( + [ + nn.Linear(in_features=forecast_steps, out_features=hidden_dim), + nn.Linear(in_features=hidden_dim, out_features=2 * num_feature_maps), + ] + ) + + def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor]: + """ + Get the scale and bias for the conditioning layers + + From the FiLM paper, each feature map (i.e. channel) has its own scale and bias layer, so needs + a scale and bias for each feature map to be generated + + Args: + x: The Tensor that is used + timestep: Index of the timestep to use, between 0 and forecast_steps + + Returns: + 2 Tensors of shape (Batch, num_feature_maps) + """ + # One hot encode the timestep + timesteps = torch.zeros(x.size()[0], self.forecast_steps, dtype=x.dtype) + timesteps[:, timestep] = 1 + # Get scales and biases + for layer in self.lead_time_network: + timesteps = layer(timesteps) + scales_and_biases = timesteps + scales_and_biases = einops.rearrange( + scales_and_biases, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2 + ) + return scales_and_biases[:, :, 0], scales_and_biases[:, :, 1] diff --git a/metnet/models/metnet2.py b/metnet/models/metnet2.py index 4fff887..2e0e532 100644 --- a/metnet/models/metnet2.py +++ b/metnet/models/metnet2.py @@ -1,13 +1,13 @@ """MetNet-2 model for weather forecasting""" from typing import List -import einops import torch import torch.nn as nn import torchvision.transforms from huggingface_hub import PyTorchModelHubMixin from metnet.layers import DownSampler, MetNetPreprocessor, TimeDistributed +from metnet.layers.ConditionTimeMetNet2 import ConditionWithTimeMetNet2 from metnet.layers.ConvLSTM import ConvLSTM from metnet.layers.DilatedCondConv import DilatedResidualConv, UpsampleResidualConv @@ -192,6 +192,16 @@ def __init__( ] ) + self.conv_lstm_time_conditioners = nn.ModuleList() + for layer in range(self.conv_lstm.num_layers): + self.conv_lstm_time_conditioners.append( + ConditionWithTimeMetNet2( + forecast_steps=forecast_steps, + hidden_dim=self.conv_lstm.cell_list[layer].hidden_dim, + num_feature_maps=self.conv_lstm.cell_list[layer], + ) + ) + self.time_conditioners = nn.ModuleList() # Go through each set of blocks and add conditioner # Context Stack @@ -286,53 +296,3 @@ def forward(self, x: torch.Tensor, lead_time: int = 0): # Softmax for rain forecasting return res - - -class ConditionWithTimeMetNet2(nn.Module): - """Compute Scale and bias for conditioning on time""" - - def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int): - """ - Compute the scale and bias factors for conditioning convolutional blocks on the forecast time - - Args: - forecast_steps: Number of forecast steps - hidden_dim: Hidden dimension size - num_feature_maps: Max number of channels in the blocks, to generate enough scale+bias values - This means extra values will be generated, but keeps implementation simpler - """ - super().__init__() - self.forecast_steps = forecast_steps - self.num_feature_maps = num_feature_maps - self.lead_time_network = nn.ModuleList( - [ - nn.Linear(in_features=forecast_steps, out_features=hidden_dim), - nn.Linear(in_features=hidden_dim, out_features=2 * num_feature_maps), - ] - ) - - def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor]: - """ - Get the scale and bias for the conditioning layers - - From the FiLM paper, each feature map (i.e. channel) has its own scale and bias layer, so needs - a scale and bias for each feature map to be generated - - Args: - x: The Tensor that is used - timestep: Index of the timestep to use, between 0 and forecast_steps - - Returns: - 2 Tensors of shape (Batch, num_feature_maps) - """ - # One hot encode the timestep - timesteps = torch.zeros(x.size()[0], self.forecast_steps, dtype=x.dtype) - timesteps[:, timestep] = 1 - # Get scales and biases - for layer in self.lead_time_network: - timesteps = layer(timesteps) - scales_and_biases = timesteps - scales_and_biases = einops.rearrange( - scales_and_biases, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2 - ) - return scales_and_biases[:, :, 0], scales_and_biases[:, :, 1]