From a3a32b47bc2cb3959118ae99e21c329a77c347f6 Mon Sep 17 00:00:00 2001 From: Ricky O'Steen <39831871+rosteen@users.noreply.github.com> Date: Mon, 28 Aug 2023 10:33:20 -0400 Subject: [PATCH] Improved GWCS handling in Spectrum1D (#1074) * Allow spectral axis to be anywhere, instead of forcing it to be last (#1033) * Starting to work on flexible spectral axis location Debugging initial spectrum creation Set private attribute here Working on debugging failing tests More things are temporarily broken, but I don't want to lose this work so I'm committing here Set spectral axis index to 0 if flux is None Working through test failures Fix codestyle Allow passing spectral_axis_index to wcs_fits loader Require specification of spectral_axis_index if WCS is 1D and flux is multi-D Decrement spectral_axis_index when slicing with integers Propagate spectral_axis_index through resampling Fix last test to account for spectral axis staying first Fix codestyle Specify spectral_axis_index in SDSS plate loader Greatly simply extract_bounding_spectral_region Account for variable spectral axis location in moment calculation, fix doc example Working on SpectrumCollection moment handling...not sure this is the way Need to add one to the axis index here Update narrative docs to reflect updates * Add back in the option to move the spectral axis to last, for back-compatibility Work around pixel unit slicing failure Change order on crop example Fix spectral slice handling in tuple input case (e.g. crop) Update output of crop example * Apply suggestions from code review Co-authored-by: Adam Ginsburg Apply suggestion from code review Add helpful comment * Address review comment about move_spectral_axis, more docs * Add suggested line to docstring Co-authored-by: Erik Tollerud * Add convenience method Make this a docstring * Add v2.0.0 changelog section --------- Co-authored-by: Erik Tollerud * Prepare changelog for 1.10.0 release * Fix Changelog * Fixed issues with ndcube 2.1 docs * Fix incorrect fluxes and uncertainties returned by FluxConservingResampler, increase computation speed (#1060) * new implementation of flux conserving resample * removed unused method * handle multi dimensional flux inputs * . * Update CHANGES.rst Co-authored-by: Erik Tollerud * omit removing units * added test to compare output to output from running SpectRes --------- Co-authored-by: Erik Tollerud * Update changelog for 1.11.0 release * Changelog back to unreleased * Working on retaining full GWCS information in Spectrum1D rather than just spectral coords * Handle getting the spectral axis out of a GWCS Add changelog heading Remove debugging prints Fix changelog Fix codestyle * Add changelog entry * Delete the commented-out old wavelength parsing code * More accurate changelog --------- Co-authored-by: Erik Tollerud Co-authored-by: Nabil Freij Co-authored-by: Clare Shanahan --- CHANGES.rst | 5 ++ specutils/io/default_loaders/jwst_reader.py | 44 ++------------ .../default_loaders/tests/test_jwst_reader.py | 2 +- specutils/spectra/spectrum1d.py | 59 ++++++++++++++----- 4 files changed, 54 insertions(+), 56 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index d66979614..e05078dfb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,11 @@ New Features - Spectral axis can now be any axis, rather than being forced to be last. See docs for more details. [#1033] +- Spectrum1D now properly handles GWCS input for wcs attribute. [#1074] + +- JWST reader no longer transposes the input data cube for 3D data and retains + full GWCS information (including spatial). [#1074] + Other Changes and Additions ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/specutils/io/default_loaders/jwst_reader.py b/specutils/io/default_loaders/jwst_reader.py index 423449c05..05e5b0e4a 100644 --- a/specutils/io/default_loaders/jwst_reader.py +++ b/specutils/io/default_loaders/jwst_reader.py @@ -8,8 +8,6 @@ from astropy.table import Table from astropy.io import fits from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance -from astropy.time import Time -from astropy.wcs import WCS from gwcs.wcstools import grid_from_bounding_box from ...spectra import Spectrum1D, SpectrumList @@ -579,38 +577,9 @@ def _jwst_s3d_loader(filename, **kwargs): except (ValueError, KeyError): flux_unit = None - # The spectral axis is first. We need it last - flux_array = hdu.data.T + flux_array = hdu.data flux = Quantity(flux_array, unit=flux_unit) - # Get the wavelength array from the GWCS object which returns a - # tuple of (RA, Dec, lambda). - # Since the spatial and spectral axes are orthogonal in s3d data, - # it is much faster to compute a slice down the spectral axis. - grid = grid_from_bounding_box(wcs.bounding_box)[:, :, 0, 0] - _, _, wavelength_array = wcs(*grid) - _, _, wavelength_unit = wcs.output_frame.unit - - wavelength = Quantity(wavelength_array, unit=wavelength_unit) - - # The GWCS is currently broken for some IFUs, here we work around that - wcs = None - if wavelength.shape[0] != flux.shape[-1]: - # Need MJD-OBS for this workaround - if 'MJD-OBS' not in hdu.header: - for key in ('MJD-BEG', 'DATE-OBS'): # Possible alternatives - if key in hdu.header: - if key.startswith('MJD'): - hdu.header['MJD-OBS'] = hdu.header[key] - break - else: - t = Time(hdu.header[key]) - hdu.header['MJD-OBS'] = t.mjd - break - wcs = WCS(hdu.header) - # Swap to match the flux transpose - wcs = wcs.swapaxes(-1, 0) - # Merge primary and slit headers and dump into meta slit_header = hdu.header header = primary_header.copy() @@ -621,7 +590,7 @@ def _jwst_s3d_loader(filename, **kwargs): ext_name = primary_header.get("ERREXT", "ERR") err_type = hdulist[ext_name].header.get("ERRTYPE", 'ERR') err_unit = hdulist[ext_name].header.get("BUNIT", None) - err_array = hdulist[ext_name].data.T + err_array = hdulist[ext_name].data # ERRTYPE can be one of "ERR", "IERR", "VAR", "IVAR" # but mostly ERR for JWST cubes @@ -639,13 +608,10 @@ def _jwst_s3d_loader(filename, **kwargs): # get mask information mask_name = primary_header.get("MASKEXT", "DQ") - mask = hdulist[mask_name].data.T + mask = hdulist[mask_name].data + + spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask, spectral_axis_index=0) - if wcs is not None: - spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask) - else: - spec = Spectrum1D(flux=flux, spectral_axis=wavelength, meta=meta, - uncertainty=err, mask=mask) spectra.append(spec) return SpectrumList(spectra) diff --git a/specutils/io/default_loaders/tests/test_jwst_reader.py b/specutils/io/default_loaders/tests/test_jwst_reader.py index 80b6d7799..50523b2f7 100644 --- a/specutils/io/default_loaders/tests/test_jwst_reader.py +++ b/specutils/io/default_loaders/tests/test_jwst_reader.py @@ -434,7 +434,7 @@ def test_jwst_s3d_single(tmp_path, cube): data = Spectrum1D.read(tmpfile, format='JWST s3d') assert type(data) is Spectrum1D - assert data.shape == (10, 10, 30) + assert data.shape == (30, 10, 10) assert data.uncertainty is not None assert data.mask is not None assert data.uncertainty.unit == 'MJy' diff --git a/specutils/spectra/spectrum1d.py b/specutils/spectra/spectrum1d.py index d57b5a34b..3bf4d4b07 100644 --- a/specutils/spectra/spectrum1d.py +++ b/specutils/spectra/spectrum1d.py @@ -3,9 +3,11 @@ import numpy as np from astropy import units as u +from astropy.coordinates import SpectralCoord from astropy.utils.decorators import lazyproperty from astropy.utils.decorators import deprecated from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin +from gwcs.wcs import WCS as GWCS from .spectral_axis import SpectralAxis from .spectrum_mixin import OneDSpectrumMixin @@ -228,24 +230,34 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None, f"of the corresponding flux axis ({flux.shape[self.spectral_axis_index]})") # If a WCS is provided, determine which axis is the spectral axis - if wcs is not None and hasattr(wcs, "naxis"): - if wcs.naxis > 1: + if wcs is not None: + naxis = None + if hasattr(wcs, "naxis"): + naxis = wcs.naxis + # GWCS doesn't have naxis + elif hasattr(wcs, "world_n_dim"): + naxis = wcs.world_n_dim + + if naxis is not None and naxis > 1: temp_axes = [] phys_axes = wcs.world_axis_physical_types - for i in range(len(phys_axes)): - if phys_axes[i] is None: - continue - if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect": - temp_axes.append(i) - if len(temp_axes) != 1: - raise ValueError("Input WCS must have exactly one axis with " - "spectral units, found {}".format(len(temp_axes))) - else: - # Due to FITS conventions, the WCS axes are listed in opposite - # order compared to the data array. - self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1 + if self._spectral_axis_index is None: + for i in range(len(phys_axes)): + if phys_axes[i] is None: + continue + if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect": + temp_axes.append(i) + if len(temp_axes) != 1: + raise ValueError("Input WCS must have exactly one axis with " + "spectral units, found {}".format(len(temp_axes))) + else: + # Due to FITS conventions, the WCS axes are listed in opposite + # order compared to the data array. + self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1 if move_spectral_axis is not None: + if isinstance(wcs, GWCS): + raise ValueError("move_spectral_axis cannot be used with GWCS") if isinstance(move_spectral_axis, str): if move_spectral_axis.lower() == 'first': move_to_index = 0 @@ -353,8 +365,23 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None, spec_axis = self.wcs.spectral.pixel_to_world( np.arange(self.flux.shape[self.spectral_axis_index])) else: - spec_axis = self.wcs.pixel_to_world( - np.arange(self.flux.shape[self.spectral_axis_index])) + # We now keep the entire GWCS, including spatial information, so we need to include + # all axes in the pixel_to_world call. Note that this assumes/requires that the + # dispersion is the same at all spatial locations. + wcs_args = [] + for i in range(len(self.flux.shape)): + wcs_args.append(np.zeros(self.flux.shape[self.spectral_axis_index])) + # Replace with arange for the spectral axis + wcs_args[self.spectral_axis_index] = np.arange(self.flux.shape[self.spectral_axis_index]) + wcs_args.reverse() + temp_coords = self.wcs.pixel_to_world(*wcs_args) + # If there are spatial axes, temp_coords will have a SkyCoord and a SpectralCoord + if isinstance(temp_coords, list): + for coords in temp_coords: + if isinstance(coords, SpectralCoord): + spec_axis = coords + else: + spec_axis = temp_coords try: if spec_axis.unit.is_equivalent(u.one):