Skip to content

Commit

Permalink
modified: phenotyping.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Mar 29, 2022
1 parent 338b155 commit 7cb7ab3
Showing 1 changed file with 61 additions and 3 deletions.
64 changes: 61 additions & 3 deletions auton_survival/phenotyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"""Utilities to phenotype individuals based on similar survival
characteristics."""

from random import random
from re import I
import numpy as np
import pandas as pd

Expand All @@ -32,6 +34,7 @@
from sklearn import cluster, decomposition, mixture

from auton_survival.utils import _get_method_kwargs
from auton_survival.experiments import CounterfactualSurvivalRegressionCV


class Phenotyper:
Expand Down Expand Up @@ -408,9 +411,64 @@ def fit_phenotype(self, features):

return self.fit(features).phenotype(features)

class CoxMixturePhenotyper:
class SurvivalVirtualTwinsPhenotyper(object):

""""Not Yet Implemented"""

def __init__(self):
raise NotImplementedError()

_VALID_PHENO_METHODS = ['rsf']

def __init__(self,
cf_method='dcph',
phenotyping_method='rsf',
cf_method_hyperparams=None,
phenotyping_method_hyperparams=None,
random_seed=0):

raise NotImplementedError()

assert cf_method in CounterfactualSurvivalRegressionCV._VALID_CF_METHODS, "Invalid Counterfactual Method: "+cf_method
assert phenotyping_method in self._VALID_PHENO_METHODS, "Invalid Phenotyping Method: "+phenotyping_method

if cf_method_hyperparams is None:
cf_method_hyperparams = {}
if phenotyping_method_hyperparams is None:
phenotyping_method_hyperparams = {}

self.cf_method = cf_method
self.phenotyping_method = phenotyping_method
self.random_seed = random_seed

def fit(self, features, outcomes, interventions, horizon):

raise NotImplementedError()

cf_model = CounterfactualSurvivalRegressionCV(**self.cf_method_hyperparams)

self.cf_model = cf_model.fit(features, outcomes, interventions)

times = np.unique(outcomes.times.values)
cf_predictions = self.cf_model.predict_counterfactual_survival(features, interventions, times)

ite_estimates = cf_predictions[1] - cf_predictions[0]

if self.phenotyping_method == 'rsf':

from sklearn.ensemble import RandomForestRegressor

pheno_model = RandomForestRegressor(**self.phenotyping_method_hyperparams)
pheno_model.fit(features.values, ite_estimates)

self.pheno_model = pheno_model

def predict(self, features):

raise NotImplementedError()

phenotype_preds= self.pheno_model.predict(features)
phenotype_preds = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min())
return phenotype_preds




0 comments on commit 7cb7ab3

Please sign in to comment.