Commit
Use float dtype for numpyro (#418)
kp992 authored Nov 5, 2024
1 parent 12d2294 commit 1ead0a6
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions lectures/bayes_nonconj.md
@@ -1,10 +1,10 @@
 ---
 jupytext:
   text_representation:
-    extension: .myst
+    extension: .md
     format_name: myst
     format_version: 0.13
-  jupytext_version: 1.13.8
+  jupytext_version: 1.16.4
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -43,7 +43,6 @@ The two Python modules are
 
 As usual, we begin by importing some Python code.
 
-
 ```{code-cell} ipython3
 :tags: [hide-output]
@@ -80,10 +79,8 @@ from numpyro.infer import SVI as nSVI
 from numpyro.infer import ELBO as nELBO
 from numpyro.infer import Trace_ELBO as nTrace_ELBO
 from numpyro.optim import Adam as nAdam
 ```
 
-
 ## Unleashing MCMC on a Binomial Likelihood
-
 This lecture begins with the binomial example in the {doc}`quantecon lecture <prob_meaning>`.
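
For orientation, here is a minimal sketch of the numpyro model-plus-NUTS pattern that these aliased imports support. It is an illustration, not code from this commit: the model name, the Beta(5, 5) prior, and the toy data are all assumptions.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as ndist
from jax import random
from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS

def beta_bernoulli_model(data):
    # theta ~ Beta(5, 5): an illustrative prior on the success probability
    theta = numpyro.sample('theta', ndist.Beta(5.0, 5.0))
    # each observation is an independent Bernoulli(theta) draw
    with numpyro.plate('N', len(data)):
        numpyro.sample('obs', ndist.Bernoulli(theta), obs=data)

data = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0])  # float dtype, per this commit
mcmc = nMCMC(nNUTS(beta_bernoulli_model), num_warmup=1000, num_samples=2000, progress_bar=False)
mcmc.run(random.PRNGKey(0), data=data)
print(mcmc.get_samples()['theta'].mean())
```
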
@@ -252,7 +249,6 @@ We will use the following priors:
 
 - The truncated Laplace can be created using `Numpyro`'s `TruncatedDistribution` class.
 
-
 ```{code-cell} ipython3
 # used by Numpyro
 def TruncatedLogNormal_trans(loc, scale):
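
As a sketch of the bullet above (the lecture's own helper continues in the diffed cell), numpyro's `TruncatedDistribution` wraps a base distribution and restricts its support; the `loc` and `scale` values here are illustrative, not the lecture's parameters.

```python
import numpyro.distributions as ndist

# a Laplace base distribution truncated to the unit interval [0, 1]
base = ndist.Laplace(loc=0.5, scale=0.1)
truncated_laplace = ndist.TruncatedDistribution(base, low=0.0, high=1.0)
```
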
@@ -560,19 +556,17 @@ class BayesianInference:
         Computes numerically the posterior distribution with beta prior parametrized by (alpha0, beta0)
         given data using MCMC
         """
-        # tensorize
-        data = torch.tensor(data)
         # use pyro
         if self.solver=='pyro':
+            # tensorize
+            data = torch.tensor(data)
             nuts_kernel = NUTS(self.model)
             mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=num_warmup, disable_progbar=True)
             mcmc.run(data)
         # use numpyro
         elif self.solver=='numpyro':
+            data = np.array(data, dtype=float)
             nuts_kernel = nNUTS(self.model)
             mcmc = nMCMC(nuts_kernel, num_samples=num_samples, num_warmup=num_warmup, progress_bar=False)
             mcmc.run(self.rng_key, data=data)
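
The substantive change in this hunk is data handling: pyro consumes `torch` tensors, while numpyro runs on JAX and expects floating-point arrays, hence the commit title. A standalone sketch of the dispatch pattern the diff converges on (the `solver` flag mirrors the attribute above):

```python
import numpy as np
import torch

solver = 'numpyro'            # assumed: either 'pyro' or 'numpyro'
raw_data = [1, 0, 1, 1, 0]    # observations often arrive as Python ints

if solver == 'pyro':
    data = torch.tensor(raw_data)            # pyro works on torch tensors
elif solver == 'numpyro':
    data = np.array(raw_data, dtype=float)   # numpyro/JAX wants float arrays
```
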
@@ -655,15 +649,15 @@ class BayesianInference:
         params : the learned parameters for guide
         losses : a vector of loss at each step
         """
-        # tensorize data
-        if not torch.is_tensor(data):
-            data = torch.tensor(data)
         # initiate SVI
         svi = self.SVI_init(guide_dist=guide_dist)
         # do gradient steps
         if self.solver=='pyro':
+            # tensorize data
+            if not torch.is_tensor(data):
+                data = torch.tensor(data)
             # store loss vector
             losses = np.zeros(n_steps)
             for step in range(n_steps):
@@ -676,6 +670,7 @@
             }
         elif self.solver=='numpyro':
+            data = np.array(data, dtype=float)
             result = svi.run(self.rng_key, n_steps, data, progress_bar=False)
             params = dict(
                 (key, np.asarray(value)) for key, value in result.params.items()
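
numpyro's `SVI.run` returns a result object carrying `params` (a dict of JAX arrays) and `losses`, and converting both to plain NumPy, as the diff does, keeps downstream code backend-agnostic. A self-contained sketch of that pattern; the Beta-Bernoulli model, guide, and optimizer settings are illustrative assumptions, not the lecture's:

```python
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as ndist
from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

def model(data):
    theta = numpyro.sample('theta', ndist.Beta(1.0, 1.0))
    with numpyro.plate('N', len(data)):
        numpyro.sample('obs', ndist.Bernoulli(theta), obs=data)

def guide(data):
    # variational Beta(alpha, beta) family for theta
    alpha = numpyro.param('alpha', 1.0, constraint=ndist.constraints.positive)
    beta = numpyro.param('beta', 1.0, constraint=ndist.constraints.positive)
    numpyro.sample('theta', ndist.Beta(alpha, beta))

data = jnp.array([1.0, 0.0, 1.0, 1.0])  # float dtype, per this commit
svi = SVI(model, guide, Adam(0.05), Trace_ELBO())
result = svi.run(random.PRNGKey(0), 2000, data, progress_bar=False)
params = {k: np.asarray(v) for k, v in result.params.items()}
losses = np.asarray(result.losses)
```
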
@@ -898,7 +893,6 @@ For the same Beta prior, we shall
 Let's start with the analytical method that we described in this quantecon lecture <https://python.quantecon.org/prob_meaning.html>
-
 ```{code-cell} ipython3
 # First examine Beta priors
 BETA_pyro = BayesianInference(param=(5,5), name_dist='beta', solver='pyro')
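
For reference, the analytical benchmark being compared against is the conjugate Beta-binomial update: a Beta(alpha0, beta0) prior combined with k successes in N Bernoulli draws yields a Beta(alpha0 + k, beta0 + N - k) posterior. A sketch with hypothetical data:

```python
from scipy.stats import beta

alpha0, beta0 = 5, 5   # the Beta(5, 5) prior instantiated above
N, k = 20, 12          # hypothetical sample: 12 successes in 20 draws
analytical_posterior = beta(alpha0 + k, beta0 + N - k)
print(analytical_posterior.mean())  # exact posterior mean of theta
```
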
@@ -952,12 +946,10 @@ will be more accurate, as we shall see next.
 (Increasing the number of steps increases computational time, though.)
-
 ```{code-cell} ipython3
 BayesianInferencePlot(true_theta, num_list, BETA_numpyro).SVI_plot(guide_dist='beta', n_steps=100000)
 ```
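
One rough way to judge whether `n_steps` is large enough is to compare the tail of the ELBO loss path. This is a heuristic sketch, not part of the lecture; `losses` is assumed to be the loss vector returned by the class's SVI routine.

```python
import numpy as np

def roughly_converged(losses, window=1000, tol=1e-3):
    """Relative change between mean losses of the last two windows."""
    if len(losses) < 2 * window:
        return False  # not enough steps to compare two full windows
    earlier = np.mean(losses[-2 * window:-window])
    recent = np.mean(losses[-window:])
    return abs(recent - earlier) / (abs(earlier) + 1e-12) < tol
```
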
 ## Non-conjugate Prior Distributions
-
 Having assured ourselves that our MCMC and VI methods can work well when we have a conjugate prior (and so can also compute the posterior analytically), we
@@ -1052,7 +1044,6 @@ SVI_num_steps = 50000
 example_CLASS = BayesianInference(param=(0,1), name_dist='uniform', solver='numpyro')
 print(f'=======INFO=======\nParameters: {example_CLASS.param}\nPrior Dist: {example_CLASS.name_dist}\nSolver: {example_CLASS.solver}')
 BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='normal', n_steps=SVI_num_steps)
 ```
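
Fitting a `normal` guide to a parameter that lives on [0, 1] amounts to using a truncated normal variational family. A minimal numpyro sketch of such a guide (illustrative: the lecture's `BayesianInference` class builds its guides internally, and the site name `theta` and the initial values are assumptions):

```python
import numpyro
import numpyro.distributions as ndist

def truncated_normal_guide(data):
    loc = numpyro.param('loc', 0.5)
    scale = numpyro.param('scale', 0.1, constraint=ndist.constraints.positive)
    numpyro.sample('theta', ndist.TruncatedNormal(loc=loc, scale=scale, low=0.0, high=1.0))
```
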
 ```{code-cell} ipython3
