diff --git a/examples/framingham_final_experiment.ipynb b/examples/framingham_final_experiment.ipynb new file mode 100644 index 0000000..d2ccfa6 --- /dev/null +++ b/examples/framingham_final_experiment.ipynb @@ -0,0 +1,856 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from dsm import datasets, DeepRecurrentSurvivalMachines" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def unrollx(data):\n", + " return np.vstack([dat for dat in data])\n", + "def unrollt(data):\n", + " return np.concatenate([dat for dat in data])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "x, t, e = datasets.load_dataset('FRAMINGHAM', sequential=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "x, t, e = np.array(x), np.array(t), np.array(e)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "times = np.quantile(unrollt(t)[unrollt(e)==1], [0.25, .5, 0.75])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([2254.5, 3913. , 5766. ])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "times" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "folds = np.array((list(range(4))*10000)[:len(x)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([8766, 4138, 8766, ..., 8766, 6376, 4565], dtype=int64)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unrollt(t)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from lifelines import CoxPHFitter" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "def convert_to_data_frame(x, t, e):\n", + "\n", + " df = pd.DataFrame(data=x, columns=['X' + str(i) for i in range(x.shape[1])])\n", + " df['T'] = pd.DataFrame(data=t.reshape(-1, 1), columns=['T'])\n", + " df['E'] = pd.DataFrame(data=e.reshape(-1, 1), columns=['E'])\n", + "\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "from sksurv.metrics import concordance_index_ipcw, brier_score" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2897, 3)\n", + "(2922, 3)\n", + "(2919, 3)\n", + "(2889, 3)\n" + ] + } + ], + "source": [ + "cis = []\n", + "brs = []\n", + "\n", + "for fold in set(folds): \n", + " \n", + " x_tr, t_tr, e_tr = x[folds!=fold], t[folds!=fold], e[folds!=fold]\n", + " x_te, t_te, e_te = x[folds==fold], t[folds==fold], e[folds==fold]\n", + " \n", + " x_tr = unrollx(x_tr)\n", + " t_tr = unrollt(t_tr)\n", + " e_tr = unrollt(e_tr)\n", + "\n", + " x_te = unrollx(x_te)\n", + " t_te = unrollt(t_te)\n", + " e_te = unrollt(e_te)\n", + " \n", + " df_tr = convert_to_data_frame(x_tr, t_tr, e_tr)\n", + "\n", + " model = CoxPHFitter(penalizer=1e-3).fit(df_tr, duration_col='T', event_col='E')\n", + " \n", + " preds = model.predict_survival_function(x_te, times).T.values\n", + " \n", + " et_tr = np.array([(e_tr[i], t_tr[i]) for i in range(len(e_tr))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " et_te = np.array([(e_te[i], t_te[i]) for i in range(len(e_te))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " \n", + " print (preds.shape)\n", + " \n", + " cis_ = []\n", + " for i in range(len(times)):\n", + " cis_.append(concordance_index_ipcw(et_tr, et_te, 1-preds[:,i], times[i])[0])\n", + " cis.append(cis_)\n", + " \n", + " brs.append(brier_score(et_tr, et_te, preds, times )[1])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.76853118 0.76241752 0.76215845]\n", + "[0.06231811 0.10537517 0.13165702]\n" + ] + } + ], + "source": [ + "print (np.mean(cis, axis=0))\n", + "print (np.mean(brs, axis=0))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "from pysurvival.models.survival_forest import RandomSurvivalForestModel\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cis = []\n", + "brs = []\n", + "\n", + "for fold in set(folds): \n", + " \n", + " x_tr, t_tr, e_tr = x[folds!=fold], t[folds!=fold], e[folds!=fold]\n", + " x_te, t_te, e_te = x[folds==fold], t[folds==fold], e[folds==fold]\n", + " \n", + " x_tr = unrollx(x_tr)\n", + " t_tr = unrollt(t_tr)\n", + " e_tr = unrollt(e_tr)\n", + "\n", + " x_te = unrollx(x_te)\n", + " t_te = unrollt(t_te)\n", + " e_te = unrollt(e_te)\n", + " \n", + " df_tr = convert_to_data_frame(x_tr, t_tr, e_tr)\n", + " \n", + " \n", + " et_tr = np.array([(e_tr[i], t_tr[i]) for i in range(len(e_tr))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " et_te = np.array([(e_te[i], t_te[i]) for i in range(len(e_te))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " \n", + " model = RandomSurvivalForestModel(num_trees=50)\n", + " model = model.fit(x_tr, t_tr, e_tr)\n", + " \n", + " preds = []\n", + " for time in times:\n", + " preds.append(model.predict_survival(x_te, time))\n", + " print (len(preds))\n", + " \n", + " cis_ = []\n", + " for i in range(len(times)):\n", + " cis_.append(concordance_index_ipcw(et_tr, et_te, 1-preds[:,i], times[i])[0])\n", + " cis.append(cis_)\n", + " \n", + " brs.append(brier_score(et_tr, et_te, preds, times )[1])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.predict_survival_function(x_te, times).T.values" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.3126132634036184,\n", + " 0.46647488732207254,\n", + " 0.5624185126209059,\n", + " 0.4059503413491215,\n", + " 0.4415347567963127,\n", + " 0.2670742533214408,\n", + " 0.06359725383083349,\n", + " 0.9349522484067815,\n", + " 0.9486941636172256,\n", + " 0.9522407447171799,\n", + " 0.9103188179868816,\n", + " 0.8824124615716761,\n", + " 0.7801756137064368,\n", + " 0.752123603797489,\n", + " 0.5288001066166764,\n", + " 0.7803207243582894,\n", + " 0.46932603202272666,\n", + " 0.9289505993097185,\n", + " 0.9355269013394406,\n", + " 0.9575783226069599,\n", + " 0.9381858563916292,\n", + " 0.9595487184435972,\n", + " 0.9310981298806621,\n", + " 0.9326698824044939,\n", + " 0.9011583356828196,\n", + " 0.9229446679118422,\n", + " 0.8835080122599095,\n", + " 0.900997315998892,\n", + " 0.8925830090019774,\n", + " 0.8831209397700821,\n", + " 0.8005351813156036,\n", + " 0.7216661520824456,\n", + " 0.7599841537178933,\n", + " 0.00015928860304218303,\n", + " 0.7890039749760369,\n", + " 0.9350205039206175,\n", + " 0.9072527188714308,\n", + " 0.9755350533464282,\n", + " 0.9857568595759039,\n", + " 0.9678514283741088,\n", + " 0.9640084861264626,\n", + " 0.7937728360920354,\n", + " 0.7358192205521833,\n", + " 0.8280615261984754,\n", + " 0.6155861711521265,\n", + " 0.765769361328416,\n", + " 0.06586886769950233,\n", + " 0.16902102092412868,\n", + " 0.1777562305251325,\n", + " 0.03417428544313084,\n", + " 0.5920674847464198,\n", + " 0.8983756086780812,\n", + " 0.9026389359055556,\n", + " 0.8033959963857669,\n", + " 0.8656146999667899,\n", + " 0.8842617423548295,\n", + " 0.7699184228028957,\n", + " 0.7393692326401295,\n", + " 0.7409772214293374,\n", + " 0.7367520868048991,\n", + " 0.8132329524856486,\n", + " 0.9962437013911171,\n", + " 0.722892426003275,\n", + " 0.8033061209341019,\n", + " 0.8384861555346423,\n", + " 0.772113571310073,\n", + " 0.9527586856351254,\n", + " 0.9646799659604668,\n", + " 0.986514763169265,\n", + " 0.9427641115911002,\n", + " 0.9253511419815649,\n", + " 0.708156275955675,\n", + " 0.8397211858505388,\n", + " 0.6529327645885434,\n", + " 0.7485936678973811,\n", + " 0.5424262779158605,\n", + " 0.7039738751110219,\n", + " 0.7115564871713116,\n", + " 0.8379814332471544,\n", + " 0.8348072491998175,\n", + " 0.8757869421845631,\n", + " 0.815678256826629,\n", + " 0.7474536543270377,\n", + " 0.6767413213717813,\n", + " 0.640657339034476,\n", + " 0.6523815703651415,\n", + " 0.6418980182491849,\n", + " 0.6129462714444834,\n", + " 0.39295436549915047,\n", + " 0.5644743579527517,\n", + " 0.33441623871917686,\n", + " 0.10509306719354915,\n", + " 0.6748054707237131,\n", + " 0.8447078794151619,\n", + " 0.9819493215529602,\n", + " 0.9758886636406668,\n", + " 0.956910373380083,\n", + " 0.9548584483556365,\n", + " 0.41541909262155086,\n", + " 0.3924680244561882,\n", + " 0.3043829008534716,\n", + " 0.2781340245897338,\n", + " 1.2613304502273227e-06,\n", + " 0.658320545976604,\n", + " 0.9044344414216838,\n", + " 0.8069121525438718,\n", + " 0.9175468414977661,\n", + " 0.9069898543923935,\n", + " 0.8845031336209687,\n", + " 0.8570323920168299,\n", + " 0.8444506604886417,\n", + " 0.9260400835125885,\n", + " 0.9521279704895809,\n", + " 0.9205819072144652,\n", + " 0.6700855236810579,\n", + " 0.5367980230695721,\n", + " 0.6648257668432667,\n", + " 0.5745808542458603,\n", + " 0.13588205146286958,\n", + " 0.09957800800658022,\n", + " 0.0003708176217721178,\n", + " 0.8020048674291576,\n", + " 0.8880341042224411,\n", + " 0.6999592957903684,\n", + " 0.7474841220967033,\n", + " 0.22926127573875427,\n", + " 0.00033600732679248027,\n", + " 0.8920917556632157,\n", + " 0.9274214892408668,\n", + " 0.9582527976047742,\n", + " 0.9592339060192541,\n", + " 0.921772367623308,\n", + " 0.9635262643689976,\n", + " 0.9428564233432247,\n", + " 0.9125914143621489,\n", + " 0.918287564643115,\n", + " 0.8983465794226003,\n", + " 0.9067773278570157,\n", + " 0.7690766562573458,\n", + " 0.9166026955495812,\n", + " 0.8612002116016616,\n", + " 0.8520244131364719,\n", + " 0.8547637916864208,\n", + " 0.8707030306639615,\n", + " 0.5975739379437166,\n", + " 0.27507122998180517,\n", + " 0.9488788450747389,\n", + " 0.956887677917689,\n", + " 0.9654419668201475,\n", + " 0.9494574706191363,\n", + " 0.9258484405656261,\n", + " 0.9500726011011853,\n", + " 0.903175174049695,\n", + " 0.8911462861039307,\n", + " 0.7972912362100846,\n", + " 0.7203290936846563,\n", + " 0.6548581427569794,\n", + " 0.5501021504126242,\n", + " 0.44935260458311965,\n", + " 0.42635329367001945,\n", + " 0.9236152451661244,\n", + " 0.9224879789082839,\n", + " 0.8850497342485366,\n", + " 0.9083710756329625,\n", + " 0.9237126782908551,\n", + " 0.894554288396283,\n", + " 0.18430150962149053,\n", + " 0.7302017007833491,\n", + " 0.7160779430502329,\n", + " 0.5549094071599893,\n", + " 0.5129254459725594,\n", + " 0.9167686155783427,\n", + " 0.9113956977765255,\n", + " 0.9075957802505701,\n", + " 0.9390146137210659,\n", + " 0.9472443861545131,\n", + " 0.9328337801205822,\n", + " 0.8370456950485082,\n", + " 0.5997948750333487,\n", + " 0.9343258168253256,\n", + " 0.9439457654128739,\n", + " 0.9280206595792848,\n", + " 0.9322397714362198,\n", + " 0.9420865049454272,\n", + " 0.9454225799163594,\n", + " 0.8996670718165377,\n", + " 0.88449093485601,\n", + " 0.8990662406292473,\n", + " 0.8900953570566361,\n", + " 0.881334207392159,\n", + " 0.8651976895713909,\n", + " 0.8861542877639267,\n", + " 0.8849661720305225,\n", + " 0.5954682765046174,\n", + " 0.5522866830533725,\n", + " 0.5448126820180267,\n", + " 0.02020368206170692,\n", + " 0.9526868500079836,\n", + " 0.9424311968328869,\n", + " 0.9473163038631395,\n", + " 0.879723352216857,\n", + " 0.81382887023091,\n", + " 0.5063033982685107,\n", + " 0.46737411250167554,\n", + " 0.20277801966357584,\n", + " 0.05393898062009144,\n", + " 0.06971893798849264,\n", + " 0.9581865115673758,\n", + " 0.9601040119424561,\n", + " 0.9513930191207932,\n", + " 0.9230437162952005,\n", + " 0.9841547435846786,\n", + " 0.9857412613884763,\n", + " 0.983817073652073,\n", + " 0.9836834675404416,\n", + " 0.7869420430902185,\n", + " 0.6433405687090631,\n", + " 0.4315214891481895,\n", + " 0.7260115220495794,\n", + " 0.6988575248832335,\n", + " 0.4534290907045462,\n", + " 0.30239167712821347,\n", + " 0.2860695589421745,\n", + " 0.5164761946551715,\n", + " 0.2522426702592779,\n", + " 0.00197960153557903,\n", + " 0.8618227167048974,\n", + " 0.9091713127289668,\n", + " 0.9095886710480476,\n", + " 0.9424076295974398,\n", + " 0.9251137411946267,\n", + " 0.8929125043943227,\n", + " 0.9227633808748936,\n", + " 0.9096963679485675,\n", + " 0.9101013093322737,\n", + " 0.8961549358286927,\n", + " 0.9118240696763442,\n", + " 0.9056960328154277,\n", + " 0.8530269355745219,\n", + " 0.6505794661814498,\n", + " 0.6572174453184076,\n", + " 0.5307738469776382,\n", + " 0.3032158592811353,\n", + " 0.0915386361343444,\n", + " 0.004460832036425752,\n", + " 0.7954037177041817,\n", + " 0.33520621222191427,\n", + " 0.42385219473283553,\n", + " 0.11415410523529125,\n", + " 0.12389657726885481,\n", + " 0.02389587535342872,\n", + " 0.9206299850509021,\n", + " 0.9300994696621918,\n", + " 0.9375880323643341,\n", + " 0.8795557027145566,\n", + " 0.8551662727717931,\n", + " 0.8805282055792464,\n", + " 0.930698190625293,\n", + " 0.9489751419877448,\n", + " 0.9381666964806937,\n", + " 0.8827748657109729,\n", + " 0.8984124855971098,\n", + " 0.8945580857060548,\n", + " 0.9059021382498761,\n", + " 0.8885983080620843,\n", + " 0.8947629451833445,\n", + " 0.9249509161078188,\n", + " 0.9464905593781671,\n", + " 0.9235464094778002,\n", + " 0.8882768513803266,\n", + " 0.8924985201185903,\n", + " 0.9270893953482611,\n", + " 0.866837535977548,\n", + " 0.90243929006013,\n", + " 0.906280116040601,\n", + " 0.8577710019492567,\n", + " 0.8921978169336371,\n", + " 0.8502797691409552,\n", + " 0.28546195947840236,\n", + " 0.398185288626109,\n", + " 0.5369426602985751,\n", + " 0.1799252069349064,\n", + " 0.7742189138362863,\n", + " 0.7925283745726461,\n", + " 0.8133157270464088,\n", + " 0.8575044927112226,\n", + " 0.7362845422519129,\n", + " 0.7036750601317182,\n", + " 0.6897696671216995,\n", + " 0.8086814631231332,\n", + " 0.428270427377073,\n", + " 0.6318889000676778,\n", + " 0.561726320410852,\n", + " 0.5128431768861994,\n", + " 0.5368918881622513,\n", + " 0.8892806685426429,\n", + " 0.8750926512298951,\n", + " 0.7806237583320027,\n", + " 0.7684665778260594,\n", + " 0.8615261308200153,\n", + " 0.846656442617957,\n", + " 0.860088758907636,\n", + " 0.8356035243644513,\n", + " 0.45912931041450084,\n", + " 0.9493197050938518,\n", + " 0.9714224821197369,\n", + " 0.9633009120360158,\n", + " 0.9592303457237034,\n", + " 0.8894417818107543,\n", + " 0.9103564405570354,\n", + " 0.914455582902078,\n", + " 0.8626489158865417,\n", + " 0.8247425923336972,\n", + " 0.843516106235579,\n", + " 0.8883461500276562,\n", + " 0.7888191019164008,\n", + " 0.7016468515587145,\n", + " 0.7952579699624737,\n", + " 0.7085263481165859,\n", + " 0.9957138759660945,\n", + " 0.6969633713373526,\n", + " 0.6650518687599166,\n", + " 0.5677168208198138,\n", + " 0.5010375501713816,\n", + " 0.6716937400870903,\n", + " 0.5263186842750209,\n", + " 0.3752190397600786,\n", + " 0.9277474224938042,\n", + " 0.9034246716757176,\n", + " 0.8623642126560304,\n", + " 0.8828365285715392,\n", + " 0.8069477291633591,\n", + " 0.5977840168390623,\n", + " 0.6575599972586608,\n", + " 0.7257005802652859,\n", + " 0.6616397497432462,\n", + " 0.7138497308406905,\n", + " 0.8559905862080395,\n", + " 0.5465469708869767,\n", + " 0.7247007978361263,\n", + " 0.4982202081321523,\n", + " 0.1435502364788124,\n", + " 0.8312160646997496,\n", + " 0.7673104339244738,\n", + " 0.9138910753213966,\n", + " 0.9079223080969111,\n", + " 0.9037941851912057,\n", + " 0.9054030389928486,\n", + " 0.8416027599455486,\n", + " 0.8777198982814582,\n", + " 0.896264047562824,\n", + " 0.9183928728613096,\n", + " 0.926234890989123,\n", + " 0.8477449274998132,\n", + " 0.9425756117883883,\n", + " 0.930425818690117,\n", + " 0.8676227063441375,\n", + " 0.9017682694985918,\n", + " 0.8250182416354152,\n", + " 0.8734526867624735,\n", + " 0.7508411926945059,\n", + " 0.7334557969611786,\n", + " 0.6953569988451388,\n", + " 0.6187676536913802,\n", + " 0.5272312526806292,\n", + " 0.25486197204460714,\n", + " 0.41581169188183437,\n", + " 0.08956291479210281,\n", + " 0.9057245705031539,\n", + " 0.8816137721066973,\n", + " 0.7197998157801146,\n", + " 0.8491990949993822,\n", + " 0.8285355188488683,\n", + " 0.8438579846834945,\n", + " 0.8165001169186776,\n", + " 0.7855076627569467,\n", + " 0.7731832923938429,\n", + " 0.7698280409087872,\n", + " 0.04306078773841739,\n", + " 0.9345193208423778,\n", + " 0.9185313864112215,\n", + " 0.9368296696015962,\n", + " 0.8842005995098796,\n", + " 0.8081473074740985,\n", + " 0.7437573722512592,\n", + " 0.4545129103156498,\n", + " 0.014301651634169134,\n", + " 0.7379679110011917,\n", + " 0.2667863361865194,\n", + " 0.05232194086368336,\n", + " 0.9101051873777896,\n", + " 0.9415855262556586,\n", + " 0.8912120472275669,\n", + " 0.8671014329978515,\n", + " 0.9332453379332762,\n", + " 0.898263104763695,\n", + " 0.9129162903881175,\n", + " 0.9119028246731342,\n", + " 0.757290127585658,\n", + " 0.8546953367453761,\n", + " 0.79905606009348,\n", + " 0.5200916429514422,\n", + " 0.41482087460157935,\n", + " 0.8658565068988433,\n", + " 0.8599416175566281,\n", + " 0.7432232481110141,\n", + " 0.4269325639335677,\n", + " 0.167144887201504,\n", + " 0.9358631812602864,\n", + " 0.9087759029109701,\n", + " 0.9246356483716845,\n", + " 0.8636000760257828,\n", + " 0.8502778060823685,\n", + " 0.8239945167311183,\n", + " 0.7913040226262458,\n", + " 0.8250240912526838,\n", + " 0.9151914446528286,\n", + " 0.9373793163952012,\n", + " 0.915873662134878,\n", + " 0.9348315744827719,\n", + " 0.8624048399418976,\n", + " 0.9045261098760244,\n", + " 0.9064311667504562,\n", + " 0.8813688497380183,\n", + " 0.9135094607462082,\n", + " 0.9482322225988866,\n", + " 0.9446357603383352,\n", + " 0.9173510791540908,\n", + " 0.9548461636548782,\n", + " 0.9444051511935787,\n", + " 0.909703525187174,\n", + " 0.900213638501747,\n", + " 0.9407011443944998,\n", + " 0.9569942070418173,\n", + " 0.8575145753354487,\n", + " 0.8770096789068078,\n", + " 0.7629346260779011,\n", + " 0.4987048733082472,\n", + " 0.9150327361260665,\n", + " 0.8887606331057204,\n", + " 0.8849814590168796,\n", + " 0.86831507399619,\n", + " 0.8443596925415235,\n", + " 0.7295135378625308,\n", + " 0.8547477188896778,\n", + " 0.8476220939507235,\n", + " 0.8157788085344,\n", + " 0.8620571095950187,\n", + " 0.8463417721218065,\n", + " 0.9375838631083165,\n", + " 0.9377547504095433,\n", + " 0.9142549931362388,\n", + " 0.8354730835481203,\n", + " 0.7345709733821627,\n", + " 0.763789204701828,\n", + " 0.6857649452929322,\n", + " 0.8197329367378913,\n", + " 0.8774874646026789,\n", + " 0.8034967082087436,\n", + " 0.7834574427067679,\n", + " 0.5343060468964511,\n", + " 0.4735468062357076,\n", + " 0.9774277899166476,\n", + " 0.9862950219071864,\n", + " 0.9867827570802857,\n", + " 0.9837798135660429,\n", + " 0.9837740836439626,\n", + " 0.9271520061949262,\n", + " 0.9465040187757684,\n", + " 0.9143759974442722,\n", + " 0.9198347819068121,\n", + " 0.8980926078628699,\n", + " 0.9158163187453798,\n", + " 0.9008031535477147,\n", + " 0.9047487601401539,\n", + " 0.9028384667697645,\n", + " 0.8846071578485817,\n", + " 0.9196068463296615,\n", + " 0.9332357764846212,\n", + " 0.8946345080109162,\n", + " 0.9148239352242258,\n", + " 0.8981043855909107,\n", + " 0.9272939713883867,\n", + " 0.9042590774195602,\n", + " 0.8958637294419638,\n", + " 0.8015070810488782,\n", + " 0.7596107228462914,\n", + " 0.724358889331443,\n", + " 0.2995966419620019,\n", + " 0.8911623357160338,\n", + " 0.8818683399376666,\n", + " 0.8558246202646637,\n", + " 0.582149983773494,\n", + " 0.07092568601575128,\n", + " 0.9480229205264794,\n", + " 0.9397867446391941,\n", + " 0.9340524066952843,\n", + " 0.9110450858911346,\n", + " 0.8757563279055323,\n", + " 0.8554999196296835,\n", + " 0.9085324590174707,\n", + " 0.9078851143828541,\n", + " 0.9088778045841828,\n", + " 0.6112634538501532,\n", + " 0.29206769385932985]" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[risk.y[3] for risk in out_risk]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/pbc_final_experiment.ipynb b/examples/pbc_final_experiment.ipynb new file mode 100644 index 0000000..d3428e9 --- /dev/null +++ b/examples/pbc_final_experiment.ipynb @@ -0,0 +1,816 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [], + "source": [ + "from dsm import datasets, DeepRecurrentSurvivalMachines" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "x, t, e = datasets.load_dataset('PBC', sequential=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "x, t, e = np.array(x), np.array(t), np.array(e)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "times = np.quantile(unrollt(t), [0.25, .5, 0.75])" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([2.1054649 , 4.57507392, 7.15967583])" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "times" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "folds = np.array((list(range(4))*1000)[:len(x)])" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "def unrollx(data):\n", + " return np.vstack([dat for dat in data])" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [], + "source": [ + "def unrollt(data):\n", + " return np.concatenate([dat for dat in data])" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 1.0951703 , 0.56948856, 14.15233819, ..., 2.92136677,\n", + " 1.86726536, 1.04588764])" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unrollt(t)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "from lifelines import CoxPHFitter" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [], + "source": [ + "def convert_to_data_frame(x, t, e):\n", + "\n", + " df = pd.DataFrame(data=x, columns=['X' + str(i) for i in range(x.shape[1])])\n", + " df['T'] = pd.DataFrame(data=t.reshape(-1, 1), columns=['T'])\n", + " df['E'] = pd.DataFrame(data=e.reshape(-1, 1), columns=['E'])\n", + "\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "metadata": {}, + "outputs": [], + "source": [ + "from sksurv.metrics import concordance_index_ipcw, brier_score" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(478, 3)\n", + "(505, 3)\n", + "(457, 3)\n", + "(505, 3)\n" + ] + } + ], + "source": [ + "cis = []\n", + "brs = []\n", + "\n", + "for fold in set(folds): \n", + " \n", + " x_tr, t_tr, e_tr = x[folds!=fold], t[folds!=fold], e[folds!=fold]\n", + " x_te, t_te, e_te = x[folds==fold], t[folds==fold], e[folds==fold]\n", + " \n", + " x_tr = unrollx(x_tr)\n", + " t_tr = unrollt(t_tr)\n", + " e_tr = unrollt(e_tr)\n", + "\n", + " x_te = unrollx(x_te)\n", + " t_te = unrollt(t_te)\n", + " e_te = unrollt(e_te)\n", + " \n", + " df_tr = convert_to_data_frame(x_tr, t_tr, e_tr)\n", + "\n", + " model = CoxPHFitter(penalizer=1e-3).fit(df_tr, duration_col='T', event_col='E')\n", + " \n", + " preds = model.predict_survival_function(x_te, times).T.values\n", + " \n", + " et_tr = np.array([(e_tr[i], t_tr[i]) for i in range(len(e_tr))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " et_te = np.array([(e_te[i], t_te[i]) for i in range(len(e_te))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " \n", + " print (preds.shape)\n", + " \n", + " cis_ = []\n", + " for i in range(len(times)):\n", + " cis_.append(concordance_index_ipcw(et_tr, et_te, 1-preds[:,i], times[i])[0])\n", + " cis.append(cis_)\n", + " \n", + " brs.append(brier_score(et_tr, et_te, preds, times )[1])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[0.8622349652479977, 0.8496178660113483, 0.8359660027234909],\n", + " [0.9012252214624602, 0.8471243525992452, 0.7944140040816455],\n", + " [0.8467573397188136, 0.8212682235979585, 0.7283492496826471],\n", + " [0.896278572134772, 0.8628305222741215, 0.7505466410210001]]" + ] + }, + "execution_count": 127, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cis" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.60925486, 0.23441728, 0.04069175],\n", + " [0.55294913, 0.17648165, 0.0217467 ],\n", + " [0.74530162, 0.42290859, 0.14965458],\n", + " ...,\n", + " [0.9551488 , 0.87429316, 0.7434154 ],\n", + " [0.74180704, 0.41712966, 0.14517828],\n", + " [0.3992932 , 0.06804004, 0.00265338]])" + ] + }, + "execution_count": 114, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.predict_survival_function(x_te, times).T.values" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.3126132634036184,\n", + " 0.46647488732207254,\n", + " 0.5624185126209059,\n", + " 0.4059503413491215,\n", + " 0.4415347567963127,\n", + " 0.2670742533214408,\n", + " 0.06359725383083349,\n", + " 0.9349522484067815,\n", + " 0.9486941636172256,\n", + " 0.9522407447171799,\n", + " 0.9103188179868816,\n", + " 0.8824124615716761,\n", + " 0.7801756137064368,\n", + " 0.752123603797489,\n", + " 0.5288001066166764,\n", + " 0.7803207243582894,\n", + " 0.46932603202272666,\n", + " 0.9289505993097185,\n", + " 0.9355269013394406,\n", + " 0.9575783226069599,\n", + " 0.9381858563916292,\n", + " 0.9595487184435972,\n", + " 0.9310981298806621,\n", + " 0.9326698824044939,\n", + " 0.9011583356828196,\n", + " 0.9229446679118422,\n", + " 0.8835080122599095,\n", + " 0.900997315998892,\n", + " 0.8925830090019774,\n", + " 0.8831209397700821,\n", + " 0.8005351813156036,\n", + " 0.7216661520824456,\n", + " 0.7599841537178933,\n", + " 0.00015928860304218303,\n", + " 0.7890039749760369,\n", + " 0.9350205039206175,\n", + " 0.9072527188714308,\n", + " 0.9755350533464282,\n", + " 0.9857568595759039,\n", + " 0.9678514283741088,\n", + " 0.9640084861264626,\n", + " 0.7937728360920354,\n", + " 0.7358192205521833,\n", + " 0.8280615261984754,\n", + " 0.6155861711521265,\n", + " 0.765769361328416,\n", + " 0.06586886769950233,\n", + " 0.16902102092412868,\n", + " 0.1777562305251325,\n", + " 0.03417428544313084,\n", + " 0.5920674847464198,\n", + " 0.8983756086780812,\n", + " 0.9026389359055556,\n", + " 0.8033959963857669,\n", + " 0.8656146999667899,\n", + " 0.8842617423548295,\n", + " 0.7699184228028957,\n", + " 0.7393692326401295,\n", + " 0.7409772214293374,\n", + " 0.7367520868048991,\n", + " 0.8132329524856486,\n", + " 0.9962437013911171,\n", + " 0.722892426003275,\n", + " 0.8033061209341019,\n", + " 0.8384861555346423,\n", + " 0.772113571310073,\n", + " 0.9527586856351254,\n", + " 0.9646799659604668,\n", + " 0.986514763169265,\n", + " 0.9427641115911002,\n", + " 0.9253511419815649,\n", + " 0.708156275955675,\n", + " 0.8397211858505388,\n", + " 0.6529327645885434,\n", + " 0.7485936678973811,\n", + " 0.5424262779158605,\n", + " 0.7039738751110219,\n", + " 0.7115564871713116,\n", + " 0.8379814332471544,\n", + " 0.8348072491998175,\n", + " 0.8757869421845631,\n", + " 0.815678256826629,\n", + " 0.7474536543270377,\n", + " 0.6767413213717813,\n", + " 0.640657339034476,\n", + " 0.6523815703651415,\n", + " 0.6418980182491849,\n", + " 0.6129462714444834,\n", + " 0.39295436549915047,\n", + " 0.5644743579527517,\n", + " 0.33441623871917686,\n", + " 0.10509306719354915,\n", + " 0.6748054707237131,\n", + " 0.8447078794151619,\n", + " 0.9819493215529602,\n", + " 0.9758886636406668,\n", + " 0.956910373380083,\n", + " 0.9548584483556365,\n", + " 0.41541909262155086,\n", + " 0.3924680244561882,\n", + " 0.3043829008534716,\n", + " 0.2781340245897338,\n", + " 1.2613304502273227e-06,\n", + " 0.658320545976604,\n", + " 0.9044344414216838,\n", + " 0.8069121525438718,\n", + " 0.9175468414977661,\n", + " 0.9069898543923935,\n", + " 0.8845031336209687,\n", + " 0.8570323920168299,\n", + " 0.8444506604886417,\n", + " 0.9260400835125885,\n", + " 0.9521279704895809,\n", + " 0.9205819072144652,\n", + " 0.6700855236810579,\n", + " 0.5367980230695721,\n", + " 0.6648257668432667,\n", + " 0.5745808542458603,\n", + " 0.13588205146286958,\n", + " 0.09957800800658022,\n", + " 0.0003708176217721178,\n", + " 0.8020048674291576,\n", + " 0.8880341042224411,\n", + " 0.6999592957903684,\n", + " 0.7474841220967033,\n", + " 0.22926127573875427,\n", + " 0.00033600732679248027,\n", + " 0.8920917556632157,\n", + " 0.9274214892408668,\n", + " 0.9582527976047742,\n", + " 0.9592339060192541,\n", + " 0.921772367623308,\n", + " 0.9635262643689976,\n", + " 0.9428564233432247,\n", + " 0.9125914143621489,\n", + " 0.918287564643115,\n", + " 0.8983465794226003,\n", + " 0.9067773278570157,\n", + " 0.7690766562573458,\n", + " 0.9166026955495812,\n", + " 0.8612002116016616,\n", + " 0.8520244131364719,\n", + " 0.8547637916864208,\n", + " 0.8707030306639615,\n", + " 0.5975739379437166,\n", + " 0.27507122998180517,\n", + " 0.9488788450747389,\n", + " 0.956887677917689,\n", + " 0.9654419668201475,\n", + " 0.9494574706191363,\n", + " 0.9258484405656261,\n", + " 0.9500726011011853,\n", + " 0.903175174049695,\n", + " 0.8911462861039307,\n", + " 0.7972912362100846,\n", + " 0.7203290936846563,\n", + " 0.6548581427569794,\n", + " 0.5501021504126242,\n", + " 0.44935260458311965,\n", + " 0.42635329367001945,\n", + " 0.9236152451661244,\n", + " 0.9224879789082839,\n", + " 0.8850497342485366,\n", + " 0.9083710756329625,\n", + " 0.9237126782908551,\n", + " 0.894554288396283,\n", + " 0.18430150962149053,\n", + " 0.7302017007833491,\n", + " 0.7160779430502329,\n", + " 0.5549094071599893,\n", + " 0.5129254459725594,\n", + " 0.9167686155783427,\n", + " 0.9113956977765255,\n", + " 0.9075957802505701,\n", + " 0.9390146137210659,\n", + " 0.9472443861545131,\n", + " 0.9328337801205822,\n", + " 0.8370456950485082,\n", + " 0.5997948750333487,\n", + " 0.9343258168253256,\n", + " 0.9439457654128739,\n", + " 0.9280206595792848,\n", + " 0.9322397714362198,\n", + " 0.9420865049454272,\n", + " 0.9454225799163594,\n", + " 0.8996670718165377,\n", + " 0.88449093485601,\n", + " 0.8990662406292473,\n", + " 0.8900953570566361,\n", + " 0.881334207392159,\n", + " 0.8651976895713909,\n", + " 0.8861542877639267,\n", + " 0.8849661720305225,\n", + " 0.5954682765046174,\n", + " 0.5522866830533725,\n", + " 0.5448126820180267,\n", + " 0.02020368206170692,\n", + " 0.9526868500079836,\n", + " 0.9424311968328869,\n", + " 0.9473163038631395,\n", + " 0.879723352216857,\n", + " 0.81382887023091,\n", + " 0.5063033982685107,\n", + " 0.46737411250167554,\n", + " 0.20277801966357584,\n", + " 0.05393898062009144,\n", + " 0.06971893798849264,\n", + " 0.9581865115673758,\n", + " 0.9601040119424561,\n", + " 0.9513930191207932,\n", + " 0.9230437162952005,\n", + " 0.9841547435846786,\n", + " 0.9857412613884763,\n", + " 0.983817073652073,\n", + " 0.9836834675404416,\n", + " 0.7869420430902185,\n", + " 0.6433405687090631,\n", + " 0.4315214891481895,\n", + " 0.7260115220495794,\n", + " 0.6988575248832335,\n", + " 0.4534290907045462,\n", + " 0.30239167712821347,\n", + " 0.2860695589421745,\n", + " 0.5164761946551715,\n", + " 0.2522426702592779,\n", + " 0.00197960153557903,\n", + " 0.8618227167048974,\n", + " 0.9091713127289668,\n", + " 0.9095886710480476,\n", + " 0.9424076295974398,\n", + " 0.9251137411946267,\n", + " 0.8929125043943227,\n", + " 0.9227633808748936,\n", + " 0.9096963679485675,\n", + " 0.9101013093322737,\n", + " 0.8961549358286927,\n", + " 0.9118240696763442,\n", + " 0.9056960328154277,\n", + " 0.8530269355745219,\n", + " 0.6505794661814498,\n", + " 0.6572174453184076,\n", + " 0.5307738469776382,\n", + " 0.3032158592811353,\n", + " 0.0915386361343444,\n", + " 0.004460832036425752,\n", + " 0.7954037177041817,\n", + " 0.33520621222191427,\n", + " 0.42385219473283553,\n", + " 0.11415410523529125,\n", + " 0.12389657726885481,\n", + " 0.02389587535342872,\n", + " 0.9206299850509021,\n", + " 0.9300994696621918,\n", + " 0.9375880323643341,\n", + " 0.8795557027145566,\n", + " 0.8551662727717931,\n", + " 0.8805282055792464,\n", + " 0.930698190625293,\n", + " 0.9489751419877448,\n", + " 0.9381666964806937,\n", + " 0.8827748657109729,\n", + " 0.8984124855971098,\n", + " 0.8945580857060548,\n", + " 0.9059021382498761,\n", + " 0.8885983080620843,\n", + " 0.8947629451833445,\n", + " 0.9249509161078188,\n", + " 0.9464905593781671,\n", + " 0.9235464094778002,\n", + " 0.8882768513803266,\n", + " 0.8924985201185903,\n", + " 0.9270893953482611,\n", + " 0.866837535977548,\n", + " 0.90243929006013,\n", + " 0.906280116040601,\n", + " 0.8577710019492567,\n", + " 0.8921978169336371,\n", + " 0.8502797691409552,\n", + " 0.28546195947840236,\n", + " 0.398185288626109,\n", + " 0.5369426602985751,\n", + " 0.1799252069349064,\n", + " 0.7742189138362863,\n", + " 0.7925283745726461,\n", + " 0.8133157270464088,\n", + " 0.8575044927112226,\n", + " 0.7362845422519129,\n", + " 0.7036750601317182,\n", + " 0.6897696671216995,\n", + " 0.8086814631231332,\n", + " 0.428270427377073,\n", + " 0.6318889000676778,\n", + " 0.561726320410852,\n", + " 0.5128431768861994,\n", + " 0.5368918881622513,\n", + " 0.8892806685426429,\n", + " 0.8750926512298951,\n", + " 0.7806237583320027,\n", + " 0.7684665778260594,\n", + " 0.8615261308200153,\n", + " 0.846656442617957,\n", + " 0.860088758907636,\n", + " 0.8356035243644513,\n", + " 0.45912931041450084,\n", + " 0.9493197050938518,\n", + " 0.9714224821197369,\n", + " 0.9633009120360158,\n", + " 0.9592303457237034,\n", + " 0.8894417818107543,\n", + " 0.9103564405570354,\n", + " 0.914455582902078,\n", + " 0.8626489158865417,\n", + " 0.8247425923336972,\n", + " 0.843516106235579,\n", + " 0.8883461500276562,\n", + " 0.7888191019164008,\n", + " 0.7016468515587145,\n", + " 0.7952579699624737,\n", + " 0.7085263481165859,\n", + " 0.9957138759660945,\n", + " 0.6969633713373526,\n", + " 0.6650518687599166,\n", + " 0.5677168208198138,\n", + " 0.5010375501713816,\n", + " 0.6716937400870903,\n", + " 0.5263186842750209,\n", + " 0.3752190397600786,\n", + " 0.9277474224938042,\n", + " 0.9034246716757176,\n", + " 0.8623642126560304,\n", + " 0.8828365285715392,\n", + " 0.8069477291633591,\n", + " 0.5977840168390623,\n", + " 0.6575599972586608,\n", + " 0.7257005802652859,\n", + " 0.6616397497432462,\n", + " 0.7138497308406905,\n", + " 0.8559905862080395,\n", + " 0.5465469708869767,\n", + " 0.7247007978361263,\n", + " 0.4982202081321523,\n", + " 0.1435502364788124,\n", + " 0.8312160646997496,\n", + " 0.7673104339244738,\n", + " 0.9138910753213966,\n", + " 0.9079223080969111,\n", + " 0.9037941851912057,\n", + " 0.9054030389928486,\n", + " 0.8416027599455486,\n", + " 0.8777198982814582,\n", + " 0.896264047562824,\n", + " 0.9183928728613096,\n", + " 0.926234890989123,\n", + " 0.8477449274998132,\n", + " 0.9425756117883883,\n", + " 0.930425818690117,\n", + " 0.8676227063441375,\n", + " 0.9017682694985918,\n", + " 0.8250182416354152,\n", + " 0.8734526867624735,\n", + " 0.7508411926945059,\n", + " 0.7334557969611786,\n", + " 0.6953569988451388,\n", + " 0.6187676536913802,\n", + " 0.5272312526806292,\n", + " 0.25486197204460714,\n", + " 0.41581169188183437,\n", + " 0.08956291479210281,\n", + " 0.9057245705031539,\n", + " 0.8816137721066973,\n", + " 0.7197998157801146,\n", + " 0.8491990949993822,\n", + " 0.8285355188488683,\n", + " 0.8438579846834945,\n", + " 0.8165001169186776,\n", + " 0.7855076627569467,\n", + " 0.7731832923938429,\n", + " 0.7698280409087872,\n", + " 0.04306078773841739,\n", + " 0.9345193208423778,\n", + " 0.9185313864112215,\n", + " 0.9368296696015962,\n", + " 0.8842005995098796,\n", + " 0.8081473074740985,\n", + " 0.7437573722512592,\n", + " 0.4545129103156498,\n", + " 0.014301651634169134,\n", + " 0.7379679110011917,\n", + " 0.2667863361865194,\n", + " 0.05232194086368336,\n", + " 0.9101051873777896,\n", + " 0.9415855262556586,\n", + " 0.8912120472275669,\n", + " 0.8671014329978515,\n", + " 0.9332453379332762,\n", + " 0.898263104763695,\n", + " 0.9129162903881175,\n", + " 0.9119028246731342,\n", + " 0.757290127585658,\n", + " 0.8546953367453761,\n", + " 0.79905606009348,\n", + " 0.5200916429514422,\n", + " 0.41482087460157935,\n", + " 0.8658565068988433,\n", + " 0.8599416175566281,\n", + " 0.7432232481110141,\n", + " 0.4269325639335677,\n", + " 0.167144887201504,\n", + " 0.9358631812602864,\n", + " 0.9087759029109701,\n", + " 0.9246356483716845,\n", + " 0.8636000760257828,\n", + " 0.8502778060823685,\n", + " 0.8239945167311183,\n", + " 0.7913040226262458,\n", + " 0.8250240912526838,\n", + " 0.9151914446528286,\n", + " 0.9373793163952012,\n", + " 0.915873662134878,\n", + " 0.9348315744827719,\n", + " 0.8624048399418976,\n", + " 0.9045261098760244,\n", + " 0.9064311667504562,\n", + " 0.8813688497380183,\n", + " 0.9135094607462082,\n", + " 0.9482322225988866,\n", + " 0.9446357603383352,\n", + " 0.9173510791540908,\n", + " 0.9548461636548782,\n", + " 0.9444051511935787,\n", + " 0.909703525187174,\n", + " 0.900213638501747,\n", + " 0.9407011443944998,\n", + " 0.9569942070418173,\n", + " 0.8575145753354487,\n", + " 0.8770096789068078,\n", + " 0.7629346260779011,\n", + " 0.4987048733082472,\n", + " 0.9150327361260665,\n", + " 0.8887606331057204,\n", + " 0.8849814590168796,\n", + " 0.86831507399619,\n", + " 0.8443596925415235,\n", + " 0.7295135378625308,\n", + " 0.8547477188896778,\n", + " 0.8476220939507235,\n", + " 0.8157788085344,\n", + " 0.8620571095950187,\n", + " 0.8463417721218065,\n", + " 0.9375838631083165,\n", + " 0.9377547504095433,\n", + " 0.9142549931362388,\n", + " 0.8354730835481203,\n", + " 0.7345709733821627,\n", + " 0.763789204701828,\n", + " 0.6857649452929322,\n", + " 0.8197329367378913,\n", + " 0.8774874646026789,\n", + " 0.8034967082087436,\n", + " 0.7834574427067679,\n", + " 0.5343060468964511,\n", + " 0.4735468062357076,\n", + " 0.9774277899166476,\n", + " 0.9862950219071864,\n", + " 0.9867827570802857,\n", + " 0.9837798135660429,\n", + " 0.9837740836439626,\n", + " 0.9271520061949262,\n", + " 0.9465040187757684,\n", + " 0.9143759974442722,\n", + " 0.9198347819068121,\n", + " 0.8980926078628699,\n", + " 0.9158163187453798,\n", + " 0.9008031535477147,\n", + " 0.9047487601401539,\n", + " 0.9028384667697645,\n", + " 0.8846071578485817,\n", + " 0.9196068463296615,\n", + " 0.9332357764846212,\n", + " 0.8946345080109162,\n", + " 0.9148239352242258,\n", + " 0.8981043855909107,\n", + " 0.9272939713883867,\n", + " 0.9042590774195602,\n", + " 0.8958637294419638,\n", + " 0.8015070810488782,\n", + " 0.7596107228462914,\n", + " 0.724358889331443,\n", + " 0.2995966419620019,\n", + " 0.8911623357160338,\n", + " 0.8818683399376666,\n", + " 0.8558246202646637,\n", + " 0.582149983773494,\n", + " 0.07092568601575128,\n", + " 0.9480229205264794,\n", + " 0.9397867446391941,\n", + " 0.9340524066952843,\n", + " 0.9110450858911346,\n", + " 0.8757563279055323,\n", + " 0.8554999196296835,\n", + " 0.9085324590174707,\n", + " 0.9078851143828541,\n", + " 0.9088778045841828,\n", + " 0.6112634538501532,\n", + " 0.29206769385932985]" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[risk.y[3] for risk in out_risk]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}