Skip to content

Commit

Permalink
support finer control of rtol in calculating CIs
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Nov 20, 2024
1 parent e394f87 commit 8801f91
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions src/elisa/infer/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def ci(
params: str | Iterable[str] | None = None,
fn: dict[str, Callable] | None = None,
method: Literal['profile', 'boot'] = 'profile',
rtol: float = 1e-3,
rtol: float | dict[str, float] = 1e-3,
parallel: bool = True,
) -> ConfidenceInterval:
"""Calculate confidence intervals.
Expand Down Expand Up @@ -613,9 +613,10 @@ def ci(
called before using this method.
The default is ``'profile'``.
rtol : float
The relative tolerance in determine the function value when
`method` is ``'profile'``. The default is 1e-3.
rtol : float, or dict of float, optional
The relative tolerance in determining the value of composite
parameters and `fn` when `method` is ``'profile'``.
The default is 1e-3.
parallel : bool, optional
Whether to evaluate `fn` in parallel when `method` is ``'boot'``.
The default is True.
Expand All @@ -625,9 +626,6 @@ def ci(
ConfidenceInterval
The confidence intervals.
"""
if rtol > 0.1:
raise ValueError('rtol must be less than 0.1')

cl = self._to_unit_cl(cl)
params = check_params(params, self._helper)
params_set = set(params)
Expand All @@ -638,6 +636,15 @@ def ci(
assert free | composite == params_set

fn = self._check_fn(fn)
rtol_keys = tuple(fn.keys()) + tuple(composite)
if isinstance(rtol, float):
rtol = {k: rtol for k in rtol_keys}
else:
rtol = jax.tree.map(float, dict(rtol))
for k in rtol_keys:
rtol.set_default[k] = 1e-3
if np.any([i > 0.1 for i in rtol.values()]):
raise ValueError('rtol must be less than 0.1')

if method == 'profile':
self._warn_invalid_fit()
Expand Down Expand Up @@ -734,15 +741,20 @@ def _ci_free(self, names: Iterable[str], cl: float | int):
}
return interval, status

def _ci_fn(self, fn: dict[str, Callable], cl: float | int, rtol=1e-3):
def _ci_fn(
self,
fn: dict[str, Callable],
cl: float | int,
rtol: dict[str, float],
):
"""Confidence intervals of function of free parameters."""
params_mle = {k: v[0] for k, v in self._mle.items()}
fn_mle = {k: v(params_mle) for k, v in fn.items()}

interval = {}
status = {}
for name, mle in fn_mle.items():
loss = self._loss_factory(fn[name], rtol * mle)
loss = self._loss_factory(fn[name], rtol[name] * mle)
grad = jax.jit(jax.grad(loss))
init = np.hstack([mle, self._minuit.values])
minuit = Minuit(loss, init, grad=grad)
Expand Down Expand Up @@ -947,7 +959,8 @@ def _(p):
return _

fn_dic = jax.tree.map(transform, fn_dic)
intervals, status = self._ci_fn(fn_dic, cl, rtol=1e-4)
rtol = {k: 1e-4 for k in fn_dic.keys()}
intervals, status = self._ci_fn(fn_dic, cl, rtol=rtol)
intervals = jax.tree.map(jnp.exp, intervals)
elif method == 'boot':
intervals, status = self._ci_boot(cl, [], fn_dic, True, params)
Expand Down

0 comments on commit 8801f91

Please sign in to comment.