Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: use casadi MX.interpn_linear function instead of plugin #3783 #4077

Merged
merged 7 commits into from
May 10, 2024
3 changes: 2 additions & 1 deletion pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def __init__(
fill_value_1 = "extrapolate"
interpolating_function = interpolate.interp1d(
x1,
y.T,
y,
bounds_error=False,
fill_value=fill_value_1,
axis=0,
)
elif interpolator == "cubic":
interpolating_function = interpolate.CubicSpline(
Expand Down
32 changes: 24 additions & 8 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,31 @@ def _convert(self, symbol, t, y, y_dot, inputs):
)

if len(converted_children) == 1:
return casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.flatten()
)(*converted_children)
if solver == "linear":
test = casadi.MX.interpn_linear(
symbol.x, symbol.y.flatten(), converted_children
)
if test.shape[0] == 1 and test.shape[1] > 1:
# for some reason, pybamm.Interpolant always returns a column vector, so match that
test = test.T
return test
else:
return casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.flatten()
)(*converted_children)
elif len(converted_children) in [2, 3]:
LUT = casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
)
res = LUT(casadi.hcat(converted_children).T).T
return res
if solver == "linear":
return casadi.MX.interpn_linear(
symbol.x,
symbol.y.ravel(order="F"),
converted_children,
)
else:
LUT = casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
)
res = LUT(casadi.hcat(converted_children).T).T
return res
else: # pragma: no cover
raise ValueError(
f"Invalid converted_children count: {len(converted_children)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_interpolation_2d(self):
# linear
y_test = np.array([0.4, 0.6])
Y = (2 * x).sum(axis=1).reshape(*[len(el) for el in x_])
for interpolator in ["linear"]:
for interpolator in ["linear", "cubic"]:
interp = pybamm.Interpolant(x_, Y, y, interpolator=interpolator)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
Expand Down
Loading