Skip to content

Commit

Permalink
Merge pull request #86 from NextGenCMB/ldev
Browse files Browse the repository at this point in the history
MF in iterators, lensed noise in iterbias and minors
  • Loading branch information
Sebastian-Belkner committed Jul 23, 2024
2 parents 5d9d368 + 5ee31c8 commit 34703fe
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 101 deletions.
9 changes: 8 additions & 1 deletion delensalot/biases/iterbiasesN0N1.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ def get_delcls(qe_key: str, itermax:int, cls_unl_fid: dict, cls_unl_true:dict, c
cls_len_true = dls2cls(lensed_cls(dls_unl_true, cldd_true))

cls_plen_true = cls_len_true
nls_plen_true = cls_noise_true
nls_plen_fid = cls_noise_fid

for irr, it in utils.enumerate_progress(range(itermax + 1)):
dls_unl_true, cldd_true = cls2dls(cls_unl_true)
dls_unl_fid, cldd_fid = cls2dls(cls_unl_fid)
Expand Down Expand Up @@ -305,8 +308,12 @@ def get_delcls(qe_key: str, itermax:int, cls_unl_fid: dict, cls_unl_true:dict, c
cldd_fid *= (1. - rho_sqd_phi) # What I think the residual lensing spec is
cls_plen_fid = dls2cls(lensed_cls(dls_unl_fid, cldd_fid))
cls_plen_true = dls2cls(lensed_cls(dls_unl_true, cldd_true))

if 'wNl' in version:
nls_plen_fid = dls2cls(lensed_cls(cls2dls(cls_noise_fid)[0], cldd_fid))
nls_plen_true = dls2cls(lensed_cls(cls2dls(cls_noise_true)[0], cldd_fid))

fal, dat_delcls, cls_w, cls_f = get_fals(qe_key, cls_plen_fid, cls_plen_true, cls_noise_fid, cls_noise_true, lmin_ivf, lmax_ivf)
fal, dat_delcls, cls_w, cls_f = get_fals(qe_key, cls_plen_fid, cls_plen_true, nls_plen_fid, nls_plen_true, lmin_ivf, lmax_ivf)

cls_ivfs_arr = utils.cls_dot([fal, dat_delcls, fal])
cls_ivfs = dict()
Expand Down
209 changes: 120 additions & 89 deletions delensalot/core/iterator/cs_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ def hlm2dlm(self, hlm, inplace):
return almxfl(hlm, h2d, self.mmax_qlm, False)


def dlm2hlm(self, dlm, inplace):
if self.h == 'd':
return dlm if inplace else dlm.copy()
if self.h == 'p':
d2h = cli(np.sqrt(np.arange(self.lmax_qlm + 1, dtype=float) * np.arange(1, self.lmax_qlm + 2, dtype=float)))
elif self.h == 'k':
d2h = 0.5 * np.sqrt(np.arange(self.lmax_qlm + 1, dtype=float) * np.arange(1, self.lmax_qlm + 2, dtype=float))
else:
assert 0, self.h + ' not implemented'
if inplace:
almxfl(dlm, d2h, self.mmax_qlm, True)
else:
return almxfl(dlm, d2h, self.mmax_qlm, False)

def _sk2plm(self, itr):
sk_fname = lambda k: 'rlm_sn_%s_%s' % (k, 'p')
rlm = self.cacher.load('phi_%slm_it000'%self.h)
Expand Down Expand Up @@ -186,7 +200,7 @@ def _is_qd_grad_done(self, itr, key):

@log_on_start(logging.DEBUG, "get_template_blm(it={it}) started")
@log_on_end(logging.DEBUG, "get_template_blm(it={it}) finished")
def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.ndarray=None, dlm_mod=None, perturbative=False, k='p_p', pwithn1=False):
def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.ndarray=None, dlm_mod=None, perturbative=False, k='p_p', pwithn1=False, plm=None):
"""Builds a template B-mode map with the iterated phi and input elm_wf
Args:
Expand All @@ -205,18 +219,19 @@ def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.n
"""
cache_cond = (lmin_plm >= 1) and (elm_wf is None)

