Skip to content

Commit

Permalink
Issue 1500 print parameter info by submodel (pybamm-team#3846)
Browse files Browse the repository at this point in the history
* Add "by_submodel" feature and modify docstring

* Added line in CHANGELOG.md

* Implemented error handling if `get_parameter_info` method is used directly on any submodel.

* Implement changes to fix problem: Keep a track of what variables were added by get_coupled_variables in model.variables

* Implement changes to fix problem: Keep a track of what variables were added by get_coupled_variables in model.variables V2

* Created object `self.variables_by_submodel` which contains submodel's name as key and its variables as value.
Modified `get_parameter_info` to parse through the submodels' variables and give out `parameter_info` in `by_submodel=True` condition
Modified `print_parameter_info` to print `parameter_info` of a model with an option to print submodel wise using `by_submodel=True`

* Removed redundant comments.

* style: pre-commit fixes

* Implemented "NotImplementedError" when user tries to use "get_parameter_info" or "print_parameter_info" directly on a submodel

* Implemented "NotImplementedError" when user tries to use "get_parameter_info" or "print_parameter_info" directly on a submodel V2

* Updated Docstring of "get_parameter_info"

* Added Test Case for "NotImplementedError" when "get_parameter_info" or "print_parameter_info" is used on a submodel.

* Renamed variables, updated test, updated docstring

* Added "_find_symbols_by_submodel" method and modified "get_parameter_info" to get new parameters

* Optimised "print_parameter_info" by simplification, improved readability and reduced repetition

* Removed "calculate_max_lengths" and "format_table_row" from being nested inside into semi-private methods

* Tests for "get_parameter_info" for both "by_submodel=False" and "by_submodel=True"

* Removed duplicate test in "test_base_submodel" for when "get_parameter_info" is used directly on a submodel

* added `test_get_parameter_info_submodel` test using custom model

* added `test_print_parameter_info` and `test_print_parameter_info_submodel`

* Modified `test_get_parameter_info_submodel` to include edge cases

* formatted `test_base_model.py`

* Moved change log to unreleased

* Changed the formatted table's style

* Added `UTF-8` encoding to the format table

* Updated jupyter notebook `parameterization.ipynb` and added `UTF-8` encoding in environment variables

* added utf-8 encoding in PYBAMM_ENV

* Resolve minor issues

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Robert Timms <43040151+rtimms@users.noreply.github.com>
Co-authored-by: Arjun Verma <arjunverma.oc@gmail.com>
Co-authored-by: Eric G. Kratz <kratman@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
Co-authored-by: cringeyburger <cringeyburger>
  • Loading branch information
6 people committed Apr 4, 2024
1 parent 8512040 commit fe4ac31
Show file tree
Hide file tree
Showing 6 changed files with 641 additions and 206 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Modified `step` function to take an array of time `t_eval` as an argument and deprecated use of `npts`. ([#3627](https://github.com/pybamm-team/PyBaMM/pull/3627))
- Renamed "electrode diffusivity" to "particle diffusivity" as a non-breaking change with a deprecation warning ([#3624](https://github.com/pybamm-team/PyBaMM/pull/3624))
- Add support for BPX version 0.4.0 which allows for blended electrodes and user-defined parameters in BPX([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414))
- Added `by_submodel` feature in `print_parameter_info` method to allow users to print parameters and types of submodels in a tabular and readable format ([#3628](https://github.com/pybamm-team/PyBaMM/pull/3628))

## Bug Fixes

Expand Down
297 changes: 151 additions & 146 deletions docs/source/examples/notebooks/parameterization/parameterization.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PYBAMM_ENV = {
"SUNDIALS_INST": f"{homedir}/.local",
"LD_LIBRARY_PATH": f"{homedir}/.local/lib",
"PYTHONIOENCODING": "utf-8",
}
VENV_DIR = Path("./venv").resolve()

Expand Down
299 changes: 240 additions & 59 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(self, name="Unnamed model"):
self._algebraic = {}
self._initial_conditions = {}
self._boundary_conditions = {}
self._variables_by_submodel = {}
self._variables = pybamm.FuzzyDict({})
self._events = []
self._concatenated_rhs = None
Expand Down Expand Up @@ -421,83 +422,232 @@ def input_parameters(self):
self._input_parameters = self._find_symbols(pybamm.InputParameter)
return self._input_parameters

def get_parameter_info(self):
def get_parameter_info(self, by_submodel=False):
"""
Extracts the parameter information and returns it as a dictionary.
To get a list of all parameter-like objects without extra information,
use :py:attr:`model.parameters`.
Parameters
----------
by_submodel : bool, optional
Whether to return the parameter info sub-model wise or not (default False)
"""
parameter_info = {}
parameters = self._find_symbols(pybamm.Parameter)
for param in parameters:
parameter_info[param.name] = (param, "Parameter")

input_parameters = self._find_symbols(pybamm.InputParameter)
for input_param in input_parameters:
if not input_param.domain:
parameter_info[input_param.name] = (input_param, "InputParameter")
else:
parameter_info[input_param.name] = (
input_param,
f"InputParameter in {input_param.domain}",

if by_submodel:
for submodel_name, submodel_vars in self._variables_by_submodel.items():
submodel_info = {}
for var_name, var_symbol in submodel_vars.items():
if isinstance(var_symbol, pybamm.Parameter):
submodel_info[var_name] = (var_symbol, "Parameter")
elif isinstance(var_symbol, pybamm.InputParameter):
if not var_symbol.domain:
submodel_info[var_name] = (var_symbol, "InputParameter")
else:
submodel_info[var_name] = (
var_symbol,
f"InputParameter in {var_symbol.domain}",
)
elif isinstance(var_symbol, pybamm.FunctionParameter):
input_names = "', '".join(var_symbol.input_names)
submodel_info[var_name] = (
var_symbol,
f"FunctionParameter with inputs(s) '{input_names}'",
)
else:
submodel_info[var_name] = (var_symbol, "Unknown Type")

parameters = self._find_symbols_by_submodel(
pybamm.Parameter, submodel_name
)
for param in parameters:
submodel_info[param.name] = (param, "Parameter")

function_parameters = self._find_symbols(pybamm.FunctionParameter)
for func_param in function_parameters:
if func_param.name not in parameter_info:
input_names = "', '".join(func_param.input_names)
parameter_info[func_param.name] = (
func_param,
f"FunctionParameter with inputs(s) '{input_names}'",
input_parameters = self._find_symbols_by_submodel(
pybamm.InputParameter, submodel_name
)
for input_param in input_parameters:
if not input_param.domain:
submodel_info[input_param.name] = (
input_param,
"InputParameter",
)
else:
submodel_info[input_param.name] = (
input_param,
f"InputParameter in {input_param.domain}",
)

return parameter_info
function_parameters = self._find_symbols_by_submodel(
pybamm.FunctionParameter, submodel_name
)
for func_param in function_parameters:
if func_param.name not in parameter_info:
input_names = "', '".join(func_param.input_names)
submodel_info[func_param.name] = (
func_param,
f"FunctionParameter with inputs(s) '{input_names}'",
)

parameter_info[submodel_name] = submodel_info

else:
parameters = self._find_symbols(pybamm.Parameter)
for param in parameters:
parameter_info[param.name] = (param, "Parameter")

input_parameters = self._find_symbols(pybamm.InputParameter)
for input_param in input_parameters:
if not input_param.domain:
parameter_info[input_param.name] = (input_param, "InputParameter")
else:
parameter_info[input_param.name] = (
input_param,
f"InputParameter in {input_param.domain}",
)

def print_parameter_info(self):
"""Print parameter information in a formatted table from a dictionary of parameters"""
info = self.get_parameter_info()
max_param_name_length = 0
max_param_type_length = 0
function_parameters = self._find_symbols(pybamm.FunctionParameter)
for func_param in function_parameters:
if func_param.name not in parameter_info:
input_names = "', '".join(func_param.input_names)
parameter_info[func_param.name] = (
func_param,
f"FunctionParameter with inputs(s) '{input_names}'",
)

for param, param_type in info.values():
param_name_length = len(getattr(param, "name", str(param)))
param_type_length = len(param_type)
max_param_name_length = max(max_param_name_length, param_name_length)
max_param_type_length = max(max_param_type_length, param_type_length)
return parameter_info

header_format = (
f"| {{:<{max_param_name_length}}} | {{:<{max_param_type_length}}} |"
def _calculate_max_lengths(self, parameter_dict):
"""
Calculate the maximum length of parameters and parameter type in a dictionary
Parameters
----------
parameter_dict : dict
The dict from which maximum lengths are calculated
"""
max_name_length = max(
len(getattr(parameter, "name", str(parameter)))
for parameter, _ in parameter_dict.values()
)
row_format = (
f"| {{:<{max_param_name_length}}} | {{:<{max_param_type_length}}} |"
max_type_length = max(
len(parameter_type) for _, parameter_type in parameter_dict.values()
)

table = [
header_format.format("Parameter", "Type of parameter"),
header_format.format(
"=" * max_param_name_length, "=" * max_param_type_length
),
return max_name_length, max_type_length

def _format_table_row(
self, param_name, param_type, max_name_length, max_type_length
):
"""
Format the parameter information in a formatted table
Parameters
----------
param_name : str
The name of the parameter
param_type : str
The type of the parameter
max_name_length : int
The maximum length of the parameter in the dictionary
max_type_length : int
The maximum length of the parameter type in the dictionary
"""
param_name_lines = [
param_name[i : i + max_name_length]
for i in range(0, len(param_name), max_name_length)
]
param_type_lines = [
param_type[i : i + max_type_length]
for i in range(0, len(param_type), max_type_length)
]
max_lines = max(len(param_name_lines), len(param_type_lines))

return [
f"│ {param_name_lines[i]:<{max_name_length}}{param_type_lines[i]:<{max_type_length}} │"
for i in range(max_lines)
]

for param, param_type in info.values():
param_name = getattr(param, "name", str(param))
param_name_lines = [
param_name[i : i + max_param_name_length]
for i in range(0, len(param_name), max_param_name_length)
]
param_type_lines = [
param_type[i : i + max_param_type_length]
for i in range(0, len(param_type), max_param_type_length)
def print_parameter_info(self, by_submodel=False):
"""
Print parameter information in a formatted table from a dictionary of parameters
Parameters
----------
by_submodel : bool, optional
Whether to print the parameter info sub-model wise or not (default False)
"""

if by_submodel:
parameter_info = self.get_parameter_info(by_submodel=True)
for submodel_name, submodel_vars in parameter_info.items():
if not submodel_vars:
print(f"'{submodel_name}' submodel parameters: \nNo parameters\n")
else:
print(f"'{submodel_name}' submodel parameters:")
(
max_param_name_length,
max_param_type_length,
) = self._calculate_max_lengths(submodel_vars)

table = [
f"┌─{'─' * max_param_name_length}─┬─{'─' * max_param_type_length}─┐",
f"│ {'Parameter':<{max_param_name_length}}{'Type of parameter':<{max_param_type_length}} │",
f"├─{'─' * max_param_name_length}─┼─{'─' * max_param_type_length}─┤",
]

for param, param_type in submodel_vars.values():
param_name = getattr(param, "name", str(param))
table.extend(
self._format_table_row(
param_name,
param_type,
max_param_name_length,
max_param_type_length,
)
)
table.extend(
[
f"└─{'─' * max_param_name_length}─┴─{'─' * max_param_type_length}─┘",
]
)
table = "\n".join(table) + "\n"
table.encode("utf-8")
print(table)

else:
info = self.get_parameter_info()
max_param_name_length, max_param_type_length = self._calculate_max_lengths(
info
)

table = [
f"┌─{'─' * max_param_name_length}─┬─{'─' * max_param_type_length}─┐",
f"│ {'Parameter':<{max_param_name_length}}{'Type of parameter':<{max_param_type_length}} │",
f"├─{'─' * max_param_name_length}─┼─{'─' * max_param_type_length}─┤",
]
max_lines = max(len(param_name_lines), len(param_type_lines))

for i in range(max_lines):
param_line = param_name_lines[i] if i < len(param_name_lines) else ""
type_line = param_type_lines[i] if i < len(param_type_lines) else ""
table.append(row_format.format(param_line, type_line))
for param, param_type in info.values():
param_name = getattr(param, "name", str(param))
table.extend(
self._format_table_row(
param_name,
param_type,
max_param_name_length,
max_param_type_length,
)
)

for line in table:
print(line)
table.extend(
[
f"└─{'─' * max_param_name_length}─┴─{'─' * max_param_type_length}─┘",
]
)

table = "\n".join(table) + "\n"
table.encode("utf-8")
print(table)

def _find_symbols(self, typ):
"""Find all the instances of `typ` in the model"""
Expand All @@ -516,6 +666,23 @@ def _find_symbols(self, typ):
)
return list(all_input_parameters)

def _find_symbols_by_submodel(self, typ, submodel):
"""Find all the instances of `typ` in the submodel"""
unpacker = pybamm.SymbolUnpacker(typ)
all_input_parameters = unpacker.unpack_list_of_symbols(
list(self.submodels[submodel].rhs.values())
+ list(self.submodels[submodel].algebraic.values())
+ list(self.submodels[submodel].initial_conditions.values())
+ [
x[side][0]
for x in self.submodels[submodel].boundary_conditions.values()
for side in x.keys()
]
+ list(self._variables_by_submodel[submodel].values())
+ [event.expression for event in self.submodels[submodel].events]
)
return list(all_input_parameters)

def new_copy(self):
"""
Creates a copy of the model, explicitly copying all the mutable attributes
Expand Down Expand Up @@ -555,11 +722,16 @@ def update(self, *submodels):

def build_fundamental(self):
# Get the fundamental variables
self._variables_by_submodel = {submodel: {} for submodel in self.submodels}
for submodel_name, submodel in self.submodels.items():
pybamm.logger.debug(
f"Getting fundamental variables for {submodel_name} submodel ({self.name})"
)
self.variables.update(submodel.get_fundamental_variables())
submodel_fundamental_variables = submodel.get_fundamental_variables()
self._variables_by_submodel[submodel_name].update(
submodel_fundamental_variables
)
self.variables.update(submodel_fundamental_variables)

self._built_fundamental = True

Expand All @@ -581,9 +753,18 @@ def build_coupled_variables(self):
f"Getting coupled variables for {submodel_name} submodel ({self.name})"
)
try:
self.variables.update(
submodel.get_coupled_variables(self.variables)
model_var_copy = self.variables.copy()
updated_variables = submodel.get_coupled_variables(
self.variables
)
self._variables_by_submodel[submodel_name].update(
{
key: updated_variables[key]
for key in updated_variables
if key not in model_var_copy
}
)
self.variables.update(updated_variables)
submodels.remove(submodel_name)
except KeyError as key:
if len(submodels) == 1 or count == 100:
Expand Down
Loading

0 comments on commit fe4ac31

Please sign in to comment.