Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 5, 2024
1 parent f890d68 commit ed12c4d
Show file tree
Hide file tree
Showing 22 changed files with 357 additions and 366 deletions.
2 changes: 1 addition & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"
interval: "monthly"
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ python:
- method: pip
path: .
extra_requirements:
- docs
- docs
2 changes: 1 addition & 1 deletion docs/_templates/breadcrumbs.html
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{%- extends "sphinx_rtd_theme/breadcrumbs.html" %}

{% block breadcrumbs_aside %}
{% endblock %}
{% endblock %}
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'numpydoc.numpydoc'
'numpydoc.numpydoc',
]

templates_path = ['_templates']
Expand Down
6 changes: 3 additions & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import nox

PYTHON_VERSIONS = ["3.9", "3.10", "3.11"]
PYTHON_VERSIONS = ['3.9', '3.10', '3.11']


@nox.session(python=PYTHON_VERSIONS)
def test(session: nox.Session) -> None:
session.install(".[test]")
session.run("pytest", "--cov-report=xml", "--cov=elisa", *session.posargs)
session.install('.[test]')
session.run('pytest', '--cov-report=xml', '--cov=elisa', *session.posargs)
55 changes: 23 additions & 32 deletions src/elisa/infer/analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Subsequent analysis of likelihood or Bayesian fit."""
from __future__ import annotations

from typing import Literal, NamedTuple, Optional, Sequence
from collections.abc import Sequence
from typing import Literal, NamedTuple, Optional

import arviz as az
import jax
Expand Down Expand Up @@ -43,11 +44,7 @@ class BootstrapResult(NamedTuple):
class MLEResult:
"""MLE result obtained from likelihood fit."""

def __init__(
self,
minuit: Minuit,
fit: _fit.LikelihoodFit
):
def __init__(self, minuit: Minuit, fit: _fit.LikelihoodFit):
self._minuit = minuit
self._helper = helper = fit._helper
self._free_names = free_names = fit._free_names
Expand Down Expand Up @@ -86,8 +83,8 @@ def __init__(
'deviance': {
'total': stat_total,
'group': stat_group,
'point': stat_info['point']
}
'point': stat_info['point'],
},
}

k = len(free_names)
Expand All @@ -101,7 +98,7 @@ def __init__(
def __repr__(self):
tab = make_pretty_table(
['Parameter', 'Value', 'Error'],
[(k, f'{v[0]:.4g}', f'{v[1]:.4g}') for k, v in self._mle.items()]
[(k, f'{v[0]:.4g}', f'{v[1]:.4g}') for k, v in self._mle.items()],
)
s = 'MLE:\n' + tab.get_string() + '\n'

Expand Down Expand Up @@ -244,10 +241,7 @@ def ci(

cl_ = 1.0 - 2.0 * norm.sf(cl) if cl >= 1.0 else cl

mle = {
k: v for k, v in self._result['params'].items()
if k in params
}
mle = {k: v for k, v in self._result['params'].items() if k in params}

helper = self._helper

Expand All @@ -257,8 +251,7 @@ def ci(
mle0 = self._minuit.values.to_dict()

others = { # set other unconstrained free parameter to mle
i: mle0[i]
for i in (set(mle0.keys()) - set(free_params))
i: mle0[i] for i in (set(mle0.keys()) - set(free_params))
}

ci = self._minuit.merrors
Expand All @@ -284,12 +277,13 @@ def ci(
# confidence interval of function of parameters,
# see, e.g. https://doi.org/10.1007/s11222-021-10012-y
for p in composite_params:

def loss(x):
"""The loss when calculating CI of composite parameter."""
unconstr = {k: v for k, v in zip(self._free_names, x[1:])}
p0 = helper.to_params_dict(unconstr)[p]
diff = (p0 / x[0] - 1) / 1e-3
return helper.deviance_unconstr(x[1:]) + diff*diff
return helper.deviance_unconstr(x[1:]) + diff * diff

mle_p = mle[p]

Expand Down Expand Up @@ -317,8 +311,10 @@ def loss(x):
else:
boot_result = self.boot(n=n)
interval = jax.tree_map(
lambda x: tuple(np.quantile(x, q=(0.5 - cl_/2, 0.5 + cl_/2))),
{k: v for k, v in boot_result.params.items() if k in params}
lambda x: tuple(
np.quantile(x, q=(0.5 - cl_ / 2, 0.5 + cl_ / 2))
),
{k: v for k, v in boot_result.params.items() if k in params},
)
error = {
k: (interval[k][0] - mle[k], interval[k][1] - mle[k])
Expand All @@ -327,7 +323,7 @@ def loss(x):
status = {
'n': boot_result.n,
'n_valid': boot_result.n_valid,
'seed': boot_result.seed
'seed': boot_result.seed,
}

else:
Expand All @@ -344,14 +340,11 @@ def format_result(v):
error=format_result(error),
cl=cl_,
method=method,
status=status
status=status,
)

def boot(
self,
n: int = 10000,
parallel: bool = True,
seed: Optional[int] = None
self, n: int = 10000, parallel: bool = True, seed: Optional[int] = None
) -> BootstrapResult:
"""Parametric bootstrap.
Expand Down Expand Up @@ -389,7 +382,7 @@ def boot(
n,
self._seed,
parallel,
run_str='Bootstrap'
run_str='Bootstrap',
)

boot_result = BootstrapResult(
Expand All @@ -401,7 +394,7 @@ def boot(
n=n,
n_valid=result['valid'].sum(),
seed=self._seed,
results=boot_result
results=boot_result,
)

self._boot = boot_result
Expand All @@ -418,6 +411,7 @@ def plot_corner(self):

class CredibleInterval(NamedTuple):
"""Credible interval result."""

mle: dict[str, float]
median: dict[str, float]
interval: dict[str, tuple[float, float]]
Expand All @@ -428,6 +422,7 @@ class CredibleInterval(NamedTuple):

class PPCResult(NamedTuple):
"""Posterior predictive check result."""

...


Expand All @@ -441,7 +436,7 @@ def __init__(
reff: float,
lnZ: tuple[float, float],
sampler,
fit: _fit.BayesianFit
fit: _fit.BayesianFit,
):
self._idata = idata
self._ess = ess
Expand Down Expand Up @@ -621,12 +616,8 @@ def ci(

...


def ppc(
self,
n: int = 10000,
parallel: bool = True,
seed: Optional[int] = None
self, n: int = 10000, parallel: bool = True, seed: Optional[int] = None
) -> PPCResult:
"""Perform posterior predictive check.
Expand Down
Loading

0 comments on commit ed12c4d

Please sign in to comment.