Skip to content

Commit

Permalink
modified: dsm/__init__.py
Browse files Browse the repository at this point in the history
	modified:   dsm/dsm_torch.py
	modified:   dsm/utilities.py
  • Loading branch information
chiragnagpal committed Feb 3, 2021
1 parent 65ad4e2 commit c03faf9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
1 change: 1 addition & 0 deletions dsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,4 @@
from dsm.dsm_api import DeepSurvivalMachines
from dsm.dsm_api import DeepConvolutionalSurvivalMachines
from dsm.dsm_api import DeepRecurrentSurvivalMachines
from dsm.dsm_api import DeepCNNRNNSurvivalMachines
20 changes: 11 additions & 9 deletions dsm/dsm_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ class DeepCNNRNNSurvivalMachinesTorch(DeepRecurrentSurvivalMachinesTorch):
Parameters
----------
inputdim: int
Dimensionality of the input features.
inputdim: tuple
Dimensionality of the input features. (height, width)
k: int
The number of underlying parametric distributions.
layers: int
Expand Down Expand Up @@ -510,16 +510,18 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1,

self._init_dsm_layers(hidden)

self.cnn = create_conv_representation(inputdim, hidden)

if self.typ == 'LSTM':
self.embedding = nn.LSTM(inputdim, hidden, layers,
bias=False, batch_first=True)
self.rnn = nn.LSTM(hidden, hidden, layers,
bias=False, batch_first=True)
if self.typ == 'RNN':
self.embedding = nn.RNN(inputdim, hidden, layers,
bias=False, batch_first=True,
nonlinearity='relu')
self.rnn = nn.RNN(hidden, hidden, layers,
bias=False, batch_first=True,
nonlinearity='relu')
if self.typ == 'GRU':
self.embedding = nn.GRU(inputdim, hidden, layers,
bias=False, batch_first=True)
self.rnn = nn.GRU(hidden, hidden, layers,
bias=False, batch_first=True)

def forward(self, x, risk='1'):
"""The forward function that is called when data is passed through DSM.
Expand Down
9 changes: 8 additions & 1 deletion dsm/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,14 @@ def train_dsm(model,
costs.append(float(valid_loss))
dics.append(deepcopy(model.state_dict()))

if (costs[-1] >= oldcost) is True:
if costs[-1] >= oldcost:
if patience == 2:
minm = np.argmin(costs)
model.load_state_dict(dics[minm])

del dics
gc.collect()

return model, i
else:
patience += 1
Expand All @@ -198,4 +199,10 @@ def train_dsm(model,

oldcost = costs[-1]

minm = np.argmin(costs)
model.load_state_dict(dics[minm])

del dics
gc.collect()

return model, i

0 comments on commit c03faf9

Please sign in to comment.