Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cardiac tutorial #233

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b8c7650
first commit
MaxBalmus Jul 22, 2024
64f9412
Fixed small plotting error
MaxBalmus Jul 22, 2024
694fedf
Initial draft for simple emulation-based sensitivity analysis pipeline
aranas Jul 24, 2024
ce92961
Small function description corrections
MaxBalmus Jul 25, 2024
137a238
Same
MaxBalmus Jul 25, 2024
caee429
renaming tutorial
aranas Aug 13, 2024
e4dc678
adding context & intro to cardiac tutorial
aranas Aug 13, 2024
0f1382d
fix simulation import
aranas Aug 13, 2024
c2700e5
flow_functions.py: apply pre-commit
MaxBalmus Aug 13, 2024
84cb431
Merge branch 'main' of https://github.com/alan-turing-institute/autoe…
MaxBalmus Jan 21, 2025
a92bb09
Added documentation for the simulation part - references to be added yet
marjanfamili Jan 24, 2025
b462e28
updating docu
marjanfamili Jan 27, 2025
557b6df
changed names of the methods to match the changes to the code
marjanfamili Jan 27, 2025
a36f768
fixing the parameter names in sensitivity_analysis
marjanfamili Jan 27, 2025
3768af1
Added history matching function to the metrics.py and used in the car…
marjanfamili Feb 3, 2025
0420064
Added MLE to the metrics and the same analysis to the notebook. It ca…
marjanfamili Feb 3, 2025
0e6e48c
removing the history matching and MLE from this PR
marjanfamili Feb 13, 2025
fbf1273
mend
marjanfamili Feb 13, 2025
da65bf0
Merge branch 'main' into cardiac-tutorial
marjanfamili Feb 13, 2025
0cf3b45
remove sensitivity analysis fix from this PR
marjanfamili Feb 17, 2025
a6f7795
removing sensitivity analysis variable fix from here, this is in a se…
marjanfamili Feb 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,6 @@ def sensitivity_analysis(
self.logger.info(
f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data."
)

Si = _sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
return Si

Expand Down
36 changes: 35 additions & 1 deletion autoemulate/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from sklearn.metrics import r2_score
from sklearn.metrics import root_mean_squared_error


def rmse(y_true, y_pred, multioutput="uniform_average"):
"""Returns the root mean squared error.
Expand Down Expand Up @@ -38,3 +37,38 @@ def r2(y_true, y_pred, multioutput="uniform_average"):
"rmse": rmse,
"r2": r2,
}


def history_matching(obs, expectations, threshold=3.0, discrepancy=0.0, rank=1):
"""
Perform history matching to compute implausibility and identify NROY and RO points.
Parameters:
obs (tuple): Observations as (mean, variance).
expectations (tuple): Predicted (mean, variance).
threshold (float): Implausibility threshold for NROY classification.
discrepancy (float or ndarray): Discrepancy value(s).
rank (int): Rank for implausibility calculation.
Returns:
dict: Contains implausibility (I), NROY indices, and RO indices.
"""
obs_mean, obs_var = np.atleast_1d(obs[0]), np.atleast_1d(obs[1])
pred_mean, pred_var = np.atleast_1d(expectations[0]), np.atleast_1d(expectations[1])

discrepancy = np.atleast_1d(discrepancy)
n_obs = len(obs_mean)
rank = min(max(rank, 0), n_obs - 1)
# Vs represents the total variance associated with the observations, predictions, and potential discrepancies.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change
# Vs represents the total variance associated with the observations, predictions, and potential discrepancies.
# Vs represents the total variance associated with the observations, predictions, and potential discrepancies.

Vs = pred_var + discrepancy[:, np.newaxis] + obs_var[:, np.newaxis]
I = np.abs(obs_mean[:, np.newaxis] - pred_mean) / np.sqrt(Vs)
I_ranked = np.partition(I, rank, axis=0)[rank]

NROY = np.where(I_ranked <= threshold)[0]
RO = np.where(I_ranked > threshold)[0]

return {
"I": I_ranked,
"NROY": list(NROY),
"RO": list(RO)
}
28 changes: 18 additions & 10 deletions autoemulate/sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _sensitivity_analysis(
Si = _sobol_analysis(model, problem, X, N, conf_level)

if as_df:
return _sobol_results_to_df(Si)
return _sobol_results_to_df(Si, problem)
else:
return Si

Expand Down Expand Up @@ -148,21 +148,30 @@ def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
return results


def _sobol_results_to_df(results):
def _sobol_results_to_df(results, problem=None):
"""
Convert Sobol results to a (long-format)pandas DataFrame.
Convert Sobol results to a (long-format) pandas DataFrame.

Parameters:
-----------
results : dict
The Sobol indices returned by sobol_analysis.
problem : dict, optional
The problem definition, including 'names'.

Returns:
--------
pd.DataFrame
A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'.
"""
rows = []
# Use custom names if provided, else default to "x1", "x2", etc.
parameter_names = (
problem["names"]
if problem is not None
else [f"x{i+1}" for i in range(len(next(iter(results.values()))["S1"]))]
)

for output, indices in results.items():
for index_type in ["S1", "ST", "S2"]:
values = indices.get(index_type)
Expand All @@ -174,7 +183,7 @@ def _sobol_results_to_df(results):
rows.extend(
{
"output": output,
"parameter": f"X{i+1}",
"parameter": parameter_names[i], # Use appropriate names
"index": index_type,
"value": value,
"confidence": conf,
Expand All @@ -187,7 +196,7 @@ def _sobol_results_to_df(results):
rows.extend(
{
"output": output,
"parameter": f"X{i+1}-X{j+1}",
"parameter": f"{parameter_names[i]}-{parameter_names[j]}", # Use appropriate names
"index": index_type,
"value": values[i, j],
"confidence": conf_values[i, j],
Expand All @@ -196,16 +205,15 @@ def _sobol_results_to_df(results):
for j in range(i + 1, n)
if not np.isnan(values[i, j])
)

return pd.DataFrame(rows)


# plotting --------------------------------------------------------------------


def _validate_input(results, index):
def _validate_input(results, problem, index):
if not isinstance(results, pd.DataFrame):
results = _sobol_results_to_df(results)
results = _sobol_results_to_df(results, problem=problem)
# we only want to plot one index type at a time
valid_indices = ["S1", "S2", "ST"]
if index not in valid_indices:
Expand Down Expand Up @@ -241,7 +249,7 @@ def _create_bar_plot(ax, output_data, output_name):
ax.set_title(f"Output: {output_name}")


def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
def _plot_sensitivity_analysis(results, problem, index="S1", n_cols=None, figsize=None):
"""
Plot the sensitivity analysis results.

Expand All @@ -263,7 +271,7 @@ def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
"""
with plt.style.context("fast"):
# prepare data
results = _validate_input(results, index)
results = _validate_input(results, problem, index)
unique_outputs = results["output"].unique()
n_outputs = len(unique_outputs)

Expand Down
162 changes: 162 additions & 0 deletions autoemulate/simulations/flow_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp


class FlowProblem:
def __init__(
self,
T=1.0,
td=0.2,
amp=900.0,
dt=0.001,
ncycles=10,
ncomp=10,
C=38.0,
R=0.06,
L=0.0017,
R_o=0.025,
p_o=10.0,
) -> None:
"""
Inputs
-----
T (float): cycle length (default 1.0)
td (float): pulse duration, make sure to make this less than T (default 0.2)
amp (float): inflow amplitude (default 1.0)
dt (float): temporal discretisation resolution (default 0.001)
C (float): tube average compliance (default 38.)
R (float): tube average impedance (default 0.06)
L (float): hydraulic impedance, inertia (default 0.0017)
R_o (float) : outflow resistance
p_o (float) : outflow pressure
"""

assert td < T, f"td should be smaller than T but {td} >= {T}."

self._td = td
self._T = T
self._dt = dt
self._amp = amp

self._ncomp = ncomp
self._ncycles = ncycles

self._C = C
self._R = R
self._L = L

self._R_o = R_o
self._p_o = p_o

self.res = None

@property
def td(self):
return self._td

@property
def T(self):
return self._T

@property
def dt(self):
return self._dt

@property
def amp(self):
return self._amp

@property
def ncomp(self):
return self._ncomp

@property
def ncycles(self):
return self._ncycles

@property
def C(self):
return self._C

@property
def L(self):
return self._L

@property
def R(self):
return self._R

@property
def R_o(self):
return self._R_o

@property
def p_o(self):
return self._p_o

def generate_pulse_function(self):
self.Q_mi_lambda = (
lambda t: np.sin(np.pi / self.td * t) ** 2.0
* np.heaviside(self.td - t, 0.0)
* self.amp
)

def dfdt_fd(self, t: float, y: np.ndarray, Q_in):
Cn = self.C / self.ncomp
Rn = self.R / self.ncomp
Ln = self.L / self.ncomp

out = np.zeros((self.ncomp, 2))
y_temp = y.reshape((-1, 2))

for i in range(self.ncomp):
if i > 0:
out[i, 0] = (y_temp[i - 1, 1] - y_temp[i, 1]) / Cn
else:
out[i, 0] = (Q_in(t % self.T) - y_temp[i, 1]) / Cn
if i < self.ncomp - 1:
out[i, 1] = (-y_temp[i + 1, 0] + y_temp[i, 0] - Rn * y_temp[i, 1]) / Ln
pass
else:
out[i, 1] = (
-self.p_o + y_temp[i, 0] - (Rn + self.R_o) * y_temp[i, 1]
) / Ln
return out.reshape((-1,))

def solve(self):
dfdt_fd_spec = lambda t, y: self.dfdt_fd(t=t, y=y, Q_in=self.Q_mi_lambda)
self.res = sp.integrate.solve_ivp(
dfdt_fd_spec,
[0.0, self.T * self.ncycles],
y0=np.zeros(self.ncomp * 2),
method="BDF",
max_step=self.dt,
)
self.res.y = self.res.y[:, self.res.t >= self.T * (self.ncycles - 1)]
self.res.t = self.res.t[self.res.t >= self.T * (self.ncycles - 1)]

def plot_res(self):
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
for i in range(self.ncomp):
ax[0].plot(
self.res.t,
self.res.y[2 * i, :],
"r",
alpha=0.1 + (1.0 - i / self.ncomp) * 0.9,
)
ax[1].plot(
self.res.t,
self.res.y[2 * i + 1, :],
"r",
alpha=0.1 + (1.0 - i / self.ncomp) * 0.9,
)

ax[0].set_title("Pressure")
ax[1].set_title("Flow rate")
ax[0].set_xlabel("Time (s)")
ax[1].set_xlabel("Time (s)")
ax[0].set_ylabel("mmHg")
ax[1].set_ylabel("$ml\cdot s^{-1}$")

return (fig, ax)
Loading
Loading