fn_blt = 'blt_p%03d_e%03d_lmax%s'%(it, it_e, lmaxb)
if dlm_mod is None:
pass
else:
fn_blt += '_dlmmod' * dlm_mod.any()
fn_blt += 'perturbative' * perturbative
fn_blt += '_wN1' * pwithn1
if cache_cond:
fn_blt = 'blt_p%03d_e%03d_lmax%s'%(it, it_e, lmaxb)
if dlm_mod is None:
pass
else:
fn_blt += '_dlmmod' * dlm_mod.any()
fn_blt += 'perturbative' * perturbative
fn_blt += '_wN1' * pwithn1

if self.blt_cacher.is_cached(fn_blt):
if cache_cond and self.blt_cacher.is_cached(fn_blt) :
return self.blt_cacher.load(fn_blt)
if elm_wf is None:
assert k in ['p', 'p_p'], "Need to have computed the WF for polarization E in terations"
if it_e > 0:
e_fname = 'wflm_%s_it%s' % ('p', it_e - 1)
assert self.wf_cacher.is_cached(e_fname)
Expand All @@ -229,8 +244,11 @@ def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.n
elm_wf = elm_wf[1]
assert Alm.getlmax(elm_wf.size, self.mmax_filt) == self.lmax_filt, "{}, {}, {}, {}".format(elm_wf.size, self.mmax_filt, Alm.getlmax(elm_wf.size, self.mmax_filt), self.lmax_filt)
mmaxb = lmaxb
dlm = self.get_hlm(it, 'p', pwithn1)

if plm is None:
dlm = self.get_hlm(it, 'p', pwithn1)
else:
dlm = plm
# subtract field from phi
if dlm_mod is not None:
dlm = dlm - dlm_mod
Expand Down Expand Up @@ -330,6 +348,7 @@ def load_graddet(self, itr, key):
return self.cacher.load(fn)

def load_gradpri(self, itr, key):
"""Compared to formalism of the papers, this returns -g_LM^{PR}"""
assert key in ['p'], key + ' not implemented'
assert self.is_iter_done(itr -1 , key)
ret = self.get_hlm(itr, key)
Expand All @@ -344,7 +363,7 @@ def load_gradient(self, itr, key):
"""Loads the total gradient at iteration iter.
All necessary alm's must have been calculated previously
Compared to formalism of the papers, this returns -g_LM^{tot}
"""
if itr == 0:
g = self.load_gradpri(0, key)
Expand Down Expand Up @@ -420,7 +439,9 @@ def iterate(self, itr, key):
"""Performs iteration number 'itr'
This is done by collecting the gradients at level iter, and the lower level potential
Compared to formalism of the papers, the total gradient is -g_{LM}^{Tot}.
These are the gradients of -ln posterior that we should minimize
"""
assert key.lower() in ['p', 'o'], key # potential or curl potential.
if not self.is_iter_done(itr, key):
Expand All @@ -443,7 +464,7 @@ def iterate(self, itr, key):
@log_on_end(logging.DEBUG, "calc_gradlik(it={itr}, key={key}) finished")
def calc_gradlik(self, itr, key, iwantit=False):
"""Computes the quadratic part of the gradient for plm iteration 'itr'
Compared to formalism of the papers, this returns -g_LM^{QD}
"""
assert self.is_iter_done(itr - 1, key)
assert itr > 0, itr
Expand Down Expand Up @@ -490,6 +511,7 @@ def calc_gradlik(self, itr, key, iwantit=False):
@log_on_start(logging.DEBUG, "calc_graddet(it={itr}, key={key}) started, subclassed")
@log_on_end(logging.DEBUG, "calc_graddet(it={itr}, key={key}) finished, subclassed")
def calc_graddet(self, itr, key):
"""Compared to formalism of the papers, this should return +g_LM^{MF}"""
assert 0, 'subclass this'


