Skip to content

Commit

Permalink
modified: dsm_torch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Feb 28, 2022
1 parent bfba055 commit 4aec0d1
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions auton_survival/models/dsm/dsm_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import torch
from torch import nn
import numpy as np

__pdoc__ = {}

Expand All @@ -46,7 +45,7 @@
__pdoc__[clsn+'.'+membr] = False


def create_representation(inputdim, layers, activation):
def create_representation(inputdim, layers, activation, bias=False):
r"""Helper function to generate the representation function for DSM.
Deep Survival Machines learns a representation (\ Phi(X) \) for the input
Expand Down Expand Up @@ -85,7 +84,7 @@ def create_representation(inputdim, layers, activation):
prevdim = inputdim

for hidden in layers:
modules.append(nn.Linear(prevdim, hidden, bias=False))
modules.append(nn.Linear(prevdim, hidden, bias=bias))
modules.append(act)
prevdim = hidden

Expand Down Expand Up @@ -257,7 +256,7 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1,
hidden=None, dist='Weibull',
temp=1000., discount=1.0,
optimizer='Adam', risks=1):

super(DeepSurvivalMachinesTorch, self).__init__()

self.k = k
Expand Down

0 comments on commit 4aec0d1

Please sign in to comment.