Skip to content

Commit

Permalink
Migrating unittest to pytest (Part 7) (#4431)
Browse files Browse the repository at this point in the history
* Migrating  unittest to pytest (Part 7)

Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>

* style: pre-commit fixes

* Removing style failures

Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>

* Update tests/unit/test_parameters/test_parameter_values.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_parameters/test_process_parameter_data.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_parameters/test_process_parameter_data.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_parameters/test_process_parameter_data.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_plotting/test_quick_plot.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_spatial_methods/test_spectral_volume.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_spatial_methods/test_spectral_volume.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_spatial_methods/test_spectral_volume.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_spatial_methods/test_spectral_volume.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_plotting/test_quick_plot.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_solvers/test_casadi_algebraic_solver.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_solvers/test_casadi_algebraic_solver.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_spatial_methods/test_spectral_volume.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_spatial_methods/test_spectral_volume.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update tests/unit/test_spatial_methods/test_spectral_volume.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Removing DepricatioWarning failure

Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>

---------

Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>
Co-authored-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
  • Loading branch information
4 people committed Sep 13, 2024
1 parent c2d8ac2 commit ba2aa67
Show file tree
Hide file tree
Showing 22 changed files with 693 additions and 964 deletions.
366 changes: 169 additions & 197 deletions tests/unit/test_parameters/test_parameter_values.py

Large diffs are not rendered by default.

56 changes: 23 additions & 33 deletions tests/unit/test_parameters/test_process_parameter_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,52 @@
import numpy as np
import pybamm

import unittest
import pytest


class TestProcessParameterData(unittest.TestCase):
class TestProcessParameterData:
def test_process_1D_data(self):
name = "lico2_ocv_example"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_1D_data(name, path)
self.assertEqual(processed[0], name)
self.assertIsInstance(processed[1], tuple)
self.assertIsInstance(processed[1][0][0], np.ndarray)
self.assertIsInstance(processed[1][1], np.ndarray)
assert processed[0] == name
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)

def test_process_2D_data(self):
name = "lico2_diffusivity_Dualfoil1998_2D"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_2D_data(name, path)
self.assertEqual(processed[0], name)
self.assertIsInstance(processed[1], tuple)
self.assertIsInstance(processed[1][0][0], np.ndarray)
self.assertIsInstance(processed[1][0][1], np.ndarray)
self.assertIsInstance(processed[1][1], np.ndarray)
assert processed[0] == name
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][0][1], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)

def test_process_2D_data_csv(self):
name = "data_for_testing_2D"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_2D_data_csv(name, path)

self.assertEqual(processed[0], name)
self.assertIsInstance(processed[1], tuple)
self.assertIsInstance(processed[1][0][0], np.ndarray)
self.assertIsInstance(processed[1][0][1], np.ndarray)
self.assertIsInstance(processed[1][1], np.ndarray)
assert processed[0] == name
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][0][1], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)

def test_process_3D_data_csv(self):
name = "data_for_testing_3D"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_3D_data_csv(name, path)

self.assertEqual(processed[0], name)
self.assertIsInstance(processed[1], tuple)
self.assertIsInstance(processed[1][0][0], np.ndarray)
self.assertIsInstance(processed[1][0][1], np.ndarray)
self.assertIsInstance(processed[1][0][2], np.ndarray)
self.assertIsInstance(processed[1][1], np.ndarray)
assert processed[0] == name
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][0][1], np.ndarray)
assert isinstance(processed[1][0][2], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)

def test_error(self):
with self.assertRaisesRegex(FileNotFoundError, "Could not find file"):
with pytest.raises(FileNotFoundError, match="Could not find file"):
pybamm.parameters.process_1D_data("not_a_real_file", "not_a_real_path")


if __name__ == "__main__":
print("Add -v for more debug output")
import sys

if "-v" in sys.argv:
debug = True
pybamm.settings.debug_mode = True
unittest.main()
113 changes: 50 additions & 63 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
import pybamm
import unittest
import pytest

