
Commit ef5fbde: fix format issues
liuxu77 committed Nov 1, 2024 (parent commit: 1c32b6e)
Showing 7 changed files with 69 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/uni2ts/common/torch_util.py
@@ -44,7 +44,7 @@ def packed_attention_mask(
 
 def packed_causal_attention_mask(
     sample_id: Int[torch.Tensor, "*batch seq_len"],
-    time_id: Int[torch.Tensor, "*batch seq_len"]
+    time_id: Int[torch.Tensor, "*batch seq_len"],
 ) -> Bool[torch.Tensor, "*batch seq_len seq_len"]:
     attention_mask = packed_attention_mask(sample_id)
     expanded_id1 = time_id.unsqueeze(-2)
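
Note on this hunk: only the trailing comma changes. For readers unfamiliar with the helper, a packed causal mask of this shape can be thought of as the intersection of a same-sample mask and a "key time id not after query time id" condition. The sketch below is illustrative only and assumes that reading; the actual packed_attention_mask in torch_util.py may differ in details.

import torch
from jaxtyping import Bool, Int


def packed_causal_attention_mask_sketch(
    sample_id: Int[torch.Tensor, "*batch seq_len"],
    time_id: Int[torch.Tensor, "*batch seq_len"],
) -> Bool[torch.Tensor, "*batch seq_len seq_len"]:
    # Attend only within the same packed sample...
    same_sample = sample_id.unsqueeze(-1).eq(sample_id.unsqueeze(-2))
    # ...and never to a key whose time id lies in the future of the query.
    causal = time_id.unsqueeze(-1) >= time_id.unsqueeze(-2)
    return same_sample & causal
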
8 changes: 7 additions & 1 deletion src/uni2ts/model/moirai/__init__.py
@@ -19,4 +19,10 @@
 from .module_moe import MoiraiMoEModule
 from .pretrain import MoiraiPretrain
 
-__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiMoEModule", "MoiraiPretrain"]
+__all__ = [
+    "MoiraiFinetune",
+    "MoiraiForecast",
+    "MoiraiModule",
+    "MoiraiMoEModule",
+    "MoiraiPretrain",
+]
56 changes: 42 additions & 14 deletions src/uni2ts/model/moirai/forecast.py
@@ -27,10 +27,10 @@
 from gluonts.transform import (
     AddObservedValuesIndicator,
     AsNumpyArray,
+    CausalMeanValueImputation,
     ExpandDimArray,
     TestSplitSampler,
     Transformation,
-    CausalMeanValueImputation,
 )
 from gluonts.transform.split import TFTInstanceSplitter
 from jaxtyping import Bool, Float, Int
@@ -345,7 +345,9 @@ def forward(
                 past_feat_dynamic_real,
                 past_observed_feat_dynamic_real,
             )
-            preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,)))
+            preds = distr.sample(
+                torch.Size((num_samples or self.hparams.num_samples,))
+            )
             return self._format_preds(
                 self.hparams.patch_size, preds, past_target.shape[-1]
             )
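
The reflowed sampling call is purely cosmetic: torch.distributions prepends the requested sample shape, so preds gains a leading dimension of num_samples (or the hparams default). A small standalone illustration with made-up shapes:

import torch

distr = torch.distributions.Normal(torch.zeros(4, 32), torch.ones(4, 32))
num_samples = 100
preds = distr.sample(torch.Size((num_samples,)))
print(preds.shape)  # torch.Size([100, 4, 32]), the sample dim comes first
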
@@ -373,10 +375,18 @@ def forward(
             past_feat_dynamic_real=past_feat_dynamic_real,
             past_observed_feat_dynamic_real=past_observed_feat_dynamic_real,
         )
-        patch_size = torch.ones_like(time_id, dtype=torch.long) * self.hparams.patch_size
+        patch_size = (
+            torch.ones_like(time_id, dtype=torch.long) * self.hparams.patch_size
+        )
 
-        pred_index = torch.arange(start=context_step-1, end=context_token, step=context_step)
-        assign_index = torch.arange(start=context_token, end=context_token+predict_token, step=predict_step)
+        pred_index = torch.arange(
+            start=context_step - 1, end=context_token, step=context_step
+        )
+        assign_index = torch.arange(
+            start=context_token,
+            end=context_token + predict_token,
+            step=predict_step,
+        )
 
         if predict_step == 1:
             distr = self.module(
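
The exploded torch.arange calls are behaviour-preserving. As far as one can tell from this hunk, pred_index selects the last token of each context block and assign_index the token slots reserved for the forecast. A worked example with hypothetical values (context_step=2, context_token=6, predict_token=3, predict_step=1):

import torch

context_step, context_token = 2, 6
predict_step, predict_token = 1, 3

pred_index = torch.arange(start=context_step - 1, end=context_token, step=context_step)
assign_index = torch.arange(
    start=context_token, end=context_token + predict_token, step=predict_step
)
print(pred_index)    # tensor([1, 3, 5])
print(assign_index)  # tensor([6, 7, 8])
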
@@ -388,7 +398,9 @@ def forward(
                 prediction_mask,
                 patch_size,
             )
-            preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,)))
+            preds = distr.sample(
+                torch.Size((num_samples or self.hparams.num_samples,))
+            )
             preds[..., assign_index, :] = preds[..., pred_index, :]
             return self._format_preds(
                 self.hparams.patch_size, preds, self.hparams.target_dim
@@ -405,13 +417,27 @@ def forward(
             )
             preds = distr.sample(torch.Size((self.hparams.num_samples,)))
 
-            expand_target = target.unsqueeze(0).repeat(self.hparams.num_samples, 1, 1, 1)
-            expand_prediction_mask = prediction_mask.unsqueeze(0).repeat(self.hparams.num_samples, 1, 1)
-            expand_observed_mask = observed_mask.unsqueeze(0).expand(self.hparams.num_samples, -1, -1, -1)
-            expand_sample_id = sample_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)
-            expand_time_id = time_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)
-            expand_variate_id = variate_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)
-            expand_patch_size = patch_size.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)
+            expand_target = target.unsqueeze(0).repeat(
+                self.hparams.num_samples, 1, 1, 1
+            )
+            expand_prediction_mask = prediction_mask.unsqueeze(0).repeat(
+                self.hparams.num_samples, 1, 1
+            )
+            expand_observed_mask = observed_mask.unsqueeze(0).expand(
+                self.hparams.num_samples, -1, -1, -1
+            )
+            expand_sample_id = sample_id.unsqueeze(0).expand(
+                self.hparams.num_samples, -1, -1
+            )
+            expand_time_id = time_id.unsqueeze(0).expand(
+                self.hparams.num_samples, -1, -1
+            )
+            expand_variate_id = variate_id.unsqueeze(0).expand(
+                self.hparams.num_samples, -1, -1
+            )
+            expand_patch_size = patch_size.unsqueeze(0).expand(
+                self.hparams.num_samples, -1, -1
+            )
 
             expand_target[..., assign_index, :] = preds[..., pred_index, :]
             expand_prediction_mask[..., assign_index] = False
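
Splitting these calls across lines also makes the repeat/expand distinction easier to see: repeat materialises a copy, while expand returns a broadcasted view that shares memory with the original. That is consistent with the assignments above, where only expand_target and expand_prediction_mask are written into. A standalone illustration of the difference (toy tensors, not model code):

import torch

x = torch.zeros(2, 3)
copied = x.unsqueeze(0).repeat(4, 1, 1)    # real copy, shape (4, 2, 3)
viewed = x.unsqueeze(0).expand(4, -1, -1)  # broadcast view, no copy

copied[0, 0, 0] = 1.0  # safe: only the copy changes
# Writing into `viewed` would be unsafe: all 4 slices share x's memory.
print(x[0, 0].item(), copied[0, 0, 0].item())  # 0.0 1.0
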
@@ -581,7 +607,9 @@ def _generate_time_id(
             "max",
             patch=patch_size,
         )
-        past_seq_id = torch.clamp(past_seq_id.cummax(dim=-1).values.cumsum(dim=-1) - 1, min=0)
+        past_seq_id = torch.clamp(
+            past_seq_id.cummax(dim=-1).values.cumsum(dim=-1) - 1, min=0
+        )
         batch_shape = " ".join(map(str, past_observed_target.shape[:-2]))
         future_seq_id = (
             repeat(
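
The clamp expression is only reflowed here. For intuition, cummax(...).values.cumsum(...) - 1 clamped at zero turns a 0/1 per-patch indicator (apparently a max over patches of the observed mask, given the "max", patch=patch_size arguments just above) into position ids that stay at 0 until the first observed patch and then increase by one per patch. A worked example with a hypothetical indicator:

import torch

past_seq_id = torch.tensor([0, 0, 1, 1, 0, 1])
out = torch.clamp(past_seq_id.cummax(dim=-1).values.cumsum(dim=-1) - 1, min=0)
print(out)  # tensor([0, 0, 0, 1, 2, 3])
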
2 changes: 1 addition & 1 deletion src/uni2ts/model/moirai/module_moe.py
@@ -34,7 +34,7 @@
     RotaryProjection,
 )
 from uni2ts.module.transformer import TransformerEncoder
-from uni2ts.module.ts_embed import MultiInSizeLinear, FeatLinear
+from uni2ts.module.ts_embed import FeatLinear, MultiInSizeLinear
 
 
 def encode_distr_output(
13 changes: 7 additions & 6 deletions src/uni2ts/module/ffn.py
@@ -49,7 +49,8 @@ def __init__(
         self.activation = activation
 
     def forward(
-        self, x: Float[torch.Tensor, "... in_dim"],
+        self,
+        x: Float[torch.Tensor, "... in_dim"],
         centroid: Optional[Float[torch.Tensor, "expert in_dim"]] = None,
     ) -> Float[torch.Tensor, "... out_dim"]:
         x = self._in_proj(x)
@@ -122,7 +123,8 @@ def __init__(
         )
 
     def forward(
-        self, x: Float[torch.Tensor, "... in_dim"],
+        self,
+        x: Float[torch.Tensor, "... in_dim"],
         centroid: Optional[Float[torch.Tensor, "expert in_dim"]] = None,
     ) -> Float[torch.Tensor, "... dim"]:
         x_squashed = x.view(-1, x.shape[-1])
@@ -136,11 +138,10 @@ def forward(
         cdist = torch.cdist(x_temp, centroid)
         gate_logits = cdist.view(-1, cdist.shape[-1])
 
-        weights, selected_experts = torch.topk(
-            gate_logits, self.num_experts_per_token
-        )
+        weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_token)
         weights = nn.functional.softmax(
-            weights, dim=1,
+            weights,
+            dim=1,
             dtype=torch.float,
         ).type_as(x)
 
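
Collapsing the topk call and exploding the softmax arguments does not change the routing: each token keeps num_experts_per_token gate logits, which are normalised in float32 and cast back to the input dtype. A compact, standalone sketch of just this gating step (the expert dispatch that follows in the file is omitted):

import torch
import torch.nn as nn

num_experts, num_experts_per_token = 8, 2
gate_logits = torch.randn(16, num_experts)  # one row of logits per token

weights, selected_experts = torch.topk(gate_logits, num_experts_per_token)
weights = nn.functional.softmax(weights, dim=1, dtype=torch.float).type_as(gate_logits)

print(selected_experts.shape, weights.shape)  # both torch.Size([16, 2])
print(weights.sum(dim=1))                     # each row sums to 1
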
12 changes: 9 additions & 3 deletions src/uni2ts/module/transformer.py
@@ -195,8 +195,8 @@ def __init__(
         )
         self.norm = norm_layer(d_model)
 
-        self.register_buffer("centroid",
-            torch.empty(num_layers, 32, d_model, dtype=torch.float64)
+        self.register_buffer(
+            "centroid", torch.empty(num_layers, 32, d_model, dtype=torch.float64)
         )
 
     @staticmethod
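
The reformatted register_buffer call keeps the same semantics: the centroid tensor is stored in the module's state_dict and follows .to()/.cuda() like a parameter, but it is not exposed through parameters() and is never touched by the optimizer. A minimal toy illustration (not the actual TransformerEncoder):

import torch
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self, num_layers: int, d_model: int):
        super().__init__()
        self.register_buffer(
            "centroid", torch.empty(num_layers, 32, d_model, dtype=torch.float64)
        )


m = Toy(num_layers=2, d_model=16)
print("centroid" in m.state_dict())                   # True
print(any(p is m.centroid for p in m.parameters()))   # False
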
@@ -222,5 +222,11 @@ def forward(
         time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None,
     ) -> Float[torch.Tensor, "*batch time_len dim"]:
         for idx, layer in enumerate(self.layers):
-            x = layer(x, attn_mask, var_id=var_id, time_id=time_id, centroid=self.centroid[idx])
+            x = layer(
+                x,
+                attn_mask,
+                var_id=var_id,
+                time_id=time_id,
+                centroid=self.centroid[idx],
+            )
         return self.norm(x)
7 changes: 2 additions & 5 deletions src/uni2ts/module/ts_embed.py
@@ -123,15 +123,12 @@ def __init__(
         self.out_features = out_features
 
         self.weight = nn.Parameter(
-            torch.empty(
-                (len(in_features_ls), out_features, out_features), dtype=dtype
-            )
+            torch.empty((len(in_features_ls), out_features, out_features), dtype=dtype)
         )
 
         if bias:
             self.bias = nn.Parameter(
-                torch.empty((len(in_features_ls), out_features), dtype=dtype
-                )
+                torch.empty((len(in_features_ls), out_features), dtype=dtype)
             )
         else:
             self.register_parameter("bias", None)

