Skip to content

Commit

Permalink
Improved GWCS handling in Spectrum1D (astropy#1074)
Browse files Browse the repository at this point in the history
* Allow spectral axis to be anywhere, instead of forcing it to be last (astropy#1033)

* Starting to work on flexible spectral axis location

Debugging initial spectrum creation

Set private attribute here

Working on debugging failing tests

More things are temporarily broken, but I don't want to lose this work so I'm committing here

Set spectral axis index to 0 if flux is None

Working through test failures

Fix codestyle

Allow passing spectral_axis_index to wcs_fits loader

Require specification of spectral_axis_index if WCS is 1D and flux is multi-D

Decrement spectral_axis_index when slicing with integers

Propagate spectral_axis_index through resampling

Fix last test to account for spectral axis staying first

Fix codestyle

Specify spectral_axis_index in SDSS plate loader

Greatly simply extract_bounding_spectral_region

Account for variable spectral axis location in moment calculation, fix doc example

Working on SpectrumCollection moment handling...not sure this is the way

Need to add one to the axis index here

Update narrative docs to reflect updates

* Add back in the option to move the spectral axis to last, for back-compatibility

Work around pixel unit slicing failure

Change order on crop example

Fix spectral slice handling in tuple input case (e.g. crop)

Update output of crop example

* Apply suggestions from code review

Co-authored-by: Adam Ginsburg <keflavich@gmail.com>

Apply suggestion from code review

Add helpful comment

* Address review comment about move_spectral_axis, more docs

* Add suggested line to docstring

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* Add convenience method

Make this a docstring

* Add v2.0.0 changelog section

---------

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* Prepare changelog for 1.10.0 release

* Fix Changelog

* Fixed issues with ndcube 2.1 docs

* Fix incorrect fluxes and uncertainties returned by FluxConservingResampler, increase computation speed  (astropy#1060)

* new implementation of flux conserving resample

* removed unused method

* handle multi dimensional flux inputs

* .

* Update CHANGES.rst

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* omit removing units

* added test to compare output to output from running SpectRes

---------

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* Update changelog for 1.11.0 release

* Changelog back to unreleased

* Working on retaining full GWCS information in Spectrum1D rather than just spectral coords

* Handle getting the spectral axis out of a GWCS

Add changelog heading

Remove debugging prints

Fix changelog

Fix codestyle

* Add changelog entry

* Delete the commented-out old wavelength parsing code

* More accurate changelog

---------

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>
Co-authored-by: Nabil Freij <nabil.freij@gmail.com>
Co-authored-by: Clare Shanahan <cshanahan@stsci.edu>
  • Loading branch information
4 people committed Jun 18, 2024
1 parent 8471326 commit d7367f6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 56 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ New Features
- Spectral axis can now be any axis, rather than being forced to be last. See docs
for more details. [#1033]

- Spectrum1D now properly handles GWCS input for wcs attribute. [#1074]

- JWST reader no longer transposes the input data cube for 3D data and retains
full GWCS information (including spatial). [#1074]

Other Changes and Additions
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
44 changes: 5 additions & 39 deletions specutils/io/default_loaders/jwst_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from astropy.table import Table
from astropy.io import fits
from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance
from astropy.time import Time
from astropy.wcs import WCS
from gwcs.wcstools import grid_from_bounding_box

from ...spectra import Spectrum1D, SpectrumList
Expand Down Expand Up @@ -579,38 +577,9 @@ def _jwst_s3d_loader(filename, **kwargs):
except (ValueError, KeyError):
flux_unit = None

# The spectral axis is first. We need it last
flux_array = hdu.data.T
flux_array = hdu.data
flux = Quantity(flux_array, unit=flux_unit)

# Get the wavelength array from the GWCS object which returns a
# tuple of (RA, Dec, lambda).
# Since the spatial and spectral axes are orthogonal in s3d data,
# it is much faster to compute a slice down the spectral axis.
grid = grid_from_bounding_box(wcs.bounding_box)[:, :, 0, 0]
_, _, wavelength_array = wcs(*grid)
_, _, wavelength_unit = wcs.output_frame.unit

wavelength = Quantity(wavelength_array, unit=wavelength_unit)

# The GWCS is currently broken for some IFUs, here we work around that
wcs = None
if wavelength.shape[0] != flux.shape[-1]:
# Need MJD-OBS for this workaround
if 'MJD-OBS' not in hdu.header:
for key in ('MJD-BEG', 'DATE-OBS'): # Possible alternatives
if key in hdu.header:
if key.startswith('MJD'):
hdu.header['MJD-OBS'] = hdu.header[key]
break
else:
t = Time(hdu.header[key])
hdu.header['MJD-OBS'] = t.mjd
break
wcs = WCS(hdu.header)
# Swap to match the flux transpose
wcs = wcs.swapaxes(-1, 0)

# Merge primary and slit headers and dump into meta
slit_header = hdu.header
header = primary_header.copy()
Expand All @@ -621,7 +590,7 @@ def _jwst_s3d_loader(filename, **kwargs):
ext_name = primary_header.get("ERREXT", "ERR")
err_type = hdulist[ext_name].header.get("ERRTYPE", 'ERR')
err_unit = hdulist[ext_name].header.get("BUNIT", None)
err_array = hdulist[ext_name].data.T
err_array = hdulist[ext_name].data

# ERRTYPE can be one of "ERR", "IERR", "VAR", "IVAR"
# but mostly ERR for JWST cubes
Expand All @@ -639,13 +608,10 @@ def _jwst_s3d_loader(filename, **kwargs):

# get mask information
mask_name = primary_header.get("MASKEXT", "DQ")
mask = hdulist[mask_name].data.T
mask = hdulist[mask_name].data

spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask, spectral_axis_index=0)

if wcs is not None:
spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask)
else:
spec = Spectrum1D(flux=flux, spectral_axis=wavelength, meta=meta,
uncertainty=err, mask=mask)
spectra.append(spec)

return SpectrumList(spectra)
2 changes: 1 addition & 1 deletion specutils/io/default_loaders/tests/test_jwst_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def test_jwst_s3d_single(tmp_path, cube):

data = Spectrum1D.read(tmpfile, format='JWST s3d')
assert type(data) is Spectrum1D
assert data.shape == (10, 10, 30)
assert data.shape == (30, 10, 10)
assert data.uncertainty is not None
assert data.mask is not None
assert data.uncertainty.unit == 'MJy'
59 changes: 43 additions & 16 deletions specutils/spectra/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import numpy as np
from astropy import units as u
from astropy.coordinates import SpectralCoord
from astropy.utils.decorators import lazyproperty
from astropy.utils.decorators import deprecated
from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin
from gwcs.wcs import WCS as GWCS

from .spectral_axis import SpectralAxis
from .spectrum_mixin import OneDSpectrumMixin
Expand Down Expand Up @@ -228,24 +230,34 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
f"of the corresponding flux axis ({flux.shape[self.spectral_axis_index]})")

# If a WCS is provided, determine which axis is the spectral axis
if wcs is not None and hasattr(wcs, "naxis"):
if wcs.naxis > 1:
if wcs is not None:
naxis = None
if hasattr(wcs, "naxis"):
naxis = wcs.naxis
# GWCS doesn't have naxis
elif hasattr(wcs, "world_n_dim"):
naxis = wcs.world_n_dim

if naxis is not None and naxis > 1:
temp_axes = []
phys_axes = wcs.world_axis_physical_types
for i in range(len(phys_axes)):
if phys_axes[i] is None:
continue
if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect":
temp_axes.append(i)
if len(temp_axes) != 1:
raise ValueError("Input WCS must have exactly one axis with "
"spectral units, found {}".format(len(temp_axes)))
else:
# Due to FITS conventions, the WCS axes are listed in opposite
# order compared to the data array.
self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1
if self._spectral_axis_index is None:
for i in range(len(phys_axes)):
if phys_axes[i] is None:
continue
if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect":
temp_axes.append(i)
if len(temp_axes) != 1:
raise ValueError("Input WCS must have exactly one axis with "
"spectral units, found {}".format(len(temp_axes)))
else:
# Due to FITS conventions, the WCS axes are listed in opposite
# order compared to the data array.
self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1

if move_spectral_axis is not None:
if isinstance(wcs, GWCS):
raise ValueError("move_spectral_axis cannot be used with GWCS")
if isinstance(move_spectral_axis, str):
if move_spectral_axis.lower() == 'first':
move_to_index = 0
Expand Down Expand Up @@ -353,8 +365,23 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
spec_axis = self.wcs.spectral.pixel_to_world(
np.arange(self.flux.shape[self.spectral_axis_index]))
else:
spec_axis = self.wcs.pixel_to_world(
np.arange(self.flux.shape[self.spectral_axis_index]))
# We now keep the entire GWCS, including spatial information, so we need to include
# all axes in the pixel_to_world call. Note that this assumes/requires that the
# dispersion is the same at all spatial locations.
wcs_args = []
for i in range(len(self.flux.shape)):
wcs_args.append(np.zeros(self.flux.shape[self.spectral_axis_index]))
# Replace with arange for the spectral axis
wcs_args[self.spectral_axis_index] = np.arange(self.flux.shape[self.spectral_axis_index])
wcs_args.reverse()
temp_coords = self.wcs.pixel_to_world(*wcs_args)
# If there are spatial axes, temp_coords will have a SkyCoord and a SpectralCoord
if isinstance(temp_coords, list):
for coords in temp_coords:
if isinstance(coords, SpectralCoord):
spec_axis = coords
else:
spec_axis = temp_coords

try:
if spec_axis.unit.is_equivalent(u.one):
Expand Down

0 comments on commit d7367f6

Please sign in to comment.