implementing multigrid preconditioned CG
louisl3grand committed Aug 23, 2024
1 parent 422e680 commit 73ad5e5
Showing 8 changed files with 134 additions and 51 deletions.
46 changes: 29 additions & 17 deletions delensalot/core/cg/cd_solve.py
@@ -8,8 +8,8 @@ def PTR(p, t, r):
return lambda i: max(0, i - max(p, int(min(t, np.mod(i, r)))))


tr_cg = (lambda i: i - 1)
tr_cd = (lambda i: 0)
tr_cg = (lambda i: i - 1) # Conjugate gradient: orthogonalize only against the previous search direction
tr_cd = (lambda i: 0) # Conjugate directions: orthogonalize against all previous search directions


class cache_mem(dict):
@@ -49,57 +49,69 @@ def cd_solve(x, b, fwd_op, pre_ops, dot_op, criterion, tr, cache=cache_mem(), ro
Note:
fwd_op, pre_op(s) and dot_op must not modify their arguments!
LL comments:
Is the preconditioner updated within the loop?
Do the calls to pre_op themselves run cd_solve recursively (e.g. for the multigrid preconditioner)?
Comments below follow the notation of the Wikipedia article on the conjugate gradient method.
"""

n_pre_ops = len(pre_ops)

# r_0 = b - A x_0, where A is the fwd operation, r_0 is the initial residual, and x_0 the initial guess
residual = b - fwd_op(x)
searchdirs = [op(residual) for op in pre_ops]

# z_0 = M^-1 r_0, where M is the preconditioner (dense or diag)
searchdirs = [op(residual) for op in pre_ops] # If the pre_op is multigrid, this will call cd_solve recursively
# p_0 = z_0, where p_0 is the initial search direction

iter = 0
while not criterion(iter, x, residual):
searchfwds = [fwd_op(searchdir) for searchdir in searchdirs]
deltas = [dot_op(searchdir, residual) for searchdir in searchdirs]
# TODO: Does this combine all the preconditioned search directions into a single search direction?
searchfwds = [fwd_op(searchdir) for searchdir in searchdirs] # A p_k
deltas = [dot_op(searchdir, residual) for searchdir in searchdirs] # \delta_{k} = r_k^T z_k

# calculate (D^T A D)^{-1}
# calculate (p_{k}^T A p_k)^{-1}
dTAd = np.zeros((n_pre_ops, n_pre_ops))
for ip1 in range(0, n_pre_ops):
for ip2 in range(0, ip1 + 1):
dTAd[ip1, ip2] = dTAd[ip2, ip1] = dot_op(searchdirs[ip1], searchfwds[ip2])
dTAd_inv = np.linalg.inv(dTAd)
dTAd[ip1, ip2] = dTAd[ip2, ip1] = dot_op(searchdirs[ip1], searchfwds[ip2]) # p_{k}^T A p_k
dTAd_inv = np.linalg.inv(dTAd) # (p_{k}^T A p_k)^{-1}

# search.
alphas = np.dot(dTAd_inv, deltas)
alphas = np.dot(dTAd_inv, deltas) # alpha_{k} = r_k^T z_k / (p_{k}^T A p_k)
for (searchdir, alpha) in zip(searchdirs, alphas):
x += searchdir * alpha
x += searchdir * alpha # x_{k+1} = x_k + \alpha_k p_k

# append to cache.
cache.store(iter, [dTAd_inv, searchdirs, searchfwds])

# update residual
iter += 1
if np.mod(iter, roundoff) == 0:
residual = b - fwd_op(x)
# In this case, recompute the exact residual
residual = b - fwd_op(x) # r_{k+1} = b - A x_{k+1}
else:
for (searchfwd, alpha) in zip(searchfwds, alphas):
residual -= searchfwd * alpha
residual -= searchfwd * alpha # r_{k+1} = r_k - \alpha_k A p_k

# initial choices for new search directions.
searchdirs = [pre_op(residual) for pre_op in pre_ops]
searchdirs = [pre_op(residual) for pre_op in pre_ops] # z_{k+1} = M^{-1} r_{k+1}

# orthogonalize w.r.t. previous searches.
prev_iters = range(tr(iter), iter)
# For CG there is only one previous search direction to orthogonalize against; for CD, all previous search directions are kept.

for titer in prev_iters:
[prev_dTAd_inv, prev_searchdirs, prev_searchfwds] = cache.restore(titer)

for searchdir in searchdirs:
proj = [dot_op(searchdir, prev_searchfwd) for prev_searchfwd in prev_searchfwds]
betas = np.dot(prev_dTAd_inv, proj)
proj = [dot_op(searchdir, prev_searchfwd) for prev_searchfwd in prev_searchfwds] # z_{k+1}^T A p_k
betas = np.dot(prev_dTAd_inv, proj) # beta_{k} = z_{k+1}^T A p_k / (p_{k}^T A p_k)

for (beta, prev_searchdir) in zip(betas, prev_searchdirs):
searchdir -= prev_searchdir * beta
searchdir -= prev_searchdir * beta # p_{k+1} = z_{k+1} - \beta_k p_k

# clear old keys from cache
cache.trim(range(tr(iter + 1), iter))
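
For reference, with a single pre_op and the tr_cg truncation the loop above reduces to textbook preconditioned CG. The sketch below (plain numpy, not the delensalot implementation, which handles several pre_ops, caching and truncation) only illustrates the notation used in the comments, with fwd_op playing the role of A and the pre_op that of M^-1:

import numpy as np

def pcg_sketch(A, b, x, M_inv, n_iter=100, tol=1e-10):
    """Minimal preconditioned CG following the notation of the comments above."""
    r = b - A @ x                      # r_0 = b - A x_0
    z = M_inv @ r                      # z_0 = M^-1 r_0
    p = z.copy()                       # p_0 = z_0
    for k in range(n_iter):
        Ap = A @ p                     # A p_k
        alpha = (r @ z) / (p @ Ap)     # alpha_k = r_k^T z_k / (p_k^T A p_k)
        x = x + alpha * p              # x_{k+1} = x_k + alpha_k p_k
        r_new = r - alpha * Ap         # r_{k+1} = r_k - alpha_k A p_k
        if np.linalg.norm(r_new) < tol:
            break
        z_new = M_inv @ r_new          # z_{k+1} = M^-1 r_{k+1}
        beta = (z_new @ Ap) / (p @ Ap) # beta_k = z_{k+1}^T A p_k / (p_k^T A p_k), Gram-Schmidt form as in cd_solve
        p = z_new - beta * p           # p_{k+1} = z_{k+1} - beta_k p_k
        r, z = r_new, z_new
    return x
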
48 changes: 27 additions & 21 deletions delensalot/core/cg/multigrid.py
@@ -22,7 +22,7 @@ def __init__(self, ids, pre_ops_descr, lmax, nside, iter_max, eps_min, tr, cache


class multigrid_chain:
def __init__(self, opfilt, chain_descr, s_cls, n_inv_filt, debug_log_prefix=None, plogdepth=0):
def __init__(self, opfilt, chain_descr, s_cls, n_inv_filt, debug_log_prefix=None, plogdepth=0, no_lensing_precond=False):
self.debug_log_prefix = debug_log_prefix
self.plogdepth = plogdepth

@@ -31,14 +31,17 @@ def __init__(self, opfilt, chain_descr, s_cls, n_inv_filt, debug_log_prefix=None

self.s_cls = s_cls
self.n_inv_filt = n_inv_filt
self.no_lensing_precond = no_lensing_precond # Switch off lensing for the preconditioner estimates

stages = {}
for [id, pre_ops_descr, lmax, nside, iter_max, eps_min, tr, cache] in self.chain_descr:
print(f'creating multigrid stage: id = {id}, pre_ops_descr = {pre_ops_descr}, lmax = {lmax}, nside = {nside}, iter_max = {iter_max}, eps_min = {eps_min:.2e}')
stages[id] = multigrid_stage(id, pre_ops_descr, lmax, nside, iter_max, eps_min, tr, cache)
for pre_op_descr in pre_ops_descr: # recursively add all stages to stages[0]
print('adding pre_op: ', pre_op_descr)
stages[id].pre_ops.append(parse_pre_op_descr(pre_op_descr, opfilt=self.opfilt,
s_cls=self.s_cls, n_inv_filt=self.n_inv_filt,
stages=stages, lmax=lmax, nside=nside, chain=self))
stages=stages, lmax=lmax, nside=nside, chain=self, no_lensing=self.no_lensing_precond))
self.bstage = stages[0] # these are the pre_ops called in cd_solve

def solve(self, soltn, tpn_map, apply_fini='', dot_op=None):
@@ -119,28 +122,31 @@ def parse_pre_op_descr(pre_op_descr, **kwargs):
kwargs_low = copy.copy(kwargs);
kwargs_low['lmax'] = lsplit
kwargs_hgh = copy.copy(kwargs);
kwargs_hgh['lmin'] = lsplit + 1
kwargs_hgh['lmin'] = lsplit + 1 # FIXME: is this ever used?
pre_op_low = parse_pre_op_descr(low_descr, **kwargs_low)
pre_op_hgh = parse_pre_op_descr(hgh_descr, **kwargs_hgh)

return pre_op_split(lsplit, kwargs['lmax'], pre_op_low, pre_op_hgh)

elif re.match("diag_cl\Z", pre_op_descr):
return kwargs['opfilt'].pre_op_diag(kwargs['s_cls'], kwargs['n_inv_filt'])
elif re.match("dense\Z", pre_op_descr):
# FIXME: remove this option in favor of dense() below.
print('creating dense preconditioner. (nside = %d, lmax = %d)' % (kwargs['nside'], kwargs['lmax']))

fwd_op = kwargs['opfilt'].fwd_op(kwargs['s_cls'], kwargs['n_inv_filt'].degrade(kwargs['nside']))

return kwargs['opfilt'].pre_op_dense(kwargs['lmax'], fwd_op)
# TODO: lmin, the minimum ell below which the dense block applies, is never used?
return kwargs['opfilt'].pre_op_diag(kwargs['s_cls'], kwargs['n_inv_filt'].degrade(kwargs['nside'], kwargs['lmax'], kwargs['lmax'], set_deflection_to_zero=kwargs['no_lensing']))

# elif re.match("dense\Z", pre_op_descr):
# # FIXME: remove this option in favor of dense() below.
# print('creating dense preconditioner. (nside = %d, lmax = %d)' % (kwargs['nside'], kwargs['lmax']))
# fwd_op = kwargs['opfilt'].fwd_op(kwargs['s_cls'], kwargs['n_inv_filt'].degrade(kwargs['nside'], kwargs['lmax'], kwargs['lmax']))
# return kwargs['opfilt'].pre_op_dense(kwargs['lmax'], fwd_op)

elif re.match("dense\((.*)\)\Z", pre_op_descr):
(dense_cache_fname,) = re.match("dense\((.*)\)\Z", pre_op_descr).groups()
if dense_cache_fname == '': dense_cache_fname = None

print('creating dense preconditioner. (nside = %d, lmax = %d, cache = %s)' % (
kwargs['nside'], kwargs['lmax'], dense_cache_fname))
fwd_op = kwargs['opfilt'].fwd_op(kwargs['s_cls'], kwargs['n_inv_filt'].degrade(kwargs['nside']))
fwd_op = kwargs['opfilt'].fwd_op(kwargs['s_cls'], kwargs['n_inv_filt'].degrade(kwargs['nside'], kwargs['lmax'], kwargs['lmax'], set_deflection_to_zero=kwargs['no_lensing']))
return kwargs['opfilt'].pre_op_dense(kwargs['lmax'], fwd_op, cache_fname=dense_cache_fname)

elif re.match("stage\(.*\)\Z", pre_op_descr):
(stage_id,) = re.match("stage\((.*)\)\Z", pre_op_descr).groups()
print('creating multigrid preconditioner: stage_id = ', stage_id)
@@ -150,9 +156,9 @@ def parse_pre_op_descr(pre_op_descr, **kwargs):
chain.log(stage, iter, eps, **kwargs))

assert (stage.lmax == kwargs['lmax'])

# TODO: Check whether no_lensing should be applied here?
return pre_op_multigrid(kwargs['opfilt'], stage.lmax, stage.nside,
kwargs['s_cls'], kwargs['n_inv_filt'].degrade(stage.nside),
kwargs['s_cls'], kwargs['n_inv_filt'].degrade(stage.nside, stage.lmax, stage.lmax, set_deflection_to_zero=kwargs['no_lensing']),
stage.pre_ops, logger, stage.tr, stage.cache,
stage.iter_max, stage.eps_min)
else:
@@ -163,7 +169,6 @@ class pre_op_split:
def __init__(self, lsplit, lmax, pre_op_low, pre_op_hgh):
self.lsplit = lsplit
self.lmax = lmax

self.pre_op_low = pre_op_low
self.pre_op_hgh = pre_op_hgh

@@ -175,8 +180,8 @@ def __call__(self, talm):
def calc(self, talm):
self.iter += 1

talm_low = self.pre_op_low(alm_copy(talm, lmax=self.lsplit))
talm_hgh = self.pre_op_hgh(alm_copy(talm, lmax=self.lmax))
talm_low = self.pre_op_low(alm_copy(talm, None, lmaxout=self.lsplit, mmaxout=self.lsplit))
talm_hgh = self.pre_op_hgh(alm_copy(talm, None, lmaxout=self.lmax, mmaxout=self.lmax))

return alm_splice(talm_low, talm_hgh, self.lsplit)
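
A minimal sketch of the splice step that pre_op_split.calc relies on, written for a toy per-ell vector rather than true healpy alm ordering (assumption: alm_splice keeps multipoles ell <= lsplit from its first argument and ell > lsplit from its second; toy_splice is a made-up stand-in):

import numpy as np

def toy_splice(vec_low, vec_high, lsplit):
    """Toy stand-in for alm_splice acting on per-ell coefficients."""
    out = vec_high.copy()
    out[:lsplit + 1] = vec_low[:lsplit + 1]  # ell <= lsplit from the low-ell preconditioner
    return out                               # ell > lsplit from the high-ell preconditioner
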

@@ -205,10 +210,11 @@ def __call__(self, talm):
return self.calc(talm)

def calc(self, talm):
monitor = cd_monitors.monitor_basic(self.opfilt.dot_op(),
monitor = cd_monitors.monitor_basic(self.opfilt.dot_op(self.lmax, None),
iter_max=self.iter_max, eps_min=self.eps_min, logger=self.logger)
soltn = talm * 0.0
cd_solve.cd_solve(soltn, alm_copy(talm, lmax=self.lmax),
self.fwd_op, self.pre_ops, self.opfilt.dot_op(), monitor, tr=self.tr, cache=self.cache)
soltn = talm * 0.0
# TODO: Is zero the best initial guess here? Could we use a previous estimate, or talm itself?
cd_solve.cd_solve(soltn, alm_copy(talm, None, lmaxout=self.lmax, mmaxout=self.lmax),
self.fwd_op, self.pre_ops, self.opfilt.dot_op(self.lmax, None), monitor, tr=self.tr, cache=self.cache)

return alm_splice(soltn, talm, self.lmax)
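
For illustration, a hypothetical three-stage chain description matching the [id, pre_ops_descr, lmax, nside, iter_max, eps_min, tr, cache] unpacking in multigrid_chain.__init__ above. The descriptor strings use only the grammar handled by parse_pre_op_descr (split, stage, dense, diag_cl); all numerical values and the empty dense cache filename are made up:

import numpy as np
from delensalot.core.cg import cd_solve

lmax_sol, nside = 3000, 2048  # illustrative outer resolution
chain_descr = [
    # coarsest stage: dense preconditioner below ell=32, diagonal above
    [2, ["split(dense(''), 32, diag_cl)"],    256,  128,      3,  0.0, cd_solve.tr_cg, cd_solve.cache_mem()],
    # intermediate stage: recurse into stage 2 below ell=256
    [1, ["split(stage(2), 256, diag_cl)"],   1024,  512,      3,  0.0, cd_solve.tr_cg, cd_solve.cache_mem()],
    # outer stage 0: the actual CG solve, preconditioned by stage 1 below ell=1024
    [0, ["split(stage(1), 1024, diag_cl)"], lmax_sol, nside, np.inf, 1e-5, cd_solve.tr_cg, cd_solve.cache_mem()],
]
# stage 0 is self.bstage, i.e. the pre_ops called in cd_solve; the coarser stages run
# their own (truncated) cd_solve each time they are applied as a preconditioner.
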
12 changes: 9 additions & 3 deletions delensalot/core/iterator/cs_iterator.py
@@ -58,7 +58,8 @@ def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
k_geom:utils_geom.Geom,
chain_descr, stepper:steps.nrstep,
logger=None,
NR_method=100, tidy=0, verbose=True, soltn_cond=True, wflm0=None, _usethisE=None):
NR_method=100, tidy=0, verbose=True, soltn_cond=True, wflm0=None, _usethisE=None,
no_lensing_precond=False):
"""Lensing map iterator
The bfgs hessian updates are called 'hlm's and are either in plm, dlm or klm space
@@ -71,7 +72,7 @@ def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
k_geom: lenspyx geometry for once-per-iterations operations (like checking for invertibility etc, QE evals...)
stepper: custom calculation of NR-step
wflm0(optional): callable with Wiener-filtered CMB map search starting point
no_lensing_precond: if True, the preconditioner will not include lensing
"""
assert h in ['k', 'p', 'd']
lmax_qlm, mmax_qlm = lm_max_dlm
@@ -123,6 +124,7 @@ def __init__(self, lib_dir:str, h:str, lm_max_dlm:tuple,
self.logger.startup(self)

self._usethisE = _usethisE
self.no_lensing_precond = no_lensing_precond

def _p2h(self, lmax):
if self.h == 'p':
@@ -465,6 +467,10 @@ def iterate(self, itr, key):
def calc_gradlik(self, itr, key, iwantit=False):
"""Computes the quadratic part of the gradient for plm iteration 'itr'
Compared to formalism of the papers, this returns -g_LM^{QD}
Args:
itr: iteration index
key: 'p' or 'o'
iwantit: if True, forces the calculation of the gradient and returns it
"""
assert self.is_iter_done(itr - 1, key)
assert itr > 0, itr
Expand All @@ -475,7 +481,7 @@ def calc_gradlik(self, itr, key, iwantit=False):
self.hlm2dlm(dlm, True)
ffi = self.filter.ffi.change_dlm([dlm, None], self.mmax_qlm, cachers.cacher_mem(safe=False))
self.filter.set_ffi(ffi)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter)
mchain = multigrid.multigrid_chain(self.opfilt, self.chain_descr, self.cls_filt, self.filter, no_lensing_precond=self.no_lensing_precond)
if self._usethisE is not None:
if callable(self._usethisE):
log.info("iterator: using custom WF E")
64 changes: 60 additions & 4 deletions delensalot/core/opfilt/MAP_opfilt_aniso_t.py
@@ -18,10 +18,17 @@

from delensalot.core.opfilt import MAP_opfilt_iso_t
from plancklens.sims import phas
import healpy as hp
from lenspyx.remapping import deflection
from plancklens.qcinv import template_removal

pre_op_dense = None # not implemented
# pre_op_dense = None # not implemented
from plancklens.qcinv.opfilt_tt import pre_op_dense

# FIXME: This is the same fwd_op as the QE; is that correct?
fwd_op = MAP_opfilt_iso_t.fwd_op


def apply_fini(*args, **kwargs):
"""cg-inversion post-operation
@@ -34,7 +41,8 @@ def apply_fini(*args, **kwargs):
class alm_filter_ninv_wl(opfilt_base.alm_filter_wl):
def __init__(self, ninv_geom:utils_geom.Geom, ninv: np.ndarray, ffi:remapping.deflection, transf:np.ndarray,
unlalm_info:tuple, lenalm_info:tuple, sht_threads:int,verbose=False,
lmin_dotop=0, tpl:tni.template_tfilt or None =None, rescal=None):
lmin_dotop=0, tpl:tni.template_tfilt or None =None,
marge_monopole = False, marge_dipole = False, rescal=None):
r"""CMB inverse-variance and Wiener filtering instance, using unlensed E and lensing deflection
Args:
@@ -79,9 +87,17 @@ def __init__(self, ninv_geom:utils_geom.Geom, ninv: np.ndarray, ffi:remapping.de

self.template = tpl

self.marge_monopole = marge_monopole
self.marge_dipole = marge_dipole

self.templates = [self.template]

if marge_monopole: self.templates.append(template_removal.template_monopole())
if marge_dipole: self.templates.append(template_removal.template_dipole())

def hashdict(self):
return {'ninv':self._ninv_hash(), 'transf':clhash(self.b_transf_tlm),
'deflection':self.ffi.hashdict(),
# 'deflection':self.ffi.hashdict(), #TODO: Deflection hashdict is not implemented
'unalm':(self.lmax_sol, self.mmax_sol), 'lenalm':(self.lmax_len, self.mmax_len) }

def _ninv_hash(self):
@@ -96,8 +112,48 @@ def get_ftl(self):
n_inv_cl_t = self.b_transf_tlm ** 2 / (self._nlevt / 180. / 60. * np.pi) ** 2
return n_inv_cl_t

def degrade(self, nside, lmax, mmax, set_deflection_to_zero=True):
"""Reproducing plancklens function, useful for multigrid preconditioner
# TODO: Check if this matches the Lensit implementation
"""
if nside == hp.npix2nside(len(self.n_inv)) and set_deflection_to_zero is False:
return self
else:
print(f"MAP OPFILT ANISO T: Degrading filtered maps to nside:{nside}, lmax:{lmax}")

if set_deflection_to_zero is True:
print("Setting deflection to zero")
_ffi = deflection(utils_geom.Geom.get_healpix_geometry(nside), np.zeros(hp.Alm.getsize(lmax)), mmax,
numthreads=self.sht_threads, verbosity=0, single_prec=False, epsilon=self.ffi.epsilon)
else:
print(f"Using the same deflection, rescaled to the new nside {nside}")
dlm = alm_copy(self.ffi.dlm, None, lmax, mmax)
if self.ffi.dclm is not None:
dclm = alm_copy(self.ffi.dclm, None, lmax, mmax)
else:
dclm = None
_ffi = deflection(utils_geom.Geom.get_healpix_geometry(nside), dlm, mmax,
dclm=dclm, numthreads=self.ffi.sht_tr,
verbosity=self.ffi.verbosity, single_prec=self.ffi.single_prec, epsilon=self.ffi.epsilon)

# tpl = tni.template_tfilt(
# self.template.lmax,
# geom=utils_geom.Geom.get_healpix_geometry(nside),
# sht_threads=self.template.sht_threads)

return alm_filter_ninv_wl(
utils_geom.Geom.get_healpix_geometry(nside),
hp.ud_grade(self.n_inv, nside, power=-2),
_ffi, self.b_transf_tlm,
(lmax, mmax), (lmax, mmax), self.sht_threads, self.verbose,
lmin_dotop=self.lmin_dotop, tpl=self.template,
marge_monopole = self.marge_monopole,
marge_dipole = self.marge_dipole,
rescal = cli(self.rescali)[:lmax+1])
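
A usage sketch of the new degrade signature as it is invoked from parse_pre_op_descr in multigrid.py (ninv_filt stands for an existing alm_filter_ninv_wl instance; the nside/lmax values are illustrative only):

# Coarse copy of the filter for a multigrid stage, with lensing switched off
# in the preconditioner (no_lensing_precond):
coarse_filt = ninv_filt.degrade(128, 256, 256, set_deflection_to_zero=True)
# The degraded filter is then used to build that stage's operators,
# e.g. opfilt.pre_op_diag(s_cls, coarse_filt) as in the diag_cl branch above.
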


def dot_op(self):
return dot_op(self.lmax_sol, self.mmax_sol, lmin=self.lmin_dotop)
return dot_op(lmax=self.lmax_sol, mmax=self.mmax_sol, lmin=self.lmin_dotop)

def apply_map(self, tmap):
"""Applies pixel inverse-noise variance maps
1 change: 1 addition & 0 deletions delensalot/core/opfilt/QE_opfilt_iso_t.py
@@ -203,6 +203,7 @@ class fwd_op:
def __init__(self, s_cls:dict, ninv_filt:alm_filter_nlev):
self.icls = {'tt': cli(s_cls['tt'][:ninv_filt.lmax_sol + 1]) * ninv_filt.rescali ** 2}
self.ninv_filt = ninv_filt
self.n_inv_filt = self.ninv_filt
self.lmax_sol = ninv_filt.lmax_sol
self.mmax_sol = ninv_filt.mmax_sol

2 changes: 1 addition & 1 deletion delensalot/core/opfilt/tmodes_ninv.py
@@ -144,7 +144,7 @@ def dot(self, tmap):
This includes a factor npix / 4pi, as the transpose differs from the inverse by that factor
"""
assert tmap.size == self.npix
assert tmap.size == self.npix, (tmap.size, self.npix)
tlm = self.geom.adjoint_synthesis(tmap, 0, self.lmax, self.lmax, self.sht_threads, apply_weights=False).squeeze()
return self._blm2rlm(tlm) # Units weight transform

8 changes: 5 additions & 3 deletions delensalot/utility/cpp_sims.py
@@ -114,9 +114,11 @@ def libdir_sim(self, simidx, tol=None, eps=None):
def get_itlib_sim(self, simidx, tol=None, eps=None):
if tol is None: tol = self.tol
if eps is None: eps = self.eps
tol_iter = 10 ** (- tol)
epsilon = 10**(-eps)
return self.param.get_itlib(self.k, simidx, self.version, cg_tol=tol_iter, epsilon=epsilon)
# tol_iter = 10 ** (- tol)
# epsilon = 10**(-eps)
# return self.param.get_itlib(self.k, simidx, self.version, cg_tol=tol_iter, epsilon=epsilon)
# qe_key:str, simidx:int, version:str, qe_version:str, tol:float, epsilon=5, nbump=0, rscal=0, verbose=False, numthreads=0
return self.param.get_itlib(self.k, simidx, self.version, self.qe_version, tol=tol, epsilon=eps)

def cacher_sim(self, simidx, verbose=False):
if self.cache_in_home is False: