WIP: Maximally selected rank statistics
sebp committed Jan 22, 2024
1 parent 3a1b40a commit abd7593
Showing 2 changed files with 370 additions and 2 deletions.
192 changes: 190 additions & 2 deletions sksurv/nonparametric.py
@@ -16,15 +16,16 @@
from scipy import stats
from sklearn.base import BaseEstimator
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted
from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted, check_random_state

from .util import check_y_survival
from .util import check_array_survival, check_y_survival

__all__ = [
"CensoringDistributionEstimator",
"kaplan_meier_estimator",
"nelson_aalen_estimator",
"ipc_weights",
"MaxStatCutpointEstimator",
"SurvivalFunctionEstimator",
]

@@ -586,3 +587,190 @@ def predict_ipcw(self, y):
        weights[event] = 1.0 / Ghat

        return weights


def _logrank_scores(times, events):
    """Compute log-rank scores for each sample.

    The score of a sample is the Nelson-Aalen estimate of the cumulative
    hazard at its observed time minus its event indicator.
    """
    # assign ranks based on observed time
    ranks = stats.rankdata(times, method="min")

    # sort by times, breaking ties with event
    order = np.lexsort((events, times))
    ordered_ranks = ranks[order]
    ordered_events = events[order]

    n = times.shape[0]
    # initialized to n + 1 such that subtracting the minimum rank of a
    # unique time point yields the number at risk at that time point
    n_risk = np.full((n,), times.shape[0] + 1, dtype=int)
    n_event = np.zeros(n, dtype=int)

    # compute number at risk and number of events at each unique time point
    value = ordered_ranks[0]
    i = 0
    for j in range(n):
        if ordered_ranks[j] != value:
            n_risk[i] -= value

            value = ordered_ranks[j]
            i += 1

        n_event[i] += ordered_events[j]

    if i < n:
        n_risk[i] -= value
        i += 1

    n_risk.resize(i, refcheck=False)
    n_event.resize(i, refcheck=False)
    # number of tied observations at each unique time point
    n_ties = -np.diff(np.concatenate((n_risk, [0])))

    # compute cumulative sum of hazard increments
    csum = np.cumsum(n_event / n_risk)
    # restore original order while accounting for ties
    idx = np.repeat(np.arange(n_ties.shape[0], dtype=int), n_ties)
    csum = csum[idx[ranks - 1]]

    scores = csum - events

    return scores
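

# A minimal illustration of ``_logrank_scores`` (not used by the estimator;
# the data below is made up): with observed times 2, 4, 3 and all events
# observed, the numbers at risk are 3, 2, 1 at times 2, 3, 4, so the
# Nelson-Aalen cumulative hazards are 1/3, 1/3 + 1/2, and 1/3 + 1/2 + 1.
# Subtracting the event indicator yields scores -2/3, 5/6, -1/6
# (in the original order), which sum to zero.
def _example_logrank_scores():
    times = np.array([2.0, 4.0, 3.0])
    events = np.array([True, True, True])
    return _logrank_scores(times, events)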


class MaxStatCutpointEstimator(BaseEstimator):
    """Estimation of cutpoints with maximally selected rank statistics.

    See [1]_, [2]_, and [3]_ for details.

    Parameters
    ----------
    min_prob : float, optional, default: 0.01
        Only cutpoints greater than or equal to the ``min_prob * 100%``
        quantile are considered. Must be between 0.0 and 0.5.

    max_prob : float, optional
        Only cutpoints less than or equal to the ``max_prob * 100%``
        quantile are considered. Must be between 0.5 and 1.0, or None.
        If None, use ``1 - min_prob``.

    n_resample : int, optional, default: 10000
        Number of Monte Carlo replicates used to estimate the null distribution.

    random_state : int or RandomState instance, optional, default: None
        Random number seed used to shuffle the data when
        estimating the null distribution.

    Attributes
    ----------
    cutpoints_ : ndarray, shape = (n_cutpoints,)
        Unique list of cutpoints that have been evaluated.

    statistics_ : ndarray, shape = (n_cutpoints,)
        Standardized linear rank statistic for each cutpoint.

    best_cutpoint_ : float
        Best cutpoint.

    best_test_statistic_ : float
        Test statistic of the best cutpoint.

    p_value_ : float
        P-value of the best cutpoint.

    References
    ----------
    .. [1] Hothorn T, Lausen B. "On the exact distribution of maximally selected rank statistics."
           Computational Statistics & Data Analysis. 2003;43(2):121-137.

    .. [2] Hothorn T, Zeileis A. "Generalized Maximally Selected Statistics."
           Biometrics. 2008;64(4):1263-1269.

    .. [3] Lausen B, Schumacher M. "Maximally Selected Rank Statistics."
           Biometrics. 1992;48(1):73-85.
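
    Examples
    --------
    A minimal sketch on synthetic data (for illustration only; the cutpoint
    is estimated for the first column of ``X``, and ``y`` is a structured
    survival array):

    >>> import numpy as np
    >>> from sksurv.nonparametric import MaxStatCutpointEstimator
    >>> from sksurv.util import Surv
    >>> rng = np.random.RandomState(0)
    >>> x = rng.uniform(size=100)
    >>> time = np.where(x <= 0.5, rng.exponential(5.0, 100), rng.exponential(20.0, 100))
    >>> event = rng.binomial(1, 0.7, 100).astype(bool)
    >>> X, y = x[:, np.newaxis], Surv.from_arrays(event, time)
    >>> est = MaxStatCutpointEstimator(n_resample=1000, random_state=0)
    >>> est = est.fit(X, y)
    >>> groups = X[:, 0] <= est.best_cutpoint_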
"""

    _parameter_constraints = {
        "min_prob": [Interval(numbers.Real, 0.0, 0.5, closed="neither")],
        "max_prob": [Interval(numbers.Real, 0.5, 1.0, closed="neither"), None],
        "n_resample": [Interval(numbers.Integral, 100, None, closed="left")],
        "random_state": ["random_state"],
    }

    def __init__(self, min_prob=0.01, max_prob=None, n_resample=10000, random_state=None):
        self.min_prob = min_prob
        self.max_prob = max_prob
        self.n_resample = n_resample
        self.random_state = random_state

    def _get_cutpoint(self, feature_vector):
        max_prob = self.max_prob
        if max_prob is None:
            max_prob = 1.0 - self.min_prob
        if max_prob <= self.min_prob:
            raise ValueError(f"max_prob ({max_prob}) must be larger than min_prob ({self.min_prob})")

        cp_min, cp_max = np.quantile(feature_vector, [self.min_prob, max_prob], method="inverted_cdf")

        cutpoints = np.unique(feature_vector)[:-1]
        cutpoints = cutpoints[(cutpoints >= cp_min) & (cutpoints <= cp_max)]

        return cutpoints

    def _maxstat_test(self, times, events, feature_vector, cutpoints):
        n = times.shape[0]
        scores = _logrank_scores(times, events)
        scores_mean = np.mean(scores)
        scores_var = np.var(scores, ddof=1)

        lin_stats = np.empty(cutpoints.shape[0], dtype=float)
        mu = np.empty(cutpoints.shape[0], dtype=float)
        var = np.empty(cutpoints.shape[0], dtype=float)
        for i, cutpoint in enumerate(cutpoints):
            s_selected = scores[feature_vector <= cutpoint]
            n_split = s_selected.shape[0]

            mu[i] = n_split * scores_mean
            var[i] = n_split * (n - n_split) / n * scores_var
            lin_stats[i] = np.sum(s_selected)

        std_lin_stats = (lin_stats - mu) / np.sqrt(var)

        # estimate null distribution by permuting the scores
        permuted_stats = np.empty((self.n_resample, cutpoints.shape[0]), dtype=float)
        rnd = check_random_state(self.random_state)
        for j in range(self.n_resample):
            rnd.shuffle(scores)
            for i, cutpoint in enumerate(cutpoints):
                s_selected = scores[feature_vector <= cutpoint]
                permuted_stats[j, i] = np.sum(s_selected)

        # standardize permuted statistics
        permuted_stats -= mu[np.newaxis]
        permuted_stats /= np.sqrt(var[np.newaxis])

        return std_lin_stats, permuted_stats

    def fit(self, X, y):
        """Estimate the best cutpoint for the first feature of ``X``.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix. The cutpoint is estimated for the first column only.

        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        self
        """
        self._validate_params()

        X = self._validate_data(X, ensure_min_samples=2)
        event, time = check_array_survival(X, y)

        fv = X[:, 0]
        cutpoints = self._get_cutpoint(fv)
        std_linear_stats, permuted_stats = self._maxstat_test(time, event, fv, cutpoints)

        abs_stats = np.abs(std_linear_stats)
        idx_max = np.argmax(abs_stats)
        self.statistics_ = std_linear_stats
        self.cutpoints_ = cutpoints
        self.best_cutpoint_ = cutpoints[idx_max]
        self.best_test_statistic_ = abs_stats[idx_max]

        # maximum over all cutpoints for each permutation replicate
        std_statistics = np.max(np.abs(permuted_stats), axis=1)
        # fraction of permuted maxima that exceed the observed test statistic
        self.p_value_ = np.mean(std_statistics >= self.best_test_statistic_)

        return self

    def predict(self, X):
        check_is_fitted(self, "best_cutpoint_")
        X = self._validate_data(X, reset=False)

        # assign each sample to the group defined by the estimated cutpoint
        return (X[:, 0] <= self.best_cutpoint_).astype(int)