Skip to content

Commit

Permalink
add coadd tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
weaverba137 committed Nov 17, 2023
1 parent b72ab31 commit 93bc0ff
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 42 deletions.
45 changes: 13 additions & 32 deletions py/desispec/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,34 +758,13 @@ def from_specutils(cls, spectra):
#
# Load objects that are independent of band from the first item.
#
try:
fibermap = sl[0].meta['fibermap']
except KeyError:
fibermap = None
try:
exp_fibermap = sl[0].meta['exp_fibermap']
except KeyError:
exp_fibermap = None
try:
meta = sl[0].meta['desi_meta']
except KeyError:
meta = None
try:
single = sl[0].meta['single']
except KeyError:
single = False
try:
scores = sl[0].meta['scores']
except KeyError:
scores = None
try:
scores_comments = sl[0].meta['scores_comments']
except KeyError:
scores_comments = None
try:
extra_catalog = sl[0].meta['extra_catalog']
except KeyError:
extra_catalog = None
fibermap = sl[0].meta.get('fibermap', None)
exp_fibermap = sl[0].meta.get('exp_fibermap', None)
meta = sl[0].meta.get('desi_meta', None)
single = sl[0].meta.get('single', False)
scores = sl[0].meta.get('scores', None)
scores_comments = sl[0].meta.get('scores_comments', None)
extra_catalog = sl[0].meta.get('extra_catalog', None)
#
# Load band-dependent quantities.
#
Expand All @@ -795,14 +774,16 @@ def from_specutils(cls, spectra):
mask = dict()
resolution_data = None
extra = None
AA = Unit('Angstrom')
specunit = Unit('10-17 erg cm-2 s-1 AA-1')
for i, band in enumerate(bands):
wave[band] = sl[i].spectral_axis.value
flux[band] = sl[i].flux.value
wave[band] = sl[i].spectral_axis.to(AA).value
flux[band] = sl[i].flux.to(specunit).value
if isinstance(sl[i].uncertainty, InverseVariance):
ivar[band] = sl[i].uncertainty.array
ivar[band] = sl[i].uncertainty.quantity.to(specunit**-2).value
elif isinstance(sl[i].uncertainty, StdDevUncertainty):
# Future: may need np.isfinite() here?
ivar[band] = (sl[i].uncertainty.array)**-2
ivar[band] = (sl[i].uncertainty.quantity.to(specunit).value)**-2
else:
raise ValueError("Unknown uncertainty type!")
try:
Expand Down
19 changes: 9 additions & 10 deletions py/desispec/test/test_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def setUp(self):
"KEY1" : "VAL1",
"KEY2" : "VAL2"
}
self.nwave = 100
self.nspec = 5
self.nwave = 101
self.nspec = 6
self.ndiag = 3

fmap = empty_fibermap(self.nspec)
Expand Down Expand Up @@ -96,9 +96,9 @@ def setUp(self):
self.extra = {}

for s in range(self.nspec):
self.wave['b'] = np.linspace(3600, 5800, self.nwave, dtype=float)
self.wave['r'] = np.linspace(5760, 7620, self.nwave, dtype=float)
self.wave['z'] = np.linspace(7520, 9824, self.nwave, dtype=float)
self.wave['b'] = np.linspace(3500, 5800, self.nwave, dtype=float)
self.wave['r'] = np.linspace(5570, 7870, self.nwave, dtype=float)
self.wave['z'] = np.linspace(7640, 9940, self.nwave, dtype=float)
for b in self.bands:
self.flux[b] = np.repeat(np.arange(self.nspec, dtype=float),
self.nwave).reshape( (self.nspec, self.nwave) ) + 3.0
Expand Down Expand Up @@ -501,24 +501,23 @@ def test_from_specutils(self):
self.assertTrue((sp1.mask[self.bands[2]] == sp2.mask[self.bands[2]]).all())
self.assertDictEqual(sp1.meta, sp2.meta)

# @unittest.skipUnless(_specutils_imported, "Unable to import specutils.")
@unittest.expectedFailure
@unittest.skipUnless(_specutils_imported, "Unable to import specutils.")
def test_from_specutils_coadd(self):
"""Test conversion from a Spectrum1D object representing a coadd across cameras.
"""
sp0 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
mask=self.mask, resolution_data=self.res,
fibermap=self.fmap1, exp_fibermap=self.efmap1,
meta=self.meta, extra=self.extra, scores=self.scores,
meta=self.meta, extra=None, scores=self.scores,
extra_catalog=self.extra_catalog)
sp1 = desispec.coaddition.coadd_cameras(sp0)
spectrum_list = sp1.to_specutils()
sp2 = Spectra.from_specutils(spectrum_list[0])
self.assertEqual(sp2.bands[0], 'brz')
self.assertListEqual(sp1.bands, sp2.bands)
self.assertTrue((sp1.flux[self.bands[0]] == sp2.flux[self.bands[0]]).all())
self.assertTrue((sp1.ivar[self.bands[1]] == sp2.ivar[self.bands[1]]).all())
self.assertTrue((sp1.mask[self.bands[2]] == sp2.mask[self.bands[2]]).all())
self.assertTrue((sp1.ivar[self.bands[0]] == sp2.ivar[self.bands[0]]).all())
self.assertTrue((sp1.mask[self.bands[0]] == sp2.mask[self.bands[0]]).all())
self.assertDictEqual(sp1.meta, sp2.meta)


Expand Down

0 comments on commit 93bc0ff

Please sign in to comment.