import numpy as np
from tempfile import TemporaryDirectory


class TestQuickPlot(unittest.TestCase):
class TestQuickPlot:
def test_simple_ode_model(self):
model = pybamm.lithium_ion.BaseModel(name="Simple ODE Model")

Expand Down Expand Up @@ -77,11 +77,11 @@ def test_simple_ode_model(self):
# update the axis
new_axis = [0, 0.5, 0, 1]
quick_plot.axis_limits.update({("a",): new_axis})
self.assertEqual(quick_plot.axis_limits[("a",)], new_axis)
assert quick_plot.axis_limits[("a",)] == new_axis

# and now reset them
quick_plot.reset_axis()
self.assertNotEqual(quick_plot.axis_limits[("a",)], new_axis)
assert quick_plot.axis_limits[("a",)] != new_axis

# check dynamic plot loads
quick_plot.dynamic_plot(show_plot=False)
Expand All @@ -90,7 +90,7 @@ def test_simple_ode_model(self):

# Test with different output variables
quick_plot = pybamm.QuickPlot(solution, ["b broadcasted"])
self.assertEqual(len(quick_plot.axis_limits), 1)
assert len(quick_plot.axis_limits) == 1
quick_plot.plot(0)

quick_plot = pybamm.QuickPlot(
Expand All @@ -103,18 +103,18 @@ def test_simple_ode_model(self):
"c broadcasted positive electrode",
],
)
self.assertEqual(len(quick_plot.axis_limits), 5)
assert len(quick_plot.axis_limits) == 5
quick_plot.plot(0)

# update the axis
new_axis = [0, 0.5, 0, 1]
var_key = ("c broadcasted",)
quick_plot.axis_limits.update({var_key: new_axis})
self.assertEqual(quick_plot.axis_limits[var_key], new_axis)
assert quick_plot.axis_limits[var_key] == new_axis

# and now reset them
quick_plot.reset_axis()
self.assertNotEqual(quick_plot.axis_limits[var_key], new_axis)
assert quick_plot.axis_limits[var_key] != new_axis

# check dynamic plot loads
quick_plot.dynamic_plot(show_plot=False)
Expand All @@ -135,19 +135,19 @@ def test_simple_ode_model(self):
labels=["sol 1", "sol 2"],
n_rows=2,
)
self.assertEqual(quick_plot.colors, ["r", "g", "b"])
self.assertEqual(quick_plot.linestyles, ["-", "--"])
self.assertEqual(quick_plot.figsize, (1, 2))
self.assertEqual(quick_plot.labels, ["sol 1", "sol 2"])
self.assertEqual(quick_plot.n_rows, 2)
self.assertEqual(quick_plot.n_cols, 1)
assert quick_plot.colors == ["r", "g", "b"]
assert quick_plot.linestyles == ["-", "--"]
assert quick_plot.figsize == (1, 2)
assert quick_plot.labels == ["sol 1", "sol 2"]
assert quick_plot.n_rows == 2
assert quick_plot.n_cols == 1

# Test different time units
quick_plot = pybamm.QuickPlot(solution, ["a"])
self.assertEqual(quick_plot.time_scaling_factor, 1)
assert quick_plot.time_scaling_factor == 1
quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="seconds")
quick_plot.plot(0)
self.assertEqual(quick_plot.time_scaling_factor, 1)
assert quick_plot.time_scaling_factor == 1
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_xdata(), t_eval
)
Expand All @@ -156,7 +156,7 @@ def test_simple_ode_model(self):
)
quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="minutes")
quick_plot.plot(0)
self.assertEqual(quick_plot.time_scaling_factor, 60)
assert quick_plot.time_scaling_factor == 60
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_xdata(), t_eval / 60
)
Expand All @@ -165,30 +165,30 @@ def test_simple_ode_model(self):
)
quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="hours")
quick_plot.plot(0)
self.assertEqual(quick_plot.time_scaling_factor, 3600)
assert quick_plot.time_scaling_factor == 3600
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_xdata(), t_eval / 3600
)
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_eval
)
with self.assertRaisesRegex(ValueError, "time unit"):
with pytest.raises(ValueError, match="time unit"):
pybamm.QuickPlot(solution, ["a"], time_unit="bad unit")
# long solution defaults to hours instead of seconds
solution_long = solver.solve(model, np.linspace(0, 1e5))
quick_plot = pybamm.QuickPlot(solution_long, ["a"])
self.assertEqual(quick_plot.time_scaling_factor, 3600)
assert quick_plot.time_scaling_factor == 3600

