MF in iterators, lensed noise in iterbias and minors #86

Merged · 6 commits · Jul 23, 2024

Changes from all commits
9 changes: 8 additions & 1 deletion delensalot/biases/iterbiasesN0N1.py
@@ -274,6 +274,9 @@ def get_delcls(qe_key: str, itermax:int, cls_unl_fid: dict, cls_unl_true:dict, c
cls_len_true = dls2cls(lensed_cls(dls_unl_true, cldd_true))

cls_plen_true = cls_len_true
nls_plen_true = cls_noise_true
nls_plen_fid = cls_noise_fid

for irr, it in utils.enumerate_progress(range(itermax + 1)):
dls_unl_true, cldd_true = cls2dls(cls_unl_true)
dls_unl_fid, cldd_fid = cls2dls(cls_unl_fid)
@@ -305,8 +308,12 @@ def get_delcls(qe_key: str, itermax:int, cls_unl_fid: dict, cls_unl_true:dict, c
cldd_fid *= (1. - rho_sqd_phi) # What I think the residual lensing spec is
cls_plen_fid = dls2cls(lensed_cls(dls_unl_fid, cldd_fid))
cls_plen_true = dls2cls(lensed_cls(dls_unl_true, cldd_true))

if 'wNl' in version:
nls_plen_fid = dls2cls(lensed_cls(cls2dls(cls_noise_fid)[0], cldd_fid))
nls_plen_true = dls2cls(lensed_cls(cls2dls(cls_noise_true)[0], cldd_fid))

fal, dat_delcls, cls_w, cls_f = get_fals(qe_key, cls_plen_fid, cls_plen_true, cls_noise_fid, cls_noise_true, lmin_ivf, lmax_ivf)
fal, dat_delcls, cls_w, cls_f = get_fals(qe_key, cls_plen_fid, cls_plen_true, nls_plen_fid, nls_plen_true, lmin_ivf, lmax_ivf)

cls_ivfs_arr = utils.cls_dot([fal, dat_delcls, fal])
cls_ivfs = dict()
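
For context, the new 'wNl' branch above treats the instrumental noise like a sky signal and lenses its spectra with the residual deflection power (note that both the fiducial and the true noise are lensed with the fiducial residual cldd_fid). A minimal sketch of that pattern, assuming the cls2dls / dls2cls / lensed_cls helpers behave as they are used in this file (the wrapper name lens_noise_cls is hypothetical):

def lens_noise_cls(cls_noise: dict, cldd_res):
    # cls2dls returns (dls, cldd); only the dls part of the noise spectra is needed here
    dls_noise = cls2dls(cls_noise)[0]
    # convolve with the residual deflection power, exactly as done for the CMB spectra
    return dls2cls(lensed_cls(dls_noise, cldd_res))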
209 changes: 120 additions & 89 deletions delensalot/core/iterator/cs_iterator.py
@@ -151,6 +151,20 @@ def hlm2dlm(self, hlm, inplace):
return almxfl(hlm, h2d, self.mmax_qlm, False)


def dlm2hlm(self, dlm, inplace):
"""Inverse of hlm2dlm: converts deflection alm's back to the h-space potential (h in ['p', 'k', 'd'])"""
if self.h == 'd':
return dlm if inplace else dlm.copy()
if self.h == 'p':
d2h = cli(np.sqrt(np.arange(self.lmax_qlm + 1, dtype=float) * np.arange(1, self.lmax_qlm + 2, dtype=float)))
elif self.h == 'k':
d2h = 0.5 * np.sqrt(np.arange(self.lmax_qlm + 1, dtype=float) * np.arange(1, self.lmax_qlm + 2, dtype=float))
else:
assert 0, self.h + ' not implemented'
if inplace:
almxfl(dlm, d2h, self.mmax_qlm, True)
else:
return almxfl(dlm, d2h, self.mmax_qlm, False)
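
A quick consistency sketch for the new dlm2hlm, which inverts hlm2dlm mode by mode (here itlib is an iterator instance and it a completed iteration index, both assumptions; the ell = 0 mode is zeroed by the cli pseudo-inverse, so it is excluded from any comparison):

hlm = itlib.get_hlm(it, 'p')
dlm = itlib.hlm2dlm(hlm.copy(), False)   # h -> deflection, e.g. sqrt(l(l+1)) phi_lm for h == 'p'
hlm_back = itlib.dlm2hlm(dlm, False)     # deflection -> h, with ell = 0 projected out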

def _sk2plm(self, itr):
sk_fname = lambda k: 'rlm_sn_%s_%s' % (k, 'p')
rlm = self.cacher.load('phi_%slm_it000'%self.h)
@@ -186,7 +200,7 @@ def _is_qd_grad_done(self, itr, key):

@log_on_start(logging.DEBUG, "get_template_blm(it={it}) started")
@log_on_end(logging.DEBUG, "get_template_blm(it={it}) finished")
def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.ndarray=None, dlm_mod=None, perturbative=False, k='p_p', pwithn1=False):
def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.ndarray=None, dlm_mod=None, perturbative=False, k='p_p', pwithn1=False, plm=None):
"""Builds a template B-mode map with the iterated phi and input elm_wf

Args:
@@ -205,18 +219,19 @@ def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.n

"""
cache_cond = (lmin_plm >= 1) and (elm_wf is None)

fn_blt = 'blt_p%03d_e%03d_lmax%s'%(it, it_e, lmaxb)
if dlm_mod is None:
pass
else:
fn_blt += '_dlmmod' * dlm_mod.any()
fn_blt += 'perturbative' * perturbative
fn_blt += '_wN1' * pwithn1
if cache_cond:
fn_blt = 'blt_p%03d_e%03d_lmax%s'%(it, it_e, lmaxb)
if dlm_mod is None:
pass
else:
fn_blt += '_dlmmod' * dlm_mod.any()
fn_blt += 'perturbative' * perturbative
fn_blt += '_wN1' * pwithn1

if self.blt_cacher.is_cached(fn_blt):
if cache_cond and self.blt_cacher.is_cached(fn_blt):
return self.blt_cacher.load(fn_blt)
if elm_wf is None:
assert k in ['p', 'p_p'], "Need to have computed the WF for polarization E in previous iterations"
if it_e > 0:
e_fname = 'wflm_%s_it%s' % ('p', it_e - 1)
assert self.wf_cacher.is_cached(e_fname)
@@ -229,8 +244,11 @@ def get_template_blm(self, it, it_e, lmaxb=1024, lmin_plm=1, elm_wf:None or np.n
elm_wf = elm_wf[1]
assert Alm.getlmax(elm_wf.size, self.mmax_filt) == self.lmax_filt, "{}, {}, {}, {}".format(elm_wf.size, self.mmax_filt, Alm.getlmax(elm_wf.size, self.mmax_filt), self.lmax_filt)
mmaxb = lmaxb
dlm = self.get_hlm(it, 'p', pwithn1)

if plm is None:
dlm = self.get_hlm(it, 'p', pwithn1)
else:
dlm = plm
# subtract field from phi
if dlm_mod is not None:
dlm = dlm - dlm_mod
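
The new plm argument lets the caller bypass the cached iterate when building the template, e.g. to test a rescaled or externally modified potential. A hypothetical usage sketch (itlib and the rescaling are assumptions); note that fn_blt does not encode the custom plm, so callers may want arguments for which cache_cond is False to avoid a cache collision:

plm_custom = 0.9 * itlib.get_hlm(itr, 'p')   # e.g. a rescaled converged iterate
blt = itlib.get_template_blm(itr, itr, lmaxb=1024, k='p_p', plm=plm_custom)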
@@ -330,6 +348,7 @@ def load_graddet(self, itr, key):
return self.cacher.load(fn)

def load_gradpri(self, itr, key):
"""Compared to formalism of the papers, this returns -g_LM^{PR}"""
assert key in ['p'], key + ' not implemented'
assert self.is_iter_done(itr -1 , key)
ret = self.get_hlm(itr, key)
@@ -344,7 +363,7 @@ def load_gradient(self, itr, key):
"""Loads the total gradient at iteration iter.

All necessary alm's must have been calculated previously

Compared to the formalism of the papers, this returns -g_LM^{tot}
"""
if itr == 0:
g = self.load_gradpri(0, key)
@@ -420,7 +439,9 @@ def iterate(self, itr, key):
"""Performs iteration number 'itr'

This is done by collecting the gradients at level 'itr', and the lower level potential


Compared to the formalism of the papers, the total gradient is -g_{LM}^{tot};
these are the gradients of the -ln(posterior) that we minimize
"""
assert key.lower() in ['p', 'o'], key # potential or curl potential.
if not self.is_iter_done(itr, key):
@@ -443,7 +464,7 @@
@log_on_end(logging.DEBUG, "calc_gradlik(it={itr}, key={key}) finished")
def calc_gradlik(self, itr, key, iwantit=False):
"""Computes the quadratic part of the gradient for plm iteration 'itr'

Compared to the formalism of the papers, this returns -g_LM^{QD}
"""
assert self.is_iter_done(itr - 1, key)
assert itr > 0, itr
@@ -490,6 +511,7 @@ def calc_gradlik(self, itr, key, iwantit=False):
@log_on_start(logging.DEBUG, "calc_graddet(it={itr}, key={key}) started, subclassed")
@log_on_end(logging.DEBUG, "calc_graddet(it={itr}, key={key}) finished, subclassed")
def calc_graddet(self, itr, key):
"""Compared to formalism of the papers, this should return +g_LM^{MF}"""
assert 0, 'subclass this'
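
Taken together, the sign notes added in these docstrings admit the following summary (a hedged reading, following the papers' formalism):

    g_LM^{tot} = g_LM^{QD} - g_LM^{MF} + g_LM^{PR}

so the pieces returned by calc_gradlik (-g_LM^{QD}), calc_graddet (+g_LM^{MF}) and load_gradpri (-g_LM^{PR}) sum to the -g_LM^{tot} assembled by load_gradient, the gradient of the -ln(posterior) along which the stepper descends.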


@@ -603,7 +625,8 @@ def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
dat_maps:list or np.ndarray, plm0:np.ndarray, mf_key:int, pp_h0:np.ndarray,
cpp_prior:np.ndarray, cls_filt:dict, ninv_filt:opfilt_base.alm_filter_wl, k_geom:utils_geom.Geom,
chain_descr, stepper:steps.nrstep, mc_sims:np.ndarray=None, sub_nolensing:bool=False,
mf_cmb_phas=None, mf_noise_phas=None, shift_phas=False, **kwargs):
mf_cmb_phas=None, mf_noise_phas=None, shift_phas=False,
mf_lowpass_filter=lambda ell: np.ones(len(ell), dtype=float), **kwargs):
super(iterator_simf_mcs, self).__init__(lib_dir, h, lm_max_dlm, dat_maps, plm0, pp_h0, cpp_prior, cls_filt,
ninv_filt, k_geom, chain_descr, stepper, **kwargs)
self.mf_key = mf_key
@@ -612,11 +635,13 @@ def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
self.mf_cmb_phas = mf_cmb_phas
self.mf_noise_phas = mf_noise_phas
self.shift_phas = shift_phas
self.mf_lowpass_filter = mf_lowpass_filter

print(f'Iterator MF key: {self.mf_key}')
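
The new mf_lowpass_filter hook defaults to unity (no filtering) and is applied to the averaged MF gradient in get_grad_mf below. A hypothetical choice, keeping the mean field only at the largest scales where it dominates (lmax_keep and the cosine roll-off are assumptions):

lmax_keep = 30
def mf_lowpass(ell):
    # unity at ell = 0, rolling off to zero at ell = lmax_keep and above
    x = np.clip(ell / lmax_keep, 0., 1.)
    return 0.5 * (1. + np.cos(np.pi * x))
# passed at construction: iterator_simf_mcs(..., mf_lowpass_filter=mf_lowpass)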

@log_on_start(logging.DEBUG, "load_graddet(it={itr}, key={key}) started")
@log_on_end(logging.DEBUG, "load_graddet(it={itr}, key={key}) finished")
def load_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims = None):
def load_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims=None, return_half_mean=False):
if mc_sims is None:
mc_sims = self.mc_sims

@@ -644,51 +669,76 @@ def load_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims = None):

print(f"****Iterator: MF estimated for iter {itr} with sims {mcs} ****")

Gmfs = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)
# Gmf = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if return_half_mean:
Gmf1 = self.get_grad_mf(itr, key, mcs[::2], nolensing=False, A_dlm=A_dlm)
Gmf2 = self.get_grad_mf(itr, key, mcs[1::2], nolensing=False, A_dlm=A_dlm)
if self.sub_nolensing:
Gmf1_nl = self.get_grad_mf(itr, key, mcs[::2], nolensing=True, A_dlm=A_dlm)
Gmf1 -= Gmf1_nl
Gmf2_nl = self.get_grad_mf(itr, key, mcs[1::2], nolensing=True, A_dlm=A_dlm)
Gmf2 -= Gmf2_nl
return Gmf1, Gmf2
else:
Gmf = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if self.sub_nolensing:
Gmfs_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmfs -= Gmfs_nl

return Gmfs if get_all_mcs else np.mean(Gmfs, axis=0)
if self.sub_nolensing:
Gmf_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmf -= Gmf_nl
return Gmf
# return Gmfs if get_all_mcs else np.mean(Gmfs, axis=0)
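
With return_half_mean=True, the even- and odd-indexed sims yield two independent mean-field estimates: their half-sum is the full estimate, and their half-difference probes the Monte-Carlo noise of the estimate. A hypothetical usage sketch (itlib assumed):

Gmf1, Gmf2 = itlib.load_graddet(itr, 'p', return_half_mean=True)
Gmf = 0.5 * (Gmf1 + Gmf2)        # full mean-field gradient estimate
mc_noise = 0.5 * (Gmf1 - Gmf2)   # Monte-Carlo scatter of the estimate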


@log_on_start(logging.DEBUG, "calc_graddet(it={itr}, key={key}) started")
@log_on_end(logging.DEBUG, "calc_graddet(it={itr}, key={key}) finished")
def calc_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims=None):
def calc_graddet(self, itr, key, get_all_mcs=False, A_dlm=1., mc_sims=None, return_half_mean=False):
if mc_sims is None:
mc_sims = self.mc_sims

assert self.is_iter_done(itr - 1, key)
assert itr > 0, itr
# assert itr > 0, itr
if itr == 0:
return 0

assert key in ['p'], key + ' not implemented'
try:
mc_sims[itr]
except IndexError:
print(f'No MC sims defined for MF estimate at iter {itr}')
return 0

if self.mc_sims[itr] == 0:
if mc_sims[itr] == 0:
print(f"No MF subtraction is performed for iter {itr}")
return 0

if self.shift_phas:
sim0 = np.sum(self.mc_sims[:itr])
sim0 = np.sum(mc_sims[:itr])
else:
sim0 = 0
try:
mcs = np.arange(sim0, sim0 + self.mc_sims[itr])
except IndexError:
print(f'No MC sims defined for MF estimate at iter {itr}')
return 0

mcs = np.arange(sim0, sim0 + mc_sims[itr])

print(f"****Iterator: MF estimated for iter {itr} with sims {mcs} ****")

Gmfs = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if self.sub_nolensing:
Gmfs_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmfs -= Gmfs_nl

return Gmfs if get_all_mcs else np.mean(Gmfs, axis=0)
if return_half_mean:
Gmf1 = self.get_grad_mf(itr, key, mcs[::2], nolensing=False, A_dlm=A_dlm)
Gmf2 = self.get_grad_mf(itr, key, mcs[1::2], nolensing=False, A_dlm=A_dlm)
if self.sub_nolensing:
Gmf1_nl = self.get_grad_mf(itr, key, mcs[::2], nolensing=True, A_dlm=A_dlm)
Gmf1 -= Gmf1_nl
Gmf2_nl = self.get_grad_mf(itr, key, mcs[1::2], nolensing=True, A_dlm=A_dlm)
Gmf2 -= Gmf2_nl
return Gmf1, Gmf2
else:
Gmf = self.get_grad_mf(itr, key, mcs, nolensing=False, A_dlm=A_dlm)

if self.sub_nolensing:
Gmf_nl = self.get_grad_mf(itr, key, mcs, nolensing=True, A_dlm=A_dlm)
Gmf -= Gmf_nl
return Gmf


def get_grad_mf(self, itr:int, key:str, mcs:np.ndarray, nolensing:bool=False, A_dlm=1.):
def get_grad_mf(self, itr:int, key:str, mcs:np.ndarray, nolensing:bool=False, A_dlm=1., debug_prints=False):
"""Returns the QD gradients for a set of simulations, and cache the results

Args:
@@ -698,36 +748,42 @@ def get_grad_mf(self, itr:int, key:str, mcs:np.ndarray, nolensing:bool=False, A_
nolensing: estimate the gradient without lensing (useful for subtracting sims with the same phases but no lensing, to reduce variance)
A_dlm: change the lensing deflection amplitude (useful for tests)
Returns:
G_mfs: an array of the gradients estimate for each simulation
G_mf: the mean of the gradient estimates over the set of simulations

"""
dlm = self.get_hlm(itr - 1, key) * A_dlm

mf_cacher = cachers.cacher_npy(opj(self.lib_dir, f'mf_sims_itr{itr:03d}'))
if debug_prints:
print(mf_cacher.lib_dir)
fn_qlm = lambda this_idx : f'qlm_mf{self.mf_key}_sim_{this_idx:04d}' + '_nolensing' * nolensing + f'_Adlm{A_dlm}' * (A_dlm!=1.)

if nolensing:
dlm *= 0.

self.hlm2dlm(dlm, True)
dlm = np.zeros(Alm.getsize(self.lmax_qlm, self.mmax_qlm), dtype='complex128')
else:
dlm = self.get_hlm(itr - 1, key) * A_dlm
self.hlm2dlm(dlm, True)

ffi = self.filter.ffi.change_dlm([dlm, None], self.mmax_qlm, cachers.cacher_mem(safe=False))
self.filter.set_ffi(ffi)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter)
t0 = time.time()
# t0 = time.time()
q_geom = self.filter.ffi.pbgeom

mf_cacher = cachers.cacher_npy(opj(self.lib_dir, f'mf_sims_itr{itr:03d}'))
fn_qlm = lambda this_idx : f'qlm_mf{self.mf_key}_sim_{this_idx:04d}' + '_nolensing' * nolensing + f'_Adlm{A_dlm}' * (A_dlm!=1.)

_Gmfs = []
_Gmfs = np.zeros(Alm.getsize(ffi.lmax_dlm, ffi.mmax_dlm), dtype='complex128')
for i, idx in enumerate_progress(mcs, label='Getting MF sims' + ' no lensing' * nolensing):
if debug_prints:
print(fn_qlm(idx))
if not mf_cacher.is_cached(fn_qlm(idx)):
#FIXME: the cmb_phas and noise_phas should have more than one field for Pol or MV estimators
#!FIXME: the cmb_phas and noise_phas should have more than one field for Pol or MV estimators
G, C = self.filter.get_qlms_mf(
self.mf_key, q_geom, mchain,
phas=self.mf_cmb_phas.get_sim(idx, idf=0),
noise_phas=self.mf_noise_phas.get_sim(idx, idf=0), cls_filt=self.cls_filt)

mf_cacher.cache(fn_qlm(idx), G)
_Gmfs.append(mf_cacher.load(fn_qlm(idx)))
return np.array(_Gmfs)
_Gmfs += mf_cacher.load(fn_qlm(idx))

almxfl(_Gmfs, self.mf_lowpass_filter(np.arange(ffi.lmax_dlm+1)), ffi.mmax_dlm, inplace=True)
return _Gmfs / len(mcs)


class iterator_simf(qlm_iterator):
@@ -739,15 +795,15 @@ class iterator_simf(qlm_iterator):
def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
dat_maps:list or np.ndarray, plm0:np.ndarray, mf_key:int, pp_h0:np.ndarray,
cpp_prior:np.ndarray, cls_filt:dict, ninv_filt:opfilt_base.alm_filter_wl, k_geom:utils_geom.Geom,
chain_descr, stepper:steps.nrstep, sub_nolensing:bool=False, **kwargs):
chain_descr, stepper:steps.nrstep, **kwargs):
super(iterator_simf, self).__init__(lib_dir, h, lm_max_dlm, dat_maps, plm0, pp_h0, cpp_prior, cls_filt,
ninv_filt, k_geom, chain_descr, stepper, **kwargs)
self.mf_key = mf_key
self.sub_nolensing = sub_nolensing


@log_on_start(logging.DEBUG, "calc_graddet(it={itr}, key={key}) started")
@log_on_end(logging.DEBUG, "calc_graddet(it={itr}, key={key}) finished")
def calc_graddet(self, itr, key, nolensing=False):
def calc_graddet(self, itr, key):
assert self.is_iter_done(itr - 1, key)
assert itr > 0, itr
assert key in ['p'], key + ' not implemented'
@@ -757,40 +813,15 @@ def calc_graddet(self, itr, key, nolensing=False):
self.filter.set_ffi(ffi)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter)
t0 = time.time()
# q_geom = pbdGeometry(self.k_geom, pbounds(0., 2 * np.pi))
q_geom = self.filter.ffi.pbgeom
if self.sub_nolensing:
phas = self.filter.get_unit_variance()
else:
phas = None
G, C = self.filter.get_qlms_mf(self.mf_key, q_geom, mchain, phas=phas, cls_filt=self.cls_filt)
q_geom = pbdGeometry(self.k_geom, pbounds(0., 2 * np.pi))
G, C = self.filter.get_qlms_mf(self.mf_key, q_geom, mchain, cls_filt=self.cls_filt)
almxfl(G if key.lower() == 'p' else C, self._h2p(self.lmax_qlm), self.mmax_qlm, True)
log.info('get_qlm_mf calculation done; (%.0f secs)' % (time.time() - t0))

#NOTE + or - here ?
ret = G if key.lower() == 'p' else C

if self.sub_nolensing:
log.info('Subtracting MF with same phase but no lensing in likelihood')
dlm *= 0
self.hlm2dlm(dlm, True)
ffi = self.filter.ffi.change_dlm([dlm, None], self.mmax_qlm, cachers.cacher_mem(safe=False))
self.filter.set_ffi(ffi)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter)

q_geom = self.filter.ffi.pbgeom
G0, C0 = self.filter.get_qlms_mf(self.mf_key, q_geom, mchain, phas=phas, cls_filt=self.cls_filt)
almxfl(G0 if key.lower() == 'p' else C0, self._h2p(self.lmax_qlm), self.mmax_qlm, True)

ret -= G0 if key.lower() == 'p' else C0


if itr == 1: # We need the gradient at 0 and the yk's to be able to rebuild all gradients
fn_lik = '%slm_grad%sdet_it%03d' % (self.h, key.lower(), 0)
self.cacher.cache(fn_lik,ret)

return ret

self.cacher.cache(fn_lik, -G if key.lower() == 'p' else -C)
# !FIXME: This sign is probably wrong, as it should return +g^MF
return -G if key.lower() == 'p' else -C

class iterator_cstmf_bfgs0(iterator_cstmf):
"""Variant of the iterator where the initial curvature guess is itself a bfgs update from phi =0 to input plm