Skip to content

Commit

Permalink
Working on resolving test failures, added warning for no overlap and …
Browse files Browse the repository at this point in the history
…return None instead of blank Spectrum1D
  • Loading branch information
rosteen committed Feb 1, 2024
1 parent dc9cdff commit e67c817
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
32 changes: 27 additions & 5 deletions specutils/analysis/template_comparison.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
from astropy.nddata import StdDevUncertainty

Expand Down Expand Up @@ -30,6 +32,7 @@ def _normalize_for_template_matching(observed_spectrum, template_spectrum, stdde
"""
if stddev is None:
stddev = observed_spectrum.uncertainty.represent_as(StdDevUncertainty).quantity

num = np.nansum((observed_spectrum.flux*template_spectrum.flux) / (stddev**2))
# We need to limit this sum to where observed_spectrum is not NaN as well.
template_filtered = ((template_spectrum.flux / stddev)**2)
Expand Down Expand Up @@ -66,7 +69,8 @@ def _resample(resample_method, extrapolation_treatment):
return None


def _chi_square_for_templates(observed_spectrum, template_spectrum, resample_method, extrapolation_treatment):
def _chi_square_for_templates(observed_spectrum, template_spectrum, resample_method,
extrapolation_treatment, warn_no_overlap=True):
"""
Resample the template spectrum to match the wavelength of the observed
spectrum. Then, calculate chi2 on the flux of the two spectra.
Expand All @@ -93,11 +97,26 @@ def _chi_square_for_templates(observed_spectrum, template_spectrum, resample_met
template_obswavelength = fluxc_resample(template_spectrum,
observed_spectrum.spectral_axis)

# With truncate, template may be smaller than observed spectrum if they don't fully overlap
matching_indices = np.intersect1d(observed_spectrum.spectral_axis,
template_obswavelength.spectral_axis,
return_indices=True)[1]
if len(matching_indices) == 0:
if warn_no_overlap:
warnings.warn("Template spectrum has no overlap with observed spectrum.")
return None, np.nan

observed_truncated = observed_spectrum[matching_indices.min():matching_indices.max()+1]

# If there was no overlap, we raise an error
#if len(template_obswavelength.flux) == 0:
# raise ValueError("The template spectrum and observed spectrum have no overlap")

# Convert the uncertainty to standard deviation if needed
stddev = observed_spectrum.uncertainty.represent_as(StdDevUncertainty).array
stddev = observed_truncated.uncertainty.represent_as(StdDevUncertainty).array

# Normalize spectra
normalization = _normalize_for_template_matching(observed_spectrum,
normalization = _normalize_for_template_matching(observed_truncated,
template_obswavelength,
stddev)

Expand Down Expand Up @@ -262,8 +281,11 @@ def template_redshift(observed_spectrum, template_spectrum, redshift,
uncertainty=template_spectrum.uncertainty,
meta=template_spectrum.meta)

normalized_spectral_template, chi2 = _chi_square_for_templates(
observed_spectrum, redshifted_spectrum, resample_method, extrapolation_treatment)
normalized_spectral_template, chi2 = _chi_square_for_templates(observed_spectrum,
redshifted_spectrum,
resample_method,
extrapolation_treatment,
warn_no_overlap=False)

chi2_list.append(chi2)

Expand Down
13 changes: 4 additions & 9 deletions specutils/tests/test_template_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,11 @@ def test_template_match_no_overlap():
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

# Get result from template_match
tm_result = template_comparison.template_match(spec, spec1)
assert tm_result[3] == 0.0
with pytest.warns(UserWarning, match="Template spectrum has no overlap with observed spectrum"):
tm_result = template_comparison.template_match(spec, spec1)
assert np.isnan(tm_result[3])

# Create new spectrum for comparison
spec_result = Spectrum1D(spectral_axis=spec_axis,
flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1))
try:
assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy)
except AssertionError:
pytest.xfail('TODO: investigate why this is failing')
assert tm_result[0] is None


def test_template_match_minimal_overlap():
Expand Down

0 comments on commit e67c817

Please sign in to comment.