Expand Down Expand Up @@ -603,7 +625,8 @@ def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
dat_maps:list or np.ndarray, plm0:np.ndarray, mf_key:int, pp_h0:np.ndarray,
cpp_prior:np.ndarray, cls_filt:dict, ninv_filt:opfilt_base.alm_filter_wl, k_geom:utils_geom.Geom,
chain_descr, stepper:steps.nrstep, mc_sims:np.ndarray=None, sub_nolensing:bool=False,
mf_cmb_phas=None, mf_noise_phas=None, shift_phas=False, **kwargs):
mf_cmb_phas=None, mf_noise_phas=None, shift_phas=False,
mf_lowpass_filter= lambda ell: np.ones(len(ell), dtype=float), **kwargs):
super(iterator_simf_mcs, self).__init__(lib_dir, h, lm_max_dlm, dat_maps, plm0, pp_h0, cpp_prior, cls_filt,
ninv_filt, k_geom, chain_descr, stepper, **kwargs)
self.mf_key = mf_key
Expand All @@ -612,11 +635,13 @@ def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
self.mf_cmb_phas = mf_cmb_phas
self.mf_noise_phas = mf_noise_phas
self.shift_phas = shift_phas
self.mf_lowpass_filter = mf_lowpass_filter

print(f'Iterator MF key: {self.mf_key}')

@log_on_start(logging.DEBUG, "load_graddet(it={itr}, key={key}) started")
@log_on_end(logging.DEBUG, "load_graddet(it={itr}, key={key}) finished")
def load_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims = None):
def load_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims = None, return_half_mean=False):
if mc_sims is None:
mc_sims = self.mc_sims

Expand Down Expand Up @@ -644,51 +669,76 @@ def load_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims = None):

print(f"****Iterator: MF estimated for iter {itr} with sims {mcs} ****")

Gmfs = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)
# Gmf = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if return_half_mean is True:
Gmf1 = self.get_grad_mf(itr, key, mcs[::2], nolensing=False, A_dlm=A_dlm)
Gmf2 = self.get_grad_mf(itr, key, mcs[1::2], nolensing=False, A_dlm=A_dlm)
if self.sub_nolensing:
Gmf1_nl = self.get_grad_mf(itr, key, mcs[::2], nolensing=True, A_dlm=A_dlm)
Gmf1 -= Gmf1_nl
Gmf2_nl = self.get_grad_mf(itr, key, mcs[1::2], nolensing=True, A_dlm=A_dlm)
Gmf2 -= Gmf2_nl
return Gmf1, Gmf2
else:
Gmf = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if self.sub_nolensing:
Gmfs_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmfs -= Gmfs_nl

return Gmfs if get_all_mcs else np.mean(Gmfs, axis=0)
if self.sub_nolensing:
Gmf_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmf -= Gmf_nl
return Gmf
# return Gmfs if get_all_mcs else np.mean(Gmfs, axis=0)


@log_on_start(logging.DEBUG, "calc_graddet(it={itr}, key={key}) started")
@log_on_end(logging.DEBUG, "calc_graddet(it={itr}, key={key}) finished")
def calc_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims=None):
def calc_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims=None, return_half_mean=False):
if mc_sims is None:
mc_sims = self.mc_sims

assert self.is_iter_done(itr - 1, key)
assert itr > 0, itr
# assert itr > 0, itr
if itr == 0:
return 0

assert key in ['p'], key + ' not implemented'
try:
mc_sims[itr]
except IndexError:
print(f'No MC sims defined for MF estimate at iter {itr}')
return 0

if self.mc_sims[itr] == 0:
if mc_sims[itr] == 0:
print(f"No MF subtraction is performed for iter {itr}")
return 0

if self.shift_phas:
sim0 = np.sum(self.mc_sims[:itr])
sim0 = np.sum(mc_sims[:itr])
else:
sim0 = 0
try:
mcs = np.arange(sim0, sim0 + self.mc_sims[itr])
except IndexError:
print(f'No MC sims defined for MF estimate at iter {itr}')
return 0

mcs = np.arange(sim0, sim0 + mc_sims[itr])

print(f"****Iterator: MF estimated for iter {itr} with sims {mcs} ****")

Gmfs = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if self.sub_nolensing:
Gmfs_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmfs -= Gmfs_nl