# Test different spatial units
quick_plot = pybamm.QuickPlot(solution, ["a"])
self.assertEqual(quick_plot.spatial_unit, r"$\mu$m")
assert quick_plot.spatial_unit == r"$\mu$m"
quick_plot = pybamm.QuickPlot(solution, ["a"], spatial_unit="m")
self.assertEqual(quick_plot.spatial_unit, "m")
assert quick_plot.spatial_unit == "m"
quick_plot = pybamm.QuickPlot(solution, ["a"], spatial_unit="mm")
self.assertEqual(quick_plot.spatial_unit, "mm")
assert quick_plot.spatial_unit == "mm"
quick_plot = pybamm.QuickPlot(solution, ["a"], spatial_unit="um")
self.assertEqual(quick_plot.spatial_unit, r"$\mu$m")
with self.assertRaisesRegex(ValueError, "spatial unit"):
assert quick_plot.spatial_unit == r"$\mu$m"
with pytest.raises(ValueError, match="spatial unit"):
pybamm.QuickPlot(solution, ["a"], spatial_unit="bad unit")

# Test 2D variables
Expand All @@ -197,24 +197,25 @@ def test_simple_ode_model(self):
quick_plot.dynamic_plot(show_plot=False)
quick_plot.slider_update(0.01)

with self.assertRaisesRegex(NotImplementedError, "Cannot plot 2D variables"):
with pytest.raises(NotImplementedError, match="Cannot plot 2D variables"):
pybamm.QuickPlot([solution, solution], ["2D variable"])

# Test different variable limits
quick_plot = pybamm.QuickPlot(
solution, ["a", ["c broadcasted", "c broadcasted"]], variable_limits="tight"
)
self.assertEqual(quick_plot.axis_limits[("a",)][2:], [None, None])
self.assertEqual(
quick_plot.axis_limits[("c broadcasted", "c broadcasted")][2:], [None, None]
)
assert quick_plot.axis_limits[("a",)][2:] == [None, None]
assert quick_plot.axis_limits[("c broadcasted", "c broadcasted")][2:] == [
None,
None,
]
quick_plot.plot(0)
quick_plot.slider_update(1)

quick_plot = pybamm.QuickPlot(
solution, ["2D variable"], variable_limits="tight"
)
self.assertEqual(quick_plot.variable_limits[("2D variable",)], (None, None))
assert quick_plot.variable_limits[("2D variable",)] == (None, None)
quick_plot.plot(0)
quick_plot.slider_update(1)

Expand All @@ -223,41 +224,37 @@ def test_simple_ode_model(self):
["a", ["c broadcasted", "c broadcasted"]],
variable_limits={"a": [1, 2], ("c broadcasted", "c broadcasted"): [3, 4]},
)
self.assertEqual(quick_plot.axis_limits[("a",)][2:], [1, 2])
self.assertEqual(
quick_plot.axis_limits[("c broadcasted", "c broadcasted")][2:], [3, 4]
)
assert quick_plot.axis_limits[("a",)][2:] == [1, 2]
assert quick_plot.axis_limits[("c broadcasted", "c broadcasted")][2:] == [3, 4]
quick_plot.plot(0)
quick_plot.slider_update(1)

