Skip to content

Commit

Permalink
fix: simulation number of boot/ppc (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve authored Aug 29, 2024
1 parent 936f3f3 commit 4f50114
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/elisa/infer/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,12 +663,16 @@ def boot(
update_rate : int, optional
The update rate of progress bar. The default is 50.
"""
n = int(n)
n_core = jax.local_device_count()
if parallel and (n % n_core):
n += n_core - n % n_core

# reuse the previous result if all setup is the same
if self._boot and self._boot.n == n and self._boot.seed == seed:
return

helper = self._helper
n = int(n)
seed = helper.seed['pred'] if seed is None else int(seed)
params = {i: self._mle[i][0] for i in helper.params_names['free']}
models = self._model_values
Expand Down Expand Up @@ -1415,13 +1419,17 @@ def ppc(
update_rate : int, optional
The update rate of progress bar. The default is 50.
"""
n = int(n)
n_core = jax.local_device_count()
if parallel and (n % n_core):
n += n_core - n % n_core

# reuse the previous result if all setup is the same
if self._ppc and self._ppc.n == n and self._ppc.seed == seed:
return

helper = self._helper
free_params = helper.params_names['free']
n = int(n)
seed = helper.seed['pred'] if seed is None else int(seed)

# randomly select n samples from posterior
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_trivial_max_like_fit():
assert result.dof == nbins - 1

# Check various methods of mle result
result.boot(1000)
result.boot(1009)
result.ci(method='boot')
result.flux(1, 2)
result.lumin(1, 10000, z=1)
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_trivial_bayes_fit():
assert result.ndata['total'] == nbins
assert result.dof == nbins - 2

result.ppc(1000)
result.ppc(1009)
result.flux(1, 2)
result.lumin(1, 10000, z=1)
result.eiso(1, 10000, z=1, duration=spec_exposure)
Expand Down

0 comments on commit 4f50114

Please sign in to comment.