return Gmfs if get_all_mcs else np.mean(Gmfs, axis=0)
if return_half_mean is True:
Gmf1 = self.get_grad_mf(itr, key, mcs[::2], nolensing=False, A_dlm=A_dlm)
Gmf2 = self.get_grad_mf(itr, key, mcs[1::2], nolensing=False, A_dlm=A_dlm)
if self.sub_nolensing:
Gmf1_nl = self.get_grad_mf(itr, key, mcs[::2], nolensing=True, A_dlm=A_dlm)
Gmf1 -= Gmf1_nl
Gmf2_nl = self.get_grad_mf(itr, key, mcs[1::2], nolensing=True, A_dlm=A_dlm)
Gmf2 -= Gmf2_nl
return Gmf1, Gmf2
else:
Gmf = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if self.sub_nolensing:
Gmf_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmf -= Gmf_nl
return Gmf


def get_grad_mf(self, itr:int, key:str, mcs:np.ndarray, nolensing:bool=False, A_dlm=1.):
def get_grad_mf(self, itr:int, key:str, mcs:np.ndarray, nolensing:bool=False, A_dlm=1., debug_prints=False):
"""Returns the QD gradients for a set of simulations, and cache the results
Args:
Expand All @@ -698,36 +748,42 @@ def get_grad_mf(self, itr:int, key:str, mcs:np.ndarray, nolensing:bool=False, A_
nolensing: estimate the gradient without lensing (useful to subtract sims with same phase but no lensing to reduce variance)
A_dlm: change the lensing deflection amplitude (useful for tests)
Returns:
G_mfs: an array of the gradients estimate for each simulation
G_mf: the mean of the gradients estimate for the set of simulations
"""
dlm = self.get_hlm(itr - 1, key) * A_dlm

mf_cacher = cachers.cacher_npy(opj(self.lib_dir, f'mf_sims_itr{itr:03d}'))
if debug_prints:
print(mf_cacher.lib_dir)
fn_qlm = lambda this_idx : f'qlm_mf{self.mf_key}_sim_{this_idx:04d}' + '_nolensing' * nolensing + f'_Adlm{A_dlm}' * (A_dlm!=1.)

if nolensing:
dlm *= 0.

self.hlm2dlm(dlm, True)
dlm = np.zeros(Alm.getsize(self.lmax_qlm, self.mmax_qlm), dtype='complex128')
else:
dlm = self.get_hlm(itr - 1, key) * A_dlm
self.hlm2dlm(dlm, True)

ffi = self.filter.ffi.change_dlm([dlm, None], self.mmax_qlm, cachers.cacher_mem(safe=False))
self.filter.set_ffi(ffi)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter)
t0 = time.time()
# t0 = time.time()
q_geom = self.filter.ffi.pbgeom

mf_cacher = cachers.cacher_npy(opj(self.lib_dir, f'mf_sims_itr{itr:03d}'))
fn_qlm = lambda this_idx : f'qlm_mf{self.mf_key}_sim_{this_idx:04d}' + '_nolensing' * nolensing + f'_Adlm{A_dlm}' * (A_dlm!=1.)

_Gmfs = []
_Gmfs = np.zeros(Alm.getsize(ffi.lmax_dlm, ffi.mmax_dlm), dtype='complex128')
for i, idx in enumerate_progress(mcs, label='Getting MF sims' + ' no lensing' * nolensing):
if debug_prints:
print(fn_qlm(idx))
if not mf_cacher.is_cached(fn_qlm(idx)):
#FIXME: the cmb_phas and noise_phas should have more than one field for Pol or MV estimators
#!FIXME: the cmb_phas and noise_phas should have more than one field for Pol or MV estimators
G, C = self.filter.get_qlms_mf(
self.mf_key, q_geom, mchain,
phas=self.mf_cmb_phas.get_sim(idx, idf=0),
noise_phas=self.mf_noise_phas.get_sim(idx, idf=0), cls_filt=self.cls_filt)

mf_cacher.cache(fn_qlm(idx), G)
_Gmfs.append(mf_cacher.load(fn_qlm(idx)))
return np.array(_Gmfs)
_Gmfs += mf_cacher.load(fn_qlm(idx))

