From 460e1e181865071d887b83d15a458f72fee88454 Mon Sep 17 00:00:00 2001 From: Garrett 'Karto' Keating Date: Fri, 31 Jan 2025 14:09:04 -0500 Subject: [PATCH] Adding missing test coverage in tools --- src/pyuvdata/utils/tools.py | 4 ++-- tests/utils/test_tools.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/pyuvdata/utils/tools.py b/src/pyuvdata/utils/tools.py index 3b43338d0..b8935e5fe 100644 --- a/src/pyuvdata/utils/tools.py +++ b/src/pyuvdata/utils/tools.py @@ -204,11 +204,11 @@ def _test_array_constant(array, *, tols=None): from pyuvdata.parameter import UVParameter if isinstance(array, UVParameter): - array_to_test = array.value + array_to_test = np.asarray(array.value) if tols is None: tols = array.tols else: - array_to_test = array + array_to_test = np.asarray(array) if tols is None: tols = (0, 0) assert isinstance(tols, tuple), "tols must be a length-2 tuple" diff --git a/tests/utils/test_tools.py b/tests/utils/test_tools.py index d09d35d20..711da6aaf 100644 --- a/tests/utils/test_tools.py +++ b/tests/utils/test_tools.py @@ -2,9 +2,11 @@ # Licensed under the 2-clause BSD License """Tests for helper utility functions.""" +import numpy as np import pytest from pyuvdata import utils +from pyuvdata.parameter import UVParameter from pyuvdata.testing import check_warnings @@ -89,3 +91,23 @@ def test_eval_inds(inds, nrecs, exp_output, nwarn): ): output = utils.tools._eval_inds(inds=inds, nrecs=nrecs, strict=False) assert all(exp_output == output) + + +@pytest.mark.parametrize("is_param", [True, False]) +@pytest.mark.parametrize( + "inp_arr,tols,exp_outcome", + [ + [np.array([0, 0, 0, 0]), (0, 0), True], + [[0, 0, 0, 0], None, True], + [[0, 0, 0, 1], (0, 0), False], + [[0, 0, 0, 1], None, False], + [[0, 0, 0, 1], (1, 0), True], + ], +) +def test_array_constant(inp_arr, is_param, tols, exp_outcome): + if is_param: + kwargs = {"value": inp_arr} + if tols is not None: + kwargs["tols"] = tols + inp_arr = UVParameter("test", **kwargs) + assert exp_outcome == utils.tools._test_array_constant(inp_arr, tols=tols)