From 491dc3f637b1087b3ad33de8e5b785b5d0b37d25 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Wed, 12 Jan 2022 15:48:03 +0000 Subject: [PATCH] #1864 saving progress for now. outstanding problem: the csr.CSR library assumes that mat-vec is done for 1D vec --- .../operations/evaluate_python.py | 101 +++++++++++++----- setup.py | 1 + 2 files changed, 78 insertions(+), 24 deletions(-) diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index ae569f3ab4..6d2f490c72 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -7,7 +7,9 @@ import numpy as np import scipy.sparse from numba import jit, types -from numba.typed import Dict +from numba.typed import Dict, List +import csr +import inspect import pybamm @@ -208,7 +210,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): children_vars[0], children_vars[1] ) else: - symbol_str = "{0}.multiply({1})".format( + print('XXX', type(dummy_eval_right)) + symbol_str = "{0}.mult_vec({1})".format( children_vars[0], children_vars[1] ) elif scipy.sparse.issparse(dummy_eval_right): @@ -217,7 +220,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): children_vars[0], children_vars[1] ) else: - symbol_str = "{1}.multiply({0})".format( + print('XXX', type(dummy_eval_left)) + symbol_str = "{1}.mult_vec({0})".format( children_vars[0], children_vars[1] ) else: @@ -231,7 +235,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): children_vars[0], children_vars[1] ) else: - symbol_str = "{0}.multiply(1/{1})".format( + print('XXX', type(dummy_eval_right)) + symbol_str = "{0}.mult_vec(1/{1})".format( children_vars[0], children_vars[1] ) else: @@ -246,7 +251,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): children_vars[0], children_vars[1] ) else: - symbol_str = "{0}.multiply({1})".format( + print('XXX', type(dummy_eval_right)) + symbol_str = "{0}.mult_vec({1})".format( children_vars[0], children_vars[1] ) elif scipy.sparse.issparse(dummy_eval_right): @@ -255,7 +261,8 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): children_vars[0], children_vars[1] ) else: - symbol_str = "{1}.multiply({0})".format( + print('XXX', type(dummy_eval_right)) + symbol_str = "{1}.mult_vec({0})".format( children_vars[0], children_vars[1] ) else: @@ -269,17 +276,22 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): elif isinstance(symbol, pybamm.MatrixMultiplication): dummy_eval_left = symbol.children[0].evaluate_for_shape() dummy_eval_right = symbol.children[1].evaluate_for_shape() - if output_jax and ( + if ( scipy.sparse.issparse(dummy_eval_left) and scipy.sparse.issparse(dummy_eval_right) ): - raise NotImplementedError( - "sparse mat-mat multiplication not supported " - "for output_jax == True" - ) + if output_jax: + raise NotImplementedError( + "sparse mat-mat multiplication not supported ", + ) + else: + symbol_str = ( + children_vars[0] + ".multiply(" + children_vars[1] + ")" + ) else: + print('XXX', type(dummy_eval_left), type(dummy_eval_right)) symbol_str = ( - children_vars[0] + " " + symbol.name + " " + children_vars[1] + children_vars[0] + ".mult_vec(" + children_vars[1] + ")" ) else: symbol_str = children_vars[0] + " " + symbol.name + " " + children_vars[1] @@ -300,13 +312,18 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): children_str = child_var else: children_str += ", " + child_var - if isinstance(symbol.function, np.ufunc): + if ( + isinstance(symbol.function, np.ufunc) or + inspect.getmodule(symbol.function) == np + ): # write any numpy functions directly symbol_str = "np.{}({})".format(symbol.function.__name__, children_str) else: # unknown function, store it as a constant and call this in the # generated code constant_symbols[symbol.id] = symbol.function + print('UHNKOWSN FHUNCC', inspect.getmodule(constant_symbols[symbol.id])) + print(inspect.getsource(constant_symbols[symbol.id])) funct_var = id_to_python_variable(symbol.id, True) symbol_str = "{}({})".format(funct_var, children_str) @@ -452,13 +469,52 @@ class EvaluatorPython: def __init__(self, symbol): constants, python_str = pybamm.to_python(symbol, debug=False) + # split constants into ndarrays and csr matrices + constants_ndarray = { + key: value for key, value in constants.items() + if isinstance(value, np.ndarray) + } + + # extract constants in generated function + for i, symbol_id in enumerate(constants_ndarray.keys()): + const_name = id_to_python_variable(symbol_id, True) + python_str = "{} = constants_ndarray[{}]\n".format(const_name, i) + python_str + + # sparse mat, convert to csr.CSR format for numba + constants_sparse_mat = { + key: csr.CSR.from_scipy(value) for key, value in constants.items() + if scipy.sparse.issparse(value) + } + # extract constants in generated function - for i, symbol_id in enumerate(constants.keys()): + for i, symbol_id in enumerate(constants_sparse_mat.keys()): const_name = id_to_python_variable(symbol_id, True) - python_str = "{} = constants[{}]\n".format(const_name, i) + python_str + python_str = "{} = constants_sparse_mat[{}]\n".format(const_name, i) + python_str + + + # functions + constants_sparse_mat = { + key: csr.CSR.from_scipy(value) for key, value in constants.items() + if scipy.sparse.issparse(value) + } + + # extract constants in generated function + for i, symbol_id in enumerate(constants_sparse_mat.keys()): + const_name = id_to_python_variable(symbol_id, True) + python_str = "{} = constants_sparse_mat[{}]\n".format(const_name, i) + python_str + + # constants passed in as an ordered dict, convert to list - self._constants = list(constants.values()) + self._constants_ndarray = list(constants_ndarray.values()) + if not self._constants_ndarray: + self._constants_ndarray = List() + self._constants_ndarray.append(np.array([1.0, 2.0])) + + self._constants_sparse_mat = list(constants_sparse_mat.values()) + if not self._constants_sparse_mat: + self._constants_sparse_mat = List() + self._constants_sparse_mat.append(csr.CSR.empty(2, 2)) # indent code python_str = " " + python_str @@ -467,7 +523,7 @@ def __init__(self, symbol): # add function def to first line python_str = ( "@jit(nopython=True)\n" - "def evaluate(constants, t=None, y=None, " + "def evaluate(constants_ndarray, constants_sparse_mat, t=None, y=None, " "y_dot=None, inputs=None, known_evals=None):\n" + python_str ) @@ -485,8 +541,8 @@ def __init__(self, symbol): # store a copy of examine_jaxpr python_str = python_str + "\nself._evaluate = evaluate" - print(python_str) + print(python_str) self._python_str = python_str self._result_var = result_var @@ -504,8 +560,6 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None): if y is not None and y.ndim == 1: y = y.reshape(-1, 1) - print(inputs) - print(known_evals) inputs_numba = Dict.empty( key_type=types.unicode_type, value_type=types.float64, @@ -517,11 +571,10 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None): ) if inputs: for k, v in inputs.items(): - inputs_numba_dict[k] = v - print(type(self._constants), type(t), type(y), type(y_dot), type(inputs_numba), - type(known_evals_numba)) + inputs_numba[k] = v - result = self._evaluate(self._constants, t, y, y_dot, inputs_numba, + result = self._evaluate(self._constants_ndarray, self._constants_sparse_mat, + t, y, y_dot, inputs_numba, known_evals_numba) # don't need known_evals, but need to reproduce Symbol.evaluate signature diff --git a/setup.py b/setup.py index 55fd14a2bc..d07afe94d6 100644 --- a/setup.py +++ b/setup.py @@ -196,6 +196,7 @@ def compile_KLU(): "jupyter", # For example notebooks "pybtex", "sympy==1.9", + "csr @ git+https://github.com/lenskit/csr.git@v0.4.0", # Note: Matplotlib is loaded for debug plots, but to ensure pybamm runs # on systems without an attached display, it should never be imported # outside of plot() methods.