Skip to content

Commit

Permalink
#1864 use numba jit and dict for python evaluator, not working due to…
Browse files Browse the repository at this point in the history
… lack of sparse matrix support
  • Loading branch information
martinjrobins committed Dec 25, 2021
1 parent 93d6ddc commit 9aa4582
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np
import scipy.sparse
from numba import jit, types
from numba.typed import Dict

import pybamm

Expand Down Expand Up @@ -464,6 +466,7 @@ def __init__(self, symbol):

# add function def to first line
python_str = (
"@jit(nopython=True)\n"
"def evaluate(constants, t=None, y=None, "
"y_dot=None, inputs=None, known_evals=None):\n" + python_str
)
Expand All @@ -482,6 +485,8 @@ def __init__(self, symbol):

# store a copy of examine_jaxpr
python_str = python_str + "\nself._evaluate = evaluate"
print(python_str)


self._python_str = python_str
self._result_var = result_var
Expand All @@ -499,7 +504,25 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

result = self._evaluate(self._constants, t, y, y_dot, inputs, known_evals)
print(inputs)
print(known_evals)
inputs_numba = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64,
)

known_evals_numba = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64,
)
if inputs:
for k, v in inputs.items():
inputs_numba_dict[k] = v
print(type(self._constants), type(t), type(y), type(y_dot), type(inputs_numba),
type(known_evals_numba))

result = self._evaluate(self._constants, t, y, y_dot, inputs_numba,
known_evals_numba)

# don't need known_evals, but need to reproduce Symbol.evaluate signature
if known_evals is not None:
Expand Down

0 comments on commit 9aa4582

Please sign in to comment.