From 6e4ddbb2f25f31f883c4c3805615791cb78765f1 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 9 May 2024 08:47:49 +0000 Subject: [PATCH] bug: use casadi MX.interpn_linear function instead of plugin #3783 --- pybamm/expression_tree/interpolant.py | 3 +- .../operations/convert_to_casadi.py | 32 ++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 10881d3084..dd0980fb46 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -126,9 +126,10 @@ def __init__( fill_value_1 = "extrapolate" interpolating_function = interpolate.interp1d( x1, - y.T, + y, bounds_error=False, fill_value=fill_value_1, + axis=0, ) elif interpolator == "cubic": interpolating_function = interpolate.CubicSpline( diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index 196da9dec9..4ec9972336 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -157,15 +157,31 @@ def _convert(self, symbol, t, y, y_dot, inputs): ) if len(converted_children) == 1: - return casadi.interpolant( - "LUT", solver, symbol.x, symbol.y.flatten() - )(*converted_children) + if solver == "linear": + test = casadi.MX.interpn_linear( + symbol.x, symbol.y.flatten(), converted_children + ) + if test.shape[0] == 1 and test.shape[1] > 1: + # for some reason, pybamm.Interpolant always returns a column vector, so match that + test = test.T + return test + else: + return casadi.interpolant( + "LUT", solver, symbol.x, symbol.y.flatten() + )(*converted_children) elif len(converted_children) in [2, 3]: - LUT = casadi.interpolant( - "LUT", solver, symbol.x, symbol.y.ravel(order="F") - ) - res = LUT(casadi.hcat(converted_children).T).T - return res + if solver == "linear": + return casadi.MX.interpn_linear( + symbol.x, + symbol.y.ravel(order="F"), + casadi.hcat(converted_children).T, + ).T + else: + LUT = casadi.interpolant( + "LUT", solver, symbol.x, symbol.y.ravel(order="F") + ) + res = LUT(casadi.hcat(converted_children).T).T + return res else: # pragma: no cover raise ValueError( f"Invalid converted_children count: {len(converted_children)}"