Skip to content

Commit

Permalink
fix permutation importance for dataset of len 1 (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt authored Aug 9, 2023
1 parent 192bdbe commit fde91c9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 9 deletions.
21 changes: 14 additions & 7 deletions bofire/surrogates/feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ def permutation_importance(
k.name: {feature.key: [] for feature in model.inputs} for k in metrics.keys()
}
pred = model.predict(X)
original_metrics = {
k.name: metrics[k](y[output_key].values, pred[output_key + "_pred"].values) # type: ignore
for k in metrics.keys()
}
if len(pred) >= 2:
original_metrics = {
k.name: metrics[k](y[output_key].values, pred[output_key + "_pred"].values) # type: ignore
for k in metrics.keys()
}
else:
original_metrics = {k.name: np.nan for k in metrics.keys()}

for feature in model.inputs:
for _ in range(n_repeats):
Expand All @@ -62,9 +65,13 @@ def permutation_importance(
pred = model.predict(X_i)
# compute scores
for metricenum, metric in metrics.items():
prelim_results[metricenum.name][feature.key].append(
metric(y[output_key].values, pred[output_key + "_pred"].values) # type: ignore
)
if len(pred) >= 2:
prelim_results[metricenum.name][feature.key].append(
metric(y[output_key].values, pred[output_key + "_pred"].values) # type: ignore
)
else:
prelim_results[metricenum.name][feature.key].append(np.nan) # type: ignore

# convert dictionaries to dataframe for easier postprocessing and statistics
# and return
results = {}
Expand Down
4 changes: 2 additions & 2 deletions bofire/surrogates/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def cross_validate(
raise NotImplementedError(
"Cross validation not implemented for multi-output models"
)
# first filter the experiments based on the model setting
experiments = self._preprocess_experiments(experiments)
n = len(experiments)
if folds > n:
warnings.warn(
Expand All @@ -117,8 +119,6 @@ def cross_validate(
# instantiate kfold object
cv = KFold(n_splits=folds, shuffle=True, random_state=random_state)
key = self.outputs.get_keys()[0] # type: ignore
# first filter the experiments based on the model setting
experiments = self._preprocess_experiments(experiments)
train_results = []
test_results = []
# now get the indices for the split
Expand Down
16 changes: 16 additions & 0 deletions tests/bofire/surrogates/test_feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ def test_permutation_importance():
assert list(results[m.name].index) == ["mean", "std"]


def test_permutation_importance_nan():
model, experiments = get_model_and_data()
X = experiments[model.inputs.get_keys()][:1]
y = experiments[["y"]][:1]
model.fit(experiments=experiments)
results = permutation_importance(model=model, X=X, y=y, n_repeats=5)
assert isinstance(results, dict)
assert len(results) == len(metrics)
for m in metrics.keys():
assert m.name in results.keys()
assert isinstance(results[m.name], pd.DataFrame)
assert list(results[m.name].columns) == model.inputs.get_keys()
assert list(results[m.name].index) == ["mean", "std"]
assert len(results[m.name].dropna()) == 0


@pytest.mark.parametrize("use_test", [True, False])
def test_permutation_importance_hook(use_test):
model, experiments = get_model_and_data()
Expand Down

0 comments on commit fde91c9

Please sign in to comment.