Skip to content

Commit

Permalink
Add pymc example.
Browse files Browse the repository at this point in the history
This turned up a mistake in the TFP vi implementation, where the log density needs to be vmapped.

PiperOrigin-RevId: 599262704
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Jan 17, 2024
1 parent b1277ff commit 3d5cc79
Show file tree
Hide file tree
Showing 3 changed files with 472 additions and 4 deletions.
10 changes: 7 additions & 3 deletions bayeux/_src/bayeux.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,13 @@ def backward(self, value, *inputs):
bwd = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values(inv_rvs))

def logp_wrap(args):
return logp([args[k] for k in names])
def logp_wrap(*args, **kwargs):
# This clause is only required because the tfp vi routine tries to
# pass dictionaries as keyword arguments, so this allows either
# log_density(params) or log_density(**params)
if args:
kwargs = args[0]
return logp([kwargs[k] for k in names])

def fwd_wrap(args):
ret = fwd(*[args[k] for k in names])
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/vi/tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_fit_kwargs(log_density, kwargs):
tfp.vi.fit_surrogate_posterior_stateless)
fit_kwargs.pop("seed")
fit_kwargs["optimizer"] = optax.adam(learning_rate=0.01)
fit_kwargs["target_log_prob_fn"] = log_density
fit_kwargs["target_log_prob_fn"] = jax.vmap(log_density)

fit_kwargs = {
"sample_size": 16,
Expand Down
Loading

0 comments on commit 3d5cc79

Please sign in to comment.