Skip to content

Commit

Permalink
remove sensitivity analysis fix from this PR
Browse files Browse the repository at this point in the history
  • Loading branch information
marjanfamili committed Feb 17, 2025
1 parent da65bf0 commit 0cf3b45
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions autoemulate/sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _sensitivity_analysis(
Si = _sobol_analysis(model, problem, X, N, conf_level)

if as_df:
return _sobol_results_to_df(Si, problem)
return _sobol_results_to_df(Si)
else:
return Si

Expand Down Expand Up @@ -148,30 +148,21 @@ def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
return results


def _sobol_results_to_df(results, problem=None):
def _sobol_results_to_df(results):
"""
Convert Sobol results to a (long-format) pandas DataFrame.
Convert Sobol results to a (long-format)pandas DataFrame.
Parameters:
-----------
results : dict
The Sobol indices returned by sobol_analysis.
problem : dict, optional
The problem definition, including 'names'.
Returns:
--------
pd.DataFrame
A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'.
"""
rows = []
# Use custom names if provided, else default to "x1", "x2", etc.
parameter_names = (
problem["names"]
if problem is not None
else [f"x{i+1}" for i in range(len(next(iter(results.values()))["S1"]))]
)

for output, indices in results.items():
for index_type in ["S1", "ST", "S2"]:
values = indices.get(index_type)
Expand All @@ -183,7 +174,7 @@ def _sobol_results_to_df(results, problem=None):
rows.extend(
{
"output": output,
"parameter": parameter_names[i], # Use appropriate names
"parameter": f"X{i+1}",
"index": index_type,
"value": value,
"confidence": conf,
Expand All @@ -196,7 +187,7 @@ def _sobol_results_to_df(results, problem=None):
rows.extend(
{
"output": output,
"parameter": f"{parameter_names[i]}-{parameter_names[j]}", # Use appropriate names
"parameter": f"X{i+1}-X{j+1}",
"index": index_type,
"value": values[i, j],
"confidence": conf_values[i, j],
Expand All @@ -205,15 +196,16 @@ def _sobol_results_to_df(results, problem=None):
for j in range(i + 1, n)
if not np.isnan(values[i, j])
)

return pd.DataFrame(rows)


# plotting --------------------------------------------------------------------


def _validate_input(results, problem, index):
def _validate_input(results, index):
if not isinstance(results, pd.DataFrame):
results = _sobol_results_to_df(results, problem=problem)
results = _sobol_results_to_df(results)
# we only want to plot one index type at a time
valid_indices = ["S1", "S2", "ST"]
if index not in valid_indices:
Expand Down Expand Up @@ -249,7 +241,7 @@ def _create_bar_plot(ax, output_data, output_name):
ax.set_title(f"Output: {output_name}")


def _plot_sensitivity_analysis(results, problem, index="S1", n_cols=None, figsize=None):
def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
"""
Plot the sensitivity analysis results.
Expand All @@ -271,7 +263,7 @@ def _plot_sensitivity_analysis(results, problem, index="S1", n_cols=None, figsiz
"""
with plt.style.context("fast"):
# prepare data
results = _validate_input(results, problem, index)
results = _validate_input(results, index)
unique_outputs = results["output"].unique()
n_outputs = len(unique_outputs)

Expand Down

0 comments on commit 0cf3b45

Please sign in to comment.