Skip to content

Commit

Permalink
modified: dcm_api.py
Browse files Browse the repository at this point in the history
	modified:   dcm_utilities.py
  • Loading branch information
chiragnagpal committed Dec 24, 2021
1 parent e03a4a4 commit b6ae902
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
10 changes: 6 additions & 4 deletions dsm/contrib/dcm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ class DeepCoxMixtures:
>>> model.fit(x, t, e)
"""
def __init__(self, k=3, layers=None, gamma=0.95, use_activation=False):

def __init__(self, k=3, layers=None, gamma=0.95, smoothing_factor=1e-4, use_activation=False):

self.k = k
self.layers = layers
self.fitted = False
self.gamma = gamma
self.gamma = gamma
self.smoothing_factor = smoothing_factor
self.use_activation = use_activation

def __call__(self):
Expand Down Expand Up @@ -146,8 +148,8 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
lr=learning_rate,
bs=batch_size,
return_losses=True,
smoothing_factor=None,
use_posteriors=True,)
smoothing_factor=self.smoothing_factor,
use_posteriors=True)

self.torch_model = (model[0].eval(), model[1])
self.fitted = True
Expand Down
5 changes: 1 addition & 4 deletions dsm/contrib/dcm_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from scipy.interpolate import UnivariateSpline
from sksurv.linear_model.coxph import BreslowEstimator
from sksurv.linear_model.coxph import BreslowEstimator

from sklearn.utils import shuffle

Expand Down Expand Up @@ -247,12 +247,9 @@ def train_dcm(model, train_data, val_data, epochs=50,
xt, tt, et = train_data
xv, tv, ev = val_data

unique_times = np.unique(tt)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer = get_optimizer(model, lr)


valc = np.inf
patience_ = 0

Expand Down

0 comments on commit b6ae902

Please sign in to comment.