aesara-devs/aehmc

AeHMC

AeHMC provides implementations of the Hamiltonian Monte Carlo (HMC) and No-U-Turn Sampler (NUTS) algorithms in Aesara.
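Both HMC and NUTS typically simulate Hamiltonian dynamics with the leapfrog integrator, and the `step_size` and `inverse_mass_matrix` arguments in the Get started example parameterize that simulation. The following is a minimal NumPy sketch of a leapfrog integrator, assuming a diagonal (here scalar) inverse mass matrix and a hand-written gradient of the log-density. The names `leapfrog` and `hamiltonian` are illustrative only and are not part of AeHMC's API.

```python
import numpy as np


def leapfrog(position, momentum, grad_logprob, step_size, inverse_mass_matrix, num_steps):
    """Integrate Hamiltonian dynamics with the leapfrog scheme (a sketch).

    Assumes a diagonal inverse mass matrix, given as a scalar or an array
    that multiplies the momentum elementwise.
    """
    q = np.array(position, dtype=float)
    p = np.array(momentum, dtype=float)
    # Initial half step for the momentum
    p = p + 0.5 * step_size * grad_logprob(q)
    for i in range(num_steps):
        # Full position step using the velocity M^{-1} p
        q = q + step_size * inverse_mass_matrix * p
        # Full momentum step, except after the last position step
        if i < num_steps - 1:
            p = p + step_size * grad_logprob(q)
    # Final half step for the momentum
    p = p + 0.5 * step_size * grad_logprob(q)
    return q, p


def hamiltonian(q, p, logprob, inverse_mass_matrix):
    # Potential energy -log p(q) plus kinetic energy p^T M^{-1} p / 2
    return -logprob(q) + 0.5 * np.sum(inverse_mass_matrix * p**2)


def logprob(q):
    # Standard normal log-density (up to a constant), as in Get started
    return -0.5 * np.sum(q**2)


def grad_logprob(q):
    return -q


q0, p0 = np.array([0.0]), np.array([1.0])
q1, p1 = leapfrog(q0, p0, grad_logprob, step_size=1e-2, inverse_mass_matrix=1.0, num_steps=100)
energy_error = hamiltonian(q1, p1, logprob, 1.0) - hamiltonian(q0, p0, logprob, 1.0)
print(abs(energy_error))  # small when the step size is small
```

The integrator only approximately conserves the Hamiltonian; that residual energy error is what an HMC or NUTS acceptance step corrects for, and it grows with the step size.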

Get started

import aesara
from aesara import tensor as at
from aesara.tensor.random.utils import RandomStream

from aeppl import joint_logprob

from aehmc import nuts

# A simple normal distribution
Y_rv = at.random.normal(0, 1)


def logprob_fn(y):
    return joint_logprob(realized={Y_rv: y})[0]


# Build the transition kernel
srng = RandomStream(seed=0)
kernel = nuts.new_kernel(srng, logprob_fn)

# Compile a function that updates the chain
y_vv = Y_rv.clone()
initial_state = nuts.new_state(y_vv, logprob_fn)

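# Sampler parameters: the integrator's step size and a scalar inverse
# mass matrix (the target here is one-dimensional)
step_size = at.as_tensor(1e-2)
inverse_mass_matrix = at.as_tensor(1.0)
# The kernel returns the chain information (including the new state) and
# the random-stream updates that must be passed to `aesara.function`
chain_info, updates = kernel(initial_state, step_size, inverse_mass_matrix)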

next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates)

print(next_step_fn(0))
# 1.1034719409361107
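Calling `next_step_fn` repeatedly, feeding each returned position back in, produces a Markov chain. To illustrate what one transition does, here is a self-contained NumPy sketch of plain HMC with a fixed number of leapfrog steps and a Metropolis accept/reject step, run on the same standard normal target. This is not AeHMC's implementation: NUTS additionally chooses the trajectory length adaptively with a U-turn criterion, which this sketch omits, and the names `hmc_step` and `num_steps` are hypothetical.

```python
import numpy as np


def logprob(q):
    # Standard normal log-density, up to an additive constant
    return -0.5 * np.sum(q**2)


def grad_logprob(q):
    return -q


def hmc_step(rng, position, step_size, inverse_mass_matrix, num_steps):
    """One HMC transition with a diagonal inverse mass matrix (a sketch)."""
    # Draw momentum p ~ N(0, M), whose standard deviation is 1 / sqrt(M^{-1})
    momentum = rng.normal(size=position.shape) / np.sqrt(inverse_mass_matrix)

    def energy(q, p):
        return -logprob(q) + 0.5 * np.sum(inverse_mass_matrix * p**2)

    # Leapfrog integration of the Hamiltonian dynamics
    q, p = position.copy(), momentum.copy()
    p = p + 0.5 * step_size * grad_logprob(q)
    for i in range(num_steps):
        q = q + step_size * inverse_mass_matrix * p
        if i < num_steps - 1:
            p = p + step_size * grad_logprob(q)
    p = p + 0.5 * step_size * grad_logprob(q)

    # Metropolis correction for the integrator's energy error
    log_accept = energy(position, momentum) - energy(q, p)
    if np.log(rng.uniform()) < log_accept:
        return q, True
    return position, False


rng = np.random.default_rng(0)
position = np.array([0.0])
samples, num_accepted = [], 0
for _ in range(5000):
    position, accepted = hmc_step(rng, position, step_size=0.2, inverse_mass_matrix=1.0, num_steps=10)
    num_accepted += accepted
    samples.append(position[0])

samples = np.array(samples)
print(samples.mean(), samples.var())  # should be close to 0 and 1
```

Increasing `step_size` in this sketch lowers the acceptance rate because the leapfrog energy error grows, which is the same trade-off governing the `step_size` passed to the NUTS kernel above.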

Install

The latest release of AeHMC can be installed from PyPI using pip:

pip install aehmc

Or via conda-forge:

conda install -c conda-forge aehmc

The current development branch of AeHMC can be installed from GitHub using pip:

pip install git+https://github.com/aesara-devs/aehmc

Get help

Report bugs by opening an issue. If you have a question about using AeHMC, start a discussion. For real-time feedback or more general chat about AeHMC, use our Discord server or Gitter room.

Contribute

AeHMC welcomes contributions. A good place to start contributing is by looking at the issues.

If you want to implement a new feature, open a discussion or come chat with us on Discord or Gitter.