diff --git a/CHANGES.rst b/CHANGES.rst index d66979614..e05078dfb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,11 @@ New Features - Spectral axis can now be any axis, rather than being forced to be last. See docs for more details. [#1033] +- Spectrum1D now properly handles GWCS input for wcs attribute. [#1074] + +- JWST reader no longer transposes the input data cube for 3D data and retains + full GWCS information (including spatial). [#1074] + Other Changes and Additions ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/specutils/io/default_loaders/jwst_reader.py b/specutils/io/default_loaders/jwst_reader.py index 423449c05..05e5b0e4a 100644 --- a/specutils/io/default_loaders/jwst_reader.py +++ b/specutils/io/default_loaders/jwst_reader.py @@ -8,8 +8,6 @@ from astropy.table import Table from astropy.io import fits from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance -from astropy.time import Time -from astropy.wcs import WCS from gwcs.wcstools import grid_from_bounding_box from ...spectra import Spectrum1D, SpectrumList @@ -579,38 +577,9 @@ def _jwst_s3d_loader(filename, **kwargs): except (ValueError, KeyError): flux_unit = None - # The spectral axis is first. We need it last - flux_array = hdu.data.T + flux_array = hdu.data flux = Quantity(flux_array, unit=flux_unit) - # Get the wavelength array from the GWCS object which returns a - # tuple of (RA, Dec, lambda). - # Since the spatial and spectral axes are orthogonal in s3d data, - # it is much faster to compute a slice down the spectral axis. - grid = grid_from_bounding_box(wcs.bounding_box)[:, :, 0, 0] - _, _, wavelength_array = wcs(*grid) - _, _, wavelength_unit = wcs.output_frame.unit - - wavelength = Quantity(wavelength_array, unit=wavelength_unit) - - # The GWCS is currently broken for some IFUs, here we work around that - wcs = None - if wavelength.shape[0] != flux.shape[-1]: - # Need MJD-OBS for this workaround - if 'MJD-OBS' not in hdu.header: - for key in ('MJD-BEG', 'DATE-OBS'): # Possible alternatives - if key in hdu.header: - if key.startswith('MJD'): - hdu.header['MJD-OBS'] = hdu.header[key] - break - else: - t = Time(hdu.header[key]) - hdu.header['MJD-OBS'] = t.mjd - break - wcs = WCS(hdu.header) - # Swap to match the flux transpose - wcs = wcs.swapaxes(-1, 0) - # Merge primary and slit headers and dump into meta slit_header = hdu.header header = primary_header.copy() @@ -621,7 +590,7 @@ def _jwst_s3d_loader(filename, **kwargs): ext_name = primary_header.get("ERREXT", "ERR") err_type = hdulist[ext_name].header.get("ERRTYPE", 'ERR') err_unit = hdulist[ext_name].header.get("BUNIT", None) - err_array = hdulist[ext_name].data.T + err_array = hdulist[ext_name].data # ERRTYPE can be one of "ERR", "IERR", "VAR", "IVAR" # but mostly ERR for JWST cubes @@ -639,13 +608,10 @@ def _jwst_s3d_loader(filename, **kwargs): # get mask information mask_name = primary_header.get("MASKEXT", "DQ") - mask = hdulist[mask_name].data.T + mask = hdulist[mask_name].data + + spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask, spectral_axis_index=0) - if wcs is not None: - spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask) - else: - spec = Spectrum1D(flux=flux, spectral_axis=wavelength, meta=meta, - uncertainty=err, mask=mask) spectra.append(spec) return SpectrumList(spectra) diff --git a/specutils/io/default_loaders/tests/test_jwst_reader.py b/specutils/io/default_loaders/tests/test_jwst_reader.py index 80b6d7799..50523b2f7 100644 --- a/specutils/io/default_loaders/tests/test_jwst_reader.py +++ b/specutils/io/default_loaders/tests/test_jwst_reader.py @@ -434,7 +434,7 @@ def test_jwst_s3d_single(tmp_path, cube): data = Spectrum1D.read(tmpfile, format='JWST s3d') assert type(data) is Spectrum1D - assert data.shape == (10, 10, 30) + assert data.shape == (30, 10, 10) assert data.uncertainty is not None assert data.mask is not None assert data.uncertainty.unit == 'MJy' diff --git a/specutils/spectra/spectrum1d.py b/specutils/spectra/spectrum1d.py index d57b5a34b..3bf4d4b07 100644 --- a/specutils/spectra/spectrum1d.py +++ b/specutils/spectra/spectrum1d.py @@ -3,9 +3,11 @@ import numpy as np from astropy import units as u +from astropy.coordinates import SpectralCoord from astropy.utils.decorators import lazyproperty from astropy.utils.decorators import deprecated from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin +from gwcs.wcs import WCS as GWCS from .spectral_axis import SpectralAxis from .spectrum_mixin import OneDSpectrumMixin @@ -228,24 +230,34 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None, f"of the corresponding flux axis ({flux.shape[self.spectral_axis_index]})") # If a WCS is provided, determine which axis is the spectral axis - if wcs is not None and hasattr(wcs, "naxis"): - if wcs.naxis > 1: + if wcs is not None: + naxis = None + if hasattr(wcs, "naxis"): + naxis = wcs.naxis + # GWCS doesn't have naxis + elif hasattr(wcs, "world_n_dim"): + naxis = wcs.world_n_dim + + if naxis is not None and naxis > 1: temp_axes = [] phys_axes = wcs.world_axis_physical_types - for i in range(len(phys_axes)): - if phys_axes[i] is None: - continue - if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect": - temp_axes.append(i) - if len(temp_axes) != 1: - raise ValueError("Input WCS must have exactly one axis with " - "spectral units, found {}".format(len(temp_axes))) - else: - # Due to FITS conventions, the WCS axes are listed in opposite - # order compared to the data array. - self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1 + if self._spectral_axis_index is None: + for i in range(len(phys_axes)): + if phys_axes[i] is None: + continue + if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect": + temp_axes.append(i) + if len(temp_axes) != 1: + raise ValueError("Input WCS must have exactly one axis with " + "spectral units, found {}".format(len(temp_axes))) + else: + # Due to FITS conventions, the WCS axes are listed in opposite + # order compared to the data array. + self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1 if move_spectral_axis is not None: + if isinstance(wcs, GWCS): + raise ValueError("move_spectral_axis cannot be used with GWCS") if isinstance(move_spectral_axis, str): if move_spectral_axis.lower() == 'first': move_to_index = 0 @@ -353,8 +365,23 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None, spec_axis = self.wcs.spectral.pixel_to_world( np.arange(self.flux.shape[self.spectral_axis_index])) else: - spec_axis = self.wcs.pixel_to_world( - np.arange(self.flux.shape[self.spectral_axis_index])) + # We now keep the entire GWCS, including spatial information, so we need to include + # all axes in the pixel_to_world call. Note that this assumes/requires that the + # dispersion is the same at all spatial locations. + wcs_args = [] + for i in range(len(self.flux.shape)): + wcs_args.append(np.zeros(self.flux.shape[self.spectral_axis_index])) + # Replace with arange for the spectral axis + wcs_args[self.spectral_axis_index] = np.arange(self.flux.shape[self.spectral_axis_index]) + wcs_args.reverse() + temp_coords = self.wcs.pixel_to_world(*wcs_args) + # If there are spatial axes, temp_coords will have a SkyCoord and a SpectralCoord + if isinstance(temp_coords, list): + for coords in temp_coords: + if isinstance(coords, SpectralCoord): + spec_axis = coords + else: + spec_axis = temp_coords try: if spec_axis.unit.is_equivalent(u.one):