Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSAI Pipeline #534

Merged
merged 28 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pypots/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
# License: BSD-3-Clause

from .brits import BRITS
from .csai import CSAI
from .grud import GRUD
from .raindrop import Raindrop

__all__ = [
"CSAI",
"BRITS",
"GRUD",
"Raindrop",
Expand Down
20 changes: 20 additions & 0 deletions pypots/classification/csai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
The package including the modules of CSAI.

Refer to the paper
`Linglong Qian, Zina Ibrahim, Hugh Logan Ellis, Ao Zhang, Yuezhou Zhang, Tao Wang, Richard Dobson.
Knowledge Enhanced Conditional Imputation for Healthcare Time-series.
In Arxiv, 2024.
<https://arxiv.org/abs/2312.16713>`_

Notes
-----
This implementation is inspired by the official implementation https://github.com/LinglongQian/CSAI.

"""

from .model import CSAI

__all__ = [
"CSAI",
]
123 changes: 123 additions & 0 deletions pypots/classification/csai/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""

"""

# Created by Linglong Qian, Joseph Arul Raj <linglong.qian@kcl.ac.uk, joseph_arul_raj@kcl.ac.uk>
# License: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...nn.modules.csai import BackboneBCSAI

# class DiceBCELoss(nn.Module):
# def __init__(self, weight=None, size_average=True):
# super(DiceBCELoss, self).__init__()
# self.bcelogits = nn.BCEWithLogitsLoss()

# def forward(self, y_score, y_out, targets, smooth=1):

# #comment out if your model contains a sigmoid or equivalent activation layer
# # inputs = F.sigmoid(inputs)

# #flatten label and prediction tensors
# BCE = self.bcelogits(y_out, targets)

# y_score = y_score.view(-1)
# targets = targets.view(-1)
# intersection = (y_score * targets).sum()
# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth)

# Dice_BCE = BCE + dice_loss

# return BCE, Dice_BCE


class _BCSAI(nn.Module):
    """Bidirectional CSAI backbone combined with per-direction classification heads.

    The module wraps ``BackboneBCSAI`` and puts one linear classifier on top of the
    forward hidden states and another on top of the backward hidden states. The
    class probabilities of both directions are averaged to form the prediction.
    In training mode the weighted sum of the consistency, reconstruction and
    classification losses is returned as well.

    Parameters
    ----------
    n_steps :
        Forwarded to ``BackboneBCSAI``.

    n_features :
        Forwarded to ``BackboneBCSAI``; also the output size of ``self.imputer``.

    rnn_hidden_size :
        Forwarded to ``BackboneBCSAI``; also the input size of both classifiers
        and of ``self.imputer``.

    imputation_weight :
        Weight applied to the reconstruction loss in the total loss.

    consistency_weight :
        Weight applied to the consistency loss in the total loss.

    classification_weight :
        Weight applied to the classification loss in the total loss.

    n_classes :
        Output size of both classifiers.

    step_channels :
        Forwarded to ``BackboneBCSAI``.

    dropout :
        Dropout probability applied to the hidden states before classification.

    intervals :
        Forwarded to ``BackboneBCSAI``.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        rnn_hidden_size: int,
        imputation_weight: float,
        consistency_weight: float,
        classification_weight: float,
        n_classes: int,
        step_channels: int,
        dropout: float = 0.5,
        intervals=None,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.n_features = n_features
        self.rnn_hidden_size = rnn_hidden_size
        self.imputation_weight = imputation_weight
        self.consistency_weight = consistency_weight
        self.classification_weight = classification_weight
        self.n_classes = n_classes
        self.step_channels = step_channels
        self.intervals = intervals

        # create models
        self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals)
        self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
        self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
        # NOTE: ``imputer`` is not used in ``forward``; it is kept so that the set of
        # registered parameters (and therefore saved checkpoints) stays unchanged.
        self.imputer = nn.Linear(self.rnn_hidden_size, n_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        """Run the backbone and the classification heads.

        Parameters
        ----------
        inputs :
            Dictionary passed unchanged to the backbone. When ``training`` is True
            it must also contain ``"labels"`` holding the class indices.

        training :
            Whether to compute and return the losses.

        Returns
        -------
        results :
            Always contains ``"imputed_data"`` and ``"classification_pred"``.
            When ``training`` is True it additionally contains
            ``"consistency_loss"``, ``"reconstruction_loss"``, ``"loss"``,
            ``"classification_loss"``, ``"f_reconstruction"`` and
            ``"b_reconstruction"``.
        """
        (
            imputed_data,
            f_reconstruction,
            b_reconstruction,
            f_hidden_states,
            b_hidden_states,
            consistency_loss,
            reconstruction_loss,
        ) = self.model(inputs)

        f_logits = self.f_classifier(self.dropout(f_hidden_states))
        b_logits = self.b_classifier(self.dropout(b_hidden_states))

        # average the class probabilities of the two directions
        f_prediction = torch.softmax(f_logits, dim=1)
        b_prediction = torch.softmax(b_logits, dim=1)
        classification_pred = (f_prediction + b_prediction) / 2

        results = {
            "imputed_data": imputed_data,
            "classification_pred": classification_pred,
        }

        # if in training mode, return results with losses
        if training:
            labels = inputs["labels"]
            # Compute the loss from the logits: cross_entropy applies log-softmax
            # internally, so it stays finite even when a softmax probability
            # underflows to zero (log(softmax(x)) would yield -inf in that case).
            f_classification_loss = F.cross_entropy(f_logits, labels)
            b_classification_loss = F.cross_entropy(b_logits, labels)
            classification_loss = f_classification_loss + b_classification_loss

            loss = (
                self.consistency_weight * consistency_loss
                + self.imputation_weight * reconstruction_loss
                + self.classification_weight * classification_loss
            )

            results["consistency_loss"] = consistency_loss
            results["reconstruction_loss"] = reconstruction_loss
            results["loss"] = loss
            results["classification_loss"] = classification_loss
            results["f_reconstruction"] = f_reconstruction
            results["b_reconstruction"] = b_reconstruction

        return results
39 changes: 39 additions & 0 deletions pypots/classification/csai/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""

"""

# Created by Joseph Arul Raj <joseph_arul_raj@kcl.ac.uk>
# License: BSD-3-Clause

from typing import Union
from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation



class DatasetForCSAI(DatasetForCSAI_Imputation):
    """Dataset for the CSAI classification model.

    A thin wrapper around the CSAI imputation dataset: every argument is forwarded
    unchanged to the parent class, and ``return_X_ori`` is fixed to ``False``.

    Parameters
    ----------
    data :
        A dictionary or a path string, forwarded to the parent dataset.

    file_type :
        Forwarded to the parent dataset.

    return_y :
        Forwarded to the parent dataset.

    removal_percent, increase_factor, compute_intervals, replacement_probabilities :
        Forwarded to the parent dataset; see the imputation dataset for their meaning.

    normalise_mean, normalise_std :
        Forwarded to the parent dataset. ``None`` (the default) is replaced by a
        fresh empty list, so no list object is shared between instances.

    training :
        Forwarded to the parent dataset.
    """

    def __init__(
        self,
        data: Union[dict, str],
        file_type: str = "hdf5",
        return_y: bool = True,
        removal_percent: float = 0.0,
        increase_factor: float = 0.1,
        compute_intervals: bool = False,
        replacement_probabilities=None,
        normalise_mean: Union[list, None] = None,
        normalise_std: Union[list, None] = None,
        training: bool = True,
    ):
        # Avoid mutable default arguments: build a new list on every call so the
        # parent still receives ``[]`` by default without sharing state.
        if normalise_mean is None:
            normalise_mean = []
        if normalise_std is None:
            normalise_std = []

        super().__init__(
            data=data,
            return_X_ori=False,
            return_y=return_y,
            file_type=file_type,
            removal_percent=removal_percent,
            increase_factor=increase_factor,
            compute_intervals=compute_intervals,
            replacement_probabilities=replacement_probabilities,
            normalise_mean=normalise_mean,
            normalise_std=normalise_std,
            training=training,
        )

Loading
Loading