Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved GWCS handling in Spectrum1D #1074

Merged
merged 12 commits into from
Aug 28, 2023
7 changes: 6 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@ New Features
- Spectral axis can now be any axis, rather than being forced to be last. See docs
for more details. [#1033]

- Spectrum1D now properly handles GWCS input for wcs attribute. [#1074]

Other Changes and Additions
^^^^^^^^^^^^^^^^^^^^^^^^^^^

1.11.0 (unreleased)
- JWST reader no longer transposes the input data cube for 3D data and retains
full GWCS information (including spatial). [#1074]

1.12.0 (unreleased)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs a change log entry

-------------------

New Features
Expand Down
44 changes: 5 additions & 39 deletions specutils/io/default_loaders/jwst_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from astropy.table import Table
from astropy.io import fits
from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance
from astropy.time import Time
from astropy.wcs import WCS
from gwcs.wcstools import grid_from_bounding_box

from ...spectra import Spectrum1D, SpectrumList
Expand Down Expand Up @@ -583,38 +581,9 @@ def _jwst_s3d_loader(filename, **kwargs):
except (ValueError, KeyError):
flux_unit = None

# The spectral axis is first. We need it last
flux_array = hdu.data.T
flux_array = hdu.data
flux = Quantity(flux_array, unit=flux_unit)

# Get the wavelength array from the GWCS object which returns a
# tuple of (RA, Dec, lambda).
# Since the spatial and spectral axes are orthogonal in s3d data,
# it is much faster to compute a slice down the spectral axis.
grid = grid_from_bounding_box(wcs.bounding_box)[:, :, 0, 0]
_, _, wavelength_array = wcs(*grid)
_, _, wavelength_unit = wcs.output_frame.unit

wavelength = Quantity(wavelength_array, unit=wavelength_unit)

# The GWCS is currently broken for some IFUs, here we work around that
wcs = None
if wavelength.shape[0] != flux.shape[-1]:
# Need MJD-OBS for this workaround
if 'MJD-OBS' not in hdu.header:
for key in ('MJD-BEG', 'DATE-OBS'): # Possible alternatives
if key in hdu.header:
if key.startswith('MJD'):
hdu.header['MJD-OBS'] = hdu.header[key]
break
else:
t = Time(hdu.header[key])
hdu.header['MJD-OBS'] = t.mjd
break
wcs = WCS(hdu.header)
# Swap to match the flux transpose
wcs = wcs.swapaxes(-1, 0)

# Merge primary and slit headers and dump into meta
slit_header = hdu.header
header = primary_header.copy()
Expand All @@ -625,7 +594,7 @@ def _jwst_s3d_loader(filename, **kwargs):
ext_name = primary_header.get("ERREXT", "ERR")
err_type = hdulist[ext_name].header.get("ERRTYPE", 'ERR')
err_unit = hdulist[ext_name].header.get("BUNIT", None)
err_array = hdulist[ext_name].data.T
err_array = hdulist[ext_name].data

# ERRTYPE can be one of "ERR", "IERR", "VAR", "IVAR"
# but mostly ERR for JWST cubes
Expand All @@ -643,13 +612,10 @@ def _jwst_s3d_loader(filename, **kwargs):

# get mask information
mask_name = primary_header.get("MASKEXT", "DQ")
mask = hdulist[mask_name].data.T
mask = hdulist[mask_name].data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious, why was data being transposed before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specutils currently demands that the spectral axis be last in a multidimensional spectrum. Specutils 2.0 is going to remove that requirements, so we'll be able to keep the data in the same shape as you would get if you read it in with say astropy.io.fits.


spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask, spectral_axis_index=0)

if wcs is not None:
spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask)
else:
spec = Spectrum1D(flux=flux, spectral_axis=wavelength, meta=meta,
uncertainty=err, mask=mask)
spectra.append(spec)

return SpectrumList(spectra)
2 changes: 1 addition & 1 deletion specutils/io/default_loaders/tests/test_jwst_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def test_jwst_s3d_single(tmp_path, cube):

data = Spectrum1D.read(tmpfile, format='JWST s3d')
assert type(data) is Spectrum1D
assert data.shape == (10, 10, 30)
assert data.shape == (30, 10, 10)
assert data.uncertainty is not None
assert data.mask is not None
assert data.uncertainty.unit == 'MJy'
59 changes: 43 additions & 16 deletions specutils/spectra/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import numpy as np
from astropy import units as u
from astropy.coordinates import SpectralCoord
from astropy.utils.decorators import lazyproperty
from astropy.utils.decorators import deprecated
from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin
from gwcs.wcs import WCS as GWCS

from .spectral_axis import SpectralAxis
from .spectrum_mixin import OneDSpectrumMixin
Expand Down Expand Up @@ -228,24 +230,34 @@
f"of the corresponding flux axis ({flux.shape[self.spectral_axis_index]})")

