Skip to content

Commit

Permalink
#1864 saving progress for now. outstanding problem: the csr.CSR libra…
Browse files Browse the repository at this point in the history
…ry assumes that mat-vec is done for 1D vec
  • Loading branch information
martinjrobins committed Jan 12, 2022
1 parent 9aa4582 commit 491dc3f
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 24 deletions.
101 changes: 77 additions & 24 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy as np
import scipy.sparse
from numba import jit, types
from numba.typed import Dict
from numba.typed import Dict, List
import csr
import inspect

import pybamm

Expand Down Expand Up @@ -208,7 +210,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0}.multiply({1})".format(
print('XXX', type(dummy_eval_right))
symbol_str = "{0}.mult_vec({1})".format(
children_vars[0], children_vars[1]
)
elif scipy.sparse.issparse(dummy_eval_right):
Expand All @@ -217,7 +220,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
children_vars[0], children_vars[1]
)
else:
symbol_str = "{1}.multiply({0})".format(
print('XXX', type(dummy_eval_left))
symbol_str = "{1}.mult_vec({0})".format(
children_vars[0], children_vars[1]
)
else:
Expand All @@ -231,7 +235,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0}.multiply(1/{1})".format(
print('XXX', type(dummy_eval_right))
symbol_str = "{0}.mult_vec(1/{1})".format(
children_vars[0], children_vars[1]
)
else:
Expand All @@ -246,7 +251,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0}.multiply({1})".format(
print('XXX', type(dummy_eval_right))
symbol_str = "{0}.mult_vec({1})".format(
children_vars[0], children_vars[1]
)
elif scipy.sparse.issparse(dummy_eval_right):
Expand All @@ -255,7 +261,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
children_vars[0], children_vars[1]
)
else:
symbol_str = "{1}.multiply({0})".format(
print('XXX', type(dummy_eval_right))
symbol_str = "{1}.mult_vec({0})".format(
children_vars[0], children_vars[1]
)
else:
Expand All @@ -269,17 +276,22 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
elif isinstance(symbol, pybamm.MatrixMultiplication):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if output_jax and (
if (
scipy.sparse.issparse(dummy_eval_left)
and scipy.sparse.issparse(dummy_eval_right)
):
raise NotImplementedError(
"sparse mat-mat multiplication not supported "
"for output_jax == True"
)
if output_jax:
raise NotImplementedError(
"sparse mat-mat multiplication not supported ",
)
else:
symbol_str = (
children_vars[0] + ".multiply(" + children_vars[1] + ")"
)
else:
print('XXX', type(dummy_eval_left), type(dummy_eval_right))
symbol_str = (
children_vars[0] + " " + symbol.name + " " + children_vars[1]
children_vars[0] + ".mult_vec(" + children_vars[1] + ")"
)
else:
symbol_str = children_vars[0] + " " + symbol.name + " " + children_vars[1]
Expand All @@ -300,13 +312,18 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
children_str = child_var
else:
children_str += ", " + child_var
if isinstance(symbol.function, np.ufunc):
if (
isinstance(symbol.function, np.ufunc) or
inspect.getmodule(symbol.function) == np
):
# write any numpy functions directly
symbol_str = "np.{}({})".format(symbol.function.__name__, children_str)
else:
# unknown function, store it as a constant and call this in the
# generated code
constant_symbols[symbol.id] = symbol.function
print('UHNKOWSN FHUNCC', inspect.getmodule(constant_symbols[symbol.id]))
print(inspect.getsource(constant_symbols[symbol.id]))
funct_var = id_to_python_variable(symbol.id, True)
symbol_str = "{}({})".format(funct_var, children_str)

Expand Down Expand Up @@ -452,13 +469,52 @@ class EvaluatorPython:
def __init__(self, symbol):
constants, python_str = pybamm.to_python(symbol, debug=False)

# split constants into ndarrays and csr matrices
constants_ndarray = {
key: value for key, value in constants.items()
if isinstance(value, np.ndarray)
}

# extract constants in generated function
for i, symbol_id in enumerate(constants_ndarray.keys()):
const_name = id_to_python_variable(symbol_id, True)
python_str = "{} = constants_ndarray[{}]\n".format(const_name, i) + python_str

# sparse mat, convert to csr.CSR format for numba
constants_sparse_mat = {
key: csr.CSR.from_scipy(value) for key, value in constants.items()
if scipy.sparse.issparse(value)
}

# extract constants in generated function
for i, symbol_id in enumerate(constants.keys()):
for i, symbol_id in enumerate(constants_sparse_mat.keys()):
const_name = id_to_python_variable(symbol_id, True)
python_str = "{} = constants[{}]\n".format(const_name, i) + python_str
python_str = "{} = constants_sparse_mat[{}]\n".format(const_name, i) + python_str


# functions
constants_sparse_mat = {
key: csr.CSR.from_scipy(value) for key, value in constants.items()
if scipy.sparse.issparse(value)
}

# extract constants in generated function
for i, symbol_id in enumerate(constants_sparse_mat.keys()):
const_name = id_to_python_variable(symbol_id, True)
python_str = "{} = constants_sparse_mat[{}]\n".format(const_name, i) + python_str



# constants passed in as an ordered dict, convert to list
self._constants = list(constants.values())
self._constants_ndarray = list(constants_ndarray.values())
if not self._constants_ndarray:
self._constants_ndarray = List()
self._constants_ndarray.append(np.array([1.0, 2.0]))

self._constants_sparse_mat = list(constants_sparse_mat.values())
if not self._constants_sparse_mat:
self._constants_sparse_mat = List()
self._constants_sparse_mat.append(csr.CSR.empty(2, 2))

# indent code
python_str = " " + python_str
Expand All @@ -467,7 +523,7 @@ def __init__(self, symbol):
# add function def to first line
python_str = (
"@jit(nopython=True)\n"
"def evaluate(constants, t=None, y=None, "
"def evaluate(constants_ndarray, constants_sparse_mat, t=None, y=None, "
"y_dot=None, inputs=None, known_evals=None):\n" + python_str
)

Expand All @@ -485,8 +541,8 @@ def __init__(self, symbol):

# store a copy of examine_jaxpr
python_str = python_str + "\nself._evaluate = evaluate"
print(python_str)

print(python_str)

self._python_str = python_str
self._result_var = result_var
Expand All @@ -504,8 +560,6 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

print(inputs)
print(known_evals)
inputs_numba = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64,
Expand All @@ -517,11 +571,10 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
)
if inputs:
for k, v in inputs.items():
inputs_numba_dict[k] = v
print(type(self._constants), type(t), type(y), type(y_dot), type(inputs_numba),
type(known_evals_numba))
inputs_numba[k] = v

result = self._evaluate(self._constants, t, y, y_dot, inputs_numba,
result = self._evaluate(self._constants_ndarray, self._constants_sparse_mat,
t, y, y_dot, inputs_numba,
known_evals_numba)

# don't need known_evals, but need to reproduce Symbol.evaluate signature
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def compile_KLU():
"jupyter", # For example notebooks
"pybtex",
"sympy==1.9",
"csr @ git+https://github.com/lenskit/csr.git@v0.4.0",
# Note: Matplotlib is loaded for debug plots, but to ensure pybamm runs
# on systems without an attached display, it should never be imported
# outside of plot() methods.
Expand Down

0 comments on commit 491dc3f

Please sign in to comment.