Skip to content

Commit

Permalink
use chain serial rather than logical index with _DefaultTrace.insert (#…
Browse files Browse the repository at this point in the history
…1590)

* use chain serial rather than logical index with _DefaultTrace.insert

* add PR #1590 to changelog
  • Loading branch information
Spaak committed Mar 3, 2021
1 parent db61a6d commit c562fad
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Integrate `index_origin` with all the library ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Fix pareto k threshold typo in reloo function ([1580](https://github.com/arviz-devs/arviz/pull/1580))
* Preserve shape from Stan code in `from_cmdstanpy` ([1579](https://github.com/arviz-devs/arviz/pull/1579))
* Correctly use chain index when constructing PyMC3 `DefaultTrace` in `from_pymc3` ([1590](https://github.com/arviz-devs/arviz/pull/1590))

### Deprecation
* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
4 changes: 2 additions & 2 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,12 @@ def _extract_log_likelihood(self, trace):
"`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
) from err
for var, log_like_fun in cached:
for chain in trace.chains:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), chain)
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
return log_likelihood_dict.trace_dict

@requires("trace")
Expand Down

0 comments on commit c562fad

Please sign in to comment.