Skip to content

Commit

Permalink
#1129 start matrix operations
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Aug 20, 2020
1 parent d2ad43a commit aa0b0b4
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 160 deletions.
73 changes: 19 additions & 54 deletions pybamm/expression_tree/operations/evaluate_julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,41 +85,12 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):
# Multiplication and Division need special handling for scipy sparse matrices
# TODO: we can pass through a dummy y and t to get the type and then hardcode
# the right line, avoiding these checks
if isinstance(symbol, pybamm.Multiplication):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if not to_dense and scipy.sparse.issparse(dummy_eval_left):
symbol_str = "{0}.multiply({1})".format(
children_vars[0], children_vars[1]
)
elif not to_dense and scipy.sparse.issparse(dummy_eval_right):
symbol_str = "{1}.multiply({0})".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0} .* {1}".format(children_vars[0], children_vars[1])
if isinstance(symbol, (pybamm.Multiplication, pybamm.Inner)):
symbol_str = "{0} .* {1}".format(children_vars[0], children_vars[1])
if isinstance(symbol, pybamm.MatrixMultiplication):
symbol_str = "{0} * {1}".format(children_vars[0], children_vars[1])
elif isinstance(symbol, pybamm.Division):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
if not to_dense and scipy.sparse.issparse(dummy_eval_left):
symbol_str = "{0}.multiply(1/{1})".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0} / {1}".format(children_vars[0], children_vars[1])

elif isinstance(symbol, pybamm.Inner):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if not to_dense and scipy.sparse.issparse(dummy_eval_left):
symbol_str = "{0}.multiply({1})".format(
children_vars[0], children_vars[1]
)
elif not to_dense and scipy.sparse.issparse(dummy_eval_right):
symbol_str = "{1}.multiply({0})".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0} * {1}".format(children_vars[0], children_vars[1])
symbol_str = "{0} ./ {1}".format(children_vars[0], children_vars[1])

elif isinstance(symbol, pybamm.Minimum):
symbol_str = "np.minimum({},{})".format(children_vars[0], children_vars[1])
Expand All @@ -133,8 +104,10 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):
elif isinstance(symbol, pybamm.UnaryOperator):
# Index has a different syntax than other univariate operations
if isinstance(symbol, pybamm.Index):
# Because of how julia indexing works, add 1 to the start, but not to the
# stop
symbol_str = "{}[{}:{}]".format(
children_vars[0], symbol.slice.start, symbol.slice.stop
children_vars[0], symbol.slice.start + 1, symbol.slice.stop
)
else:
symbol_str = symbol.name + children_vars[0]
Expand All @@ -148,24 +121,17 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):
children_str += ", " + child_var
# write functions directly
julia_name = symbol.julia_name
symbol_str = "{}({})".format(julia_name, children_str)
# add a . to allow elementwise operations
symbol_str = "{}.({})".format(julia_name, children_str)

elif isinstance(symbol, pybamm.Concatenation):

# don't bother to concatenate if there is only a single child
if isinstance(symbol, pybamm.NumpyConcatenation):
if len(children_vars) > 1:
symbol_str = "np.concatenate(({}))".format(",".join(children_vars))
if isinstance(symbol, (pybamm.NumpyConcatenation, pybamm.SparseStack)):
if len(children_vars) == 1:
symbol_str = children_vars
else:
symbol_str = "{}".format(",".join(children_vars))

elif isinstance(symbol, pybamm.SparseStack):
if not to_dense and len(children_vars) > 1:
symbol_str = "scipy.sparse.vstack(({}))".format(",".join(children_vars))
elif len(children_vars) > 1:
symbol_str = "np.vstack(({}))".format(",".join(children_vars))
else:
symbol_str = "{}".format(",".join(children_vars))
symbol_str = "vcat({})".format(",".join(children_vars))

# DomainConcatenation specifies a particular ordering for the concatenation,
# which we must follow
Expand Down Expand Up @@ -217,7 +183,9 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):

