diff --git a/auton_survival/models/dcm/__init__.py b/auton_survival/models/dcm/__init__.py index 72b44c9..4bff763 100644 --- a/auton_survival/models/dcm/__init__.py +++ b/auton_survival/models/dcm/__init__.py @@ -88,7 +88,8 @@ class DeepCoxMixtures: """ def __init__(self, k=3, layers=None, gamma=10, - smoothing_factor=1e-4, use_activation=False): + smoothing_factor=1e-4, use_activation=False, + random_seed=0): self.k = k self.layers = layers @@ -96,6 +97,7 @@ def __init__(self, k=3, layers=None, gamma=10, self.gamma = gamma self.smoothing_factor = smoothing_factor self.use_activation = use_activation + self.random_seed = random_seed def __call__(self): if self.fitted: @@ -109,10 +111,10 @@ def __call__(self): def _preprocess_test_data(self, x): return torch.from_numpy(x).float() - def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state): + def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed): idx = list(range(x.shape[0])) - np.random.seed(random_state) + np.random.seed(random_seed) np.random.shuffle(idx) x_train, t_train, e_train = x[idx], t[idx], e[idx] @@ -141,6 +143,10 @@ def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state): def _gen_torch_model(self, inputdim, optimizer): """Helper function to return a torch model.""" + + np.random.seed(self.random_seed) + torch.manual_seed(self.random_seed) + return DeepCoxMixturesTorch(inputdim, k=self.k, gamma=self.gamma, @@ -150,7 +156,7 @@ def _gen_torch_model(self, inputdim, optimizer): def fit(self, x, t, e, vsize=0.15, val_data=None, iters=1, learning_rate=1e-3, batch_size=100, - optimizer="Adam", random_state=100): + optimizer="Adam"): r"""This method is used to train an instance of the DSM model. @@ -177,14 +183,14 @@ def fit(self, x, t, e, vsize=0.15, val_data=None, optimizer: str The choice of the gradient based optimization method. One of 'Adam', 'RMSProp' or 'SGD'. 
- random_state: float + random_seed: int random seed that determines how the validation set is chosen. """ processed_data = self._preprocess_training_data(x, t, e, vsize, val_data, - random_state) + self.random_seed) x_train, t_train, e_train, x_val, t_val, e_val = processed_data #Todo: Change this somehow. The base design shouldn't depend on child @@ -201,7 +207,8 @@ def fit(self, x, t, e, vsize=0.15, val_data=None, bs=batch_size, return_losses=True, smoothing_factor=self.smoothing_factor, - use_posteriors=True) + use_posteriors=True, + random_seed=self.random_seed) self.torch_model = (model[0].eval(), model[1]) self.fitted = True diff --git a/auton_survival/models/dcm/dcm_utilities.py b/auton_survival/models/dcm/dcm_utilities.py index 49a798c..57547a8 100644 --- a/auton_survival/models/dcm/dcm_utilities.py +++ b/auton_survival/models/dcm/dcm_utilities.py @@ -242,12 +242,12 @@ def test_step(model, x, t, e, breslow_splines, loss='q', typ='soft'): def train_dcm(model, train_data, val_data, epochs=50, patience=3, vloss='q', bs=256, typ='soft', lr=1e-3, - use_posteriors=True, debug=False, random_state=0, + use_posteriors=True, debug=False, random_seed=0, return_losses=False, update_splines_after=10, smoothing_factor=1e-2): - torch.manual_seed(random_state) - np.random.seed(random_state) + torch.manual_seed(random_seed) + np.random.seed(random_seed) if val_data is None: val_data = train_data