Skip to content

Commit

Permalink
Merge pull request #1919 from Saransh-cpp/issue-1918-parameter-path
Browse files Browse the repository at this point in the history
Make parameters importable from a directory having "pybamm" in its name
  • Loading branch information
valentinsulzer authored Feb 5, 2022
2 parents 1c209f2 + eab470b commit 3d55351
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ setup.log

# test
test.c
test.json

# tox
.tox/
Expand Down
6 changes: 2 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@
## Bug fixes

- Fixed a bug where thermal submodels could not be used with half-cells ([#1929](https://github.com/pybamm-team/PyBaMM/pull/1929))
- Parameters can now be imported from a directory having "pybamm" in its name ([#1919](https://github.com/pybamm-team/PyBaMM/pull/1919))
- `scikit.odes` and `SUNDIALS` can now be installed using `pybamm_install_odes` ([#1916](https://github.com/pybamm-team/PyBaMM/pull/1916))

## Breaking changes

- The `domain` setter and `auxiliary_domains` getter have been deprecated, `domains` setter/getter should be used instead. The `domain` getter is still active. We now recommend creating symbols with `domains={...}` instead of `domain=..., auxiliary_domains={...}`, but the latter is not yet deprecated ([#1866](https://github.com/pybamm-team/PyBaMM/pull/1866))

## Bug fixes

- `scikit.odes` and `SUNDIALS` can now be installed using `pybamm_install_odes` ([#1916](https://github.com/pybamm-team/PyBaMM/pull/1916))

# [v22.1](https://github.com/pybamm-team/PyBaMM/tree/v22.1) - 2022-01-31

## Features
Expand Down
10 changes: 5 additions & 5 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,25 +288,25 @@ def load_function(filename):
orig_dir = os.getcwd()

# Strip absolute path to pybamm/input/example.py
if "pybamm" in filename:
if "pybamm/input/parameters" in filename or "pybamm\\input\\parameters" in filename:
root_path = filename[filename.rfind("pybamm") :]
# If the function is in the current working directory
elif os.getcwd() in filename:
root_path = filename.replace(os.getcwd(), "")
# getcwd() returns "C:\\" when in the root drive and "C:\\a\\b\\c" otherwise
if root_path[0] == "\\" or root_path[0] == "/":
root_path = root_path[1:]
# If the function is not in the current working directory and the path provided is
# absolute
elif os.path.isabs(filename) and not os.getcwd() in filename: # pragma: no cover
# Change directory to import the function
dir_path = os.path.split(filename)[0]
os.chdir(dir_path)
root_path = filename.replace(os.getcwd(), "")
root_path = root_path[1:]
else:
root_path = filename

# getcwd() returns "C:\\" when in the root drive and "C:\\a\\b\\c" otherwise
if root_path[0] == "\\" or root_path[0] == "/":
root_path = root_path[1:]

path = root_path.replace("/", ".")
path = path.replace("\\", ".")
pybamm.logger.debug(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def test_spm_simulation(self):
quick_plot.plot(0)

# test creating a GIF
quick_plot.create_gif(number_of_images=5, duration=3)
quick_plot.create_gif(number_of_images=3, duration=3)
assert not os.path.exists("plot*.png")
assert os.path.exists("plot.gif")
os.remove("plot.gif")
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,14 @@ def test_plot(self):

def test_create_gif(self):
sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
t_eval = np.linspace(0, 100, 5)
sim.solve(t_eval=t_eval)
sim.solve(t_eval=[0, 10])

# create a GIF without calling the plot method
sim.create_gif(number_of_images=5, duration=1)
sim.create_gif(number_of_images=3, duration=1)

# call the plot method before creating the GIF
sim.plot(testing=True)
sim.create_gif(number_of_images=5, duration=1)
sim.create_gif(number_of_images=3, duration=1)

os.remove("plot.gif")

Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@
import numpy as np
import os
import pybamm
import shutil
import tempfile
import unittest
import importlib
import subprocess
from unittest.mock import patch
from io import StringIO

# Insert .../x/y/z/PyBaMM in sys.path when running this file individually
import sys

if os.getcwd() not in sys.path:
sys.path.insert(0, os.getcwd())


class TestUtil(unittest.TestCase):
"""
Expand Down Expand Up @@ -75,6 +84,38 @@ def test_load_function(self):
pybamm.input.parameters.lithium_ion.negative_electrodes.graphite_Chen2020.graphite_LGM50_electrolyte_exchange_current_density_Chen2020.graphite_LGM50_electrolyte_exchange_current_density_Chen2020, # noqa
)

# Test function load for parameters in a directory having "pybamm" in its name
# create a new lithium_ion folder in the root PyBaMM directory
subprocess.run(["pybamm_edit_parameter", "lithium_ion"])

# path for a function in the created directory ->
# x/y/z/PyBaMM/lithium_ion/negative_electrode/ ....
test_path = os.path.join(
os.getcwd(),
"lithium_ion",
"negative_electrodes",
"graphite_Chen2020",
"graphite_LGM50_electrolyte_exchange_current_density_Chen2020.py",
)

# load the function
func = pybamm.load_function(test_path)

# cannot directly do - lithium_ion.negative_electrodes.graphite_Chen2020 as
# lithium_ion is not a python module
module_object = importlib.import_module(
"lithium_ion.negative_electrodes.graphite_Chen2020.graphite_LGM50_electrolyte_exchange_current_density_Chen2020" # noqa
)
self.assertEqual(
func,
getattr(
module_object,
"graphite_LGM50_electrolyte_exchange_current_density_Chen2020",
),
)

shutil.rmtree("lithium_ion")

def test_rmse(self):
self.assertEqual(pybamm.rmse(np.ones(5), np.zeros(5)), 1)
self.assertEqual(pybamm.rmse(2 * np.ones(5), np.zeros(5)), 2)
Expand Down Expand Up @@ -127,6 +168,7 @@ def test_get_parameters_filepath(self):
tempfile_obj = tempfile.NamedTemporaryFile("w", dir=package_dir)
path = os.path.join(package_dir, tempfile_obj.name)
self.assertTrue(pybamm.get_parameters_filepath(tempfile_obj.name) == path)
tempfile_obj.close()

def test_is_jax_compatible(self):
if pybamm.have_jax():
Expand Down

0 comments on commit 3d55351

Please sign in to comment.