From 9b599d7c47a217cde3c83ab021707135789f72df Mon Sep 17 00:00:00 2001 From: "W.-C. Xue" <58248583+wcxve@users.noreply.github.com> Date: Mon, 23 Dec 2024 03:59:22 +0800 Subject: [PATCH] fix: improve precision of function CIs (#151) --- src/elisa/infer/results.py | 54 +++++++++++++++++++++++++++----------- tests/test_results.py | 2 +- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/src/elisa/infer/results.py b/src/elisa/infer/results.py index 26cd2ee..3a041f1 100644 --- a/src/elisa/infer/results.py +++ b/src/elisa/infer/results.py @@ -583,7 +583,7 @@ def ci( params: str | Iterable[str] | None = None, fn: dict[str, Callable] | None = None, method: Literal['profile', 'boot'] = 'profile', - rtol: float | dict[str, float] = 1e-3, + rtol: float | dict[str, float] = 1e-6, parallel: bool = True, ) -> ConfidenceInterval: """Calculate confidence intervals. @@ -616,7 +616,7 @@ def ci( rtol : float, or dict of float, optional The relative tolerance in determining the value of composite parameters and `fn` when `method` is ``'profile'``. - The default is 1e-3. + The default is 1e-6. parallel : bool, optional Whether to evaluate `fn` in parallel when `method` is ``'boot'``. The default is True. @@ -642,9 +642,9 @@ def ci( else: rtol = jax.tree.map(float, dict(rtol)) for k in rtol_keys: - rtol.setdefault(k, 1e-3) - if np.any([i > 0.1 for i in rtol.values()]): - raise ValueError('rtol must be less than 0.1') + rtol.setdefault(k, 1e-6) + if np.any([i > 0.01 for i in rtol.values()]): + raise ValueError('rtol must be less than 0.01') if method == 'profile': self._warn_invalid_fit() @@ -751,15 +751,37 @@ def _ci_fn( params_mle = {k: v[0] for k, v in self._mle.items()} fn_mle = {k: v(params_mle) for k, v in fn.items()} - interval = {} - status = {} - for name, mle in fn_mle.items(): - loss = self._loss_factory(fn[name], rtol[name] * mle) + def get_minuit(name, mle, r) -> Minuit: + loss = self._loss_factory(fn[name], r) grad = jax.jit(jax.grad(loss)) init = np.hstack([mle, self._minuit.values]) minuit = Minuit(loss, init, grad=grad) minuit.strategy = 2 minuit.migrad() + return minuit + + def get_minuit_iter_rtol(name, mle) -> Minuit: + rtol_desired = rtol[name] + minuit0 = get_minuit(name, mle, rtol_desired) + if not minuit0.accurate: + # When profiling the likelihood, deviance difference for + # 1-sigma confidence interval is 1, thus the deviance + # varies slow within 1-sigma confidence interval. + # The rtol should be less than 1% of the variance. + fn_var = self.covar(params=(), fn={'_': fn[name]}).matrix[0, 0] + rel_err = np.sqrt(fn_var) / np.abs(mle) + rtol_max = np.min([0.01, 0.01 * rel_err, 100 * rtol_desired]) + if rtol_desired < rtol_max: + for r in np.geomspace(rtol_desired, rtol_max, num=15)[1:]: + minuit = get_minuit(name, mle, r) + if minuit.accurate: + return minuit + return minuit0 + + interval = {} + status = {} + for name, mle in fn_mle.items(): + minuit = get_minuit_iter_rtol(name, mle) minuit.minos(0, cl=cl) ci = minuit.merrors[0] interval[name] = (mle + ci.lower, mle + ci.upper) @@ -772,15 +794,15 @@ def _ci_fn( return interval, status - def _loss_factory(self, fn: Callable, atol: float): + def _loss_factory(self, fn: Callable, rtol: float): """Factory method to create joint loss of params and func of params. Parameters ---------- fn : Callable Function accepts model parameters and outputs a single value. - atol : float - Absolute tolerance of the function value. + rtol : float + Relative tolerance of the function value. References ---------- @@ -789,7 +811,6 @@ def _loss_factory(self, fn: Callable, atol: float): """ helper = self._helper params_free = helper.params_names['free'] - atol_inv = 1.0 / atol @jax.jit def loss(x: np.ndarray): @@ -797,8 +818,9 @@ def loss(x: np.ndarray): unconstr_dic = dict(zip(params_free, x[1:])) params = helper.unconstr_dic_to_params_dic(unconstr_dic) fn_value = fn(params) - s = (fn_value - x[0]) * atol_inv - return helper.deviance_total(x[1:]) + s * s + s1 = fn_value / x[0] - 1.0 + s2 = (s1 * s1) / rtol + return helper.deviance_total(x[1:]) + s2 return loss @@ -959,7 +981,7 @@ def _(p): return _ fn_dic = jax.tree.map(transform, fn_dic) - rtol = {k: 1e-4 for k in fn_dic.keys()} + rtol = {k: 1e-8 for k in fn_dic.keys()} intervals, status = self._ci_fn(fn_dic, cl, rtol=rtol) intervals = jax.tree.map(jnp.exp, intervals) elif method == 'boot': diff --git a/tests/test_results.py b/tests/test_results.py index 3bc81be..f57b132 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -94,7 +94,7 @@ def test_mle_ci_fn(mle_result2, powerlaw_flux): params['PowerLaw.alpha'], params['PowerLaw.K'], emin, emax ) - ci0 = result.ci(params=[], fn={'fn': fn}, rtol={'fn': 5e-4}) + ci0 = result.ci(params=[], fn={'fn': fn}, rtol={'fn': 1e-10}) ci1 = result.ci(params=[], fn={'fn': fn}) ci2 = result.ci(params=[], fn={'fn': fn}, method='boot') ci3 = result.ci(params=[], fn={'fn': fn}, method='boot', parallel=False)