Skip to content

Commit

Permalink
Test MinlipSurvivalAnalysis with callable kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
sebp committed Dec 31, 2024
1 parent f8b9855 commit 5cd347e
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/test_minlip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sksurv.column import encode_categorical
from sksurv.datasets import load_gbsg2
from sksurv.exceptions import NoComparablePairException
from sksurv.kernels import ClinicalKernelTransform
from sksurv.svm._minlip import create_difference_matrix
from sksurv.svm.minlip import HingeLossSurvivalSVM, MinlipSurvivalAnalysis
from sksurv.testing import FixtureParameterFactory, assert_cindex_almost_equal
Expand Down Expand Up @@ -396,6 +397,22 @@ def test_kernel_precomputed(gbsg2_scaled, solver):
p = m.predict(X_test)
assert_cindex_almost_equal(y_test["cens"], y_test["time"], p, (0.6518928901200369, 8472, 4524, 0, 3))

@staticmethod
@pytest.mark.slow()
def test_fit_clinical_kernel(make_whas500):
whas500 = make_whas500(with_mean=False, with_std=False)

trans = ClinicalKernelTransform()
trans.fit(whas500.x_data_frame)

m = MinlipSurvivalAnalysis(kernel=trans.pairwise_kernel)
m.fit(whas500.x, whas500.y)

assert not m.__sklearn_tags__().input_tags.pairwise

c = m.score(whas500.x, whas500.y)
assert c == pytest.approx(0.7314135916645598)

@staticmethod
@pytest.mark.parametrize("solver", ["osqp", "ecos"])
def test_max_iter(gbsg2_scaled, solver):
Expand Down

0 comments on commit 5cd347e

Please sign in to comment.