else:
raise NotImplementedError(
"Not implemented for a symbol of type '{}'".format(type(symbol))
"Conversion to Julia not implemented for a symbol of type '{}'".format(
type(symbol)
)
)

variable_symbols[symbol.id] = symbol_str
Expand Down Expand Up @@ -294,12 +262,9 @@ def get_julia_function(symbol):
constants, julia_str = to_julia(symbol, debug=False)

# extract constants in generated function
for i, symbol_id in enumerate(constants.keys()):
for symbol_id, const_value in constants.items():
const_name = id_to_julia_variable(symbol_id, True)
julia_str = "{} = constants[{}]\n".format(const_name, i) + julia_str

# constants passed in as an ordered dict, convert to list
constants = list(constants.values())
julia_str = "{} = {}\n".format(const_name, const_value) + julia_str

# indent code
julia_str = " " + julia_str
Expand Down
93 changes: 50 additions & 43 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

import numbers
from platform import system

if system() != "Windows":
import jax

from jax.config import config

config.update("jax_enable_x64", True)


Expand Down Expand Up @@ -95,30 +97,35 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if not to_dense and scipy.sparse.issparse(dummy_eval_left):
symbol_str = "{0}.multiply({1})"\
.format(children_vars[0], children_vars[1])
symbol_str = "{0}.multiply({1})".format(
children_vars[0], children_vars[1]
)
elif not to_dense and scipy.sparse.issparse(dummy_eval_right):
symbol_str = "{1}.multiply({0})"\
.format(children_vars[0], children_vars[1])
symbol_str = "{1}.multiply({0})".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0} * {1}".format(children_vars[0], children_vars[1])
elif isinstance(symbol, pybamm.Division):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
if not to_dense and scipy.sparse.issparse(dummy_eval_left):
symbol_str = "{0}.multiply(1/{1})"\
.format(children_vars[0], children_vars[1])
symbol_str = "{0}.multiply(1/{1})".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0} / {1}".format(children_vars[0], children_vars[1])

elif isinstance(symbol, pybamm.Inner):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if not to_dense and scipy.sparse.issparse(dummy_eval_left):
symbol_str = "{0}.multiply({1})"\
.format(children_vars[0], children_vars[1])
symbol_str = "{0}.multiply({1})".format(
children_vars[0], children_vars[1]
)
elif not to_dense and scipy.sparse.issparse(dummy_eval_right):
symbol_str = "{1}.multiply({0})"\
.format(children_vars[0], children_vars[1])
symbol_str = "{1}.multiply({0})".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = "{0} * {1}".format(children_vars[0], children_vars[1])

Expand Down Expand Up @@ -159,18 +166,18 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):

# don't bother to concatenate if there is only a single child
if isinstance(symbol, pybamm.NumpyConcatenation):
if len(children_vars) > 1:
symbol_str = "np.concatenate(({}))".format(",".join(children_vars))
if len(children_vars) == 1:
symbol_str = children_vars
else:
symbol_str = "{}".format(",".join(children_vars))
symbol_str = "np.concatenate(({}))".format(",".join(children_vars))

elif isinstance(symbol, pybamm.SparseStack):
if not to_dense and len(children_vars) > 1:
if len(children_vars) == 1:
symbol_str = children_vars
elif not to_dense:
symbol_str = "scipy.sparse.vstack(({}))".format(",".join(children_vars))
elif len(children_vars) > 1:
symbol_str = "np.vstack(({}))".format(",".join(children_vars))
else:
symbol_str = "{}".format(",".join(children_vars))
symbol_str = "np.vstack(({}))".format(",".join(children_vars))

# DomainConcatenation specifies a particular ordering for the concatenation,
# which we must follow
Expand Down Expand Up @@ -217,7 +224,9 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):

else:
raise NotImplementedError(
"Not implemented for a symbol of type '{}'".format(type(symbol))
"Conversion to python not implemented for a symbol of type '{}'".format(
type(symbol)
)
)

