Skip to content

Commit

Permalink
Update from_cmdstan converter to follow schema convention (#1541)
Browse files Browse the repository at this point in the history
* Update from_cmdstan converter to follow schema convention

* Updated CHANGELOG.md
  • Loading branch information
utkarsh-maheshwari committed Feb 9, 2021
1 parent 6fa1ce8 commit 78ec5a0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Added confidence interval band to auto-correlation plot ([1535](https://github.com/arviz-devs/arviz/pull/1535))

### Maintenance and fixes
* Updated `from_cmdstan` and `from_numpyro` converter to follow schema convention ([1541](https://github.com/arviz-devs/arviz/pull/1541) and [1525](https://github.com/arviz-devs/arviz/pull/1525))

### Deprecation
* Removed Geweke diagnostic ([1545](https://github.com/arviz-devs/arviz/pull/1545))
Expand Down
22 changes: 12 additions & 10 deletions arviz/data/io_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,22 +203,24 @@ def posterior_to_xarray(self):
def sample_stats_to_xarray(self):
"""Extract sample_stats from fit."""
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}
rename_dict = {
"divergent": "diverging",
"n_leapfrog": "n_steps",
"treedepth": "tree_depth",
"stepsize": "step_size",
"accept_stat": "acceptance_rate",
}

sampler_params, sampler_params_warmup = self.sample_stats

for j, s_params in enumerate(sampler_params):
rename_dict = {}
for key in s_params:
key_, *end = key.split(".")
name = re.sub("__$", "", key_)
name = "diverging" if name == "divergent" else name
rename_dict[key] = ".".join((name, *end))
sampler_params[j][key] = s_params[key].astype(dtypes.get(key_))
sampler_params_warmup[j][key] = sampler_params_warmup[j][key].astype(
dtypes.get(key_)
name = re.sub("__$", "", key)
name = rename_dict.get(name, name)
sampler_params[j][name] = s_params[key].astype(dtypes.get(key))
sampler_params_warmup[j][name] = sampler_params_warmup[j][key].astype(
dtypes.get(key)
)
sampler_params[j] = sampler_params[j].rename(columns=rename_dict)
sampler_params_warmup[j] = sampler_params_warmup[j].rename(columns=rename_dict)
data = _unpack_dataframes(sampler_params)
data_warmup = _unpack_dataframes(sampler_params_warmup)
return (
Expand Down

0 comments on commit 78ec5a0

Please sign in to comment.