diff --git a/auton_survival/estimators.py b/auton_survival/estimators.py index b7442a7..1a3186d 100644 --- a/auton_survival/estimators.py +++ b/auton_survival/estimators.py @@ -725,10 +725,16 @@ def __init__(self, treated_model, control_model): self.treated_model = treated_model self.control_model = control_model - def predict_counterfactuals(self, features, times): + def predict_counterfactual_survival(self, features, times): - control_outcomes = self.control_model.predict(features, times) - treated_outcomes = self.treated_model.predict(features, times) + control_outcomes = self.control_model.predict_survival(features, times) + treated_outcomes = self.treated_model.predict_survival(features, times) return treated_outcomes, control_outcomes - \ No newline at end of file + + def predict_counterfactual_risk(self, features, times): + + control_outcomes = self.control_model.predict_risk(features, times) + treated_outcomes = self.treated_model.predict_risk(features, times) + + return treated_outcomes, control_outcomes \ No newline at end of file