From ddb01c59e245c7e91d7fc0ca8e45b908feac680d Mon Sep 17 00:00:00 2001
From: Chirag Nagpal
Date: Mon, 25 Oct 2021 12:28:34 -0400
Subject: [PATCH] new file:   dsm/contrib/__init__.py
 new file:   dsm/contrib/dcm_api.py
 new file:   dsm/contrib/dcm_torch.py
 new file:   dsm/contrib/dcm_utilities.py
 modified:   dsm/dsm_api.py
 modified:   dsm/dsm_torch.py
 modified:   dsm/utilities.py

---
 dsm/contrib/__init__.py      |  61 +++++++
 dsm/contrib/dcm_api.py       | 181 ++++++++++++++++++++
 dsm/contrib/dcm_torch.py     |  57 +++++++
 dsm/contrib/dcm_utilities.py | 309 +++++++++++++++++++++++++++++++++++
 dsm/dsm_api.py               |  32 ++--
 dsm/dsm_torch.py             |  14 +-
 dsm/utilities.py             |   1 +
 7 files changed, 628 insertions(+), 27 deletions(-)
 create mode 100644 dsm/contrib/__init__.py
 create mode 100644 dsm/contrib/dcm_api.py
 create mode 100644 dsm/contrib/dcm_torch.py
 create mode 100644 dsm/contrib/dcm_utilities.py

diff --git a/dsm/contrib/__init__.py b/dsm/contrib/__init__.py
new file mode 100644
index 0000000..c2165a2
--- /dev/null
+++ b/dsm/contrib/__init__.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# MIT License
+
+# Copyright (c) 2020 Carnegie Mellon University, Auton Lab
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+r"""
+`dsm` includes extended functionality for survival analysis as part
+of `dsm.contrib`.
+
+Contributed Modules
+--------------------
+This submodule incorporates contributed survival analysis methods.
+
+
+Deep Cox Mixtures
+------------------
+
+The Cox Mixture assumes that the survival function of each individual
+is a mixture of K Cox models. Conditioned on each subgroup Z=k, the
+proportional hazards assumption holds, and the baseline hazard rate is
+estimated non-parametrically with a spline-interpolated Breslow
+estimator.
+
+For full details on Deep Cox Mixtures, refer to the paper [1].
+
+References
+----------
+[1] Deep Cox Mixtures for Survival Regression.
+Machine Learning for Healthcare Conference (2021)
+
+```
+  @article{nagpal2021dcm,
+  title={Deep Cox mixtures for survival regression},
+  author={Nagpal, Chirag and Yadlowsky, Steve and Rostamzadeh, Negar and Heller, Katherine},
+  journal={arXiv preprint arXiv:2101.06536},
+  year={2021}
+  }
+```
+
+"""
+
+from dsm.contrib.dcm_api import DeepCoxMixtures
\ No newline at end of file
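The module above only re-exports `DeepCoxMixtures`. As a quick orientation — not part of the patch — here is a minimal usage sketch against the API it exposes; the synthetic `x`, `t`, `e` arrays are illustrative stand-ins, not data from the patch:

```
# Hypothetical smoke test for dsm.contrib (illustrative, synthetic data).
import numpy as np
from dsm.contrib import DeepCoxMixtures

n, d = 1000, 10
x = np.random.randn(n, d).astype('float32')               # features
t = np.random.exponential(1, size=n).astype('float32')    # event/censoring times
e = np.random.binomial(1, 0.7, size=n).astype('float32')  # 1 = event, 0 = censored

model = DeepCoxMixtures(k=3, layers=[100])
model.fit(x, t, e, iters=10, learning_rate=1e-3)
surv = model.predict_survival(x, [1.0, 2.0, 5.0])  # (n, 3) survival probabilities
```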
diff --git a/dsm/contrib/dcm_api.py b/dsm/contrib/dcm_api.py
new file mode 100644
index 0000000..6863bf7
--- /dev/null
+++ b/dsm/contrib/dcm_api.py
@@ -0,0 +1,181 @@
+
+import torch
+import numpy as np
+
+from dsm.contrib.dcm_torch import DeepCoxMixturesTorch
+from dsm.contrib.dcm_utilities import train_dcm, predict_survival
+
+
+class DeepCoxMixtures():
+  """A Deep Cox Mixture model.
+
+  This is the main interface to a Deep Cox Mixture model.
+  A model is instantiated with an appropriate set of hyperparameters and
+  fit on numpy arrays consisting of the features, event/censoring times
+  and the event/censoring indicators.
+
+  For full details on Deep Cox Mixtures, refer to the paper [1].
+
+  References
+  ----------
+  [1] Deep Cox Mixtures for Survival Regression.
+  Machine Learning for Healthcare Conference (2021)
+
+  Parameters
+  ----------
+  k: int
+      The number of underlying Cox distributions.
+  layers: list
+      A list of integers consisting of the number of neurons in each
+      hidden layer.
+
+  Example
+  -------
+  >>> from dsm.contrib import DeepCoxMixtures
+  >>> model = DeepCoxMixtures()
+  >>> model.fit(x, t, e)
+
+  """
+  def __init__(self, k=3, layers=None, distribution="Weibull",
+               temp=1000., discount=1.0):
+    self.k = k
+    self.layers = layers
+    self.dist = distribution
+    self.temp = temp
+    self.discount = discount
+    self.fitted = False
+
+  def __call__(self):
+    if self.fitted:
+      print("A fitted instance of the Deep Cox Mixtures model")
+    else:
+      print("An unfitted instance of the Deep Cox Mixtures model")
+
+    print("Number of underlying Cox distributions (k):", self.k)
+    print("Hidden Layers:", self.layers)
+
+  def _preprocess_test_data(self, x):
+    return torch.from_numpy(x).float()
+
+  def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
+
+    idx = list(range(x.shape[0]))
+    np.random.seed(random_state)
+    np.random.shuffle(idx)
+    x_train, t_train, e_train = x[idx], t[idx], e[idx]
+
+    x_train = torch.from_numpy(x_train).float()
+    t_train = torch.from_numpy(t_train).float()
+    e_train = torch.from_numpy(e_train).float()
+
+    if val_data is None:
+
+      vsize = int(vsize*x_train.shape[0])
+      x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]
+
+      x_train = x_train[:-vsize]
+      t_train = t_train[:-vsize]
+      e_train = e_train[:-vsize]
+
+    else:
+
+      x_val, t_val, e_val = val_data
+
+      x_val = torch.from_numpy(x_val).float()
+      t_val = torch.from_numpy(t_val).float()
+      e_val = torch.from_numpy(e_val).float()
+
+    return (x_train, t_train, e_train, x_val, t_val, e_val)
+
+  def _gen_torch_model(self, inputdim, optimizer):
+    """Helper function to return a torch model."""
+    return DeepCoxMixturesTorch(inputdim,
+                                k=self.k,
+                                layers=self.layers,
+                                optimizer=optimizer)
+
+  def fit(self, x, t, e, vsize=0.15, val_data=None,
+          iters=1, learning_rate=1e-3, batch_size=100,
+          optimizer="Adam", random_state=100):
+
+    r"""This method is used to train an instance of the Deep Cox
+    Mixtures model.
+
+    Parameters
+    ----------
+    x: np.ndarray
+        A numpy array of the input features, \( x \).
+    t: np.ndarray
+        A numpy array of the event/censoring times, \( t \).
+    e: np.ndarray
+        A numpy array of the event/censoring indicators, \( \delta \).
+        \( \delta = 1 \) means the event took place.
+    vsize: float
+        Amount of data to set aside as the validation set.
+    val_data: tuple
+        A tuple of the validation dataset. If passed, vsize is ignored.
+    iters: int
+        The maximum number of training iterations on the training dataset.
+    learning_rate: float
+        The learning rate for the `Adam` optimizer.
+    batch_size: int
+        Learning is performed on mini-batches of input data. This parameter
+        specifies the size of each mini-batch.
+    optimizer: str
+        The choice of the gradient-based optimization method. One of
+        'Adam', 'RMSProp' or 'SGD'.
+    random_state: float
+        Random seed that determines how the validation set is chosen.
+
+    """
+
+    processed_data = self._preprocess_training_data(x, t, e,
+                                                    vsize, val_data,
+                                                    random_state)
+    x_train, t_train, e_train, x_val, t_val, e_val = processed_data
+
+    #Todo: Change this somehow. The base design shouldn't depend on child
+
+    inputdim = x_train.shape[-1]
+
+    model = self._gen_torch_model(inputdim, optimizer)
+
+    model, _ = train_dcm(model,
+                         (x_train, t_train, e_train),
+                         (x_val, t_val, e_val),
+                         epochs=iters,
+                         lr=learning_rate,
+                         bs=batch_size,
+                         return_losses=True,
+                         smoothing_factor=None,
+                         use_posteriors=True,)
+
+    self.torch_model = (model[0].eval(), model[1])
+    self.fitted = True
+
+    return self
+
+  def predict_survival(self, x, t):
+    r"""Returns the estimated survival probability at time \( t \),
+      \( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
+
+    Parameters
+    ----------
+    x: np.ndarray
+        A numpy array of the input features, \( x \).
+    t: list or float
+        A list or float of the times at which the survival probability is
+        to be computed.
+    Returns:
+      np.array: numpy array of the survival probabilities at each time in t.
+
+    """
+    x = self._preprocess_test_data(x)
+    if not isinstance(t, list):
+      t = [t]
+    if self.fitted:
+      scores = predict_survival(self.torch_model, x, t)
+      return scores
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_survival`.")
\ No newline at end of file
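One detail of `fit` above worth spelling out: `vsize` and `val_data` are two mutually exclusive ways to obtain a validation set. A hedged sketch with synthetic data and illustrative sizes:

```
# Two ways to control validation data in fit() (illustrative, synthetic data).
import numpy as np
from dsm.contrib import DeepCoxMixtures

x = np.random.randn(200, 5).astype('float32')
t = np.random.exponential(1, size=200).astype('float32')
e = np.random.binomial(1, 0.7, size=200).astype('float32')

model = DeepCoxMixtures(k=3)

# 1. Hold out the last 15% of the internally shuffled data:
model.fit(x, t, e, vsize=0.15, iters=5)

# 2. Pass an explicit validation split; vsize is then ignored:
model.fit(x[:150], t[:150], e[:150],
          val_data=(x[150:], t[150:], e[150:]), iters=5)
```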
diff --git a/dsm/contrib/dcm_torch.py b/dsm/contrib/dcm_torch.py
new file mode 100644
index 0000000..9dc1d65
--- /dev/null
+++ b/dsm/contrib/dcm_torch.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+
+from dsm.dsm_torch import create_representation
+
+class DeepCoxMixturesTorch(nn.Module):
+  """PyTorch model definition of the Deep Cox Mixture Survival Model.
+
+  The Cox Mixture assumes that the survival function of each individual
+  is a mixture of K Cox models. Conditioned on each subgroup Z=k, the
+  proportional hazards assumption holds, and the baseline hazard rate is
+  estimated non-parametrically with a spline-interpolated Breslow
+  estimator.
+
+  """
+
+  def _init_dcm_layers(self, lastdim):
+
+    self.gate = torch.nn.Linear(lastdim, self.k, bias=False)
+    self.expert = torch.nn.Linear(lastdim, self.k, bias=False)
+
+  def __init__(self, inputdim, k, layers=None, optimizer='Adam'):
+
+    super(DeepCoxMixturesTorch, self).__init__()
+
+    if not isinstance(k, int):
+      raise ValueError(f'k must be int, but supplied k is {type(k)}')
+
+    self.k = k
+    self.optimizer = optimizer
+
+    if layers is None: layers = []
+    self.layers = layers
+
+    if len(layers) == 0: lastdim = inputdim
+    else: lastdim = layers[-1]
+
+    self._init_dcm_layers(lastdim)
+    self.embedding = create_representation(inputdim, layers, 'ReLU6')
+
+  def forward(self, x):
+
+    x = self.embedding(x)
+
+    log_hazard_ratios = torch.clamp(self.expert(x), min=-7e-1, max=7e-1)
+    #log_hazard_ratios = self.expert(x)
+    #log_hazard_ratios = torch.nn.Tanh()(self.expert(x))
+    log_gate_prob = torch.nn.LogSoftmax(dim=1)(self.gate(x))
+
+    return log_gate_prob, log_hazard_ratios
\ No newline at end of file
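The torch module above returns a pair of `(batch, k)` tensors: log gate probabilities and clamped log hazard ratios. A shape-check sketch, assuming the module behaves as defined in the patch:

```
# Shape check for the torch module (sketch, not part of the patch).
import torch
from dsm.contrib.dcm_torch import DeepCoxMixturesTorch

m = DeepCoxMixturesTorch(inputdim=10, k=3, layers=[100])
x = torch.randn(32, 10)
log_gate_prob, log_hazard_ratios = m(x)
print(log_gate_prob.shape)      # torch.Size([32, 3]); rows sum to 1 after .exp()
print(log_hazard_ratios.shape)  # torch.Size([32, 3]); clamped to [-0.7, 0.7]
```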
+ """ + + def _init_dcm_layers(self, lastdim): + + self.gate = torch.nn.Linear(lastdim, self.k, bias=False) + self.expert = torch.nn.Linear(lastdim, self.k, bias=False) + + def __init__(self, inputdim, k, layers=None, optimizer='Adam'): + + super(DeepCoxMixturesTorch, self).__init__() + + if not isinstance(k, int): + raise ValueError(f'k must be int, but supplied k is {type(k)}') + + self.k = k + self.optimizer = optimizer + + if layers is None: layers = [] + self.layers = layers + + if len(layers) == 0: lastdim = inputdim + else: lastdim = layers[-1] + + self._init_dcm_layers(lastdim) + self.embedding = create_representation(inputdim, layers, 'ReLU6') + + def forward(self, x): + + x = self.embedding(x) + + log_hazard_ratios = torch.clamp(self.expert(x), min=-7e-1, max=7e-1) + #log_hazard_ratios = self.expert(x) + #log_hazard_ratios = torch.nn.Tanh()(self.expert(x)) + log_gate_prob = torch.nn.LogSoftmax(dim=1)(self.gate(x)) + + return log_gate_prob, log_hazard_ratios \ No newline at end of file diff --git a/dsm/contrib/dcm_utilities.py b/dsm/contrib/dcm_utilities.py new file mode 100644 index 0000000..d332e30 --- /dev/null +++ b/dsm/contrib/dcm_utilities.py @@ -0,0 +1,309 @@ + +import logging +from matplotlib.pyplot import get + +import torch +import numpy as np + +from scipy.interpolate import UnivariateSpline +from sksurv.linear_model.coxph import BreslowEstimator + +from sklearn.utils import shuffle + + +from tqdm import tqdm + + +from dsm.utilities import get_optimizer + +def randargmax(b,**kw): + """ a random tie-breaking argmax""" + return np.argmax(np.random.random(b.shape) * (b==b.max()), **kw) + +def partial_ll_loss(lrisks, tb, eb, eps=1e-2): + + tb = tb + eps*np.random.random(len(tb)) + sindex = np.argsort(-tb) + + tb = tb[sindex] + eb = eb[sindex] + + lrisks = lrisks[sindex] # lrisks = tf.gather(lrisks, sindex) + # lrisksdenom = tf.math.cumulative_logsumexp(lrisks) + lrisksdenom = torch.logcumsumexp(lrisks, dim = 0) + + plls = lrisks - lrisksdenom + pll = plls[eb == 1] + + pll = torch.sum(pll) # pll = tf.reduce_sum(pll) + + return -pll + +def fit_spline(t, surv, s=1e-4): + return UnivariateSpline(t, surv, s=s, ext=3, k=1) + +def smooth_bl_survival(breslow, smoothing_factor): + + blsurvival = breslow.baseline_survival_ + x, y = blsurvival.x, blsurvival.y + return fit_spline(x, y, s=smoothing_factor) + +def get_probability_(lrisks, ts, spl): + risks = np.exp(lrisks) + s0ts = (-risks)*(spl(ts)**(risks-1)) + return s0ts * spl.derivative()(ts) + +def get_survival_(lrisks, ts, spl): + risks = np.exp(lrisks) + return spl(ts)**risks + +def get_probability(lrisks, breslow_splines, t): + psurv = [] + for i in range(lrisks.shape[1]): + p = get_probability_(lrisks[:, i], t, breslow_splines[i]) + psurv.append(p) + psurv = np.array(psurv).T + return psurv + +def get_survival(lrisks, breslow_splines, t): + psurv = [] + for i in range(lrisks.shape[1]): + p = get_survival_(lrisks[:, i], t, breslow_splines[i]) + psurv.append(p) + psurv = np.array(psurv).T + return psurv + +def get_posteriors(probs): + probs_ = probs+1e-8 + return probs-torch.logsumexp(probs, dim=1).reshape(-1,1) + +def get_hard_z(gates_prob): + return torch.argmax(gates_prob, dim=1) + +def sample_hard_z(gates_prob): + return torch.multinomial(gates_prob.exp(), num_samples=1)[:, 0] + +def repair_probs(probs): + probs[torch.isnan(probs)] = -10 + probs[probs<-10] = -10 + return probs + +def get_likelihood(model, breslow_splines, x, t, e, log=False): + + # Function requires numpy/torch + + gates, lrisks = model(x) + lrisks = 
+
+def get_likelihood(model, breslow_splines, x, t, e, log=False):
+
+  # The Breslow splines live in numpy; the gates stay in torch.
+  gates, lrisks = model(x)
+  lrisks = lrisks.numpy()
+  e, t = e.numpy(), t.numpy()
+
+  survivals = get_survival(lrisks, breslow_splines, t)
+  probability = get_probability(lrisks, breslow_splines, t)
+
+  # Censored rows contribute the survival, uncensored rows the density.
+  event_probs = np.array([survivals, probability])
+  event_probs = event_probs[e.astype('int'), range(len(e)), :]
+  probs = gates + np.log(event_probs)
+
+  return probs
+
+def q_function(model, x, t, e, posteriors, typ='soft'):
+
+  if typ == 'hard': z = get_hard_z(posteriors)
+  else: z = sample_hard_z(posteriors)
+
+  gates, lrisks = model(x)
+
+  k = model.k
+
+  loss = 0
+  for i in range(k):
+    lrisks_ = lrisks[z == i][:, i]
+    loss += partial_ll_loss(lrisks_, t[z == i], e[z == i])
+
+  gate_loss = posteriors.exp()*gates
+  gate_loss = -torch.sum(gate_loss)
+  loss += gate_loss
+
+  return loss
+
+def e_step(model, breslow_splines, x, t, e, log=False):
+
+  # TODO: Do this in log space.
+  if breslow_splines is None:
+    # If Breslow splines are not available, as in the first
+    # iteration of learning, posteriors are drawn at random.
+    posteriors = get_posteriors(torch.rand(len(x), model.k))
+  else:
+    probs = get_likelihood(model, breslow_splines, x, t, e)
+    posteriors = get_posteriors(repair_probs(probs))
+
+  return posteriors
+
+def m_step(model, optimizer, x, t, e, posteriors, typ='soft'):
+
+  optimizer.zero_grad()
+  loss = q_function(model, x, t, e, posteriors, typ)
+  loss.backward()
+  optimizer.step()
+
+  return float(loss)
+
+def fit_breslow(model, x, t, e, posteriors=None, smoothing_factor=1e-4, typ='soft'):
+
+  # TODO: Make Breslow fitting native to torch!
+
+  gates, lrisks = model(x)
+
+  lrisks = lrisks.numpy()
+
+  e = e.numpy()
+  t = t.numpy()
+
+  if posteriors is None: z_probs = gates
+  else: z_probs = posteriors
+
+  if typ == 'soft': z = sample_hard_z(z_probs)
+  else: z = get_hard_z(z_probs)
+
+  breslow_splines = {}
+  for i in range(model.k):
+    breslowk = BreslowEstimator().fit(lrisks[:, i][z == i], e[z == i], t[z == i])
+    breslow_splines[i] = smooth_bl_survival(breslowk,
+                                            smoothing_factor=smoothing_factor)
+
+  return breslow_splines
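Together, `e_step`, `m_step`, and `fit_breslow` form one EM iteration. A sketch of a single full-batch iteration using the patch's own helpers (synthetic tensors; on the first iteration the E-step falls back to random posteriors because no splines exist yet):

```
# One EM iteration, spelled out (sketch with synthetic data).
import torch
from dsm.contrib.dcm_torch import DeepCoxMixturesTorch
from dsm.contrib.dcm_utilities import e_step, m_step, fit_breslow
from dsm.utilities import get_optimizer

torch.manual_seed(0)
model = DeepCoxMixturesTorch(inputdim=10, k=3, layers=[100])
optimizer = get_optimizer(model, 1e-3)

x = torch.randn(500, 10)
t = torch.rand(500)*10 + 0.1
e = (torch.rand(500) < 0.7).float()

with torch.no_grad():             # E-step: posteriors over the k experts
  posteriors = e_step(model, None, x, t, e)

loss = m_step(model, optimizer, x, t, e, posteriors)   # M-step: gradient update

with torch.no_grad():             # refit the spline-smoothed Breslow baselines
  splines = fit_breslow(model, x, t, e, posteriors=posteriors)
```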
+
+
+def train_step(model, x, t, e, breslow_splines, optimizer,
+               bs=256, seed=100, typ='soft', use_posteriors=False,
+               update_splines_after=10, smoothing_factor=1e-4):
+
+  x, t, e = shuffle(x, t, e, random_state=seed)
+
+  n = x.shape[0]
+
+  batches = (n // bs) + 1
+
+  epoch_loss = 0
+  for i in range(batches):
+
+    xb = x[i*bs:(i+1)*bs]
+    tb = t[i*bs:(i+1)*bs]
+    eb = e[i*bs:(i+1)*bs]
+
+    # E-Step
+    with torch.no_grad():
+      posteriors = e_step(model, breslow_splines, xb, tb, eb)
+
+    # M-Step
+    with torch.enable_grad():
+      loss = m_step(model, optimizer, xb, tb, eb, posteriors, typ=typ)
+
+    # Periodically refit the Breslow baselines on the full data.
+    with torch.no_grad():
+      try:
+        if i % update_splines_after == 0:
+          if use_posteriors:
+            posteriors = e_step(model, breslow_splines, x, t, e)
+            breslow_splines = fit_breslow(model, x, t, e,
+                                          posteriors=posteriors, typ='soft')
+          else:
+            breslow_splines = fit_breslow(model, x, t, e,
+                                          posteriors=None, typ='soft')
+      except Exception as exce:
+        print("Exception!!!:", exce)
+        logging.warning("Couldn't fit splines, reusing from previous epoch")
+
+    epoch_loss += loss
+
+  return breslow_splines
+
+
+def test_step(model, x, t, e, breslow_splines, loss='q', typ='soft'):
+
+  if loss == 'q':
+    with torch.no_grad():
+      posteriors = e_step(model, breslow_splines, x, t, e)
+      loss = q_function(model, x, t, e, posteriors, typ=typ)
+
+  return float(loss/x.shape[0])
+
+
+def train_dcm(model, train_data, val_data, epochs=50,
+              patience=3, vloss='q', bs=256, typ='soft', lr=1e-3,
+              use_posteriors=True, debug=False, random_state=0,
+              return_losses=False, update_splines_after=10,
+              smoothing_factor=1e-4):
+
+  torch.manual_seed(random_state)
+  np.random.seed(random_state)
+
+  if val_data is None:
+    val_data = train_data
+
+  xt, tt, et = train_data
+  xv, tv, ev = val_data
+
+  optimizer = get_optimizer(model, lr)
+
+  valc = np.inf
+  patience_ = 0
+
+  breslow_splines = None
+
+  losses = []
+
+  for epoch in tqdm(range(epochs)):
+
+    breslow_splines = train_step(model, xt, tt, et, breslow_splines,
+                                 optimizer, bs=bs, seed=epoch, typ=typ,
+                                 use_posteriors=use_posteriors,
+                                 update_splines_after=update_splines_after,
+                                 smoothing_factor=smoothing_factor)
+
+    valcn = test_step(model, xv, tv, ev, breslow_splines, loss=vloss, typ=typ)
+
+    losses.append(valcn)
+
+    if debug: print(patience_, epoch, valcn)
+
+    if valcn > valc: patience_ += 1
+    else: patience_ = 0
+
+    if patience_ == patience:
+      if return_losses: return (model, breslow_splines), losses
+      else: return (model, breslow_splines)
+
+    valc = valcn
+
+  if return_losses: return (model, breslow_splines), losses
+  else: return (model, breslow_splines)
+
+
+def predict_survival(model, x, t):
+
+  if isinstance(t, (int, float)): t = [t]
+
+  model, breslow_splines = model
+  gates, lrisks = model(x)
+
+  lrisks = lrisks.detach().numpy()
+  gate_probs = torch.exp(gates).detach().numpy()
+
+  predictions = []
+
+  for t_ in t:
+    expert_output = get_survival(lrisks, breslow_splines, t_)
+    predictions.append((gate_probs*expert_output).sum(axis=1))
+
+  return np.array(predictions).T
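`predict_survival` above mixes the expert-specific survival curves with the gate probabilities, S(t|x) = Σ_k P(Z=k|x) · S_k(t|x). A worked arithmetic check with made-up numbers:

```
# Mixture survival for one subject at one time point (illustrative numbers).
import numpy as np

gate_probs = np.array([[0.2, 0.5, 0.3]])       # P(Z = k | x)
expert_surv = np.array([[0.9, 0.6, 0.75]])     # S_k(t | x) from each Breslow spline
print((gate_probs * expert_surv).sum(axis=1))  # -> [0.705]
```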
diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index 57a3623..59da5a7 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -108,9 +108,9 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
 
     """
 
-    processed_data = self._prepocess_training_data(x, t, e,
-                                                   vsize, val_data,
-                                                   random_state)
+    processed_data = self._preprocess_training_data(x, t, e,
+                                                    vsize, val_data,
+                                                    random_state)
     x_train, t_train, e_train, x_val, t_val, e_val = processed_data
 
     #Todo: Change this somehow. The base design shouldn't depend on child
@@ -133,7 +133,7 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
     self.torch_model = model.eval()
     self.fitted = True
 
-    return self
+    return self
 
 
   def compute_nll(self, x, t, e):
@@ -158,7 +158,7 @@ def compute_nll(self, x, t, e):
       raise Exception("The model has not been fitted yet. Please fit the " +
                       "model using the `fit` method on some training data " +
                       "before calling `_eval_nll`.")
-    processed_data = self._prepocess_training_data(x, t, e, 0, None, 0)
+    processed_data = self._preprocess_training_data(x, t, e, 0, None, 0)
     _, _, _, x_val, t_val, e_val = processed_data
     x_val, t_val, e_val = x_val,\
         _reshape_tensor_with_nans(t_val),\
@@ -170,10 +170,10 @@ def compute_nll(self, x, t, e):
                                risk=str(r+1)).detach().numpy())
     return loss
 
-  def _prepocess_test_data(self, x):
+  def _preprocess_test_data(self, x):
     return torch.from_numpy(x)
 
-  def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
+  def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
 
     idx = list(range(x.shape[0]))
     np.random.seed(random_state)
@@ -201,8 +201,7 @@ def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
       t_val = torch.from_numpy(t_val).double()
       e_val = torch.from_numpy(e_val).double()
 
-    return (x_train, t_train, e_train,
-            x_val, t_val, e_val)
+    return (x_train, t_train, e_train, x_val, t_val, e_val)
 
 
   def predict_mean(self, x, risk=1):
@@ -218,7 +217,7 @@ def predict_mean(self, x, risk=1):
 
     """
     if self.fitted:
-      x = self._prepocess_test_data(x)
+      x = self._preprocess_test_data(x)
       scores = losses.predict_mean(self.torch_model, x, risk=str(risk))
       return scores
     else:
@@ -264,7 +263,7 @@ def predict_survival(self, x, t, risk=1):
       np.array: numpy array of the survival probabilites at each time in t.
 
     """
-    x = self._prepocess_test_data(x)
+    x = self._preprocess_test_data(x)
     if not isinstance(t, list):
       t = [t]
     if self.fitted:
@@ -290,7 +289,7 @@ def predict_pdf(self, x, t, risk=1):
       np.array: numpy array of the estimated pdf at each time in t.
 
     """
-    x = self._prepocess_test_data(x)
+    x = self._preprocess_test_data(x)
     if not isinstance(t, list):
       t = [t]
     if self.fitted:
@@ -302,8 +301,6 @@ def predict_pdf(self, x, t, risk=1):
                       "before calling `predict_survival`.")
 
 
-
-
 class DeepSurvivalMachines(DSMBase):
   """A Deep Survival Machines model.
 
@@ -396,10 +393,10 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
                                           typ=self.typ,
                                           risks=risks)
 
-  def _prepocess_test_data(self, x):
+  def _preprocess_test_data(self, x):
     return torch.from_numpy(_get_padded_features(x))
 
-  def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
+  def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
     """RNNs require different preprocessing for variable length sequences"""
 
     idx = list(range(x.shape[0]))
@@ -438,8 +435,7 @@ def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
       t_val = torch.from_numpy(t_val).double()
      e_val = torch.from_numpy(e_val).double()
 
-    return (x_train, t_train, e_train,
-            x_val, t_val, e_val)
+    return (x_train, t_train, e_train, x_val, t_val, e_val)
 
 
 class DeepConvolutionalSurvivalMachines(DSMBase):
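The changes above are a spelling fix, `_prepocess_*` → `_preprocess_*`, applied across `DSMBase` and its subclasses. Note that this is API-visible: any external subclass overriding the old hooks must follow the rename. An illustrative (hypothetical) subclass:

```
# Hypothetical subclass; `MyModel` is not part of the patch.
from dsm.dsm_api import DSMBase

class MyModel(DSMBase):
  # was: def _prepocess_test_data(self, x): ...
  def _preprocess_test_data(self, x):
    return super()._preprocess_test_data(x)
```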
diff --git a/dsm/dsm_torch.py b/dsm/dsm_torch.py
index ce19d91..e9717ca 100644
--- a/dsm/dsm_torch.py
+++ b/dsm/dsm_torch.py
@@ -184,17 +184,13 @@ def __init__(self, inputdim, k, layers=None, dist='Weibull',
     self.optimizer = optimizer
     self.risks = risks
 
-    if layers is None:
-      layers = []
+    if layers is None: layers = []
     self.layers = layers
 
-    if len(layers) == 0:
-      lastdim = inputdim
-    else:
-      lastdim = layers[-1]
+    if len(layers) == 0: lastdim = inputdim
+    else: lastdim = layers[-1]
 
     self._init_dsm_layers(lastdim)
-
     self.embedding = create_representation(inputdim, layers, 'ReLU6')
 
 
@@ -213,8 +209,7 @@ def forward(self, x, risk='1'):
            self.gate[risk](xrep)/self.temp)
 
   def get_shape_scale(self, risk='1'):
-    return(self.shape[risk],
-           self.scale[risk])
+    return(self.shape[risk], self.scale[risk])
 
 class DeepRecurrentSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
   """A Torch implementation of Deep Recurrent Survival Machines model.
@@ -261,6 +256,7 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1,
                hidden=None, dist='Weibull',
                temp=1000., discount=1.0,
                optimizer='Adam', risks=1):
+
     super(DeepSurvivalMachinesTorch, self).__init__()
 
     self.k = k
diff --git a/dsm/utilities.py b/dsm/utilities.py
index 2c5396a..6a8d4cc 100644
--- a/dsm/utilities.py
+++ b/dsm/utilities.py
@@ -207,3 +207,4 @@ def train_dsm(model,
   gc.collect()
 
   return model, i
+
\ No newline at end of file
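Finally, for readers who want to bypass the `DeepCoxMixtures` wrapper, a sketch of driving `train_dcm` directly (synthetic tensors; this mirrors what `fit` does internally):

```
# Driving the training loop directly (sketch with synthetic data).
import torch
from dsm.contrib.dcm_torch import DeepCoxMixturesTorch
from dsm.contrib.dcm_utilities import train_dcm, predict_survival

x = torch.randn(400, 10)
t = torch.rand(400)*10 + 0.1
e = (torch.rand(400) < 0.7).float()

model = DeepCoxMixturesTorch(inputdim=10, k=3, layers=[100])
(trained, splines), losses = train_dcm(model,
                                       (x[:300], t[:300], e[:300]),
                                       (x[300:], t[300:], e[300:]),
                                       epochs=10, lr=1e-3, bs=128,
                                       return_losses=True)

surv = predict_survival((trained.eval(), splines), x[300:], [1.0, 5.0])
```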