diff --git a/src/uni2ts/common/torch_util.py b/src/uni2ts/common/torch_util.py index 682ae10..2c041b5 100644 --- a/src/uni2ts/common/torch_util.py +++ b/src/uni2ts/common/torch_util.py @@ -44,7 +44,7 @@ def packed_attention_mask( def packed_causal_attention_mask( sample_id: Int[torch.Tensor, "*batch seq_len"], - time_id: Int[torch.Tensor, "*batch seq_len"] + time_id: Int[torch.Tensor, "*batch seq_len"], ) -> Bool[torch.Tensor, "*batch seq_len seq_len"]: attention_mask = packed_attention_mask(sample_id) expanded_id1 = time_id.unsqueeze(-2) diff --git a/src/uni2ts/model/moirai/__init__.py b/src/uni2ts/model/moirai/__init__.py index 097476d..affc159 100644 --- a/src/uni2ts/model/moirai/__init__.py +++ b/src/uni2ts/model/moirai/__init__.py @@ -19,4 +19,10 @@ from .module_moe import MoiraiMoEModule from .pretrain import MoiraiPretrain -__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiMoEModule", "MoiraiPretrain"] +__all__ = [ + "MoiraiFinetune", + "MoiraiForecast", + "MoiraiModule", + "MoiraiMoEModule", + "MoiraiPretrain", +] diff --git a/src/uni2ts/model/moirai/forecast.py b/src/uni2ts/model/moirai/forecast.py index 5394489..de7e864 100644 --- a/src/uni2ts/model/moirai/forecast.py +++ b/src/uni2ts/model/moirai/forecast.py @@ -27,10 +27,10 @@ from gluonts.transform import ( AddObservedValuesIndicator, AsNumpyArray, + CausalMeanValueImputation, ExpandDimArray, TestSplitSampler, Transformation, - CausalMeanValueImputation, ) from gluonts.transform.split import TFTInstanceSplitter from jaxtyping import Bool, Float, Int @@ -345,7 +345,9 @@ def forward( past_feat_dynamic_real, past_observed_feat_dynamic_real, ) - preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,))) + preds = distr.sample( + torch.Size((num_samples or self.hparams.num_samples,)) + ) return self._format_preds( self.hparams.patch_size, preds, past_target.shape[-1] ) @@ -373,10 +375,18 @@ def forward( past_feat_dynamic_real=past_feat_dynamic_real, past_observed_feat_dynamic_real=past_observed_feat_dynamic_real, ) - patch_size = torch.ones_like(time_id, dtype=torch.long) * self.hparams.patch_size + patch_size = ( + torch.ones_like(time_id, dtype=torch.long) * self.hparams.patch_size + ) - pred_index = torch.arange(start=context_step-1, end=context_token, step=context_step) - assign_index = torch.arange(start=context_token, end=context_token+predict_token, step=predict_step) + pred_index = torch.arange( + start=context_step - 1, end=context_token, step=context_step + ) + assign_index = torch.arange( + start=context_token, + end=context_token + predict_token, + step=predict_step, + ) if predict_step == 1: distr = self.module( @@ -388,7 +398,9 @@ def forward( prediction_mask, patch_size, ) - preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,))) + preds = distr.sample( + torch.Size((num_samples or self.hparams.num_samples,)) + ) preds[..., assign_index, :] = preds[..., pred_index, :] return self._format_preds( self.hparams.patch_size, preds, self.hparams.target_dim @@ -405,13 +417,27 @@ def forward( ) preds = distr.sample(torch.Size((self.hparams.num_samples,))) - expand_target = target.unsqueeze(0).repeat(self.hparams.num_samples, 1, 1, 1) - expand_prediction_mask = prediction_mask.unsqueeze(0).repeat(self.hparams.num_samples, 1, 1) - expand_observed_mask = observed_mask.unsqueeze(0).expand(self.hparams.num_samples, -1, -1, -1) - expand_sample_id = sample_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1) - expand_time_id = time_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1) - expand_variate_id = variate_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1) - expand_patch_size = patch_size.unsqueeze(0).expand(self.hparams.num_samples, -1, -1) + expand_target = target.unsqueeze(0).repeat( + self.hparams.num_samples, 1, 1, 1 + ) + expand_prediction_mask = prediction_mask.unsqueeze(0).repeat( + self.hparams.num_samples, 1, 1 + ) + expand_observed_mask = observed_mask.unsqueeze(0).expand( + self.hparams.num_samples, -1, -1, -1 + ) + expand_sample_id = sample_id.unsqueeze(0).expand( + self.hparams.num_samples, -1, -1 + ) + expand_time_id = time_id.unsqueeze(0).expand( + self.hparams.num_samples, -1, -1 + ) + expand_variate_id = variate_id.unsqueeze(0).expand( + self.hparams.num_samples, -1, -1 + ) + expand_patch_size = patch_size.unsqueeze(0).expand( + self.hparams.num_samples, -1, -1 + ) expand_target[..., assign_index, :] = preds[..., pred_index, :] expand_prediction_mask[..., assign_index] = False @@ -581,7 +607,9 @@ def _generate_time_id( "max", patch=patch_size, ) - past_seq_id = torch.clamp(past_seq_id.cummax(dim=-1).values.cumsum(dim=-1) - 1, min=0) + past_seq_id = torch.clamp( + past_seq_id.cummax(dim=-1).values.cumsum(dim=-1) - 1, min=0 + ) batch_shape = " ".join(map(str, past_observed_target.shape[:-2])) future_seq_id = ( repeat( diff --git a/src/uni2ts/model/moirai/module_moe.py b/src/uni2ts/model/moirai/module_moe.py index 9b6c768..7a7414a 100644 --- a/src/uni2ts/model/moirai/module_moe.py +++ b/src/uni2ts/model/moirai/module_moe.py @@ -34,7 +34,7 @@ RotaryProjection, ) from uni2ts.module.transformer import TransformerEncoder -from uni2ts.module.ts_embed import MultiInSizeLinear, FeatLinear +from uni2ts.module.ts_embed import FeatLinear, MultiInSizeLinear def encode_distr_output( diff --git a/src/uni2ts/module/ffn.py b/src/uni2ts/module/ffn.py index 5f52518..dafdef1 100644 --- a/src/uni2ts/module/ffn.py +++ b/src/uni2ts/module/ffn.py @@ -49,7 +49,8 @@ def __init__( self.activation = activation def forward( - self, x: Float[torch.Tensor, "... in_dim"], + self, + x: Float[torch.Tensor, "... in_dim"], centroid: Optional[Float[torch.Tensor, "expert in_dim"]] = None, ) -> Float[torch.Tensor, "... out_dim"]: x = self._in_proj(x) @@ -122,7 +123,8 @@ def __init__( ) def forward( - self, x: Float[torch.Tensor, "... in_dim"], + self, + x: Float[torch.Tensor, "... in_dim"], centroid: Optional[Float[torch.Tensor, "expert in_dim"]] = None, ) -> Float[torch.Tensor, "... dim"]: x_squashed = x.view(-1, x.shape[-1]) @@ -136,11 +138,10 @@ def forward( cdist = torch.cdist(x_temp, centroid) gate_logits = cdist.view(-1, cdist.shape[-1]) - weights, selected_experts = torch.topk( - gate_logits, self.num_experts_per_token - ) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_token) weights = nn.functional.softmax( - weights, dim=1, + weights, + dim=1, dtype=torch.float, ).type_as(x) diff --git a/src/uni2ts/module/transformer.py b/src/uni2ts/module/transformer.py index 67fa4c2..8666ab1 100644 --- a/src/uni2ts/module/transformer.py +++ b/src/uni2ts/module/transformer.py @@ -195,8 +195,8 @@ def __init__( ) self.norm = norm_layer(d_model) - self.register_buffer("centroid", - torch.empty(num_layers, 32, d_model, dtype=torch.float64) + self.register_buffer( + "centroid", torch.empty(num_layers, 32, d_model, dtype=torch.float64) ) @staticmethod @@ -222,5 +222,11 @@ def forward( time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, ) -> Float[torch.Tensor, "*batch time_len dim"]: for idx, layer in enumerate(self.layers): - x = layer(x, attn_mask, var_id=var_id, time_id=time_id, centroid=self.centroid[idx]) + x = layer( + x, + attn_mask, + var_id=var_id, + time_id=time_id, + centroid=self.centroid[idx], + ) return self.norm(x) diff --git a/src/uni2ts/module/ts_embed.py b/src/uni2ts/module/ts_embed.py index 0524705..02a44e2 100644 --- a/src/uni2ts/module/ts_embed.py +++ b/src/uni2ts/module/ts_embed.py @@ -123,15 +123,12 @@ def __init__( self.out_features = out_features self.weight = nn.Parameter( - torch.empty( - (len(in_features_ls), out_features, out_features), dtype=dtype - ) + torch.empty((len(in_features_ls), out_features, out_features), dtype=dtype) ) if bias: self.bias = nn.Parameter( - torch.empty((len(in_features_ls), out_features), dtype=dtype - ) + torch.empty((len(in_features_ls), out_features), dtype=dtype) ) else: self.register_parameter("bias", None)