Skip to content

Commit

Permalink
Merge pull request #450 from tongnie/main
Browse files Browse the repository at this point in the history
Add Imputeformer
  • Loading branch information
WenjieDu authored Jul 1, 2024
2 parents 45b9573 + 177e4ab commit 18d94e4
Show file tree
Hide file tree
Showing 7 changed files with 782 additions and 0 deletions.
20 changes: 20 additions & 0 deletions pypots/imputation/imputeformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
The package of the partially-observed time-series imputation model Imputeformer.
Refer to the paper
`Tong Nie, Guoyang Qin, Wei Ma, Yuewen Mei, Jian Sun.
"ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation"
KDD 2024.
<https://doi.org/10.48550/arXiv.2312.01728>`_
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import Imputeformer

__all__ = [
"Imputeformer",
]
126 changes: 126 additions & 0 deletions pypots/imputation/imputeformer/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
The core wrapper assembles the submodules of the Imputeformer imputation model
and takes over the forward process of the algorithm.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.modules.saits import SaitsLoss
from ...nn.modules.imputeformer import EmbeddedAttentionLayer, ProjectedAttentionLayer, MLP
from einops import repeat


class _Imputeformer(nn.Module):
    """Core network of the Imputeformer imputation model (KDD'24).

    A spatiotemporal imputation Transformer induced by low-rank factorization.

    Note:
        This is a simplified implementation under the SAITS framework (ORT+MIT).
        The timestamp encoding is also removed for ease of implementation.

    Parameters
    ----------
    n_steps :
        Number of time steps per sample; also the first dimension of the
        learnable embedding.
    n_features :
        Number of features, treated as the number of nodes.
    n_layers :
        Number of stacked (temporal, spatial) attention layer pairs.
    d_input_embed :
        Output width of the linear input projection.
    d_learnable_embed :
        Width of the learnable embedding that is concatenated to the
        projected input.
    d_proj :
        Projection size handed to every ``ProjectedAttentionLayer``.
    d_ffn :
        Third constructor argument of every ``EmbeddedAttentionLayer``
        (presumably its feed-forward width -- confirm against that layer).
    num_temporal_heads :
        Head count handed to every ``ProjectedAttentionLayer``.
    dropout :
        Dropout value handed to every ``ProjectedAttentionLayer``.
    input_dim :
        Input width of the linear input projection. ``forward`` appends one
        trailing axis of size 1 to ``X`` before projecting it.
    output_dim :
        Output width of the readout MLP. ``forward`` squeezes the last axis of
        the readout output, which only removes it when it has size 1.
    ORT_weight :
        First argument given to ``SaitsLoss``.
    MIT_weight :
        Second argument given to ``SaitsLoss``.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_layers: int,
        d_input_embed: int,
        d_learnable_embed: int,
        d_proj: int,
        d_ffn: int,
        num_temporal_heads: int,
        dropout: float = 0.0,
        input_dim: int = 1,
        output_dim: int = 1,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
    ):
        super().__init__()

        # Plain configuration attributes.
        self.num_nodes = n_features
        self.in_steps = n_steps
        self.out_steps = n_steps
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_embedding_dim = d_input_embed
        self.learnable_embedding_dim = d_learnable_embed
        self.model_dim = d_input_embed + d_learnable_embed
        self.num_temporal_heads = num_temporal_heads
        self.num_layers = n_layers
        self.dim_proj = d_proj

        # NOTE: the creation order of the learnable components below is kept
        # stable on purpose, so random initialisation under a fixed seed and
        # the layout of the state dict do not change.
        self.input_proj = nn.Linear(input_dim, self.input_embedding_dim)

        # One embedding vector per (time step, node) pair, Xavier-initialised.
        self.learnable_embedding = nn.Parameter(
            torch.empty(
                self.in_steps,
                self.num_nodes,
                self.learnable_embedding_dim,
            )
        )
        nn.init.xavier_uniform_(self.learnable_embedding)

        self.readout = MLP(self.model_dim, self.model_dim, output_dim, n_layers=2)

        # Attention layers applied first in each pair inside ``forward``.
        self.attn_layers_t = nn.ModuleList(
            [
                ProjectedAttentionLayer(
                    self.num_nodes,
                    self.dim_proj,
                    self.model_dim,
                    num_temporal_heads,
                    self.model_dim,
                    dropout,
                )
                for _ in range(self.num_layers)
            ]
        )

        # Attention layers applied second in each pair; they additionally
        # receive the learnable embedding at call time.
        self.attn_layers_s = nn.ModuleList(
            [
                EmbeddedAttentionLayer(
                    self.model_dim,
                    self.learnable_embedding_dim,
                    d_ffn,
                )
                for _ in range(self.num_layers)
            ]
        )

        # apply SAITS loss function to Transformer on the imputation task
        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        """Impute the missing entries of ``inputs["X"]``.

        Parameters
        ----------
        inputs :
            Must contain ``"X"`` and ``"missing_mask"``, both laid out as
            (batch_size, in_steps, num_nodes). When ``training`` is true it
            must also contain ``"X_ori"`` and ``"indicating_mask"``.
        training :
            Whether to compute the SAITS losses in addition to the imputation.

        Returns
        -------
        dict
            Always holds ``"imputed_data"``. In training mode it also holds
            ``"ORT_loss"``, ``"MIT_loss"`` and ``"loss"``.
        """
        X = inputs["X"]
        missing_mask = inputs["missing_mask"]

        # Imputeformer is designed for spatial-temporal data shaped [B, S, N, C]
        # (N nodes, C extra feature channels), so a singleton channel axis is
        # appended to both the data and the mask.
        mask_4d = missing_mask.unsqueeze(-1)  # [b s n c]

        # Whiten the missing values, then embed each scalar observation.
        hidden = self.input_proj(X.unsqueeze(-1) * mask_4d)  # [b s n input_embedding_dim]

        # Broadcast the learnable embedding over the batch and append it.
        n_samples = X.shape[0]
        embedding = self.learnable_embedding
        batched_embedding = embedding.expand(n_samples, *embedding.shape)
        hidden = torch.cat([hidden, batched_embedding], dim=-1)  # [b s n model_dim]

        # Spatial and temporal processing with the customized attention layers,
        # operating on the node-major layout.
        hidden = hidden.transpose(1, 2)  # [b n s c]
        for temporal_layer, spatial_layer in zip(self.attn_layers_t, self.attn_layers_s):
            hidden = temporal_layer(hidden)
            hidden = spatial_layer(hidden, self.learnable_embedding, dim=1)

        # Readout: back to the step-major layout, then drop the channel axis.
        reconstruction = self.readout(hidden.transpose(1, 2))  # [b s n c]
        reconstruction = reconstruction.squeeze(-1)  # [b s n]

        # SAITS processing pipeline: keep the observed values from X and take
        # the reconstruction only where data is missing.
        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction

        # ensemble the results as a dictionary for return
        results = {"imputed_data": imputed_data}
        if not training:
            return results

        # in training mode, return results with losses
        loss, ORT_loss, MIT_loss = self.saits_loss_func(
            reconstruction,
            inputs["X_ori"],
            missing_mask,
            inputs["indicating_mask"],
        )
        results["ORT_loss"] = ORT_loss
        results["MIT_loss"] = MIT_loss
        # `loss` is always the item for backward propagating to update the model
        results["loss"] = loss
        return results

22 changes: 22 additions & 0 deletions pypots/imputation/imputeformer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Dataset class for the imputation model Imputeformer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForImputeformer(DatasetForSAITS):
    """Dataset for Imputeformer; it reuses the SAITS dataset without changes.

    Every constructor argument is forwarded positionally, in the same order,
    to ``DatasetForSAITS``; see that class for their exact meaning.

    Parameters
    ----------
    data :
        Either a dict or a string, handed to the parent class as-is.
    return_X_ori :
        Forwarded to the parent class.
    return_y :
        Forwarded to the parent class.
    file_type :
        Forwarded to the parent class; defaults to ``"hdf5"``.
    rate :
        Forwarded to the parent class; defaults to ``0.2``.
    """

    def __init__(
        self,
        data: Union[dict, str],
        return_X_ori: bool,
        return_y: bool,
        file_type: str = "hdf5",
        rate: float = 0.2,
    ) -> None:
        super().__init__(
            data,
            return_X_ori,
            return_y,
            file_type,
            rate,
        )
Loading

0 comments on commit 18d94e4

Please sign in to comment.