Skip to content

Commit

Permalink
Fix bayeux after blackjax update.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634490737
  • Loading branch information
ColCarroll authored and The bayeux Authors committed May 16, 2024
1 parent 316ccd0 commit 40d1672
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
}


def _convert_algorithm(algorithm):
# Remove this after blackjax is stable
if hasattr(algorithm, "differentiable"):
return algorithm.differentiable
return algorithm


def get_extra_kwargs(kwargs):
defaults = {
"chain_method": "vectorized",
Expand All @@ -64,8 +71,8 @@ def get_kwargs(self, **kwargs):
adapt_fn, algorithm, constrained_log_density, extra_parameters | kwargs)
return {adapt_fn: adaptation_kwargs,
"adapt.run": run_kwargs,
algorithm: get_algorithm_kwargs(
algorithm, constrained_log_density, kwargs),
_convert_algorithm(algorithm): get_algorithm_kwargs(
_convert_algorithm(algorithm), constrained_log_density, kwargs),
"extra_parameters": extra_parameters}

def __call__(self, seed, **kwargs):
Expand Down Expand Up @@ -171,7 +178,7 @@ def _blackjax_inference(
(states, infos), adaptation_parameters
"""

algorithm_kwargs = kwargs[algorithm] | adapt_parameters
algorithm_kwargs = kwargs[_convert_algorithm(algorithm)] | adapt_parameters
inference_algorithm = algorithm(**algorithm_kwargs)
_, states, infos = blackjax.util.run_inference_algorithm(
rng_key=seed,
Expand Down Expand Up @@ -257,8 +264,8 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs):
adaptation_required.remove("algorithm")
adaptation_kwargs["algorithm"] = algorithm
adaptation_kwargs = (
get_algorithm_kwargs(algorithm, log_density, kwargs) | adaptation_kwargs
)
get_algorithm_kwargs(_convert_algorithm(algorithm), log_density, kwargs)
| adaptation_kwargs)

adaptation_required = adaptation_required - adaptation_kwargs.keys()

Expand Down

0 comments on commit 40d1672

Please sign in to comment.