Skip to content

Commit

Permalink
More test simplification with fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Mar 3, 2023
1 parent ac23464 commit cb2729f
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 118 deletions.
66 changes: 44 additions & 22 deletions reproject/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import os

import pytest
import numpy as np
from astropy.wcs import WCS
import pytest
from astropy.io import fits
from astropy.nddata import NDData
from astropy.wcs import WCS

try:
from pytest_astropy_header.display import PYTEST_HEADER_MODULES, TESTED_VERSIONS
Expand Down Expand Up @@ -37,23 +37,7 @@ def pytest_configure(config):
TESTED_VERSIONS["reproject"] = __version__


@pytest.fixture(params= [
"filename",
"path",
"hdulist",
"hdulist_all",
"primary_hdu",
"image_hdu",
"comp_image_hdu",
"ape14_wcs",
"shape_wcs_tuple",
"data_wcs_tuple",
"nddata",
])
def valid_celestial_input(tmp_path, request):

request.param

array = np.ones((30, 40))

wcs = WCS(naxis=2)
Expand All @@ -78,12 +62,10 @@ def valid_celestial_input(tmp_path, request):
if request.param == "filename":
input_value = str(input_value)
hdulist.writeto(input_value)
kwargs['hdu_in'] = 0
kwargs["hdu_in"] = 0
elif request.param == "hdulist":
input_value = hdulist
kwargs['hdu_in'] = 1
elif request.param == "hdulist_all":
input_value = hdulist
kwargs["hdu_in"] = 1
elif request.param == "primary_hdu":
input_value = hdulist[0]
elif request.param == "image_hdu":
Expand All @@ -99,7 +81,47 @@ def valid_celestial_input(tmp_path, request):
input_value = (array, wcs)
elif request.param == "nddata":
input_value = NDData(data=array, wcs=wcs)
elif request.param == "ape14_wcs":
input_value = wcs
input_value._naxis = list(array.shape[::-1])
elif request.param == "shape_wcs_tuple":
input_value = (array.shape, wcs)

else:
raise ValueError(f"Unknown mode: {request.param}")

return array, wcs, input_value, kwargs


@pytest.fixture(
params=[
"filename",
"path",
"hdulist",
"primary_hdu",
"image_hdu",
"comp_image_hdu",
"data_wcs_tuple",
"nddata",
]
)
def valid_celestial_input_data(tmp_path, request):
return valid_celestial_input(tmp_path, request)


@pytest.fixture(
params=[
"filename",
"path",
"hdulist",
"primary_hdu",
"image_hdu",
"comp_image_hdu",
"data_wcs_tuple",
"nddata",
"ape14_wcs",
"shape_wcs_tuple",
]
)
def valid_celestial_input_shapes(tmp_path, request):
return valid_celestial_input(tmp_path, request)
17 changes: 9 additions & 8 deletions reproject/mosaicking/tests/test_wcs_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,22 +230,23 @@ def test_args_tuple_header(self):


@pytest.mark.parametrize("iterable", [False, True])
def test_input_types(valid_celestial_input, iterable):

def test_input_types(valid_celestial_input_shapes, iterable):
# Test different kinds of inputs and check the result is always the same

array, wcs, input_value, kwargs = valid_celestial_input
array, wcs, input_value, kwargs = valid_celestial_input_shapes

wcs_ref, shape_ref = find_optimal_celestial_wcs([(array, wcs)], frame=FK5())

if isinstance(input_value, fits.HDUList) and iterable and kwargs == {}:
pytest.skip()

if iterable:
input_value = [input_value]

wcs_test, shape_test = find_optimal_celestial_wcs(input_value, frame=FK5(), **kwargs)

assert_header_allclose(wcs_test.to_header(), wcs_ref.to_header())

assert shape_test == shape_ref

if isinstance(input_value, fits.HDUList) and not iterable:
# Also check case of not passing hdu_in and having all HDUs being included

wcs_test, shape_test = find_optimal_celestial_wcs(input_value, frame=FK5())
assert_header_allclose(wcs_test.to_header(), wcs_ref.to_header())
assert shape_test == shape_ref
120 changes: 34 additions & 86 deletions reproject/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,115 +5,51 @@
from astropy.utils.data import get_pkg_data_filename
from astropy.wcs import WCS

from reproject.tests.helpers import assert_header_allclose
from reproject.utils import parse_input_data, parse_input_shape, parse_output_projection


@pytest.mark.filterwarnings("ignore:unclosed file:ResourceWarning")
def test_parse_input_data(tmpdir):
header = fits.Header.fromtextfile(get_pkg_data_filename("data/gc_ga.hdr"))
def test_parse_input_data(tmpdir, valid_celestial_input_data, request):
array_ref, wcs_ref, input_value, kwargs = valid_celestial_input_data

data = np.arange(200).reshape((10, 20))
data, wcs = parse_input_data(input_value, **kwargs)
np.testing.assert_allclose(data, array_ref)
assert_header_allclose(wcs.to_header(), wcs_ref.to_header())

hdu = fits.ImageHDU(data, header)

# We want to test that the WCS is being parsed and output correctly in each
# of these cases. WCS doesn't seem to implement __eq__, so we convert the
# output WCS to a Header and compare that. Here we convert the original
# Header to a WCS and back to ensure an apples-to-apples comparision.
ref_coord_system = WCS(header).to_header()
def test_parse_input_data_invalid():
data = np.ones((30, 40))

