Skip to content

Commit

Permalink
Update from_cmdstan converter to follow schema convention
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh-maheshwari committed Feb 7, 2021
1 parent 74c3d19 commit acc7535
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions arviz/data/io_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,16 @@ def sample_stats_to_xarray(self):
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}

sampler_params, sampler_params_warmup = self.sample_stats

rename_dict = {
"divergent__": "diverging",
"n_leapfrog__": "n_steps",
"treedepth__": "tree_depth",
}
for j, s_params in enumerate(sampler_params):
rename_dict = {}
for key in s_params:
key_, *end = key.split(".")
name = re.sub("__$", "", key_)
name = "diverging" if name == "divergent" else name
name = "n_steps" if name == "n_leapfrog" else name
name = "tree_depth" if name == "treedepth" else name
rename_dict[key] = ".".join((name, *end))
sampler_params[j][key] = s_params[key].astype(dtypes.get(key_))
sampler_params[j][key] = s_params[key].astype(dtypes.get(key))
sampler_params_warmup[j][key] = sampler_params_warmup[j][key].astype(
dtypes.get(key_)
dtypes.get(key)
)
sampler_params[j] = sampler_params[j].rename(columns=rename_dict)
sampler_params_warmup[j] = sampler_params_warmup[j].rename(columns=rename_dict)
Expand Down

0 comments on commit acc7535

Please sign in to comment.