quick_plot = pybamm.QuickPlot(
solution, ["a", "b broadcasted"], variable_limits={"a": "tight"}
)
self.assertEqual(quick_plot.axis_limits[("a",)][2:], [None, None])
self.assertNotEqual(
quick_plot.axis_limits[("b broadcasted",)][2:], [None, None]
)
assert quick_plot.axis_limits[("a",)][2:] == [None, None]
assert quick_plot.axis_limits[("b broadcasted",)][2:] != [None, None]
quick_plot.plot(0)
quick_plot.slider_update(1)

with self.assertRaisesRegex(
TypeError, "variable_limits must be 'fixed', 'tight', or a dict"
with pytest.raises(
TypeError, match="variable_limits must be 'fixed', 'tight', or a dict"
):
pybamm.QuickPlot(
solution, ["a", "b broadcasted"], variable_limits="bad variable limits"
)

# Test errors
with self.assertRaisesRegex(ValueError, "Mismatching variable domains"):
with pytest.raises(ValueError, match="Mismatching variable domains"):
pybamm.QuickPlot(solution, [["a", "b broadcasted"]])
with self.assertRaisesRegex(ValueError, "labels"):
with pytest.raises(ValueError, match="labels"):
pybamm.QuickPlot(
[solution, solution], ["a"], labels=["sol 1", "sol 2", "sol 3"]
)

# No variable can be NaN
with self.assertRaisesRegex(
ValueError, "All-NaN variable 'NaN variable' provided"
with pytest.raises(
ValueError, match="All-NaN variable 'NaN variable' provided"
):
pybamm.QuickPlot(solution, ["NaN variable"])

Expand All @@ -269,7 +266,7 @@ def test_plot_with_different_models(self):
model.rhs = {a: pybamm.Scalar(0)}
model.initial_conditions = {a: pybamm.Scalar(0)}
solution = pybamm.CasadiSolver("fast").solve(model, [0, 1])
with self.assertRaisesRegex(ValueError, "No default output variables"):
with pytest.raises(ValueError, match="No default output variables"):
pybamm.QuickPlot(solution)

def test_spm_simulation(self):
Expand Down Expand Up @@ -462,17 +459,17 @@ def test_plot_2plus1D_spm(self):
][1]
np.testing.assert_array_almost_equal(qp_data.T, phi_n[:, :, -1])

with self.assertRaisesRegex(NotImplementedError, "Shape not recognized for"):
with pytest.raises(NotImplementedError, match="Shape not recognized for"):
pybamm.QuickPlot(solution, ["Negative particle concentration [mol.m-3]"])

pybamm.close_plots()

def test_invalid_input_type_failure(self):
with self.assertRaisesRegex(TypeError, "Solutions must be"):
with pytest.raises(TypeError, match="Solutions must be"):
pybamm.QuickPlot(1)

def test_empty_list_failure(self):
with self.assertRaisesRegex(TypeError, "QuickPlot requires at least 1"):
with pytest.raises(TypeError, match="QuickPlot requires at least 1"):
pybamm.QuickPlot([])

def test_model_with_inputs(self):
Expand Down Expand Up @@ -509,20 +506,10 @@ def test_model_with_inputs(self):
pybamm.close_plots()


class TestQuickPlotAxes(unittest.TestCase):
class TestQuickPlotAxes:
def test_quick_plot_axes(self):
axes = pybamm.QuickPlotAxes()
axes.add(("test 1", "test 2"), 1)
self.assertEqual(axes[0], 1)
self.assertEqual(axes.by_variable("test 1"), 1)
self.assertEqual(axes.by_variable("test 2"), 1)


if __name__ == "__main__":
print("Add -v for more debug output")
import sys

if "-v" in sys.argv:
debug = True
pybamm.settings.debug_mode = True
unittest.main()
assert axes[0] == 1
assert axes.by_variable("test 1") == 1
assert axes.by_variable("test 2") == 1
Loading

0 comments on commit ba2aa67

Please sign in to comment.