Skip to content

Commit

Permalink
Bug Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
laqua-stack committed Mar 20, 2020
1 parent 01caaf2 commit d933299
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
48 changes: 48 additions & 0 deletions examples/rsf_brier_calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 20 00:15:45 2020
Example using brier-score and calibration plot with random survival forrest
@author: Fabian
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

from sksurv.metrics import brier_score, integrated_brier_score, calibration_curve

if __name__ == '__main__':

X, y = load_gbsg2()

grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
grade_num = OrdinalEncoder(categories=[["'I'", "'II'", "'III'"]]).fit_transform(grade_str)

X_no_grade = X.drop("tgrade", axis=1)
Xt = OneHotEncoder().fit_transform(X_no_grade)
Xt = np.column_stack((Xt.values, grade_num))

feature_names = X_no_grade.columns.tolist() + ["tgrade"]

random_state = 20

X_train, X_test, y_train, y_test = train_test_split(
Xt, y, test_size=0.25, random_state=random_state)

rsf = RandomSurvivalForest(n_estimators=1000,
min_samples_split=10,
min_samples_leaf=15,
max_features="sqrt",
n_jobs=-1,
random_state=random_state)
rsf.fit(X_train, y_train)

SurvivalFunction = rsf.predict_survival_function(X_test)
bs=brier_score(y_train,y_test,SurvivalFunction,rsf.event_times_)
4 changes: 2 additions & 2 deletions sksurv/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def brier_score(survival_train,survival_test, estimate, times,
is_control = (T > t)

# get survival function S(t) by interpolating the Survival function
S=numpy.interpolate(Survival,times,t)
S=numpy.interp(Survival,times,t)
S2=numpy.multiply(S,S)
omS2=numpy.multiply(1 - S,1 - S)

Expand Down Expand Up @@ -972,7 +972,7 @@ def calibration_curve(survival_train,survival_test, estimate, times,
test_time_traintest=test_time
test_event_traintest=test_event
# interpolate predicted survival at fu_time.
pred_surv = numpy.interpolate(estimate, times, fu_time)
pred_surv = numpy.interp(estimate, times, fu_time)

# sort by pred_surv in ascending order
order=numpy.argsort(pred_surv)
Expand Down

0 comments on commit d933299

Please sign in to comment.