Skip to content

Commit

Permalink
fix: improve precision of function CIs (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve authored Dec 22, 2024
1 parent f0fb092 commit 9b599d7
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
54 changes: 38 additions & 16 deletions src/elisa/infer/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def ci(
params: str | Iterable[str] | None = None,
fn: dict[str, Callable] | None = None,
method: Literal['profile', 'boot'] = 'profile',
rtol: float | dict[str, float] = 1e-3,
rtol: float | dict[str, float] = 1e-6,
parallel: bool = True,
) -> ConfidenceInterval:
"""Calculate confidence intervals.
Expand Down Expand Up @@ -616,7 +616,7 @@ def ci(
rtol : float, or dict of float, optional
The relative tolerance in determining the value of composite
parameters and `fn` when `method` is ``'profile'``.
The default is 1e-3.
The default is 1e-6.
parallel : bool, optional
Whether to evaluate `fn` in parallel when `method` is ``'boot'``.
The default is True.
Expand All @@ -642,9 +642,9 @@ def ci(
else:
rtol = jax.tree.map(float, dict(rtol))
for k in rtol_keys:
rtol.setdefault(k, 1e-3)
if np.any([i > 0.1 for i in rtol.values()]):
raise ValueError('rtol must be less than 0.1')
rtol.setdefault(k, 1e-6)
if np.any([i > 0.01 for i in rtol.values()]):
raise ValueError('rtol must be less than 0.01')

if method == 'profile':
self._warn_invalid_fit()
Expand Down Expand Up @@ -751,15 +751,37 @@ def _ci_fn(
params_mle = {k: v[0] for k, v in self._mle.items()}
fn_mle = {k: v(params_mle) for k, v in fn.items()}

interval = {}
status = {}
for name, mle in fn_mle.items():
loss = self._loss_factory(fn[name], rtol[name] * mle)
def get_minuit(name, mle, r) -> Minuit:
loss = self._loss_factory(fn[name], r)
grad = jax.jit(jax.grad(loss))
init = np.hstack([mle, self._minuit.values])
minuit = Minuit(loss, init, grad=grad)
minuit.strategy = 2
minuit.migrad()
return minuit

def get_minuit_iter_rtol(name, mle) -> Minuit:
rtol_desired = rtol[name]
minuit0 = get_minuit(name, mle, rtol_desired)
if not minuit0.accurate:
# When profiling the likelihood, deviance difference for
# 1-sigma confidence interval is 1, thus the deviance
# varies slow within 1-sigma confidence interval.
# The rtol should be less than 1% of the variance.
fn_var = self.covar(params=(), fn={'_': fn[name]}).matrix[0, 0]
rel_err = np.sqrt(fn_var) / np.abs(mle)
rtol_max = np.min([0.01, 0.01 * rel_err, 100 * rtol_desired])
if rtol_desired < rtol_max:
for r in np.geomspace(rtol_desired, rtol_max, num=15)[1:]:
minuit = get_minuit(name, mle, r)
if minuit.accurate:
return minuit
return minuit0

interval = {}
status = {}
for name, mle in fn_mle.items():
minuit = get_minuit_iter_rtol(name, mle)
minuit.minos(0, cl=cl)
ci = minuit.merrors[0]
interval[name] = (mle + ci.lower, mle + ci.upper)
Expand All @@ -772,15 +794,15 @@ def _ci_fn(

return interval, status

def _loss_factory(self, fn: Callable, atol: float):
def _loss_factory(self, fn: Callable, rtol: float):
"""Factory method to create joint loss of params and func of params.
Parameters
----------
fn : Callable
Function accepts model parameters and outputs a single value.
atol : float
Absolute tolerance of the function value.
rtol : float
Relative tolerance of the function value.
References
----------
Expand All @@ -789,16 +811,16 @@ def _loss_factory(self, fn: Callable, atol: float):
"""
helper = self._helper
params_free = helper.params_names['free']
atol_inv = 1.0 / atol

@jax.jit
def loss(x: np.ndarray):
"""Joint loss of params and func of params."""
unconstr_dic = dict(zip(params_free, x[1:]))
params = helper.unconstr_dic_to_params_dic(unconstr_dic)
fn_value = fn(params)
s = (fn_value - x[0]) * atol_inv
return helper.deviance_total(x[1:]) + s * s
s1 = fn_value / x[0] - 1.0
s2 = (s1 * s1) / rtol
return helper.deviance_total(x[1:]) + s2

return loss

Expand Down Expand Up @@ -959,7 +981,7 @@ def _(p):
return _

fn_dic = jax.tree.map(transform, fn_dic)
rtol = {k: 1e-4 for k in fn_dic.keys()}
rtol = {k: 1e-8 for k in fn_dic.keys()}
intervals, status = self._ci_fn(fn_dic, cl, rtol=rtol)
intervals = jax.tree.map(jnp.exp, intervals)
elif method == 'boot':
Expand Down
2 changes: 1 addition & 1 deletion tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_mle_ci_fn(mle_result2, powerlaw_flux):
params['PowerLaw.alpha'], params['PowerLaw.K'], emin, emax
)

ci0 = result.ci(params=[], fn={'fn': fn}, rtol={'fn': 5e-4})
ci0 = result.ci(params=[], fn={'fn': fn}, rtol={'fn': 1e-10})
ci1 = result.ci(params=[], fn={'fn': fn})
ci2 = result.ci(params=[], fn={'fn': fn}, method='boot')
ci3 = result.ci(params=[], fn={'fn': fn}, method='boot', parallel=False)
Expand Down

0 comments on commit 9b599d7

Please sign in to comment.