# As HDU
array, coordinate_system = parse_input_data(hdu)
np.testing.assert_allclose(array, data)
assert coordinate_system.to_header() == ref_coord_system
with pytest.raises(TypeError, match="input_data should either be an HDU object"):
parse_input_data(data)

# As filename
filename = tmpdir.join("test.fits").strpath
hdu.writeto(filename)

with pytest.raises(ValueError) as exc:
array, coordinate_system = parse_input_data(filename)
assert exc.value.args[0] == (
"More than one HDU is present, please specify HDU to use with ``hdu_in=`` option"
def test_parse_input_shape_missing_hdu_in():
hdulist = fits.HDUList(
[fits.PrimaryHDU(data=np.ones((30, 40))), fits.ImageHDU(data=np.ones((20, 30)))]
)

array, coordinate_system = parse_input_data(filename, hdu_in=1)
np.testing.assert_allclose(array, data)
assert coordinate_system.to_header() == ref_coord_system

# As array, header
array, coordinate_system = parse_input_data((data, header))
np.testing.assert_allclose(array, data)
assert coordinate_system.to_header() == ref_coord_system

# As array, WCS
wcs = WCS(hdu.header)
array, coordinate_system = parse_input_data((data, wcs))
np.testing.assert_allclose(array, data)
assert coordinate_system is wcs

ndd = NDData(data, wcs=wcs)
array, coordinate_system = parse_input_data(ndd)
np.testing.assert_allclose(array, data)
assert coordinate_system is wcs

# Invalid
with pytest.raises(TypeError) as exc:
parse_input_data(data)
assert exc.value.args[0] == (
"input_data should either be an HDU object or a tuple of (array, WCS) or (array, Header)"
)
with pytest.raises(TypeError, match="More than one HDU"):
parse_input_data(hdulist)


@pytest.mark.filterwarnings("ignore:unclosed file:ResourceWarning")
def test_parse_input_shape(tmpdir):
def test_parse_input_shape(tmpdir, valid_celestial_input_shapes):
"""
This should support everything that parse_input_data does, *plus* an
"array-like" argument that is just a shape rather than a populated array.
"""
header = fits.Header.fromtextfile(get_pkg_data_filename("data/gc_ga.hdr"))
in_shape = (10, 20)
data = np.arange(200).reshape(in_shape)
hdu = fits.ImageHDU(data)

# As HDU
shape, coordinate_system = parse_input_shape(hdu)
assert shape == in_shape

# As filename
filename = tmpdir.join("test.fits").strpath
hdu.writeto(filename)
array_ref, wcs_ref, input_value, kwargs = valid_celestial_input_shapes

with pytest.raises(ValueError) as exc:
shape, coordinate_system = parse_input_shape(filename)
assert exc.value.args[0] == (
"More than one HDU is present, please specify HDU to use with ``hdu_in=`` option"
)

shape, coordinate_system = parse_input_shape(filename, hdu_in=1)
assert shape == in_shape

# As array, header
shape, coordinate_system = parse_input_shape((data, header))
assert shape == in_shape

# As array, WCS
wcs = WCS(hdu.header)
shape, coordinate_system = parse_input_shape((data, wcs))
assert shape == in_shape

ndd = NDData(data, wcs=wcs)
shape, coordinate_system = parse_input_shape(ndd)
assert shape == in_shape
assert coordinate_system is wcs
shape, wcs = parse_input_shape(input_value, **kwargs)
assert shape == array_ref.shape
assert_header_allclose(wcs.to_header(), wcs_ref.to_header())

# As shape, header
shape, coordinate_system = parse_input_shape((data.shape, header))
assert shape == in_shape

# As shape, WCS
shape, coordinate_system = parse_input_shape((data.shape, wcs))
assert shape == in_shape
def test_parse_input_shape_invalid():
data = np.ones((30, 40))

# Invalid
with pytest.raises(TypeError) as exc:
Expand All @@ -124,6 +60,18 @@ def test_parse_input_shape(tmpdir):
)


def test_parse_input_shape_missing_hdu_in():
hdulist = fits.HDUList(
[fits.PrimaryHDU(data=np.ones((30, 40))), fits.ImageHDU(data=np.ones((20, 30)))]
)

with pytest.raises(ValueError) as exc:
shape, coordinate_system = parse_input_shape(hdulist)
assert exc.value.args[0] == (
"More than one HDU is present, please specify HDU to use with ``hdu_in=`` option"
)


def test_parse_output_projection(tmpdir):
header = fits.Header.fromtextfile(get_pkg_data_filename("data/gc_ga.hdr"))
wcs = WCS(header)
Expand Down
4 changes: 2 additions & 2 deletions reproject/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from astropy.io import fits
from astropy.io.fits import CompImageHDU, HDUList, Header, ImageHDU, PrimaryHDU
from astropy.wcs import WCS
from astropy.wcs.wcsapi import BaseLowLevelWCS, BaseHighLevelWCS, SlicedLowLevelWCS
from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, SlicedLowLevelWCS
from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper
from dask.utils import SerializableLock

Expand All @@ -26,7 +26,7 @@ def parse_input_data(input_data, hdu_in=None):
Parse input data to return a Numpy array and WCS object.
"""

if isinstance(input_data, str):
if isinstance(input_data, (str, Path)):
with fits.open(input_data) as hdul:
return parse_input_data(hdul, hdu_in=hdu_in)
elif isinstance(input_data, HDUList):
Expand Down

0 comments on commit cb2729f

Please sign in to comment.