Skip to content

Commit

Permalink
allow matched pixels/wavelengths to be in either direction (#177)
Browse files Browse the repository at this point in the history
* enforce spectral axis being left to right for wavelength calibration

* remove reference to specutils branch, add check here temporarily

* Update specreduce/wavelength_calibration.py

Co-authored-by: Ricky O'Steen <39831871+rosteen@users.noreply.github.com>

* .

* remove sort, add check for increasing/decreasing

* code style

* added test coverage

* review suggestion

* review suggestion

---------

Co-authored-by: Ricky O'Steen <39831871+rosteen@users.noreply.github.com>
  • Loading branch information
cshanahan1 and rosteen authored Jun 9, 2023
1 parent e0299cc commit 0d800fd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 13 deletions.
26 changes: 26 additions & 0 deletions specreduce/tests/test_wavelength_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,29 @@ def test_fit_residuals_access(spec1d):
line_wavelengths=w)
test.residuals
test.wcs


def test_unsorted_pixels_wavelengths(spec1d):
# make sure an error is raised if input matched pixels/wavelengths are
# not strictly increasing or decreasing.

centers = np.array([0, 10, 5, 30])
w = (0.5 * centers + 2) * u.AA

with pytest.raises(ValueError, match='Pixels must be strictly increasing or decreasing.'):
WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w)

# now test that it fails when wavelengths are unsorted
centers = np.array([0, 10, 20, 30])
w = np.array([2, 5, 6, 1]) * u.AA
with pytest.raises(ValueError, match='Wavelengths must be strictly increasing or decreasing.'):
WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w)

# and same if those wavelengths are provided in a table
table = QTable([w], names=["wavelength"])
with pytest.raises(ValueError, match='Wavelengths must be strictly increasing or decreasing.'):
WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=table)

# and again with decreasing pixels but unsorted wavelengths
with pytest.raises(ValueError, match='Wavelengths must be strictly increasing or decreasing.'):
WavelengthCalibration1D(spec1d, line_pixels=centers[::-1], line_wavelengths=w)
43 changes: 30 additions & 13 deletions specreduce/wavelength_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def concatenate_catalogs():
pass


def _check_arr_monotonic(arr):
# returns True if ``arr`` is either strictly increasing or strictly
# decreasing, otherwise returns False.

sorted_increasing = np.all(arr[1:] >= arr[:-1])
sorted_decreasing = np.all(arr[1:] <= arr[:-1])
return sorted_increasing or sorted_decreasing


class WavelengthCalibration1D():

def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None,
Expand All @@ -40,18 +49,16 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None,
and wavelengths populated.
line_pixels: list, array, `~astropy.table.QTable`, optional
List or array of line pixel locations to anchor the wavelength solution fit.
Will be converted to an astropy table internally if a list or array was input.
Can also be input as an `~astropy.table.QTable` table with (minimally) a column
named "pixel_center".
line_wavelengths: `~astropy.units.Quantity`, `~astropy.table.QTable`, optional
`astropy.units.Quantity` array of line wavelength values corresponding to the
line pixels defined in ``line_list``. Does not have to be in the same order]
(the lists will be sorted) but does currently need to be the same length as
line_list. Can also be input as an `~astropy.table.QTable` with (minimally)
line pixels defined in ``line_list``, assumed to be in the same
order. Can also be input as an `~astropy.table.QTable` with (minimally)
a "wavelength" column.
catalog: list, str, `~astropy.table.QTable`, optional
The name of a catalog of line wavelengths to load and use in automated and
template-matching line matching.
template-matching line matching. NOTE: This option is currently not implemented.
input_model: `~astropy.modeling.Model`
The model to fit for the wavelength solution. Defaults to a linear model.
fitter: `~astropy.modeling.fitting.Fitter`, optional
Expand All @@ -70,6 +77,9 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None,
self._potential_wavelengths = None
self._catalog = catalog

if not isinstance(input_spectrum, Spectrum1D):
raise ValueError('Input spectrum must be Spectrum1D.')

# ToDo: Implement having line catalogs
self._available_catalogs = get_available_catalogs()

Expand All @@ -95,9 +105,11 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None,
if self._matched_line_list["pixel_center"].unit is None:
self._matched_line_list["pixel_center"].unit = u.pix

# Make sure our pixel locations are sorted
self._matched_line_list.sort("pixel_center")
# check that pixels are monotonic
if not _check_arr_monotonic(self._matched_line_list["pixel_center"]):
raise ValueError('Pixels must be strictly increasing or decreasing.')

# now that pixels have been determined from input, figure out wavelengths.
if (line_wavelengths is None and catalog is None
and "wavelength" not in self._matched_line_list.columns):
raise ValueError("You must specify at least one of line_wavelengths, "
Expand All @@ -115,15 +127,20 @@ def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None,
if not isinstance(line_wavelengths, (u.Quantity, QTable)):
raise ValueError("line_wavelengths must be specified as an astropy.units.Quantity"
" array or as an astropy.table.QTable")

# make sure wavelengths (or freq) are monotonic and add wavelengths
# to _matched_line_list
if isinstance(line_wavelengths, u.Quantity):
# Ensure frequency is descending or wavelength is ascending
if str(line_wavelengths.unit.physical_type) == "frequency":
line_wavelengths[::-1].sort()
else:
line_wavelengths.sort()
if not _check_arr_monotonic(line_wavelengths):
if str(line_wavelengths.unit.physical_type) == "frequency":
raise ValueError('Frequencies must be strictly increasing or decreasing.')
raise ValueError('Wavelengths must be strictly increasing or decreasing.')

self._matched_line_list["wavelength"] = line_wavelengths

elif isinstance(line_wavelengths, QTable):
line_wavelengths.sort("wavelength")
if not _check_arr_monotonic(line_wavelengths['wavelength']):
raise ValueError('Wavelengths must be strictly increasing or decreasing.')
self._matched_line_list = hstack([self._matched_line_list, line_wavelengths])

# Parse desired catalogs of lines for matching.
Expand Down

0 comments on commit 0d800fd

Please sign in to comment.