diff --git a/specutils/analysis/template_comparison.py b/specutils/analysis/template_comparison.py index 42ec0b0c2..4de9fb82d 100644 --- a/specutils/analysis/template_comparison.py +++ b/specutils/analysis/template_comparison.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from astropy.nddata import StdDevUncertainty @@ -30,6 +32,7 @@ def _normalize_for_template_matching(observed_spectrum, template_spectrum, stdde """ if stddev is None: stddev = observed_spectrum.uncertainty.represent_as(StdDevUncertainty).quantity + num = np.nansum((observed_spectrum.flux*template_spectrum.flux) / (stddev**2)) # We need to limit this sum to where observed_spectrum is not NaN as well. template_filtered = ((template_spectrum.flux / stddev)**2) @@ -66,7 +69,8 @@ def _resample(resample_method, extrapolation_treatment): return None -def _chi_square_for_templates(observed_spectrum, template_spectrum, resample_method, extrapolation_treatment): +def _chi_square_for_templates(observed_spectrum, template_spectrum, resample_method, + extrapolation_treatment, warn_no_overlap=True): """ Resample the template spectrum to match the wavelength of the observed spectrum. Then, calculate chi2 on the flux of the two spectra. @@ -93,11 +97,26 @@ def _chi_square_for_templates(observed_spectrum, template_spectrum, resample_met template_obswavelength = fluxc_resample(template_spectrum, observed_spectrum.spectral_axis) + # With truncate, template may be smaller than observed spectrum if they don't fully overlap + matching_indices = np.intersect1d(observed_spectrum.spectral_axis, + template_obswavelength.spectral_axis, + return_indices=True)[1] + if len(matching_indices) == 0: + if warn_no_overlap: + warnings.warn("Template spectrum has no overlap with observed spectrum.") + return None, np.nan + + observed_truncated = observed_spectrum[matching_indices.min():matching_indices.max()+1] + + # If there was no overlap, we raise an error + #if len(template_obswavelength.flux) == 0: + # raise ValueError("The template spectrum and observed spectrum have no overlap") + # Convert the uncertainty to standard deviation if needed - stddev = observed_spectrum.uncertainty.represent_as(StdDevUncertainty).array + stddev = observed_truncated.uncertainty.represent_as(StdDevUncertainty).array # Normalize spectra - normalization = _normalize_for_template_matching(observed_spectrum, + normalization = _normalize_for_template_matching(observed_truncated, template_obswavelength, stddev) @@ -262,8 +281,11 @@ def template_redshift(observed_spectrum, template_spectrum, redshift, uncertainty=template_spectrum.uncertainty, meta=template_spectrum.meta) - normalized_spectral_template, chi2 = _chi_square_for_templates( - observed_spectrum, redshifted_spectrum, resample_method, extrapolation_treatment) + normalized_spectral_template, chi2 = _chi_square_for_templates(observed_spectrum, + redshifted_spectrum, + resample_method, + extrapolation_treatment, + warn_no_overlap=False) chi2_list.append(chi2) diff --git a/specutils/tests/test_template_comparison.py b/specutils/tests/test_template_comparison.py index d7cf8978c..9132a634c 100644 --- a/specutils/tests/test_template_comparison.py +++ b/specutils/tests/test_template_comparison.py @@ -29,16 +29,11 @@ def test_template_match_no_overlap(): uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) # Get result from template_match - tm_result = template_comparison.template_match(spec, spec1) - assert tm_result[3] == 0.0 + with pytest.warns(UserWarning, match="Template spectrum has no overlap with observed spectrum"): + tm_result = template_comparison.template_match(spec, spec1) + assert np.isnan(tm_result[3]) - # Create new spectrum for comparison - spec_result = Spectrum1D(spectral_axis=spec_axis, - flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1)) - try: - assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy) - except AssertionError: - pytest.xfail('TODO: investigate why this is failing') + assert tm_result[0] is None def test_template_match_minimal_overlap():