Skip to content

Commit

Permalink
modified: __init__.py
Browse files Browse the repository at this point in the history
	modified:   cmhe_torch.py
	modified:   cmhe_utilities.py
  • Loading branch information
chiragnagpal committed May 27, 2022
1 parent dc2d411 commit 04524b3
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
6 changes: 6 additions & 0 deletions auton_survival/models/cmhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class DeepCoxMixturesHeterogenousEffects:
layers: list
A list of integers consisting of the number of neurons in each
hidden layer.
gate_l2_penalty: float
Strength of the l2 penalty term for the gate layers.
Higher means stronger regularization.
random_seed: int
Controls the reproducibility of called functions.
Expand All @@ -122,6 +125,7 @@ class DeepCoxMixturesHeterogenousEffects:

def __init__(self, k, g, layers=None, gamma=100,
smoothing_factor=1e-4,
gate_l2_penalty=1e-4,
random_seed=0):

self.k = k
Expand All @@ -130,6 +134,7 @@ def __init__(self, k, g, layers=None, gamma=100,
self.fitted = False
self.gamma = gamma
self.smoothing_factor = smoothing_factor
self.gate_l2_penalty = gate_l2_penalty
self.random_seed = random_seed

def __call__(self):
Expand Down Expand Up @@ -205,6 +210,7 @@ def _gen_torch_model(self, inputdim, optimizer):
layers=self.layers,
gamma=self.gamma,
smoothing_factor=self.smoothing_factor,
gate_l2_penalty=self.gate_l2_penalty,
optimizer=optimizer)

def fit(self, x, t, e, a, vsize=0.15, val_data=None,
Expand Down
5 changes: 4 additions & 1 deletion auton_survival/models/cmhe/cmhe_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def _init_dcmhe_layers(self, lastdim):
self.omega = torch.nn.Parameter(torch.rand(self.g)-0.5)

def __init__(self, k, g, inputdim, layers=None, gamma=100,
smoothing_factor=1e-4, optimizer='Adam'):
smoothing_factor=1e-4, gate_l2_penalty=1e-4,
optimizer='Adam'):

super(DeepCMHETorch, self).__init__()

Expand All @@ -65,6 +66,8 @@ def __init__(self, k, g, inputdim, layers=None, gamma=100,

self._init_dcmhe_layers(lastdim)

self.gate_l2_penalty = gate_l2_penalty

self.embedding = create_representation(inputdim, layers, 'Tanh')


Expand Down
3 changes: 3 additions & 0 deletions auton_survival/models/cmhe/cmhe_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def m_step(model, optimizer, x, t, e, a, log_likelihoods, typ='soft'):

optimizer.zero_grad()
loss = q_function(model, x, t, e, a, log_likelihoods, typ)
gate_regularization_loss = (model.phi_gate.weight**2).sum()
gate_regularization_loss += (model.z_gate.weight**2).sum()
loss += (model.gate_l2_penalty)*gate_regularization_loss
loss.backward()
optimizer.step()

Expand Down

0 comments on commit 04524b3

Please sign in to comment.