variable_symbols[symbol.id] = symbol_str
Expand Down Expand Up @@ -294,18 +303,20 @@ def __init__(self, symbol):
# extract constants in generated function
for i, symbol_id in enumerate(constants.keys()):
const_name = id_to_python_variable(symbol_id, True)
python_str = '{} = constants[{}]\n'.format(const_name, i) + python_str
python_str = "{} = constants[{}]\n".format(const_name, i) + python_str

# constants passed in as an ordered dict, convert to list
self._constants = list(constants.values())

# indent code
python_str = ' ' + python_str
python_str = python_str.replace('\n', '\n ')
python_str = " " + python_str
python_str = python_str.replace("\n", "\n ")

# add function def to first line
python_str = 'def evaluate(constants, t=None, y=None, '\
'y_dot=None, inputs=None, known_evals=None):\n' + python_str
python_str = (
"def evaluate(constants, t=None, y=None, "
"y_dot=None, inputs=None, known_evals=None):\n" + python_str
)

# calculate the final variable that will output the result of calling `evaluate`
# on `symbol`
Expand All @@ -315,21 +326,18 @@ def __init__(self, symbol):

# add return line
if symbol.is_constant() and isinstance(result_value, numbers.Number):
python_str = python_str + '\n return ' + str(result_value)
python_str = python_str + "\n return " + str(result_value)
else:
python_str = python_str + '\n return ' + result_var
python_str = python_str + "\n return " + result_var

# store a copy of examine_jaxpr
python_str = python_str + \
'\nself._evaluate = evaluate'
python_str = python_str + "\nself._evaluate = evaluate"

self._python_str = python_str
self._symbol = symbol

# compile and run the generated python code,
compiled_function = compile(
python_str, result_var, "exec"
)
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
Expand Down Expand Up @@ -377,7 +385,7 @@ def __init__(self, symbol):
constants, python_str = pybamm.to_python(symbol, debug=False, to_dense=True)

# replace numpy function calls to jax numpy calls
python_str = python_str.replace('np.', 'jax.numpy.')
python_str = python_str.replace("np.", "jax.numpy.")

# convert all numpy constants to device vectors
for symbol_id in constants:
Expand All @@ -387,18 +395,20 @@ def __init__(self, symbol):
# extract constants in generated function
for i, symbol_id in enumerate(constants.keys()):
const_name = id_to_python_variable(symbol_id, True)
python_str = '{} = constants[{}]\n'.format(const_name, i) + python_str
python_str = "{} = constants[{}]\n".format(const_name, i) + python_str

# constants passed in as an ordered dict, convert to list
self._constants = list(constants.values())

# indent code
python_str = ' ' + python_str
python_str = python_str.replace('\n', '\n ')
python_str = " " + python_str
python_str = python_str.replace("\n", "\n ")

# add function def to first line
python_str = 'def evaluate_jax(constants, t=None, y=None, '\
'y_dot=None, inputs=None, known_evals=None):\n' + python_str
python_str = (
"def evaluate_jax(constants, t=None, y=None, "
"y_dot=None, inputs=None, known_evals=None):\n" + python_str
)

# calculate the final variable that will output the result of calling `evaluate`
# on `symbol`
Expand All @@ -408,18 +418,15 @@ def __init__(self, symbol):

# add return line
if symbol.is_constant() and isinstance(result_value, numbers.Number):
python_str = python_str + '\n return ' + str(result_value)
python_str = python_str + "\n return " + str(result_value)
else:
python_str = python_str + '\n return ' + result_var
python_str = python_str + "\n return " + result_var

# store a copy of examine_jaxpr
python_str = python_str + \
'\nself._evaluate_jax = evaluate_jax'
python_str = python_str + "\nself._evaluate_jax = evaluate_jax"

# compile and run the generated python code,
compiled_function = compile(
python_str, result_var, "exec"
)
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)

self._jit_evaluate = jax.jit(self._evaluate_jax, static_argnums=(0, 4, 5))
Expand Down
Loading

0 comments on commit aa0b0b4

Please sign in to comment.