From abd75932f8c410d0bfb10e0e05f6ff89a13b7462 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?=
Date: Mon, 22 Jan 2024 21:15:43 +0100
Subject: [PATCH] WIP: Maximally selected rank statistics

---
 sksurv/nonparametric.py     | 192 +++++++++++++++++++++++++++++++++++-
 tests/test_nonparametric.py | 180 +++++++++++++++++++++++++++++++++
 2 files changed, 370 insertions(+), 2 deletions(-)

diff --git a/sksurv/nonparametric.py b/sksurv/nonparametric.py
index 82ae7364..2a715abf 100644
--- a/sksurv/nonparametric.py
+++ b/sksurv/nonparametric.py
@@ -16,15 +16,16 @@
 from scipy import stats
 from sklearn.base import BaseEstimator
 from sklearn.utils._param_validation import Interval, StrOptions
-from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted
+from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted, check_random_state
 
-from .util import check_y_survival
+from .util import check_array_survival, check_y_survival
 
 __all__ = [
     "CensoringDistributionEstimator",
     "kaplan_meier_estimator",
     "nelson_aalen_estimator",
     "ipc_weights",
+    "MaxStatCutpointEstimator",
     "SurvivalFunctionEstimator",
 ]
 
@@ -586,3 +587,190 @@ def predict_ipcw(self, y):
         weights[event] = 1.0 / Ghat
 
         return weights
+
+
+def _logrank_scores(times, events):
+    # assign ranks based on observed time
+    ranks = stats.rankdata(times, method="min")
+
+    # sort by times, breaking ties with event
+    order = np.lexsort((events, times))
+    ordered_ranks = ranks[order]
+    ordered_events = events[order]
+
+    n = times.shape[0]
+    n_risk = np.full((n,), times.shape[0] + 1, dtype=int)
+    n_event = np.zeros(n, dtype=int)
+
+    # compute number at risk and number of events at each unique time point
+    value = ordered_ranks[0]
+    i = 0
+    for j in range(n):
+        if ordered_ranks[j] != value:
+            n_risk[i] -= value
+
+            value = ordered_ranks[j]
+            i += 1
+
+        n_event[i] += ordered_events[j]
+
+    if i < n:
+        n_risk[i] -= value
+        i += 1
+
+    n_risk.resize(i, refcheck=False)
+    n_event.resize(i, refcheck=False)
+    n_ties = -np.diff(np.concatenate((n_risk, [0])))
+
+    # cumulative sum of events over number at risk (Nelson-Aalen estimate)
+    csum = np.cumsum(n_event / n_risk)
+    # restore original order while accounting for ties
+    idx = np.repeat(np.arange(n_ties.shape[0], dtype=int), n_ties)
+    csum = csum[idx[ranks - 1]]
+
+    scores = csum - events
+
+    return scores
+
+
+class MaxStatCutpointEstimator(BaseEstimator):
+    """Estimation of cutpoints with maximally selected rank statistics.
+
+    See [1]_, [2]_, and [3]_ for details.
+
+    Parameters
+    ----------
+    min_prob : float, optional, default: 0.01
+        Only cutpoints greater than or equal to the ``min_prob * 100%``
+        quantile are considered. Must be between 0.0 and 0.5.
+
+    max_prob : float, optional
+        Only cutpoints less than or equal to the ``max_prob * 100%``
+        quantile are considered. Must be between 0.5 and 1.0, or None.
+        If None, use ``1 - min_prob``.
+
+    n_resample : int, optional, default: 10000
+        Number of Monte Carlo replicates used to estimate the null distribution.
+
+    random_state : int, RandomState instance, default: None
+        Random number seed used to shuffle the data for
+        estimating the null distribution.
+
+    Attributes
+    ----------
+    cutpoints_ : ndarray, shape = (n_cutpoints,)
+        Unique list of cutpoints that have been evaluated.
+
+    statistics_ : ndarray, shape = (n_cutpoints,)
+        Standardized linear rank statistics for each cutpoint.
+
+    best_cutpoint_ : float
+        Best cutpoint.
+
+    best_test_statistic_ : float
+        Test statistic of the best cutpoint.
+
+    p_value_ : float
+        P-value of the best cutpoint.
+
+    References
+    ----------
+    .. [1] Hothorn T, Lausen B. "On the exact distribution of maximally selected rank statistics."
+           Computational Statistics & Data Analysis. 2003 Jun;43(2):121-37.
+
+    .. [2] Hothorn T, Zeileis A. "Generalized Maximally Selected Statistics."
+           Biometrics. 2008;64(4):1263-9.
+
+    .. [3] Lausen B, Schumacher M. "Maximally Selected Rank Statistics."
+           Biometrics. 1992 Mar;48(1):73.
+    """
+
+    _parameter_constraints = {
+        "min_prob": [Interval(numbers.Real, 0.0, 0.5, closed="neither")],
+        "max_prob": [Interval(numbers.Real, 0.5, 1.0, closed="neither"), None],
+        "n_resample": [Interval(numbers.Integral, 100, None, closed="left")],
+        "random_state": ["random_state"],
+    }
+
+    def __init__(self, min_prob=0.01, max_prob=None, n_resample=10000, random_state=None):
+        self.min_prob = min_prob
+        self.max_prob = max_prob
+        self.n_resample = n_resample
+        self.random_state = random_state
+
+    def _get_cutpoint(self, feature_vector):
+        max_prob = self.max_prob
+        if max_prob is None:
+            max_prob = 1.0 - self.min_prob
+        if max_prob <= self.min_prob:
+            raise ValueError(f"max_prob ({max_prob}) must be larger than min_prob ({self.min_prob})")
+
+        cp_min, cp_max = np.quantile(feature_vector, [self.min_prob, max_prob], method="inverted_cdf")
+
+        cutpoints = np.unique(feature_vector)[:-1]
+        cutpoints = cutpoints[(cutpoints >= cp_min) & (cutpoints <= cp_max)]
+
+        return cutpoints
+
+    def _maxstat_test(self, times, events, feature_vector, cutpoints):
+        n = times.shape[0]
+        scores = _logrank_scores(times, events)
+        scores_mean = np.mean(scores)
+        scores_var = np.var(scores, ddof=1)
+
+        lin_stats = np.empty(cutpoints.shape[0], dtype=float)
+        mu = np.empty(cutpoints.shape[0], dtype=float)
+        var = np.empty(cutpoints.shape[0], dtype=float)
+        for i, cutpoint in enumerate(cutpoints):
+            s_selected = scores[feature_vector <= cutpoint]
+            n_split = s_selected.shape[0]
+
+            mu[i] = n_split * scores_mean
+            var[i] = n_split * (n - n_split) / n * scores_var
+            lin_stats[i] = np.sum(s_selected)
+
+        std_lin_stats = (lin_stats - mu) / np.sqrt(var)
+
+        # estimate null distribution
+        permuted_stats = np.empty((self.n_resample, cutpoints.shape[0]), dtype=float)
+        rnd = check_random_state(self.random_state)
+        for j in range(self.n_resample):
+            rnd.shuffle(scores)
+            for i, cutpoint in enumerate(cutpoints):
+                s_selected = scores[feature_vector <= cutpoint]
+                permuted_stats[j, i] = np.sum(s_selected)
+
+        # standardize permuted scores
+        permuted_stats -= mu[np.newaxis]
+        permuted_stats /= np.sqrt(var[np.newaxis])
+
+        return std_lin_stats, permuted_stats
+
+    def fit(self, X, y):
+        self._validate_params()
+
+        X = self._validate_data(X, ensure_min_samples=2)
+        event, time = check_array_survival(X, y)
+
+        fv = X[:, 0]
+        cutpoints = self._get_cutpoint(fv)
+        std_linear_stats, permuted_stats = self._maxstat_test(time, event, fv, cutpoints)
+
+        abs_stats = np.abs(std_linear_stats)
+        idx_max = np.argmax(abs_stats)
+        self.statistics_ = std_linear_stats
+        self.cutpoints_ = cutpoints
+        self.best_cutpoint_ = cutpoints[idx_max]
+        self.best_test_statistic_ = abs_stats[idx_max]
+
+        # row-wise max over cutpoints
+        std_statistics = np.max(np.abs(permuted_stats), axis=1)
+        # fraction of permuted statistics exceeding the test statistic
+        self.p_value_ = np.mean(std_statistics >= self.best_test_statistic_)
+
+        return self
+
+    def predict(self, X):
+        X = self._validate_data(X, reset=False)
+
+        return None
diff --git a/tests/test_nonparametric.py b/tests/test_nonparametric.py
index 84b4cd64..1b661536 100644
--- a/tests/test_nonparametric.py
+++ b/tests/test_nonparametric.py
@@ -7,6 +7,7 @@
 from sksurv.nonparametric import (
     CensoringDistributionEstimator,
+    MaxStatCutpointEstimator,
     SurvivalFunctionEstimator,
     kaplan_meier_estimator,
     nelson_aalen_estimator,
 )
@@ -6232,3 +6233,182 @@ def test_whas500(make_whas500, whas500_true_x):
     )
 
     assert_array_almost_equal(y, true_y)
+
+
+class MaxStatCases(FixtureParameterFactory):
+    def data_simple_1(self):
+        X = np.arange(5)
+        y = Surv.from_arrays(event=[1, 1, 1, 1, 0], time=[3, 2, 1, 6, 7.2])
+
+        expected = np.array([-0.293940160569987, -0.849234789286624, -1.73539283028136, -1.74103018183762])
+        return X, y, expected
+
+    def data_simple_2(self):
+        X = np.arange(7)
+        y = Surv.from_arrays(event=[1, 1, 1, 1, 0, 1, 0], time=[3, 2, 1, 5, 5, 6, 7.2])
+
+        expected = np.array(
+            [
+                -0.670970617929866,
+                -1.25139260427407,
+                -1.97149213626872,
+                -2.20410978318828,
+                -1.60965419662672,
+                -1.72302648973252,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_3(self):
+        X = np.arange(7)
+        y = Surv.from_arrays(event=[1, 1, 1, 1, 1, 1, 0], time=[3, 2, 1, 5, 5, 6, 7.2])
+
+        expected = np.array(
+            [
+                -0.649612821471534,
+                -1.21155928249016,
+                -1.90873718599155,
+                -1.89981785334673,
+                -2.07137554748317,
+                -1.99929382918909,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_4(self):
+        X = np.arange(7)
+        y = Surv.from_arrays(event=[1, 1, 0, 1, 1, 1, 0], time=[3, 2, 5, 5, 5, 6, 7.2])
+
+        expected = np.array(
+            [
+                -0.992127249600952,
+                -1.72249655501033,
+                -0.851524299687802,
+                -1.1466548808296,
+                -1.57939684120948,
+                -1.73793325102512,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_5(self):
+        X = np.arange(7)
+        y = Surv.from_arrays(event=[1, 1, 0, 1, 1, 1, 0], time=[3, 2, 1, 1, 5, 6, 7.2])
+
+        expected = np.array(
+            [
+                -0.569810094947253,
+                -1.15376451307546,
+                -0.911863843758283,
+                -1.7601092798125,
+                -2.0081181010575,
+                -1.99600144370412,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_6(self):
+        X = np.arange(7)
+        y = Surv.from_arrays(event=[1, 1, 0, 1, 1, 1, 0], time=[3, 2, 1, 1, 5, 6, 7.2])
+
+        expected = np.array(
+            [
+                -0.569810094947253,
+                -1.15376451307546,
+                -0.911863843758283,
+                -1.7601092798125,
+                -2.0081181010575,
+                -1.99600144370412,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_7(self):
+        X = np.arange(8)
+        y = Surv.from_arrays(event=[1, 1, 0, 1, 1, 1, 0, 0], time=[3, 2, 1, 4, 5, 6, 7, 7])
+
+        expected = np.array(
+            [
+                -0.988016177516947,
+                -1.69136683842127,
+                -1.51280449023955,
+                -1.92898500262744,
+                -2.22731368793731,
+                -2.38873039641651,
+                -1.56379112234579,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_8(self):
+        X = np.arange(8)
+        y = Surv.from_arrays(event=[1, 1, 0, 1, 1, 1, 0, 1], time=[3, 2, 1, 4, 5, 6, 7, 7])
+
+        expected = np.array(
+            [
+                -0.93028311158255,
+                -1.59253465791262,
+                -1.4244063006468,
+                -1.81626800377926,
+                -2.09716435341382,
+                -2.24914894763659,
+                -0.798760326841569,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_9(self):
+        X = np.arange(8)
+        y = Surv.from_arrays(event=[1, 1, 0, 1, 1, 1, 1, 1], time=[3, 2, 1, 4, 5, 6, 7, 7])
+
+        expected = np.array(
+            [
+                -0.988016177516947,
+                -1.69136683842127,
+                -1.51280449023955,
+                -1.92898500262744,
+                -2.22731368793731,
+                -2.38873039641651,
+                -1.56379112234579,
+            ]
+        )
+        return X, y, expected
+
+    def data_simple_10(self):
+        X = np.arange(9)
+        y = Surv.from_arrays(event=[1, 1, 0, 1, 1, 0, 0, 1, 1], time=[3, 3, 1, 1, 6, 6, 7, 7, 7])
+
+        expected = np.array(
+            [
+                -0.930210957630102,
+                -1.40634677755228,
+                -1.12604484344696,
+                -1.93525350694573,
+                -2.3284970195571,
+                -1.84083852667852,
+                -0.614351487035996,
+                -0.406355313069992,
+            ]
+        )
+        return X, y, expected
+
+
+class TestMaxStat:
+    @staticmethod
+    @pytest.mark.parametrize("X,y,expected", MaxStatCases().get_cases())
+    def test_logrank_score(X, y, expected):
+        est = MaxStatCutpointEstimator(n_resample=100)
+        est.fit(X[:, np.newaxis], y)
+
+        assert_array_almost_equal(est.statistics_, expected)
+
+    @staticmethod
+    def test_whas500(make_whas500):
+        whas500 = make_whas500(with_mean=False, with_std=False)
+        X = whas500.x_data_frame.loc[:, ["sysbp"]]
+
+        est = MaxStatCutpointEstimator(random_state=24)
+        est.fit(X, whas500.y)
+
+        assert est.best_cutpoint_ == 137
+        assert est.p_value_ == pytest.approx(0.0113, 1e-5)
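
Usage sketch (editorial note, not part of the patch): the snippet below shows how the estimator added above could be applied to a single feature of the WHAS500 data shipped with scikit-survival, mirroring the `test_whas500` test. It assumes the patch has been applied; the column "sysbp" and the chosen `random_state` are illustrative only, and the reported p-value is a Monte Carlo estimate that varies with `n_resample` and `random_state`.

    from sksurv.datasets import load_whas500
    from sksurv.nonparametric import MaxStatCutpointEstimator

    # load data and keep a single continuous feature to dichotomize (illustrative choice)
    X, y = load_whas500()
    sysbp = X.loc[:, ["sysbp"]]

    # estimate the best cutpoint and its permutation-based p-value
    est = MaxStatCutpointEstimator(n_resample=10000, random_state=0)
    est.fit(sysbp, y)

    print("cutpoints considered:", est.cutpoints_.shape[0])
    print("best cutpoint:", est.best_cutpoint_)
    print("standardized test statistic:", est.best_test_statistic_)
    print("p-value:", est.p_value_)

Because the null distribution is obtained by permuting the logrank scores, a small p-value indicates that splitting the cohort at `best_cutpoint_` separates survival better than expected by chance after accounting for the multiple cutpoints evaluated.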