# If a WCS is provided, determine which axis is the spectral axis
if wcs is not None and hasattr(wcs, "naxis"):
if wcs.naxis > 1:
if wcs is not None:
naxis = None
if hasattr(wcs, "naxis"):
naxis = wcs.naxis
# GWCS doesn't have naxis
elif hasattr(wcs, "world_n_dim"):
naxis = wcs.world_n_dim

if naxis is not None and naxis > 1:
temp_axes = []
phys_axes = wcs.world_axis_physical_types
for i in range(len(phys_axes)):
if phys_axes[i] is None:
continue
if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect":
temp_axes.append(i)
if len(temp_axes) != 1:
raise ValueError("Input WCS must have exactly one axis with "
"spectral units, found {}".format(len(temp_axes)))
else:
# Due to FITS conventions, the WCS axes are listed in opposite
# order compared to the data array.
self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1
if self._spectral_axis_index is None:
for i in range(len(phys_axes)):
if phys_axes[i] is None:
continue
if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect":
temp_axes.append(i)
if len(temp_axes) != 1:
raise ValueError("Input WCS must have exactly one axis with "

Check warning on line 251 in specutils/spectra/spectrum1d.py

View check run for this annotation

Codecov / codecov/patch

specutils/spectra/spectrum1d.py#L251

Added line #L251 was not covered by tests
"spectral units, found {}".format(len(temp_axes)))
else:
# Due to FITS conventions, the WCS axes are listed in opposite
# order compared to the data array.
self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1

if move_spectral_axis is not None:
if isinstance(wcs, GWCS):
raise ValueError("move_spectral_axis cannot be used with GWCS")

Check warning on line 260 in specutils/spectra/spectrum1d.py

View check run for this annotation

Codecov / codecov/patch

specutils/spectra/spectrum1d.py#L260

Added line #L260 was not covered by tests
if isinstance(move_spectral_axis, str):
if move_spectral_axis.lower() == 'first':
move_to_index = 0
Expand Down Expand Up @@ -353,8 +365,23 @@
spec_axis = self.wcs.spectral.pixel_to_world(
np.arange(self.flux.shape[self.spectral_axis_index]))
else:
spec_axis = self.wcs.pixel_to_world(
np.arange(self.flux.shape[self.spectral_axis_index]))
# We now keep the entire GWCS, including spatial information, so we need to include
# all axes in the pixel_to_world call. Note that this assumes/requires that the
# dispersion is the same at all spatial locations.
wcs_args = []
for i in range(len(self.flux.shape)):
wcs_args.append(np.zeros(self.flux.shape[self.spectral_axis_index]))
# Replace with arange for the spectral axis
wcs_args[self.spectral_axis_index] = np.arange(self.flux.shape[self.spectral_axis_index])
wcs_args.reverse()
temp_coords = self.wcs.pixel_to_world(*wcs_args)
# If there are spatial axes, temp_coords will have a SkyCoord and a SpectralCoord
if isinstance(temp_coords, list):
for coords in temp_coords:
if isinstance(coords, SpectralCoord):
spec_axis = coords
else:
spec_axis = temp_coords

try:
if spec_axis.unit.is_equivalent(u.one):
Expand Down
Loading