diff --git a/reproject/conftest.py b/reproject/conftest.py index bb3685062..cc4a7ab91 100644 --- a/reproject/conftest.py +++ b/reproject/conftest.py @@ -5,11 +5,11 @@ import os -import pytest import numpy as np -from astropy.wcs import WCS +import pytest from astropy.io import fits from astropy.nddata import NDData +from astropy.wcs import WCS try: from pytest_astropy_header.display import PYTEST_HEADER_MODULES, TESTED_VERSIONS @@ -37,23 +37,7 @@ def pytest_configure(config): TESTED_VERSIONS["reproject"] = __version__ -@pytest.fixture(params= [ - "filename", - "path", - "hdulist", - "hdulist_all", - "primary_hdu", - "image_hdu", - "comp_image_hdu", - "ape14_wcs", - "shape_wcs_tuple", - "data_wcs_tuple", - "nddata", - ]) def valid_celestial_input(tmp_path, request): - - request.param - array = np.ones((30, 40)) wcs = WCS(naxis=2) @@ -78,12 +62,10 @@ def valid_celestial_input(tmp_path, request): if request.param == "filename": input_value = str(input_value) hdulist.writeto(input_value) - kwargs['hdu_in'] = 0 + kwargs["hdu_in"] = 0 elif request.param == "hdulist": input_value = hdulist - kwargs['hdu_in'] = 1 - elif request.param == "hdulist_all": - input_value = hdulist + kwargs["hdu_in"] = 1 elif request.param == "primary_hdu": input_value = hdulist[0] elif request.param == "image_hdu": @@ -99,7 +81,47 @@ def valid_celestial_input(tmp_path, request): input_value = (array, wcs) elif request.param == "nddata": input_value = NDData(data=array, wcs=wcs) + elif request.param == "ape14_wcs": + input_value = wcs + input_value._naxis = list(array.shape[::-1]) + elif request.param == "shape_wcs_tuple": + input_value = (array.shape, wcs) + else: raise ValueError(f"Unknown mode: {request.param}") return array, wcs, input_value, kwargs + + +@pytest.fixture( + params=[ + "filename", + "path", + "hdulist", + "primary_hdu", + "image_hdu", + "comp_image_hdu", + "data_wcs_tuple", + "nddata", + ] +) +def valid_celestial_input_data(tmp_path, request): + return valid_celestial_input(tmp_path, request) + + +@pytest.fixture( + params=[ + "filename", + "path", + "hdulist", + "primary_hdu", + "image_hdu", + "comp_image_hdu", + "data_wcs_tuple", + "nddata", + "ape14_wcs", + "shape_wcs_tuple", + ] +) +def valid_celestial_input_shapes(tmp_path, request): + return valid_celestial_input(tmp_path, request) diff --git a/reproject/mosaicking/tests/test_wcs_helpers.py b/reproject/mosaicking/tests/test_wcs_helpers.py index 7aea8fbc8..87132f0ef 100644 --- a/reproject/mosaicking/tests/test_wcs_helpers.py +++ b/reproject/mosaicking/tests/test_wcs_helpers.py @@ -230,22 +230,23 @@ def test_args_tuple_header(self): @pytest.mark.parametrize("iterable", [False, True]) -def test_input_types(valid_celestial_input, iterable): - +def test_input_types(valid_celestial_input_shapes, iterable): # Test different kinds of inputs and check the result is always the same - array, wcs, input_value, kwargs = valid_celestial_input + array, wcs, input_value, kwargs = valid_celestial_input_shapes wcs_ref, shape_ref = find_optimal_celestial_wcs([(array, wcs)], frame=FK5()) - if isinstance(input_value, fits.HDUList) and iterable and kwargs == {}: - pytest.skip() - if iterable: input_value = [input_value] wcs_test, shape_test = find_optimal_celestial_wcs(input_value, frame=FK5(), **kwargs) - assert_header_allclose(wcs_test.to_header(), wcs_ref.to_header()) - assert shape_test == shape_ref + + if isinstance(input_value, fits.HDUList) and not iterable: + # Also check case of not passing hdu_in and having all HDUs being included + + wcs_test, shape_test = find_optimal_celestial_wcs(input_value, frame=FK5()) + assert_header_allclose(wcs_test.to_header(), wcs_ref.to_header()) + assert shape_test == shape_ref diff --git a/reproject/tests/test_utils.py b/reproject/tests/test_utils.py index 5908f231b..43d4f3f69 100644 --- a/reproject/tests/test_utils.py +++ b/reproject/tests/test_utils.py @@ -5,115 +5,51 @@ from astropy.utils.data import get_pkg_data_filename from astropy.wcs import WCS +from reproject.tests.helpers import assert_header_allclose from reproject.utils import parse_input_data, parse_input_shape, parse_output_projection @pytest.mark.filterwarnings("ignore:unclosed file:ResourceWarning") -def test_parse_input_data(tmpdir): - header = fits.Header.fromtextfile(get_pkg_data_filename("data/gc_ga.hdr")) +def test_parse_input_data(tmpdir, valid_celestial_input_data, request): + array_ref, wcs_ref, input_value, kwargs = valid_celestial_input_data - data = np.arange(200).reshape((10, 20)) + data, wcs = parse_input_data(input_value, **kwargs) + np.testing.assert_allclose(data, array_ref) + assert_header_allclose(wcs.to_header(), wcs_ref.to_header()) - hdu = fits.ImageHDU(data, header) - # We want to test that the WCS is being parsed and output correctly in each - # of these cases. WCS doesn't seem to implement __eq__, so we convert the - # output WCS to a Header and compare that. Here we convert the original - # Header to a WCS and back to ensure an apples-to-apples comparision. - ref_coord_system = WCS(header).to_header() +def test_parse_input_data_invalid(): + data = np.ones((30, 40)) - # As HDU - array, coordinate_system = parse_input_data(hdu) - np.testing.assert_allclose(array, data) - assert coordinate_system.to_header() == ref_coord_system + with pytest.raises(TypeError, match="input_data should either be an HDU object"): + parse_input_data(data) - # As filename - filename = tmpdir.join("test.fits").strpath - hdu.writeto(filename) - with pytest.raises(ValueError) as exc: - array, coordinate_system = parse_input_data(filename) - assert exc.value.args[0] == ( - "More than one HDU is present, please specify HDU to use with ``hdu_in=`` option" +def test_parse_input_shape_missing_hdu_in(): + hdulist = fits.HDUList( + [fits.PrimaryHDU(data=np.ones((30, 40))), fits.ImageHDU(data=np.ones((20, 30)))] ) - array, coordinate_system = parse_input_data(filename, hdu_in=1) - np.testing.assert_allclose(array, data) - assert coordinate_system.to_header() == ref_coord_system - - # As array, header - array, coordinate_system = parse_input_data((data, header)) - np.testing.assert_allclose(array, data) - assert coordinate_system.to_header() == ref_coord_system - - # As array, WCS - wcs = WCS(hdu.header) - array, coordinate_system = parse_input_data((data, wcs)) - np.testing.assert_allclose(array, data) - assert coordinate_system is wcs - - ndd = NDData(data, wcs=wcs) - array, coordinate_system = parse_input_data(ndd) - np.testing.assert_allclose(array, data) - assert coordinate_system is wcs - - # Invalid - with pytest.raises(TypeError) as exc: - parse_input_data(data) - assert exc.value.args[0] == ( - "input_data should either be an HDU object or a tuple of (array, WCS) or (array, Header)" - ) + with pytest.raises(TypeError, match="More than one HDU"): + parse_input_data(hdulist) @pytest.mark.filterwarnings("ignore:unclosed file:ResourceWarning") -def test_parse_input_shape(tmpdir): +def test_parse_input_shape(tmpdir, valid_celestial_input_shapes): """ This should support everything that parse_input_data does, *plus* an "array-like" argument that is just a shape rather than a populated array. """ - header = fits.Header.fromtextfile(get_pkg_data_filename("data/gc_ga.hdr")) - in_shape = (10, 20) - data = np.arange(200).reshape(in_shape) - hdu = fits.ImageHDU(data) - - # As HDU - shape, coordinate_system = parse_input_shape(hdu) - assert shape == in_shape - # As filename - filename = tmpdir.join("test.fits").strpath - hdu.writeto(filename) + array_ref, wcs_ref, input_value, kwargs = valid_celestial_input_shapes - with pytest.raises(ValueError) as exc: - shape, coordinate_system = parse_input_shape(filename) - assert exc.value.args[0] == ( - "More than one HDU is present, please specify HDU to use with ``hdu_in=`` option" - ) - - shape, coordinate_system = parse_input_shape(filename, hdu_in=1) - assert shape == in_shape - - # As array, header - shape, coordinate_system = parse_input_shape((data, header)) - assert shape == in_shape - - # As array, WCS - wcs = WCS(hdu.header) - shape, coordinate_system = parse_input_shape((data, wcs)) - assert shape == in_shape - - ndd = NDData(data, wcs=wcs) - shape, coordinate_system = parse_input_shape(ndd) - assert shape == in_shape - assert coordinate_system is wcs + shape, wcs = parse_input_shape(input_value, **kwargs) + assert shape == array_ref.shape + assert_header_allclose(wcs.to_header(), wcs_ref.to_header()) - # As shape, header - shape, coordinate_system = parse_input_shape((data.shape, header)) - assert shape == in_shape - # As shape, WCS - shape, coordinate_system = parse_input_shape((data.shape, wcs)) - assert shape == in_shape +def test_parse_input_shape_invalid(): + data = np.ones((30, 40)) # Invalid with pytest.raises(TypeError) as exc: @@ -124,6 +60,18 @@ def test_parse_input_shape(tmpdir): ) +def test_parse_input_shape_missing_hdu_in(): + hdulist = fits.HDUList( + [fits.PrimaryHDU(data=np.ones((30, 40))), fits.ImageHDU(data=np.ones((20, 30)))] + ) + + with pytest.raises(ValueError) as exc: + shape, coordinate_system = parse_input_shape(hdulist) + assert exc.value.args[0] == ( + "More than one HDU is present, please specify HDU to use with ``hdu_in=`` option" + ) + + def test_parse_output_projection(tmpdir): header = fits.Header.fromtextfile(get_pkg_data_filename("data/gc_ga.hdr")) wcs = WCS(header) diff --git a/reproject/utils.py b/reproject/utils.py index d7ba76c9c..8407a4a30 100644 --- a/reproject/utils.py +++ b/reproject/utils.py @@ -9,7 +9,7 @@ from astropy.io import fits from astropy.io.fits import CompImageHDU, HDUList, Header, ImageHDU, PrimaryHDU from astropy.wcs import WCS -from astropy.wcs.wcsapi import BaseLowLevelWCS, BaseHighLevelWCS, SlicedLowLevelWCS +from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, SlicedLowLevelWCS from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper from dask.utils import SerializableLock @@ -26,7 +26,7 @@ def parse_input_data(input_data, hdu_in=None): Parse input data to return a Numpy array and WCS object. """ - if isinstance(input_data, str): + if isinstance(input_data, (str, Path)): with fits.open(input_data) as hdul: return parse_input_data(hdul, hdu_in=hdu_in) elif isinstance(input_data, HDUList):