Skip to content

Commit

Permalink
Add missing-values support to ExtraSurvivalTrees
Browse files Browse the repository at this point in the history
  • Loading branch information
sebp committed Dec 21, 2024
1 parent 8a97071 commit 4b7e664
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 41 deletions.
2 changes: 1 addition & 1 deletion sksurv/tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __init__(

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
-tags.input_tags.allow_nan = self.splitter == "best"
+tags.input_tags.allow_nan = self.splitter in ("best", "random")
return tags

def _support_missing_values(self, X):
Expand Down
32 changes: 9 additions & 23 deletions tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def test_fit_predict(make_whas500, forest_cls, expected_c):
assert_cindex_almost_equal(whas500.y["fstat"], whas500.y["lenfol"], pred, expected_c)


-def test_fit_missing_values(make_whas500):
+@pytest.mark.parametrize(
+"forest_cls,expected_cindex", [(ExtraSurvivalTrees, 0.7486232588273405), (RandomSurvivalForest, 0.7444120505344995)]
+)
+def test_fit_missing_values(make_whas500, forest_cls, expected_cindex):
whas500 = make_whas500(to_numeric=True)

rng = np.random.RandomState(42)
Expand All @@ -52,43 +55,26 @@ def test_fit_missing_values(make_whas500):
X_train, y_train = X[:400], whas500.y[:400]
X_test, y_test = X[400:], whas500.y[400:]

-forest = RandomSurvivalForest(random_state=42)
+forest = forest_cls(random_state=42)
forest.fit(X_train, y_train)

tags = forest.__sklearn_tags__()
assert tags.input_tags.allow_nan

cindex = forest.score(X_test, y_test)
-assert cindex == pytest.approx(0.7444120505344995)
+assert cindex == pytest.approx(expected_cindex)


-def test_fit_missing_values_not_supported(make_whas500):
-whas500 = make_whas500(to_numeric=True)
-
-rng = np.random.RandomState(42)
-mask = rng.binomial(n=1, p=0.15, size=whas500.x.shape)
-mask = mask.astype(bool)
-X = whas500.x.copy()
-X[mask] = np.nan
-
-forest = ExtraSurvivalTrees(random_state=42)
-with pytest.raises(ValueError, match="Input X contains NaN"):
-forest.fit(X, whas500.y)
-
-tags = forest.__sklearn_tags__()
-assert not tags.input_tags.allow_nan


-@pytest.mark.parametrize("forst_cls,allows_nan", [(ExtraTreesClassifier, False), (RandomForestClassifier, True)])
-def test_sklearn_random_forest_tags(forst_cls, allows_nan):
+@pytest.mark.parametrize("forst_cls", [ExtraTreesClassifier, RandomForestClassifier])
+def test_sklearn_random_forest_tags(forst_cls):
est = forst_cls()

# https://scikit-learn.org/stable/developers/develop.html#estimator-tags
tags = est.__sklearn_tags__()
assert tags.target_tags.multi_output
assert tags.requires_fit
assert tags.target_tags.required
-assert tags.input_tags.allow_nan is allows_nan
+assert tags.input_tags.allow_nan


@pytest.mark.parametrize("forest_cls", FORESTS)
Expand Down
18 changes: 1 addition & 17 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sksurv.compare import compare_survival
from sksurv.datasets import load_breast_cancer, load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator, nelson_aalen_estimator
-from sksurv.tree import ExtraSurvivalTree, SurvivalTree
+from sksurv.tree import SurvivalTree
from sksurv.util import Surv


Expand Down Expand Up @@ -837,19 +837,3 @@ def test_missing_values_best_splitter_to_right():
# missing values go to the right
y_expected = tree.tree_.value[4]
assert_array_almost_equal(y_pred, y_expected)


-@pytest.mark.parametrize("is_sparse", [False, True])
-def test_missing_value_random_splitter_errors(is_sparse):
-X = np.array([[3, 5, 7, 11, np.nan, 13, 17, np.nan, 19]], dtype=np.float32).T
-y = Surv.from_arrays(
-event=np.array([True, True, True, False, True, False, False, False, True]),
-time=np.array([90, 80, 70, 60, 50, 40, 30, 20, 10]),
-)
-
-if is_sparse:
-X = sparse.csr_matrix(X)
-
-tree = ExtraSurvivalTree()
-with pytest.raises(ValueError, match="Input X contains NaN"):
-tree.fit(X, y)

0 comments on commit 4b7e664

Please sign in to comment.