diff --git a/sksurv/tree/tree.py b/sksurv/tree/tree.py index aecc4a65..815e7387 100644 --- a/sksurv/tree/tree.py +++ b/sksurv/tree/tree.py @@ -206,7 +206,7 @@ def __init__( def __sklearn_tags__(self): tags = super().__sklearn_tags__() - tags.input_tags.allow_nan = self.splitter == "best" + tags.input_tags.allow_nan = self.splitter in ("best", "random") return tags def _support_missing_values(self, X): diff --git a/tests/test_forest.py b/tests/test_forest.py index a8cb74ef..ea482600 100644 --- a/tests/test_forest.py +++ b/tests/test_forest.py @@ -40,7 +40,10 @@ def test_fit_predict(make_whas500, forest_cls, expected_c): assert_cindex_almost_equal(whas500.y["fstat"], whas500.y["lenfol"], pred, expected_c) -def test_fit_missing_values(make_whas500): +@pytest.mark.parametrize( + "forest_cls,expected_cindex", [(ExtraSurvivalTrees, 0.7486232588273405), (RandomSurvivalForest, 0.7444120505344995)] +) +def test_fit_missing_values(make_whas500, forest_cls, expected_cindex): whas500 = make_whas500(to_numeric=True) rng = np.random.RandomState(42) @@ -52,35 +55,18 @@ def test_fit_missing_values(make_whas500): X_train, y_train = X[:400], whas500.y[:400] X_test, y_test = X[400:], whas500.y[400:] - forest = RandomSurvivalForest(random_state=42) + forest = forest_cls(random_state=42) forest.fit(X_train, y_train) tags = forest.__sklearn_tags__() assert tags.input_tags.allow_nan cindex = forest.score(X_test, y_test) - assert cindex == pytest.approx(0.7444120505344995) + assert cindex == pytest.approx(expected_cindex) -def test_fit_missing_values_not_supported(make_whas500): - whas500 = make_whas500(to_numeric=True) - - rng = np.random.RandomState(42) - mask = rng.binomial(n=1, p=0.15, size=whas500.x.shape) - mask = mask.astype(bool) - X = whas500.x.copy() - X[mask] = np.nan - - forest = ExtraSurvivalTrees(random_state=42) - with pytest.raises(ValueError, match="Input X contains NaN"): - forest.fit(X, whas500.y) - - tags = forest.__sklearn_tags__() - assert not tags.input_tags.allow_nan - - -@pytest.mark.parametrize("forst_cls,allows_nan", [(ExtraTreesClassifier, False), (RandomForestClassifier, True)]) -def test_sklearn_random_forest_tags(forst_cls, allows_nan): +@pytest.mark.parametrize("forst_cls", [ExtraTreesClassifier, RandomForestClassifier]) +def test_sklearn_random_forest_tags(forst_cls): est = forst_cls() # https://scikit-learn.org/stable/developers/develop.html#estimator-tags @@ -88,7 +74,7 @@ def test_sklearn_random_forest_tags(forst_cls, allows_nan): assert tags.target_tags.multi_output assert tags.requires_fit assert tags.target_tags.required - assert tags.input_tags.allow_nan is allows_nan + assert tags.input_tags.allow_nan @pytest.mark.parametrize("forest_cls", FORESTS) diff --git a/tests/test_tree.py b/tests/test_tree.py index 4178bfd1..fc2a214a 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -14,7 +14,7 @@ from sksurv.compare import compare_survival from sksurv.datasets import load_breast_cancer, load_veterans_lung_cancer from sksurv.nonparametric import kaplan_meier_estimator, nelson_aalen_estimator -from sksurv.tree import ExtraSurvivalTree, SurvivalTree +from sksurv.tree import SurvivalTree from sksurv.util import Surv @@ -837,19 +837,3 @@ def test_missing_values_best_splitter_to_right(): # missing values go to the right y_expected = tree.tree_.value[4] assert_array_almost_equal(y_pred, y_expected) - - -@pytest.mark.parametrize("is_sparse", [False, True]) -def test_missing_value_random_splitter_errors(is_sparse): - X = np.array([[3, 5, 7, 11, np.nan, 13, 17, np.nan, 19]], dtype=np.float32).T - y = Surv.from_arrays( - event=np.array([True, True, True, False, True, False, False, False, True]), - time=np.array([90, 80, 70, 60, 50, 40, 30, 20, 10]), - ) - - if is_sparse: - X = sparse.csr_matrix(X) - - tree = ExtraSurvivalTree() - with pytest.raises(ValueError, match="Input X contains NaN"): - tree.fit(X, y)