Commit
modified: __init__.py
	modified:   datasets.py
	modified:   dsm_api.py
	modified:   dsm_torch.py
	modified:   losses.py
	modified:   utilities.py
chiragnagpal committed Oct 29, 2020
1 parent 9d150e0 commit 28af769
Showing 6 changed files with 162 additions and 46 deletions.
6 changes: 3 additions & 3 deletions dsm/__init__.py
@@ -14,10 +14,10 @@
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with Deep Survival Machines.
# If not, see <https://www.gnu.org/licenses/>.

"""
r"""
Python package `dsm` provides an API to train the Deep Survival Machines
and associated models for problems in survival analysis. The underlying model
is implemented in `pytorch`.
@@ -137,7 +137,7 @@
<img style="float: right;" width ="200px" src="https://www.cmu.edu/brand/downloads/assets/images/wordmarks-600x600-min.jpg">
<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png">
<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png">
<br><br><br><br><br>
<br><br><br><br><br>
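
For orientation, a minimal usage sketch of the `dsm` API touched by this commit; the constructor argument names (`k`, `distribution`) and the `load_dataset` return signature are assumptions based on the docstrings below, not confirmed by this diff:

    from dsm import DeepSurvivalMachines
    from dsm import datasets

    # Assumed return signature: (features, event/censoring times, event indicators).
    x, t, e = datasets.load_dataset('SUPPORT')

    # 'k' and 'distribution' are assumed keyword names, per the class docstring.
    model = DeepSurvivalMachines(k=3, distribution='Weibull')
    model.fit(x, t, e, iters=100, learning_rate=1e-3)
    survival = model.predict_survival(x, 365)
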
5 changes: 3 additions & 2 deletions dsm/datasets.py
@@ -14,7 +14,7 @@
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with Deep Survival Machines.
# If not, see <https://www.gnu.org/licenses/>.


@@ -144,7 +144,8 @@ def load_dataset(dataset='SUPPORT', **kwargs):
Parameters
----------
dataset: str
-The choice of dataset to load. Currently implemented is 'SUPPORT'.
+The choice of dataset to load. Currently implemented is 'SUPPORT'
+and 'PBC'.
**kwargs: dict
Dataset specific keyword arguments.
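
A short sketch of the newly documented option, assuming `load_dataset` returns the `(features, times, events)` triple used in the sketch above:

    from dsm import datasets

    x, t, e = datasets.load_dataset('SUPPORT')   # existing option
    x, t, e = datasets.load_dataset('PBC')       # newly documented option
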
14 changes: 6 additions & 8 deletions dsm/dsm_api.py
@@ -28,7 +28,6 @@
from dsm.utilities import train_dsm

import torch

import numpy as np

class DeepSurvivalMachines():
@@ -97,11 +96,11 @@ def __call__(self):
print("Distribution Choice:", self.dist)


def fit(self, x, t, e, vsize=0.15,
iters=1, learning_rate=1e-3, batch_size=100,
elbo=True, optimizer="Adam", random_state=100):

"""This method is used to train an instance of the DSM model.
r"""This method is used to train an instance of the DSM model.
Parameters
----------
@@ -123,10 +122,10 @@ def fit(self, x, t, e, vsize=0.15,
learning is performed on mini-batches of input data. This parameter
specifies the size of each mini-batch.
elbo: bool
-Whether to use the Evidence Lower Bound for Optimization.
+Whether to use the Evidence Lower Bound for optimization.
Default is True.
optimizer: str
The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
random_state: float
random seed that determines how the validation set is chosen.
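
Continuing the sketch above, a `fit` call that spells out every documented argument with its stated default (illustrative only):

    model.fit(x, t, e, vsize=0.15, iters=1, learning_rate=1e-3,
              batch_size=100, elbo=True, optimizer='Adam', random_state=100)
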
@@ -155,7 +154,6 @@ def fit(self, x, t, e, vsize=0.15,
model = DeepSurvivalMachinesTorch(inputdim,
k=self.k,
layers=self.layers,
-init=False,
dist=self.dist,
temp=self.temp,
discount=self.discount,
@@ -177,7 +175,7 @@


def predict_risk(self, x, t):
"""Returns the estimated risk of an event occuring before time \( t \)
r"""Returns the estimated risk of an event occuring before time \( t \)
\( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).
Parameters
Expand All @@ -201,7 +199,7 @@ def predict_risk(self, x, t):


def predict_survival(self, x, t):
"""Returns the estimated survival probability at time \( t \),
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
Parameters
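
Per the two docstrings, `predict_risk` estimates P(T <= t | X) and `predict_survival` estimates P(T > t | X), so the outputs should be complements; a hedged sketch, reusing `model` and `x` from the sketch above:

    import numpy as np

    risk = model.predict_risk(x, 365)      # estimated P(T <= t | X)
    surv = model.predict_survival(x, 365)  # estimated P(T > t | X)
    assert np.allclose(risk, 1.0 - surv)   # assumed complementary relationship
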
120 changes: 112 additions & 8 deletions dsm/dsm_torch.py
@@ -14,7 +14,7 @@
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with Deep Survival Machines.
# If not, see <https://www.gnu.org/licenses/>.


@@ -32,7 +32,7 @@


def create_representation(inputdim, layers, activation):
"""Helper function to generate the representation function for DSM.
r"""Helper function to generate the representation function for DSM.
Deep Survival Machines learns a representation \( \Phi(X) \) for the input
data. This representation is parameterized using a Non Linear Multilayer
@@ -116,7 +116,7 @@ class DeepSurvivalMachinesTorch(nn.Module):
Default is 1.
"""

-def __init__(self, inputdim, k, layers=None, init=False, dist='Weibull',
+def __init__(self, inputdim, k, layers=None, dist='Weibull',
temp=1000., discount=1.0, optimizer='Adam'):
super(DeepSurvivalMachinesTorch, self).__init__()

@@ -134,11 +134,13 @@ def __init__(self, inputdim, k, layers=None, init=False, dist='Weibull',
self.act = nn.SELU()
self.scale = nn.Parameter(-torch.ones(k))
self.shape = nn.Parameter(-torch.ones(k))

elif self.dist == 'LogNormal':
self.act = nn.Tanh()
self.scale = nn.Parameter(torch.ones(k))
self.shape = nn.Parameter(torch.ones(k))
else:
raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
' yet.')

self.embedding = create_representation(inputdim, layers, 'ReLU6')
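
The new `else` branch makes an unsupported distribution fail fast at construction time, before any layers are built. A sketch, assuming the constructor signature shown in this hunk:

    from dsm.dsm_torch import DeepSurvivalMachinesTorch

    try:
        DeepSurvivalMachinesTorch(inputdim=10, k=3, dist='Gamma')
    except NotImplementedError as err:
        print(err)  # Distribution: Gamma not implemented yet.
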

@@ -152,10 +154,6 @@ def __init__(self, inputdim, k, layers=None, init=False, dist='Weibull',
self.scaleg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
self.shapeg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))

-if init is not False:
-self.shape.data.fill_(init[0])
-self.scale.data.fill_(init[1])

def forward(self, x):
"""The forward function that is called when data is passed through DSM.
@@ -171,3 +169,109 @@ def forward(self, x):
def get_shape_scale(self):
return(self.shape,
self.scale)

class DeepRecurrentSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
"""A Torch implementation of Deep Recurrent Survival Machines model.
This is an implementation of Deep Recurrent Survival Machines model
in torch. It inherits from `DeepSurvivalMachinesTorch` and replaces the
input representation learning MLP with an LSTM or RNN, the parameters of the
underlying distributions and the forward function which is called whenever
data is passed to the module. Each of the parameters are nn.Parameters and
torch automatically keeps track and computes gradients for them.
.. warning::
Not designed to be used directly.
Please use the API interface `dsm.dsm_api.DeepRecurrentSurvivalMachines`!!
Parameters
----------
inputdim: int
Dimensionality of the input features.
k: int
The number of underlying parametric distributions.
layers: int
The number of hidden layers in the LSTM or RNN cell.
hidden: int
The number of neurons in each hidden layer.
init: tuple
A tuple for initialization of the parameters for the underlying
distributions. (shape, scale).
dist: str
Choice of the underlying survival distributions.
One of 'Weibull', 'LogNormal'.
Default is 'Weibull'.
temp: float
The logits for the gate are rescaled with this value.
Default is 1000.
discount: float
a float in [0,1] that determines how to discount the tail bias
from the uncensored instances.
Default is 1.
"""

def __init__(self, inputdim, k, typ='LSTM', layers=1,
hidden=None, dist='Weibull',
temp=1000., discount=1.0, optimizer='Adam'):
super(DeepSurvivalMachinesTorch, self).__init__()

self.k = k
self.dist = dist
self.temp = float(temp)
self.discount = float(discount)
self.optimizer = optimizer
self.hidden = hidden
self.layers = layers
self.typ = typ

if self.dist == 'Weibull':
self.act = nn.SELU()
self.scale = nn.Parameter(-torch.ones(k))
self.shape = nn.Parameter(-torch.ones(k))
elif self.dist == 'LogNormal':
self.act = nn.Tanh()
self.scale = nn.Parameter(torch.ones(k))
self.shape = nn.Parameter(torch.ones(k))
else:
raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
' yet.')

self.gate = nn.Sequential(nn.Linear(hidden, k, bias=False))
self.scaleg = nn.Sequential(nn.Linear(hidden, k, bias=True))
self.shapeg = nn.Sequential(nn.Linear(hidden, k, bias=True))

if self.typ == 'LSTM':
self.embedding = nn.LSTM(inputdim, hidden, layers,
bias=False, batch_first=True)
if self.typ == 'RNN':
self.embedding = nn.RNN(inputdim, hidden, layers,
bias=False, batch_first=True)

#self.embedding = nn.ReLU6(self.embedding)


def forward(self, x):
"""The forward function that is called when data is passed through DSM.
Note: As compared to DSM, the input data for DRSM is a tensor. The forward
function involves unpacking the tensor in order to directly use the
DSM loss functions.
Args:
x:
a torch.tensor of the input features.
"""
x = x.detach().clone()
inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1)
x[torch.isnan(x)] = 0
xrep, _ = self.embedding(x)
xrep = xrep.contiguous().view(-1, self.hidden)
xrep = xrep[inputmask]
#xrep = nn.ReLU6()(xrep)
return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
self.gate(xrep)/self.temp)

def get_shape_scale(self):
return(self.shape,
self.scale)
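
The `forward` above implies a padding convention for variable-length series: sequences are stacked into one `(batch, timesteps, features)` tensor, shorter series are padded with NaN, and `inputmask` drops the padded steps after flattening. A minimal sketch of that convention (shapes and values are illustrative assumptions):

    import numpy as np
    import torch
    from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch

    # Two series with 3 and 2 visits, 5 features each, NaN-padded to length 3.
    batch = np.full((2, 3, 5), np.nan, dtype='float32')
    batch[0, :3] = np.random.randn(3, 5)
    batch[1, :2] = np.random.randn(2, 5)
    x = torch.from_numpy(batch)

    model = DeepRecurrentSurvivalMachinesTorch(inputdim=5, k=3, typ='LSTM',
                                               layers=1, hidden=16)
    shape, scale, logits = model(x)
    print(shape.shape)  # torch.Size([5, 3]): 3 + 2 valid timesteps, k = 3
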
7 changes: 4 additions & 3 deletions dsm/losses.py
@@ -143,14 +143,16 @@ def _conditional_lognormal_loss(model, x, t, e, elbo=True):
cens = np.where(e.cpu().data.numpy() == 0)[0]
ll = lossf[uncens].sum() + alpha*losss[cens].sum()

-return -ll/x.shape[0]
+return -ll.mean()
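
One subtlety worth flagging: `ll` here is a 0-dim tensor (a sum of sums), so `.mean()` over it is a no-op. If that reading is correct, this change removes the division by batch size rather than computing a per-sample average. A toy check:

    import torch

    lossf = torch.ones(4)  # stand-in uncensored log-likelihood terms
    losss = torch.ones(2)  # stand-in censored log-survival terms
    ll = lossf.sum() + 1.0 * losss.sum()  # 0-dim tensor(6.)
    assert ll.dim() == 0 and ll.mean().item() == ll.item() == 6.0
    # old: -ll / x.shape[0] -> -1.0 for a batch of 6; new: -ll.mean() -> -6.0
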


def _conditional_weibull_loss(model, x, t, e, elbo=True):

alpha = model.discount
shape, scale, logits = model.forward(x)

#print (shape, scale, logits)

k_ = shape
b_ = scale

@@ -192,7 +194,7 @@ def _conditional_weibull_loss(model, x, t, e, elbo=True):
cens = np.where(e.cpu().data.numpy() == 0)[0]
ll = lossf[uncens].sum() + alpha*losss[cens].sum()

-return -ll/x.shape[0]
+return -ll.mean()


def conditional_loss(model, x, t, e, elbo=True):
@@ -280,6 +282,5 @@ def predict_cdf(model, x, t_horizon):
torch.no_grad()
if model.dist == 'Weibull':
return _weibull_cdf(model, x, t_horizon)

if model.dist == 'LogNormal':
return _lognormal_cdf(model, x, t_horizon)
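
A small caution on the unchanged context above: a bare `torch.no_grad()` call creates and immediately discards the guard object, so gradient tracking is not actually disabled; it only takes effect as a context manager. A sketch of the presumably intended pattern:

    import torch

    def predict_cdf(model, x, t_horizon):
        with torch.no_grad():
            if model.dist == 'Weibull':
                return _weibull_cdf(model, x, t_horizon)
            if model.dist == 'LogNormal':
                return _lognormal_cdf(model, x, t_horizon)
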