diff --git a/auton_survival/estimators.py b/auton_survival/estimators.py index 8b230b3..2a0f215 100644 --- a/auton_survival/estimators.py +++ b/auton_survival/estimators.py @@ -26,7 +26,6 @@ import numpy as np import pandas as pd -from auton_survival.models.dcm.dcm_api import DeepCoxMixtures def _get_valid_idx(n, size, random_seed): @@ -108,7 +107,7 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams): k = hyperparams.get("k", 3) layers = hyperparams.get("layers", [100]) batch_size = hyperparams.get("batch_size", 128) - lr = hyperparams.get("lr", 1e-3) + learning_rate = hyperparams.get("learning_rate", 1e-3) epochs = hyperparams.get("epochs", 50) smoothing_factor = hyperparams.get("smoothing_factor", 1e-4) gamma = hyperparams.get("gamma", 10) @@ -119,7 +118,7 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams): smoothing_factor=smoothing_factor) model.fit(features.values, outcomes.time.values, outcomes.event.values, - iters=epochs, batch_size=batch_size, lr=lr, + iters=epochs, batch_size=batch_size, learning_rate=learning_rate, random_seed=random_seed) return model