almxfl(_Gmfs, self.mf_lowpass_filter(np.arange(ffi.lmax_dlm+1)), ffi.mmax_dlm, inplace=True)
return _Gmfs / len(mcs)


class iterator_simf(qlm_iterator):
Expand All @@ -739,15 +795,15 @@ class iterator_simf(qlm_iterator):
def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
dat_maps:list or np.ndarray, plm0:np.ndarray, mf_key:int, pp_h0:np.ndarray,
cpp_prior:np.ndarray, cls_filt:dict, ninv_filt:opfilt_base.alm_filter_wl, k_geom:utils_geom.Geom,
chain_descr, stepper:steps.nrstep, sub_nolensing:bool=False, **kwargs):
chain_descr, stepper:steps.nrstep, **kwargs):
super(iterator_simf, self).__init__(lib_dir, h, lm_max_dlm, dat_maps, plm0, pp_h0, cpp_prior, cls_filt,
ninv_filt, k_geom, chain_descr, stepper, **kwargs)
self.mf_key = mf_key
self.sub_nolensing = sub_nolensing


@log_on_start(logging.DEBUG, "calc_graddet(it={itr}, key={key}) started")
@log_on_end(logging.DEBUG, "calc_graddet(it={itr}, key={key}) finished")
def calc_graddet(self, itr, key, nolensing=False):
def calc_graddet(self, itr, key):
assert self.is_iter_done(itr - 1, key)
assert itr > 0, itr
assert key in ['p'], key + ' not implemented'
Expand All @@ -757,40 +813,15 @@ def calc_graddet(self, itr, key, nolensing=False):
self.filter.set_ffi(ffi)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter)
t0 = time.time()
# q_geom = pbdGeometry(self.k_geom, pbounds(0., 2 * np.pi))
q_geom = self.filter.ffi.pbgeom
if self.sub_nolensing:
phas = self.filter.get_unit_variance()
else:
phas = None
G, C = self.filter.get_qlms_mf(self.mf_key, q_geom, mchain, phas=phas, cls_filt=self.cls_filt)
q_geom = pbdGeometry(self.k_geom, pbounds(0., 2 * np.pi))
G, C = self.filter.get_qlms_mf(self.mf_key, q_geom, mchain, cls_filt=self.cls_filt)
almxfl(G if key.lower() == 'p' else C, self._h2p(self.lmax_qlm), self.mmax_qlm, True)
log.info('get_qlm_mf calculation done; (%.0f secs)' % (time.time() - t0))

#NOTE + or - here ?
ret = G if key.lower() == 'p' else C

if self.sub_nolensing:
log.info('Subtracting MF with same phase but no lensing in likelihood')
dlm *= 0
self.hlm2dlm(dlm, True)
ffi = self.filter.ffi.change_dlm([dlm, None], self.mmax_qlm, cachers.cacher_mem(safe=False))
self.filter.set_ffi(ffi)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter)

q_geom = self.filter.ffi.pbgeom
G0, C0 = self.filter.get_qlms_mf(self.mf_key, q_geom, mchain, phas=phas, cls_filt=self.cls_filt)
almxfl(G0 if key.lower() == 'p' else C0, self._h2p(self.lmax_qlm), self.mmax_qlm, True)

ret -= G0 if key.lower() == 'p' else C0


if itr == 1: # We need the gradient at 0 and the yk's to be able to rebuild all gradients
fn_lik = '%slm_grad%sdet_it%03d' % (self.h, key.lower(), 0)
self.cacher.cache(fn_lik,ret)

return ret

self.cacher.cache(fn_lik, -G if key.lower() == 'p' else -C)
# !FIXME: This sign is probably wrong, as it should return +g^MF
return -G if key.lower() == 'p' else -C

class iterator_cstmf_bfgs0(iterator_cstmf):
"""Variant of the iterator where the initial curvature guess is itself a bfgs update from phi =0 to input plm
Expand Down
Loading

0 comments on commit 34703fe

Please sign in to comment.