Adds minimum impurity decrease hyper-param and a feature importance based on it (#35)

* format

* Exclude some debug code
lgmoneda authored Nov 15, 2021
1 parent fca99df commit 184c750
Showing 2 changed files with 122 additions and 17 deletions.
52 changes: 50 additions & 2 deletions time_robust_forest/functions.py
@@ -81,7 +81,9 @@ def score_by_period(
aggregation.
"""
if criterion == "gini":
current_score = gini_impurity_score_by_period(right_dict, left_dict)
current_score = gini_impurity_score_by_period(
right_dict, left_dict, verbose=verbose
)

if criterion == "std":
current_score = std_score_by_period(right_dict, left_dict)
@@ -98,6 +100,52 @@ def score_by_period(
return np.max(current_score)


def impurity_decrease_by_period(
right_dict,
left_dict,
total_sample,
period_criterion="avg",
verbose=False,
):
"""
Calculate the gini impurity decrease by period given two dictionaries
that characterize the left and right leaves after the potential split.
Each period's decrease is weighted by the share of that period's total
sample reaching the node, then aggregated across periods ("avg" takes
the mean, any other value the minimum).
"""
impurity_decreases = []

for key in right_dict.keys():
total_count = left_dict[key]["count"] + right_dict[key]["count"]
total_positive = left_dict[key]["sum"] + right_dict[key]["sum"]
p_positive = total_positive / total_count

previous_impurity = 1 - ((1 - p_positive) ** 2 + (p_positive) ** 2)

left_proba = left_dict[key]["sum"] / float(left_dict[key]["count"])
left_gini = 1 - ((1 - left_proba) ** 2 + (left_proba) ** 2)

right_proba = right_dict[key]["sum"] / float(right_dict[key]["count"])
right_gini = 1 - ((1 - right_proba) ** 2 + (right_proba) ** 2)

score = left_gini * (
left_dict[key]["count"] / total_count
) + right_gini * (right_dict[key]["count"] / total_count)

impurity_decrease = (
total_count / total_sample[key] * (previous_impurity - score)
)

if verbose:
print(
"Period: {}, score:{}, left: {}, right: {}, impurity decrease: {}".format(
key, score, left_gini, right_gini, impurity_decrease
)
)

impurity_decreases.append(impurity_decrease)

if period_criterion == "avg":
return np.mean(impurity_decreases)
else:
return np.min(impurity_decreases)


def std_score_by_period(right_dict, left_dict, norm=False):
"""
Calculate the standard deviation score by period given two dictionaries that
@@ -133,7 +181,7 @@ def std_score_by_period(right_dict, left_dict, norm=False):
return current_score


def gini_impurity_score_by_period(right_dict, left_dict):
def gini_impurity_score_by_period(right_dict, left_dict, verbose=False):
"""
Calculate the gini impurity score by period given two dictionaries that
characterize the left and right leaves after the potential split.
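For reference, a quick sanity check of the new impurity_decrease_by_period: for each period it takes the parent Gini impurity minus the count-weighted Gini of the two children, scales that by the fraction of the period's total sample reaching the node, and aggregates across periods. The toy dictionaries below are illustrative, and the snippet assumes the package is installed so that time_robust_forest.functions is importable.

from time_robust_forest.functions import impurity_decrease_by_period

# Candidate split evaluated over two periods: "count" is the number of
# examples that fall into each child for that period, "sum" the number
# of positives among them (toy values).
left_dict = {2019: {"count": 60, "sum": 10}, 2020: {"count": 50, "sum": 5}}
right_dict = {2019: {"count": 40, "sum": 30}, 2020: {"count": 50, "sum": 35}}

# Per-period size of the full training set; here the node holds every
# example, so the total_count / total_sample weight is 1 for both periods.
total_sample = {2019: 100, 2020: 100}

# "avg" averages the per-period decreases; any other value returns the
# minimum across periods.
print(impurity_decrease_by_period(right_dict, left_dict, total_sample, "avg"))
print(impurity_decrease_by_period(right_dict, left_dict, total_sample, "min"))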
87 changes: 72 additions & 15 deletions time_robust_forest/models.py
@@ -10,6 +10,7 @@
check_min_sample_periods,
check_min_sample_periods_dict,
fill_right_dict,
impurity_decrease_by_period,
initialize_period_dict,
score_by_period,
)
@@ -51,6 +52,7 @@ def __init__(
max_features="auto",
bootstrapping=True,
period_criterion="avg",
min_impurity_decrease=0,
n_jobs=-1,
multi=True,
):
@@ -226,6 +228,7 @@ def __init__(
bootstrapping=True,
criterion="gini",
period_criterion="avg",
min_impurity_decrease=0,
n_jobs=-1,
multi=True,
):
@@ -239,6 +242,7 @@ def __init__(
self.bootstrapping = bootstrapping
self.criterion = criterion
self.period_criterion = period_criterion
self.min_impurity_decrease = min_impurity_decrease

def fit(self, X, y, sample_weight=None, verbose=False):
"""
@@ -263,6 +267,7 @@ def fit(self, X, y, sample_weight=None, verbose=False):

self.train_target_proportion = np.mean(y)
self.classes_ = np.unique(y)
self.total_sample = X[self.time_column].value_counts().to_dict()

self.n_estimators_ = []
if not self.multi:
@@ -281,6 +286,8 @@ def fit(self, X, y, sample_weight=None, verbose=False):
max_features=self.max_features,
criterion=self.criterion,
period_criterion=self.period_criterion,
min_impurity_decrease=self.min_impurity_decrease,
total_sample=self.total_sample,
random_state=i,
)
for i in range(self.n_estimators)
@@ -300,6 +307,8 @@ def fit(self, X, y, sample_weight=None, verbose=False):
verbose=verbose,
max_features=self.max_features,
period_criterion=self.period_criterion,
min_impurity_decrease=self.min_impurity_decrease,
total_sample=self.total_sample,
criterion=self.criterion,
random_state=i,
)
@@ -353,7 +362,7 @@ def score(self, X, y):
predictions = self.predict_proba_(X)
return metrics.roc_auc_score(y, predictions)

def feature_importance(self):
def feature_importance(self, impurity_decrease=False):
"""
Retrieves the feature importance in terms of number of
times a feature was used to split the data.
@@ -363,13 +372,15 @@ def feature_importance(self):
return (
pd.concat(
[
n_estimator.feature_importance()
n_estimator.feature_importance(
impurity_decrease=impurity_decrease
)
for n_estimator in self.n_estimators_
]
)
.groupby(level=0)
.groupby("Feature")
.sum()
.sort_values(ascending=False)
.sort_values(by="Importance", ascending=False)
)


@@ -420,15 +431,19 @@ def __init__(
bootstrapping=True,
criterion="gini",
period_criterion="avg",
min_impurity_decrease=0,
total_sample=None,
min_sample_periods=100,
sample_weight=None,
depth=None,
verbose=False,
split_verbose=False,
impurity_verbose=False,
random_state=42,
):
if len(row_indexes) == 0:
row_indexes = np.arange(len(y))
X.reset_index(inplace=True, drop=True)
### Reindex
if depth == None:
depth = 0
@@ -456,11 +471,14 @@
self.min_sample_periods = min_sample_periods
self.verbose = verbose
self.split_verbose = split_verbose
self.impurity_verbose = impurity_verbose
self.max_features = max_features
self.split_variable = "LEAF"
self.bootstrapping = bootstrapping
self.criterion = criterion
self.period_criterion = period_criterion
self.min_impurity_decrease = min_impurity_decrease
self.total_sample = total_sample

if sample_weight is not None:
self.sample_weight = sample_weight
@@ -540,8 +558,12 @@ def create_split(self):
min_leaf=self.min_leaf,
criterion=self.criterion,
period_criterion=self.period_criterion,
min_impurity_decrease=self.min_impurity_decrease,
total_sample=self.total_sample,
sample_weight=self.sample_weight,
verbose=self.verbose,
split_verbose=self.split_verbose,
impurity_verbose=self.impurity_verbose,
)
self.right_split = _RandomTimeSplitTree(
self.X,
Expand All @@ -556,8 +578,12 @@ def create_split(self):
min_leaf=self.min_leaf,
criterion=self.criterion,
period_criterion=self.period_criterion,
min_impurity_decrease=self.min_impurity_decrease,
total_sample=self.total_sample,
sample_weight=self.sample_weight,
verbose=self.verbose,
split_verbose=self.split_verbose,
impurity_verbose=self.impurity_verbose,
)

def find_better_split(self, variable, variable_idx):
Expand Down Expand Up @@ -625,6 +651,7 @@ def find_better_split(self, variable, variable_idx):

if self.split_verbose:
print(f"Evaluate a split on variable {variable} at value {x_i}")

current_score = score_by_period(
right_period_dict,
left_period_dict,
@@ -634,12 +661,22 @@ def find_better_split(self, variable, variable_idx):
)

if current_score < self.score:
self.split_variable, self.score, self.split_example = (
variable,
current_score,
x_i,
impurity_decrease = impurity_decrease_by_period(
right_period_dict,
left_period_dict,
self.total_sample,
self.period_criterion,
self.impurity_verbose,
)
self.split_variable_idx = variable_idx

if impurity_decrease >= self.min_impurity_decrease:
self.split_variable, self.score, self.split_example = (
variable,
current_score,
x_i,
)
self.split_variable_idx = variable_idx
self.impurity_decrease = impurity_decrease

def _is_leaf(self):
"""
@@ -691,7 +728,19 @@ def _get_split_variable(self):
)
return "LEAF"

def feature_importance(self):
def _get_impurity_decrease(self):
"""
Returns the impurity decrease of every split in the subtree rooted at
this node, with "LEAF" marking leaf nodes.
"""
if not self._is_leaf():
return (
[self.impurity_decrease]
+ self.left_split._get_impurity_decrease()
+ self.right_split._get_impurity_decrease()
)
return ["LEAF"]

def feature_importance(self, impurity_decrease=False):
"""
Retrieves the feature importance in terms of number of
times a feature was used to split the data.
@@ -701,10 +750,18 @@ def feature_importance(self):
splits = self._get_split_variable()
splits_features = splits.replace("@LEAF", "").split("@")

if impurity_decrease:
impurity_decreases = self._get_impurity_decrease()
impurity_decreases = [i for i in impurity_decreases if i != "LEAF"]
importance = impurity_decreases
else:
importance = [1 for i in splits_features]

return (
pd.DataFrame(splits_features, columns=["Feature Importance"])[
"Feature Importance"
]
.value_counts()
.sort_values(ascending=False)
pd.DataFrame(
zip(splits_features, importance),
columns=["Feature", "Importance"],
)
.groupby("Feature")
.sum()
)

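Putting the two changes together, a minimal usage sketch. The estimator name TimeForestClassifier and its time_column and n_estimators arguments are assumptions for illustration (they are not shown in this diff); min_impurity_decrease, criterion, period_criterion, and feature_importance(impurity_decrease=True) come from the changes above, and the synthetic data is only there to make the snippet self-contained.

import numpy as np
import pandas as pd

from time_robust_forest.models import TimeForestClassifier  # class name assumed

# Synthetic data: two features plus the period column used for the
# per-period split evaluation (values are illustrative).
rng = np.random.default_rng(0)
periods = [2018, 2019, 2020]
n_per_period = 1000
X = pd.DataFrame(
    {
        "x_1": rng.normal(size=n_per_period * len(periods)),
        "x_2": rng.normal(size=n_per_period * len(periods)),
        "period": np.repeat(periods, n_per_period),
    }
)
y = (X["x_1"] + 0.5 * rng.normal(size=len(X)) > 0).astype(int).to_numpy()

model = TimeForestClassifier(
    time_column="period",  # assumed constructor argument
    n_estimators=5,  # assumed constructor argument
    criterion="gini",
    period_criterion="avg",
    min_impurity_decrease=0.001,  # new: splits below this threshold are rejected
)
model.fit(X, y)

# Default importance: number of times each feature was used to split.
print(model.feature_importance())

# New in this commit: importance accumulated from the per-period weighted
# impurity decrease of each split.
print(model.feature_importance(impurity_decrease=True))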