From e60cf2a969d3374995b42c5a426561c7d421df23 Mon Sep 17 00:00:00 2001 From: Ricky O'Steen <39831871+rosteen@users.noreply.github.com> Date: Mon, 27 Mar 2023 12:16:41 -0400 Subject: [PATCH] Allow spectral axis to be anywhere, instead of forcing it to be last (#1033) * Starting to work on flexible spectral axis location Debugging initial spectrum creation Set private attribute here Working on debugging failing tests More things are temporarily broken, but I don't want to lose this work so I'm committing here Set spectral axis index to 0 if flux is None Working through test failures Fix codestyle Allow passing spectral_axis_index to wcs_fits loader Require specification of spectral_axis_index if WCS is 1D and flux is multi-D Decrement spectral_axis_index when slicing with integers Propagate spectral_axis_index through resampling Fix last test to account for spectral axis staying first Fix codestyle Specify spectral_axis_index in SDSS plate loader Greatly simply extract_bounding_spectral_region Account for variable spectral axis location in moment calculation, fix doc example Working on SpectrumCollection moment handling...not sure this is the way Need to add one to the axis index here Update narrative docs to reflect updates * Add back in the option to move the spectral axis to last, for back-compatibility Work around pixel unit slicing failure Change order on crop example Fix spectral slice handling in tuple input case (e.g. crop) Update output of crop example * Apply suggestions from code review Co-authored-by: Adam Ginsburg Apply suggestion from code review Add helpful comment * Address review comment about move_spectral_axis, more docs * Add suggested line to docstring Co-authored-by: Erik Tollerud * Add convenience method Make this a docstring * Add v2.0.0 changelog section --------- Co-authored-by: Erik Tollerud --- CHANGES.rst | 8 +- docs/index.rst | 20 +- docs/spectral_cube.rst | 11 +- docs/spectrum1d.rst | 28 +- specutils/analysis/moment.py | 30 ++- specutils/fitting/fitmodels.py | 2 +- specutils/io/default_loaders/sdss.py | 3 +- specutils/io/default_loaders/wcs_fits.py | 5 +- .../manipulation/extract_spectral_region.py | 41 +-- specutils/manipulation/resample.py | 9 +- specutils/spectra/spectrum1d.py | 241 +++++++++++++----- specutils/spectra/spectrum_mixin.py | 7 +- specutils/tests/test_loaders.py | 10 +- specutils/tests/test_region_extract.py | 3 +- specutils/tests/test_spectrum1d.py | 31 ++- 15 files changed, 304 insertions(+), 145 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 0c8248dcb..d66979614 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,11 +1,11 @@ -1.14.0 (unreleased) -------------------- +2.0.0 (unreleased) +------------------ New Features ^^^^^^^^^^^^ -Bug Fixes -^^^^^^^^^ +- Spectral axis can now be any axis, rather than being forced to be last. See docs + for more details. [#1033] Other Changes and Additions ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/index.rst b/docs/index.rst index 516c6ad12..12f9f3cba 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -23,10 +23,22 @@ details about the underlying principles, see `APE13 `_, the guiding document for spectroscopic development in the Astropy Project. -.. note:: - While specutils is available for general use, the API is in an early enough - development stage that some interfaces may change if user feedback and - experience warrants it. + +Changes in version 2 +==================== + +Specutils version 2 implemented a major change in that `~specutils.Spectrum1D` +no longer forces the spectral axis to be last for multi-dimensional data. This +was motivated by the desire for greater flexibility to allow for interoperability +with other packages that may wish to use ``specutils`` classes as the basis for +their own, and by the desire for consistency with the axis order that results +from a simple ``astropy.io.fits.read`` of a file. The legacy behavior can be +replicated by setting ``move_spectral_axis='last'`` when creating a new +`~specutils.Spectrum1D` object. + +For a summary of other changes in version 2, please see the +`release notes `_. + Getting started with :ref:`specutils ` ================================================= diff --git a/docs/spectral_cube.rst b/docs/spectral_cube.rst index 8a8d4bd8b..6b548bb01 100644 --- a/docs/spectral_cube.rst +++ b/docs/spectral_cube.rst @@ -41,7 +41,7 @@ The cube has 74x74 spaxels with 4563 spectral axis points in each one: .. code-block:: python >>> sc.shape # doctest: +REMOTE_DATA - (74, 74, 4563) + (4563, 74, 74) Print the contents of 3 spectral axis points in a 3x3 spaxel array: @@ -104,9 +104,8 @@ Moments ======= The `~specutils.analysis.moment` function can be used to compute moments of any order -along one of the cube's axes. By default, ``axis=-1``, which computes moments -along the spectral axis (remember that the spectral axis is always last in a -:class:`~specutils.Spectrum1D`). +along one of the cube's axes. By default, ``axis='spectral'``, in which case the moment +is computed along the spectral axis. .. code-block:: python @@ -153,8 +152,8 @@ cube, using `~specutils.manipulation.spectral_slab` and # Convert flux density to microJy and correct negative flux offset for # this particular dataset - ha_flux = (np.sum(subspec.flux.value, axis=(0,1)) + 0.0093) * 1.0E-6*u.Jy - ha_flux_wide = (np.sum(subspec_wide.flux.value, axis=(0,1)) + 0.0093) * 1.0E-6*u.Jy + ha_flux = (np.sum(subspec.flux.value, axis=(1,2)) + 0.0093) * 1.0E-6*u.Jy + ha_flux_wide = (np.sum(subspec_wide.flux.value, axis=(1,2)) + 0.0093) * 1.0E-6*u.Jy # Compute moment maps for H-alpha line moment0_halpha = moment(subspec, order=0) diff --git a/docs/spectrum1d.rst b/docs/spectrum1d.rst index f844a5fe5..c3bb8dcc2 100644 --- a/docs/spectrum1d.rst +++ b/docs/spectrum1d.rst @@ -231,19 +231,32 @@ Providing a FITS-style WCS >>> spec.wcs.pixel_to_world(np.arange(3)) # doctest: +FLOAT_CMP +When creating a `~specutils.Spectrum1D` using a WCS, you can also use the +``move_spectral_axis`` argument to force the spectral axis to a certain dimension +of a multi-dimenasional flux array. Prior to ``specutils`` version 2.0, the flux +array was always reordered such that the spectral axis corresponded to the last +flux axis - this behavior can be reproduced by setting ``move_spectral_axis=-1`` +or ``move_spectral_axis='last'``. Note that the relevant axes in the flux, mask, +and uncertainty arrays are simply swapped, and the swap is also reflected in the +resulting WCS. No check is currently done to ensure that the resulting array has +the spatial axes (most often RA and Dec) in any particular order. Multi-dimensional Data Sets --------------------------- `~specutils.Spectrum1D` also supports the multidimensional case where you -have, say, an ``(n_spectra, n_pix)`` +have, for example, an ``(n_spectra, n_pix)`` shaped data set where each ``n_spectra`` element provides a different flux -data array and so ``flux`` and ``uncertainty`` may be multidimensional as -long as the last dimension matches the shape of spectral_axis This is meant +data array. ``flux`` and ``uncertainty`` may be multidimensional as +long as one dimension matches the shape of the spectral_axis. This is meant to allow fast operations on collections of spectra that share the same ``spectral_axis``. While it may seem to conflict with the ā€œ1Dā€ in the class name, this name scheme is meant to communicate the presence of a single -common spectral axis. +common spectral axis. In cases where the flux axis corresponding to the spectral +axis cannot be determined automatically (for example, if multiple flux axes +have the same length as the spectral axis), the spectral axis must be specified +with the ``spectral_axis_index`` argument when initializing the +`~specutils.Spectrum1D`. .. note:: The case where each flux data array is related to a *different* spectral axis is encapsulated in the :class:`~specutils.SpectrumCollection` @@ -263,8 +276,7 @@ common spectral axis. 0.33281393, 0.59830875, 0.18673419, 0.67275604, 0.94180287] Jy> While the above example only shows two dimensions, this concept generalizes to -any number of dimensions for `~specutils.Spectrum1D`, as long as the spectral -axis is always the last. +any number of dimensions for `~specutils.Spectrum1D`. Slicing @@ -323,8 +335,8 @@ value will apply to the lower bound input. ... 'SPECSYS': 'BARYCENT', 'RADESYS': 'ICRS', 'EQUINOX': 2000.0, ... 'LONPOLE': 180.0, 'LATPOLE': 27.004754}) >>> spec = Spectrum1D(flux=np.random.default_rng(12345).random((20, 5, 10)) * u.Jy, wcs=w) # doctest: +IGNORE_WARNINGS - >>> lower = [SpectralCoord(4.9, unit=u.um), SkyCoord(ra=205, dec=26, unit=u.deg)] - >>> upper = [SpectralCoord(4.9, unit=u.um), SkyCoord(ra=205.5, dec=27.5, unit=u.deg)] + >>> lower = [SkyCoord(ra=205, dec=26, unit=u.deg), SpectralCoord(4.9, unit=u.um)] + >>> upper = [SkyCoord(ra=205.5, dec=27.5, unit=u.deg), SpectralCoord(4.9, unit=u.um)] >>> spec.crop(lower, upper) # doctest: +IGNORE_WARNINGS +FLOAT_CMP len(spectral_axis.shape): - _shape = flux.shape[:-1] + (1,) - dispersion = np.tile(spectral_axis, _shape) + for i in range(flux.ndim): + if i != calc_spectrum.spectral_axis_index: + dispersion = np.expand_dims(dispersion, i) + dispersion = np.repeat(dispersion, flux.shape[i], i) if order == 1: return np.sum(flux * dispersion, axis=axis) / np.sum(flux, axis=axis) diff --git a/specutils/fitting/fitmodels.py b/specutils/fitting/fitmodels.py index 66f19de49..5f80853b6 100644 --- a/specutils/fitting/fitmodels.py +++ b/specutils/fitting/fitmodels.py @@ -276,7 +276,7 @@ def fit_lines(spectrum, model, fitter=fitting.LevMarLSQFitter(calc_uncertainties exclude_regions : list of `~specutils.SpectralRegion` List of regions to exclude in the fitting. weights : array-like or 'unc', optional - If 'unc', the unceratinties from the spectrum object are used to + If 'unc', the uncertainties from the spectrum object are used to to calculate the weights. If array-like, represents the weights to use in the fitting. Note that if a mask is present on the spectrum, it will be applied to the ``weights`` as it would be to the spectrum diff --git a/specutils/io/default_loaders/sdss.py b/specutils/io/default_loaders/sdss.py index 261034141..97ccd095a 100644 --- a/specutils/io/default_loaders/sdss.py +++ b/specutils/io/default_loaders/sdss.py @@ -263,4 +263,5 @@ def spPlate_loader(file_obj, limit=None, **kwargs): wcs=fixed_wcs, uncertainty=uncertainty, meta=meta, - mask=mask) + mask=mask, + spectral_axis_index=-1) diff --git a/specutils/io/default_loaders/wcs_fits.py b/specutils/io/default_loaders/wcs_fits.py index 9ac9d435c..3055d439b 100644 --- a/specutils/io/default_loaders/wcs_fits.py +++ b/specutils/io/default_loaders/wcs_fits.py @@ -105,6 +105,8 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, if verbose: print("Spectrum file looks like wcs1d-fits") + spectral_axis_index = kwargs.get("spectral_axis_index") + with read_fileobj_or_hdulist(file_obj, **kwargs) as hdulist: if hdu is None: for ext in ('FLUX', 'SCI', 'DATA', 'PRIMARY'): @@ -212,7 +214,8 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, if wcs.naxis > 4: raise ValueError('FITS file input to wcs1d_fits_loader is > 4D') - return Spectrum1D(flux=data, wcs=wcs, mask=mask, uncertainty=uncertainty, meta=meta) + return Spectrum1D(flux=data, wcs=wcs, mask=mask, uncertainty=uncertainty, + meta=meta, spectral_axis_index=spectral_axis_index) @custom_writer("wcs1d-fits") diff --git a/specutils/manipulation/extract_spectral_region.py b/specutils/manipulation/extract_spectral_region.py index 720754130..05793cec1 100644 --- a/specutils/manipulation/extract_spectral_region.py +++ b/specutils/manipulation/extract_spectral_region.py @@ -1,5 +1,3 @@ -import sys - from math import floor, ceil # faster than int(np.floor/ceil(float)) import numpy as np @@ -162,7 +160,13 @@ def extract_region(spectrum, region, return_single_spectrum=False): flux=[]*spectrum.flux.unit) extracted_spectrum.append(empty_spectrum) else: - extracted_spectrum.append(spectrum[..., left_index:right_index]) + slices = [slice(None),] * len(spectrum.shape) + slices[spectrum.spectral_axis_index] = slice(left_index, right_index) + if len(slices) == 1: + slices = slices[0] + else: + slices = tuple(slices) + extracted_spectrum.append(spectrum[slices]) # If there is only one subregion in the region then we will # just return a spectrum. @@ -270,33 +274,10 @@ def extract_bounding_spectral_region(spectrum, region): if len(region) == 1: return extract_region(spectrum, region) - min_left = sys.maxsize - max_right = -sys.maxsize - 1 - # Look for indices that bound the entire set of sub-regions. - index_list = [_subregion_to_edge_pixels(sr, spectrum) for sr in region._subregions] - - for left_index, right_index in index_list: - if left_index is not None: - min_left = min(left_index, min_left) - if right_index is not None: - max_right = max(right_index, max_right) - - # If both indices are out of bounds then return an empty spectrum - if min_left is None and max_right is None: - empty_spectrum = Spectrum1D(spectral_axis=[]*spectrum.spectral_axis.unit, - flux=[]*spectrum.flux.unit) - return empty_spectrum - else: - # If only one index is out of bounds then set it to - # the lower or upper extent - if min_left is None: - min_left = 0 - - if max_right is None: - max_right = len(spectrum.spectral_axis) + min_list = [min(sr) for sr in region._subregions] + max_list = [max(sr) for sr in region._subregions] - if min_left > max_right: - min_left, max_right = max_right, min_left + single_region = SpectralRegion(min(min_list), max(max_list)) - return spectrum[..., min_left:max_right] + return extract_region(spectrum, single_region) diff --git a/specutils/manipulation/resample.py b/specutils/manipulation/resample.py index e208d2d14..af6e642c3 100644 --- a/specutils/manipulation/resample.py +++ b/specutils/manipulation/resample.py @@ -300,7 +300,8 @@ def resample1d(self, orig_spectrum, fin_spec_axis): resampled_spectrum = Spectrum1D(flux=output_fluxes, spectral_axis=fin_spec_axis, - uncertainty=new_errs) + uncertainty=new_errs, + spectral_axis_index = orig_spectrum.spectral_axis_index) return resampled_spectrum @@ -383,7 +384,8 @@ def resample1d(self, orig_spectrum, fin_spec_axis): return Spectrum1D(spectral_axis=fin_spec_axis, flux=out_flux, - uncertainty=new_unc) + uncertainty=new_unc, + spectral_axis_index = orig_spectrum.spectral_axis_index) class SplineInterpolatedResampler(ResamplerBase): @@ -475,4 +477,5 @@ def resample1d(self, orig_spectrum, fin_spec_axis): return Spectrum1D(spectral_axis=fin_spec_axis, flux=out_flux_val*orig_spectrum.flux.unit, - uncertainty=new_unc) + uncertainty=new_unc, + spectral_axis_index = orig_spectrum.spectral_axis_index) diff --git a/specutils/spectra/spectrum1d.py b/specutils/spectra/spectrum1d.py index 7702b58f5..d57b5a34b 100644 --- a/specutils/spectra/spectrum1d.py +++ b/specutils/spectra/spectrum1d.py @@ -42,6 +42,16 @@ class Spectrum1D(OneDSpectrumMixin, NDCube, NDIOMixin, NDArithmeticMixin): Dispersion information with the same shape as the last (or only) dimension of flux, or one greater than the last dimension of flux if specifying bin edges. + spectral_axis_index : integer, optional + If it is ambiguous which axis is the spectral axis (e.g., if there are multiple + axes in the flux array with the same length as the input spectral_axis), + this argument is used to specify which is the spectral axis. + move_spectral_axis : int, str, optional + Force the spectral axis to be either the last axis (the default behavior prior + to version 2.0) by setting this argument to 'last' or -1, or the first axis by + setting this argument to 'first' or 0. + This will do a simple ``swapaxis`` between the relevant axis and original + spectral axis. If None, the spectral axis is left wherever it is in the input. wcs : `~astropy.wcs.WCS` or `~gwcs.wcs.WCS` WCS information object that either has a spectral component or is only spectral. @@ -70,9 +80,25 @@ class Spectrum1D(OneDSpectrumMixin, NDCube, NDIOMixin, NDArithmeticMixin): Arbitrary container for any user-specific information to be carried around with the spectrum container object. """ - def __init__(self, flux=None, spectral_axis=None, wcs=None, - velocity_convention=None, rest_value=None, redshift=None, - radial_velocity=None, bin_specification=None, **kwargs): + def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None, + wcs=None, velocity_convention=None, rest_value=None, + redshift=None, radial_velocity=None, bin_specification=None, + move_spectral_axis=None, **kwargs): + + # If the flux (data) argument is already a Spectrum1D (as it would + # be for internal arithmetic operations), avoid setup entirely. + if isinstance(flux, Spectrum1D): + super().__init__(flux) + return + + self._spectral_axis_index = spectral_axis_index + # Might as well handle this right away + if spectral_axis_index is None and flux is not None: + if flux.ndim == 1: + self._spectral_axis_index = 0 + elif flux is None: + self._spectral_axis_index = 0 + # Check for pre-defined entries in the kwargs dictionary. unknown_kwargs = set(kwargs).difference( {'data', 'unit', 'uncertainty', 'meta', 'mask', 'copy', @@ -82,12 +108,6 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None, raise ValueError("Initializer contains unknown arguments(s): {}." "".format(', '.join(map(str, unknown_kwargs)))) - # If the flux (data) argument is already a Spectrum1D (as it would - # be for internal arithmetic operations), avoid setup entirely. - if isinstance(flux, Spectrum1D): - super().__init__(flux) - return - # Handle initializing from NDCube objects elif isinstance(flux, NDCube): if flux.unit is None: @@ -163,27 +183,51 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None, "energy/wavelength/frequency equivalent.") # If flux and spectral axis are both specified, check that their lengths - # match or are off by one (implying the spectral axis stores bin edges) + # match or are off by one (implying the spectral axis stores bin edges). + # If we can't determine which flux axis corresponds to the spectral axis + # we raise an error. if flux is not None and spectral_axis is not None: - if spectral_axis.shape[0] == flux.shape[-1]: + if spectral_axis_index is None: + if flux.ndim == 1: + self._spectral_axis_index = 0 + else: + matching_axes = [] + if bin_specification == "centers": + add_elements = [0,] + elif bin_specification == "edges": + add_elements = [1,] + elif bin_specification is None: + add_elements = [0,1] + for i in range(flux.ndim): + for add_element in add_elements: + if spectral_axis.shape[0] == flux.shape[i] + add_element: + matching_axes.append(i) + + if len(matching_axes) == 1: + self._spectral_axis_index = matching_axes[0] + else: + raise ValueError("Unable to determine which flux axis corresponds to " + "the spectral axis. Please specify spectral_axis_index" + " or provide a spectral_axis matching a flux axis.") + + # Make sure the length of the spectral axis matches the appropriate flux axis + if spectral_axis.shape[0] == flux.shape[self.spectral_axis_index]: if bin_specification == "edges": - raise ValueError("A spectral axis input as bin edges" - "must have length one greater than the flux axis") + raise ValueError("A spectral axis input as bin edges must " + "have length one greater than the flux axis") bin_specification = "centers" - elif spectral_axis.shape[0] == flux.shape[-1]+1: + elif spectral_axis.shape[0] == flux.shape[self.spectral_axis_index]+1: if bin_specification == "centers": - raise ValueError("A spectral axis input as bin centers" + raise ValueError("A spectral axis input as bin centers " "must be the same length as the flux axis") bin_specification = "edges" else: raise ValueError( - "Spectral axis length ({}) must be the same size or one " - "greater (if specifying bin edges) than that of the last " - "flux axis ({})".format(spectral_axis.shape[0], - flux.shape[-1])) + f"Spectral axis length ({spectral_axis.shape[0]}) must be the " + "same size or one greater (if specifying bin edges) than that " + f"of the corresponding flux axis ({flux.shape[self.spectral_axis_index]})") - # If a WCS is provided, check that the spectral axis is last and reorder - # the arrays if not + # If a WCS is provided, determine which axis is the spectral axis if wcs is not None and hasattr(wcs, "naxis"): if wcs.naxis > 1: temp_axes = [] @@ -196,33 +240,57 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None, if len(temp_axes) != 1: raise ValueError("Input WCS must have exactly one axis with " "spectral units, found {}".format(len(temp_axes))) + else: + # Due to FITS conventions, the WCS axes are listed in opposite + # order compared to the data array. + self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1 + + if move_spectral_axis is not None: + if isinstance(move_spectral_axis, str): + if move_spectral_axis.lower() == 'first': + move_to_index = 0 + elif move_spectral_axis.lower() == 'last': + move_to_index = wcs.naxis - 1 + else: + raise ValueError("move_spectral_axis must be either 'first' or 'last'") + elif isinstance(move_spectral_axis, int): + move_to_index = move_spectral_axis + else: + raise ValueError("move_spectral_axis must be an integer or 'first'/'last'") + + if move_to_index != self._spectral_axis_index: + wcs = wcs.swapaxes(self._spectral_axis_index, move_to_index) + if flux is not None: + flux = np.swapaxes(flux, self._spectral_axis_index, move_to_index) + if "mask" in kwargs: + if kwargs["mask"] is not None: + kwargs["mask"] = np.swapaxes(kwargs["mask"], + self._spectral_axis_index, move_to_index) + if "uncertainty" in kwargs: + if kwargs["uncertainty"] is not None: + if isinstance(kwargs["uncertainty"], NDUncertainty): + # Account for Astropy uncertainty types + temp_unc = np.swapaxes(kwargs["uncertainty"].array, + self._spectral_axis_index, move_to_index) + if kwargs["uncertainty"].unit is not None: + temp_unc = temp_unc * u.Unit(kwargs["uncertainty"].unit) + kwargs["uncertainty"] = type(kwargs["uncertainty"])(temp_unc) + else: + kwargs["uncertainty"] = np.swapaxes(kwargs["uncertainty"], + self._spectral_axis_index, move_to_index) + + self._spectral_axis_index = move_to_index - # Due to FITS conventions, a WCS with spectral axis first corresponds - # to a flux array with spectral axis last. - if temp_axes[0] != 0: - warnings.warn("Input WCS indicates that the spectral axis is not" - " last. Reshaping arrays to put spectral axis last.") - wcs = wcs.swapaxes(0, temp_axes[0]) - if flux is not None: - flux = np.swapaxes(flux, len(flux.shape)-temp_axes[0]-1, -1) - if "mask" in kwargs: - if kwargs["mask"] is not None: - kwargs["mask"] = np.swapaxes(kwargs["mask"], - len(kwargs["mask"].shape)-temp_axes[0]-1, -1) - if "uncertainty" in kwargs: - if kwargs["uncertainty"] is not None: - if isinstance(kwargs["uncertainty"], NDUncertainty): - # Account for Astropy uncertainty types - unc_len = len(kwargs["uncertainty"].array.shape) - temp_unc = np.swapaxes(kwargs["uncertainty"].array, - unc_len-temp_axes[0]-1, -1) - if kwargs["uncertainty"].unit is not None: - temp_unc = temp_unc * u.Unit(kwargs["uncertainty"].unit) - kwargs["uncertainty"] = type(kwargs["uncertainty"])(temp_unc) - else: - kwargs["uncertainty"] = np.swapaxes(kwargs["uncertainty"], - len(kwargs["uncertainty"].shape) - - temp_axes[0]-1, -1) + else: + if flux is not None and flux.ndim == 1: + self._spectral_axis_index = 0 + else: + if self.spectral_axis_index is None: + raise ValueError("WCS is 1D but flux is multi-dimensional. Please" + " specify spectral_axis_index.") + + elif move_spectral_axis is not None: + raise ValueError("Unable to use `move_spectral_axis` without a multi-dimensional WCS") # Attempt to parse the spectral axis. If none is given, try instead to # parse a given wcs. This is put into a GWCS object to @@ -236,10 +304,6 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None, # If spectral axis is provided as an astropy Quantity, convert it # to a specutils SpectralAxis object. if not isinstance(spectral_axis, SpectralAxis): - if spectral_axis.shape[0] == flux.shape[-1] + 1: - bin_specification = "edges" - else: - bin_specification = "centers" self._spectral_axis = SpectralAxis( spectral_axis, redshift=redshift, radial_velocity=radial_velocity, doppler_rest=rest_value, @@ -262,7 +326,13 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None, elif wcs is None: # If no spectral axis or wcs information is provided, initialize # with an empty gwcs based on the flux. - size = flux.shape[-1] if not flux.isscalar else 1 + if self.spectral_axis_index is None: + if flux.ndim == 1: + self._spectral_axis_index = 0 + else: + raise ValueError("Must specify spectral_axis_index if no WCS or spectral" + " axis is input.") + size = flux.shape[self.spectral_axis_index] if not flux.isscalar else 1 wcs = gwcs_from_array(np.arange(size) * u.Unit("")) super().__init__( @@ -277,11 +347,14 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None, if hasattr(self.wcs, "spectral"): # Handle generated 1D WCS that aren't set to spectral if not self.wcs.is_spectral and self.wcs.naxis == 1: - spec_axis = self.wcs.pixel_to_world(np.arange(self.flux.shape[-1])) + spec_axis = self.wcs.pixel_to_world( + np.arange(self.flux.shape[self.spectral_axis_index])) else: - spec_axis = self.wcs.spectral.pixel_to_world(np.arange(self.flux.shape[-1])) + spec_axis = self.wcs.spectral.pixel_to_world( + np.arange(self.flux.shape[self.spectral_axis_index])) else: - spec_axis = self.wcs.pixel_to_world(np.arange(self.flux.shape[-1])) + spec_axis = self.wcs.pixel_to_world( + np.arange(self.flux.shape[self.spectral_axis_index])) try: if spec_axis.unit.is_equivalent(u.one): @@ -322,22 +395,33 @@ def __getitem__(self, item): The first case is handled by the parent class, while the second is handled here. """ + new_spectral_axis_index = self.spectral_axis_index if self.flux.ndim > 1 or (isinstance(item, tuple) and item[0] is Ellipsis): if isinstance(item, tuple): - if len(item) == len(self.flux.shape) or item[0] is Ellipsis: - spec_item = item[-1] + if len(item) == self.flux.ndim or item[0] == Ellipsis: + spec_item = item[self.spectral_axis_index] if not isinstance(spec_item, slice): if isinstance(item, u.Quantity): raise ValueError("Indexing on single spectral axis " "values is not currently allowed, " "please use a slice.") spec_item = slice(spec_item, spec_item+1, None) - item = item[:-1] + (spec_item,) + # We have to hack around updating this tuple entry + item = list(item) + item[self.spectral_axis_index] = spec_item + item = tuple(item) else: # Slicing on less than the full number of axes means we want # to keep the whole spectral axis spec_item = slice(None, None, None) + # If any slices are single integers, need to decrement the spectral axis index + + for i in range(len(item)-1): + # Decrement spectral_axis_index for each single element slice + if isinstance(item[i], int): + new_spectral_axis_index -= 1 + elif isinstance(item, slice) and (isinstance(item.start, u.Quantity) or isinstance(item.stop, u.Quantity)): # We only allow slicing with world coordinates along the spectral @@ -353,7 +437,10 @@ def __getitem__(self, item): spec_item = item else: # Slicing with a single integer or slice uses the leading axis, - # so we keep the whole spectral axis, which is last + # so we keep the whole spectral axis, which is last. Need to decrement + # the spectral axis index by one + if isinstance(item, int): + new_spectral_axis_index -= 1 spec_item = slice(None, None, None) if (isinstance(spec_item.start, u.Quantity) or @@ -377,7 +464,7 @@ def __getitem__(self, item): uncertainty=self.uncertainty[item] if self.uncertainty is not None else None, mask=self.mask[item] if self.mask is not None else None, - meta=new_meta, wcs=None) + meta=new_meta, wcs=None, spectral_axis_index=new_spectral_axis_index) if not isinstance(item, slice): if isinstance(item, u.Quantity): @@ -394,6 +481,20 @@ def __getitem__(self, item): elif (isinstance(item.start, u.Quantity) or isinstance(item.stop, u.Quantity)): return self._spectral_slice(item) + # Work around error in SpectralCoord creation in super().__getitem__ for + # spectral axis in pixels. + if self.spectral_axis.unit == u.pix: + if "original_wcs" not in self.meta: + new_meta = deepcopy(self.meta) + new_meta["original_wcs"] = deepcopy(self.wcs) + return self._copy( + flux=self.flux[item], + spectral_axis=self.spectral_axis[item], + uncertainty=self.uncertainty[item] + if self.uncertainty is not None else None, + mask=self.mask[item] if self.mask is not None else None, + meta=new_meta, wcs=None, spectral_axis_index=new_spectral_axis_index) + tmp_spec = super().__getitem__(item) # TODO: this is a workaround until we figure out how to deal with non- @@ -428,7 +529,8 @@ def _copy(self, **kwargs): meta=deepcopy(self.meta), unit=deepcopy(self.unit), velocity_convention=deepcopy(self.velocity_convention), - rest_value=deepcopy(self.rest_value)) + rest_value=deepcopy(self.rest_value), + spectral_axis_index=deepcopy(self.spectral_axis_index)) alt_kwargs.update(kwargs) @@ -492,10 +594,11 @@ def collapse(self, method, axis=None): if isinstance(axis, str): if axis == 'spectral': - axis = -1 + axis = self.spectral_axis_index elif axis == 'spatial': # generate tuple if needed for multiple spatial axes - axis = tuple([x for x in range(len(self.flux.shape) - 1)]) + axis = tuple([x for x in range(len(self.flux.shape)) if + x != self.spectral_axis_index]) else: raise ValueError("String axis input must be 'spatial' or 'spectral'") @@ -512,9 +615,9 @@ def collapse(self, method, axis=None): collapsed_flux = collapse_funcs[method](flux_to_collapse, axis=axis) # Return a Spectrum1D if we collapsed over the spectral axis, a Quantity if not - if axis in (-1, None, len(self.flux.shape)-1): + if axis in (self.spectral_axis_index, None): return collapsed_flux - elif isinstance(axis, tuple) and -1 in axis: + elif isinstance(axis, tuple) and self.spectral_axis_index in axis: return collapsed_flux else: return Spectrum1D(collapsed_flux, wcs=self.wcs) @@ -670,6 +773,14 @@ def shift_spectrum_to(self, *, redshift=None, radial_velocity=None): else: raise ValueError("One of redshift or radial_velocity must be set.") + def with_spectral_axis_last(self): + """ + Convenience method to return a new copy of theSpectrum1D with the spectral axis last. + """ + return Spectrum1D(flux=self.flux, wcs=self.wcs, + mask=self.mask, uncertainty=self.uncertainty, + redshift=self.redshift, move_spectral_axis="last") + @redshift.setter @deprecated('1.8.0', alternative='set_redshift_to or shift_spectrum_to') def redshift(self, val): diff --git a/specutils/spectra/spectrum_mixin.py b/specutils/spectra/spectrum_mixin.py index 3f3001749..7c1b7c237 100644 --- a/specutils/spectra/spectrum_mixin.py +++ b/specutils/spectra/spectrum_mixin.py @@ -15,8 +15,8 @@ class OneDSpectrumMixin(): @property - def _spectral_axis_numpy_index(self): - return self.data.ndim - 1 - self.wcs.wcs.spec + def spectral_axis_index(self): + return self._spectral_axis_index @property def _spectral_axis_len(self): @@ -242,6 +242,9 @@ def with_spectral_axis_unit(self, unit, velocity_convention=None, return self.__class__(flux=self.flux, spectral_axis=new_spectral_axis, meta=meta, uncertainty=self.uncertainty, mask=self.mask) + def _axis_length_validation(self): + pass + def _new_wcs_argument_validation(self, unit, velocity_convention, rest_value): # Allow string specification of units, for example diff --git a/specutils/tests/test_loaders.py b/specutils/tests/test_loaders.py index 87bc7f17f..52d998e3a 100644 --- a/specutils/tests/test_loaders.py +++ b/specutils/tests/test_loaders.py @@ -149,7 +149,7 @@ def test_manga_cube(): assert isinstance(spec, Spectrum1D) assert spec.flux.size > 0 assert spec.meta['header']['INSTRUME'] == 'MaNGA' - assert spec.shape == (34, 34, 4563) + assert spec.shape == (4563, 34, 34) @pytest.mark.remote_data @@ -1081,11 +1081,11 @@ def test_wcs1d_fits_multid(tmp_path, spectral_axis): shape = [-1, 1] for i in range(2, 5): flux = flux * np.arange(i, i+5).reshape(*shape) - spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr)) + spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), spectral_axis_index=-1) tmpfile = tmp_path / f'wcs_{i}d.fits' spectrum.write(tmpfile, format='wcs1d-fits') - spec = Spectrum1D.read(tmpfile, format='wcs1d-fits') + spec = Spectrum1D.read(tmpfile, format='wcs1d-fits', spectral_axis_index=-1) assert spec.flux.ndim == i assert quantity_allclose(spec.spectral_axis, disp) assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis) @@ -1094,7 +1094,7 @@ def test_wcs1d_fits_multid(tmp_path, spectral_axis): # Test exception for NAXIS > 4 flux = flux * np.arange(i+1, i+6).reshape(*shape) - spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr)) + spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), spectral_axis_index=-1) tmpfile = tmp_path / f'wcs_{i+1}d.fits' spectrum.write(tmpfile, format='wcs1d-fits') @@ -1113,7 +1113,7 @@ def test_wcs1d_fits_non1d(tmp_path, spectral_axis): 'CRPIX1': 1, 'CRVAL1': 1, 'CDELT1': 0.01} # Create a small 2D data set flux = np.arange(1, 11)**2 * np.arange(4).reshape(-1, 1) * 1.e-14 * u.Jy - spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr)) + spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), spectral_axis_index=-1) tmpfile = tmp_path / f'wcs_{2}d.fits' spectrum.write(tmpfile, format='wcs1d-fits') diff --git a/specutils/tests/test_region_extract.py b/specutils/tests/test_region_extract.py index f40eb5b9f..3a7dc5f3d 100644 --- a/specutils/tests/test_region_extract.py +++ b/specutils/tests/test_region_extract.py @@ -128,7 +128,8 @@ def test_slab_simple(simulated_spectra): def test_slab_pixels(): range_iter = range(5) spectrum = Spectrum1D(flux=np.stack( - [np.zeros((2, 2)) + i for i in range_iter], axis=-1) * u.nJy) + [np.zeros((2, 2)) + i for i in range_iter], axis=-1) * u.nJy, + spectral_axis_index=2) for i in range_iter: sub_spectrum = spectral_slab(spectrum, i * u.pix, (i + 0.5) * u.pix) assert_quantity_allclose(sub_spectrum.flux, i * u.nJy) diff --git a/specutils/tests/test_spectrum1d.py b/specutils/tests/test_spectrum1d.py index 7b8e1aac4..78ca5c07b 100644 --- a/specutils/tests/test_spectrum1d.py +++ b/specutils/tests/test_spectrum1d.py @@ -5,6 +5,7 @@ import pytest from astropy.nddata import StdDevUncertainty from astropy.coordinates import SpectralCoord +from astropy.tests.helper import quantity_allclose from astropy.wcs import WCS from numpy.testing import assert_allclose @@ -121,12 +122,27 @@ def test_create_from_cube(): w = WCS(wcs_dict) spec = Spectrum1D(flux=flux, wcs=w) + spec_axis_from_wcs = (np.exp(np.array([1,2])*w.wcs.cdelt[-1]/w.wcs.crval[-1]) * + w.wcs.crval[-1]*spec.spectral_axis.unit) - assert spec.flux.shape == (4,3,2) - assert spec.flux[3,2,1] == 23*u.Jy - assert_allclose( - spec.spectral_axis.value, - np.exp(np.array([1, 2]) * w.wcs.cdelt[-1] / w.wcs.crval[-1]) * w.wcs.crval[-1]) + assert spec.flux.shape == (2,3,4) + assert spec.flux[1,2,3] == 23*u.Jy + assert quantity_allclose(spec.spectral_axis, spec_axis_from_wcs) + + with pytest.raises(ValueError): + spec2 = Spectrum1D(flux=flux, wcs=w, move_spectral_axis='Bad string') + + # Test moving spectral axis from first to last + spec2 = Spectrum1D(flux=flux, wcs=w, move_spectral_axis='last') + assert spec2.flux.shape == (4,3,2) + assert spec2.flux[3,2,1] == 23*u.Jy + assert quantity_allclose(spec2.spectral_axis, spec_axis_from_wcs) + + # Test moving spectral axis from last to first + spec3 = Spectrum1D(flux=spec2.flux, wcs=spec2.wcs, move_spectral_axis='first') + assert spec3.flux.shape == (2,3,4) + assert spec3.flux[1,2,3] == 23*u.Jy + assert quantity_allclose(spec3.spectral_axis, spec_axis_from_wcs) def test_spectral_axis_conversions(): @@ -192,7 +208,8 @@ def test_spectral_slice(): # Test higher dimensional slicing spec = Spectrum1D(spectral_axis=np.linspace(100, 1000, 10) * u.nm, - flux=np.random.random((10, 10)) * u.Jy) + flux=np.random.random((10, 10)) * u.Jy, + spectral_axis_index=1) sliced_spec = spec[300*u.nm:600*u.nm] assert np.all(sliced_spec.spectral_axis == [300, 400, 500] * u.nm) @@ -499,7 +516,7 @@ def test_collapse_flux(): flux = [[2,4,6], [0, 8, 12]] * u.Jy sa = [100,200,300]*u.um mask = [[False, True, False], [True, False, False]] - spec = Spectrum1D(flux, sa, mask=mask) + spec = Spectrum1D(flux, sa, mask=mask, spectral_axis_index=1) assert spec.mean() == 7 * u.Jy assert spec.max() == 12 * u.Jy