diff --git a/specreduce/tests/test_wavelength_calibration.py b/specreduce/tests/test_wavelength_calibration.py index c7210da3..7accf9dd 100644 --- a/specreduce/tests/test_wavelength_calibration.py +++ b/specreduce/tests/test_wavelength_calibration.py @@ -109,3 +109,29 @@ def test_fit_residuals_access(spec1d): line_wavelengths=w) test.residuals test.wcs + + +def test_unsorted_pixels_wavelengths(spec1d): + # make sure an error is raised if input matched pixels/wavelengths are + # not strictly increasing or decreasing. + + centers = np.array([0, 10, 5, 30]) + w = (0.5 * centers + 2) * u.AA + + with pytest.raises(ValueError, match='Pixels must be strictly increasing or decreasing.'): + WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w) + + # now test that it fails when wavelengths are unsorted + centers = np.array([0, 10, 20, 30]) + w = np.array([2, 5, 6, 1]) * u.AA + with pytest.raises(ValueError, match='Wavelengths must be strictly increasing or decreasing.'): + WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w) + + # and same if those wavelengths are provided in a table + table = QTable([w], names=["wavelength"]) + with pytest.raises(ValueError, match='Wavelengths must be strictly increasing or decreasing.'): + WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=table) + + # and again with decreasing pixels but unsorted wavelengths + with pytest.raises(ValueError, match='Wavelengths must be strictly increasing or decreasing.'): + WavelengthCalibration1D(spec1d, line_pixels=centers[::-1], line_wavelengths=w) diff --git a/specreduce/wavelength_calibration.py b/specreduce/wavelength_calibration.py index 8c100e73..43696c92 100644 --- a/specreduce/wavelength_calibration.py +++ b/specreduce/wavelength_calibration.py @@ -27,6 +27,15 @@ def concatenate_catalogs(): pass +def _check_arr_monotonic(arr): + # returns True if ``arr`` is either strictly increasing or strictly + # decreasing, otherwise returns False. + + sorted_increasing = np.all(arr[1:] >= arr[:-1]) + sorted_decreasing = np.all(arr[1:] <= arr[:-1]) + return sorted_increasing or sorted_decreasing + + class WavelengthCalibration1D(): def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None, @@ -40,18 +49,16 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None, and wavelengths populated. line_pixels: list, array, `~astropy.table.QTable`, optional List or array of line pixel locations to anchor the wavelength solution fit. - Will be converted to an astropy table internally if a list or array was input. Can also be input as an `~astropy.table.QTable` table with (minimally) a column named "pixel_center". line_wavelengths: `~astropy.units.Quantity`, `~astropy.table.QTable`, optional `astropy.units.Quantity` array of line wavelength values corresponding to the - line pixels defined in ``line_list``. Does not have to be in the same order] - (the lists will be sorted) but does currently need to be the same length as - line_list. Can also be input as an `~astropy.table.QTable` with (minimally) + line pixels defined in ``line_list``, assumed to be in the same + order. Can also be input as an `~astropy.table.QTable` with (minimally) a "wavelength" column. catalog: list, str, `~astropy.table.QTable`, optional The name of a catalog of line wavelengths to load and use in automated and - template-matching line matching. + template-matching line matching. NOTE: This option is currently not implemented. input_model: `~astropy.modeling.Model` The model to fit for the wavelength solution. Defaults to a linear model. fitter: `~astropy.modeling.fitting.Fitter`, optional @@ -70,6 +77,9 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None, self._potential_wavelengths = None self._catalog = catalog + if not isinstance(input_spectrum, Spectrum1D): + raise ValueError('Input spectrum must be Spectrum1D.') + # ToDo: Implement having line catalogs self._available_catalogs = get_available_catalogs() @@ -95,9 +105,11 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None, if self._matched_line_list["pixel_center"].unit is None: self._matched_line_list["pixel_center"].unit = u.pix - # Make sure our pixel locations are sorted - self._matched_line_list.sort("pixel_center") + # check that pixels are monotonic + if not _check_arr_monotonic(self._matched_line_list["pixel_center"]): + raise ValueError('Pixels must be strictly increasing or decreasing.') + # now that pixels have been determined from input, figure out wavelengths. if (line_wavelengths is None and catalog is None and "wavelength" not in self._matched_line_list.columns): raise ValueError("You must specify at least one of line_wavelengths, " @@ -115,15 +127,20 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None, if not isinstance(line_wavelengths, (u.Quantity, QTable)): raise ValueError("line_wavelengths must be specified as an astropy.units.Quantity" " array or as an astropy.table.QTable") + + # make sure wavelengths (or freq) are monotonic and add wavelengths + # to _matched_line_list if isinstance(line_wavelengths, u.Quantity): - # Ensure frequency is descending or wavelength is ascending - if str(line_wavelengths.unit.physical_type) == "frequency": - line_wavelengths[::-1].sort() - else: - line_wavelengths.sort() + if not _check_arr_monotonic(line_wavelengths): + if str(line_wavelengths.unit.physical_type) == "frequency": + raise ValueError('Frequencies must be strictly increasing or decreasing.') + raise ValueError('Wavelengths must be strictly increasing or decreasing.') + self._matched_line_list["wavelength"] = line_wavelengths + elif isinstance(line_wavelengths, QTable): - line_wavelengths.sort("wavelength") + if not _check_arr_monotonic(line_wavelengths['wavelength']): + raise ValueError('Wavelengths must be strictly increasing or decreasing.') self._matched_line_list = hstack([self._matched_line_list, line_wavelengths]) # Parse desired catalogs of lines for matching.