diff --git a/time_robust_forest/functions.py b/time_robust_forest/functions.py index 34c05f0..936c73c 100644 --- a/time_robust_forest/functions.py +++ b/time_robust_forest/functions.py @@ -81,7 +81,9 @@ def score_by_period( aggregation. """ if criterion == "gini": - current_score = gini_impurity_score_by_period(right_dict, left_dict) + current_score = gini_impurity_score_by_period( + right_dict, left_dict, verbose=verbose + ) if criterion == "std": current_score = std_score_by_period(right_dict, left_dict) @@ -98,6 +100,52 @@ def score_by_period( return np.max(current_score) +def impurity_decrease_by_period( + right_dict, + left_dict, + total_sample, + period_criterion="avg", + verbose=False, +): + + impurity_decreases = [] + + for key in right_dict.keys(): + total_count = left_dict[key]["count"] + right_dict[key]["count"] + total_positive = left_dict[key]["sum"] + right_dict[key]["sum"] + p_positive = total_positive / total_count + + previous_impurity = 1 - ((1 - p_positive) ** 2 + (p_positive) ** 2) + + left_proba = left_dict[key]["sum"] / float(left_dict[key]["count"]) + left_gini = 1 - ((1 - left_proba) ** 2 + (left_proba) ** 2) + + right_proba = right_dict[key]["sum"] / float(right_dict[key]["count"]) + right_gini = 1 - ((1 - right_proba) ** 2 + (right_proba) ** 2) + + score = left_gini * ( + left_dict[key]["count"] / total_count + ) + right_gini * (right_dict[key]["count"] / total_count) + + impurity_decrease = ( + total_count / total_sample[key] * (previous_impurity - score) + ) + + if verbose: + print( + "Period: {}, score:{}, left: {}, right: {}, impurity decrease: {}".format( + key, score, left_gini, right_gini, impurity_decrease + ) + ) + + impurity_decreases.append(impurity_decrease) + + if period_criterion == "avg": + return np.mean(impurity_decreases) + else: + return np.min(impurity_decreases) + + def std_score_by_period(right_dict, left_dict, norm=False): """ Calculate the standard deviation score by period given two dictionaries that @@ -133,7 +181,7 @@ def std_score_by_period(right_dict, left_dict, norm=False): return current_score -def gini_impurity_score_by_period(right_dict, left_dict): +def gini_impurity_score_by_period(right_dict, left_dict, verbose=False): """ Calculate the gini impurity score by period given two dictionaries that charactize the left and right leaf after the potential split. diff --git a/time_robust_forest/models.py b/time_robust_forest/models.py index 49415e1..db76bb1 100644 --- a/time_robust_forest/models.py +++ b/time_robust_forest/models.py @@ -10,6 +10,7 @@ check_min_sample_periods, check_min_sample_periods_dict, fill_right_dict, + impurity_decrease_by_period, initialize_period_dict, score_by_period, ) @@ -51,6 +52,7 @@ def __init__( max_features="auto", bootstrapping=True, period_criterion="avg", + min_impurity_decrease=0, n_jobs=-1, multi=True, ): @@ -226,6 +228,7 @@ def __init__( bootstrapping=True, criterion="gini", period_criterion="avg", + min_impurity_decrease=0, n_jobs=-1, multi=True, ): @@ -239,6 +242,7 @@ def __init__( self.bootstrapping = bootstrapping self.criterion = criterion self.period_criterion = period_criterion + self.min_impurity_decrease = min_impurity_decrease def fit(self, X, y, sample_weight=None, verbose=False): """ @@ -263,6 +267,7 @@ def fit(self, X, y, sample_weight=None, verbose=False): self.train_target_proportion = np.mean(y) self.classes_ = np.unique(y) + self.total_sample = X[self.time_column].value_counts().to_dict() self.n_estimators_ = [] if not self.multi: @@ -281,6 +286,8 @@ def fit(self, X, y, sample_weight=None, verbose=False): max_features=self.max_features, criterion=self.criterion, period_criterion=self.period_criterion, + min_impurity_decrease=self.min_impurity_decrease, + total_sample=self.total_sample, random_state=i, ) for i in range(self.n_estimators) @@ -300,6 +307,8 @@ def fit(self, X, y, sample_weight=None, verbose=False): verbose=verbose, max_features=self.max_features, period_criterion=self.period_criterion, + min_impurity_decrease=self.min_impurity_decrease, + total_sample=self.total_sample, criterion=self.criterion, random_state=i, ) @@ -353,7 +362,7 @@ def score(self, X, y): predictions = self.predict_proba_(X) return metrics.roc_auc_score(y, predictions) - def feature_importance(self): + def feature_importance(self, impurity_decrease=False): """ Retrieves the feature importance in terms of number of times a feature was used to split the data. @@ -363,13 +372,15 @@ def feature_importance(self): return ( pd.concat( [ - n_estimator.feature_importance() + n_estimator.feature_importance( + impurity_decrease=impurity_decrease + ) for n_estimator in self.n_estimators_ ] ) - .groupby(level=0) + .groupby("Feature") .sum() - .sort_values(ascending=False) + .sort_values(by="Importance", ascending=False) ) @@ -420,15 +431,19 @@ def __init__( bootstrapping=True, criterion="gini", period_criterion="avg", + min_impurity_decrease=0, + total_sample=None, min_sample_periods=100, sample_weight=None, depth=None, verbose=False, split_verbose=False, + impurity_verbose=False, random_state=42, ): if len(row_indexes) == 0: row_indexes = np.arange(len(y)) + X.reset_index(inplace=True, drop=True) ### Reindex if depth == None: depth = 0 @@ -456,11 +471,14 @@ def __init__( self.min_sample_periods = min_sample_periods self.verbose = verbose self.split_verbose = split_verbose + self.impurity_verbose = impurity_verbose self.max_features = max_features self.split_variable = "LEAF" self.bootstrapping = bootstrapping self.criterion = criterion self.period_criterion = period_criterion + self.min_impurity_decrease = min_impurity_decrease + self.total_sample = total_sample if sample_weight is not None: self.sample_weight = sample_weight @@ -540,8 +558,12 @@ def create_split(self): min_leaf=self.min_leaf, criterion=self.criterion, period_criterion=self.period_criterion, + min_impurity_decrease=self.min_impurity_decrease, + total_sample=self.total_sample, sample_weight=self.sample_weight, verbose=self.verbose, + split_verbose=self.split_verbose, + impurity_verbose=self.impurity_verbose, ) self.right_split = _RandomTimeSplitTree( self.X, @@ -556,8 +578,12 @@ def create_split(self): min_leaf=self.min_leaf, criterion=self.criterion, period_criterion=self.period_criterion, + min_impurity_decrease=self.min_impurity_decrease, + total_sample=self.total_sample, sample_weight=self.sample_weight, verbose=self.verbose, + split_verbose=self.split_verbose, + impurity_verbose=self.impurity_verbose, ) def find_better_split(self, variable, variable_idx): @@ -625,6 +651,7 @@ def find_better_split(self, variable, variable_idx): if self.split_verbose: print(f"Evaluate a split on variable {variable} at value {x_i}") + current_score = score_by_period( right_period_dict, left_period_dict, @@ -634,12 +661,22 @@ def find_better_split(self, variable, variable_idx): ) if current_score < self.score: - self.split_variable, self.score, self.split_example = ( - variable, - current_score, - x_i, + impurity_decrease = impurity_decrease_by_period( + right_period_dict, + left_period_dict, + self.total_sample, + self.period_criterion, + self.impurity_verbose, ) - self.split_variable_idx = variable_idx + + if impurity_decrease >= self.min_impurity_decrease: + self.split_variable, self.score, self.split_example = ( + variable, + current_score, + x_i, + ) + self.split_variable_idx = variable_idx + self.impurity_decrease = impurity_decrease def _is_leaf(self): """ @@ -691,7 +728,19 @@ def _get_split_variable(self): ) return "LEAF" - def feature_importance(self): + def _get_impurity_decrease(self): + """ + Returns the splitting variable name for the current tree instance. + """ + if not self._is_leaf(): + return ( + [self.impurity_decrease] + + self.left_split._get_impurity_decrease() + + self.right_split._get_impurity_decrease() + ) + return ["LEAF"] + + def feature_importance(self, impurity_decrease=False): """ Retrieves the feature importance in terms of number of times a feature was used to split the data. @@ -701,10 +750,18 @@ def feature_importance(self): splits = self._get_split_variable() splits_features = splits.replace("@LEAF", "").split("@") + if impurity_decrease: + impurity_decreases = self._get_impurity_decrease() + impurity_decreases = [i for i in impurity_decreases if i != "LEAF"] + importance = impurity_decreases + else: + importance = [1 for i in splits_features] + return ( - pd.DataFrame(splits_features, columns=["Feature Importance"])[ - "Feature Importance" - ] - .value_counts() - .sort_values(ascending=False) + pd.DataFrame( + zip(splits_features, importance), + columns=["Feature", "Importance"], + ) + .groupby("Feature") + .sum() )