Possible bug in numpyro's AffineTransform #376

Closed
gvegayon opened this issue Aug 14, 2024 · 6 comments

@gvegayon (Member)

After @damonbayer's comment in #344 (comment), I found what I think is a bug in numpyro.distributions.transforms.AffineTransform. Wrapping a TransformedDistribution that uses it in a DistributionalRV yields unexpected results:

import jax
import jax.numpy as jnp
import numpyro as npro

from pyrenew.metaclass import DistributionalRV, Model

# Random variables to compare: d1 and d2 wrap a Dirichlet scaled by 7 via
# AffineTransform (d1, named 'x0', is only used for direct sampling below),
# d3 is the plain Dirichlet, and d4 applies an identity AffineTransform(0, 1).

d1 = DistributionalRV(
    'x0',
    npro.distributions.TransformedDistribution(
        npro.distributions.Dirichlet(jnp.ones(7)),
        npro.distributions.transforms.AffineTransform(0, 7),
    ),
)

d2 = DistributionalRV(
    'x',
    npro.distributions.TransformedDistribution(
        npro.distributions.Dirichlet(jnp.ones(7)),
        npro.distributions.transforms.AffineTransform(0, 7),
    ),
)

d3 = DistributionalRV(
    'x',
    npro.distributions.Dirichlet(jnp.ones(7)),
)

d4 = DistributionalRV(
    'x',
    npro.distributions.TransformedDistribution(
        npro.distributions.Dirichlet(jnp.ones(7)),
        npro.distributions.transforms.AffineTransform(0, 1),
    ),
)

# Dummy model
class MyModel(Model):
    def __init__(self, rv):
        super().__init__()
        self.rv = rv

    def validate(self):
        pass

    def sample(self):
        return self.rv()

mymodel2 = MyModel(d2)
mymodel3 = MyModel(d3)
mymodel4 = MyModel(d4)

mymodel2.run(
    num_samples=1000,
    num_warmup=0,
    rng_key=jax.random.key(4),
)

mymodel3.run(
    num_samples=1000,
    num_warmup=0,
    rng_key=jax.random.key(4),
)

mymodel4.run(
    num_samples=1000,
    num_warmup=0,
    rng_key=jax.random.key(4),
)

# Sampling directly from the transformed Dirichlet (d1) for comparison
with npro.handlers.seed(rng_seed=0):
    x2 = d1.sample()[0].value
    for i in range(1000):
        x2 = jnp.hstack([x2, d1.sample()[0].value])
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1162.46it/s, 1023 steps of size 1.00e+00. acc. prob=0.94]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1519.08it/s, 7 steps of size 1.00e+00. acc. prob=0.84]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1081.25it/s, 1023 steps of size 1.00e+00. acc. prob=0.94]

Looking at the samples:

print(f'x2 mean: {x2.mean()}')
print('mymodel2')
mymodel2.print_summary()
print('mymodel3')
mymodel3.print_summary()
print('mymodel4')
mymodel4.print_summary()
x2 mean: 1.0
mymodel2

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      x[0]   5074.25   5014.87   3495.59      1.52  13601.10      4.37      1.14
      x[1]   2513.24   2755.99   1235.35      0.10   6909.12      5.60      1.21
      x[2]   7556.17   6975.86   6669.29      0.47  18388.43      2.67      3.22
      x[3]   7253.91   6509.18   8881.08      0.17  15493.91      2.50      4.47
      x[4]   2760.71   2696.15   2026.63      0.21   6670.87      6.61      1.03
      x[5]   2979.05   3857.88   1342.97      0.03   9023.24      6.11      1.19
      x[6]   2376.66   2192.88   1508.76      0.31   5635.58     11.43      1.46

Number of divergences: 517
mymodel3

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      x[0]      0.14      0.12      0.10      0.00      0.32   1108.14      1.00
      x[1]      0.15      0.13      0.11      0.00      0.32   1437.83      1.00
      x[2]      0.14      0.12      0.11      0.00      0.33   1329.15      1.00
      x[3]      0.14      0.12      0.11      0.00      0.31    984.57      1.00
      x[4]      0.15      0.12      0.11      0.00      0.33   1200.59      1.00
      x[5]      0.14      0.12      0.11      0.00      0.31    879.64      1.00
      x[6]      0.15      0.13      0.11      0.00      0.33    931.85      1.00

Number of divergences: 0
mymodel4

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      x[0]   5074.25   5014.87   3495.59      1.52  13601.10      4.37      1.14
      x[1]   2513.24   2755.99   1235.35      0.10   6909.12      5.60      1.21
      x[2]   7556.17   6975.86   6669.29      0.47  18388.43      2.67      3.22
      x[3]   7253.91   6509.18   8881.08      0.17  15493.91      2.50      4.47
      x[4]   2760.71   2696.15   2026.63      0.21   6670.87      6.61      1.03
      x[5]   2979.05   3857.88   1342.97      0.03   9023.24      6.11      1.19
      x[6]   2376.66   2192.88   1508.76      0.31   5635.58     11.43      1.46

Number of divergences: 517
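
For reference, the same contrast can be reproduced with plain numpyro, without pyrenew. This is a minimal sketch using only the public numpyro API; the names scaled_dirichlet, transformed_model, and direct_draws are illustrative. NUTS on the TransformedDistribution is where the divergences appear in the summaries above, while direct draws from the same distribution have the expected component mean of about 1.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

scaled_dirichlet = dist.TransformedDistribution(
    dist.Dirichlet(jnp.ones(7)),
    dist.transforms.AffineTransform(0, 7),
)

def transformed_model():
    # single sample site with the transformed (scaled) Dirichlet
    numpyro.sample("x", scaled_dirichlet)

# NUTS over the transformed distribution (mirrors mymodel2/mymodel4 above)
mcmc = MCMC(NUTS(transformed_model), num_warmup=0, num_samples=1000)
mcmc.run(jax.random.key(4))
mcmc.print_summary()

# Direct draws from the same distribution (mirrors the seed-handler loop above)
direct_draws = scaled_dirichlet.sample(jax.random.key(0), (1000,))
print(direct_draws.mean())  # about 1.0, as expected for a simplex scaled by 7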
@gvegayon (Member Author)

@CDCgov/multisignal-epi-inference-devs what am I missing?

@gvegayon (Member Author)

Possibly related to this: pyro-ppl/numpyro#1756.

@damonbayer (Collaborator)

No idea what is going on here, but wanted to note my original comment suggests using a TransformedRandomVariable, not a TransformedDistribution.

Investigated a bit on my own and realized we do not properly record TransformedRandomVariable. Opening an issue.

@gvegayon (Member Author)

> No idea what is going on here, but wanted to note my original comment suggests using a TransformedRandomVariable, not a TransformedDistribution.
>
> Investigated a bit on my own and realized we do not properly record TransformedRandomVariable. Opening an issue.

Yes! That's why I followed this approach instead; I had also realized we were not recording the transformed RV. Either way, until one of the two is fixed, I am not able to plot the posterior distribution of the day-of-week (DOW) effect samples.

@dylanhmorris (Collaborator)

Record it with an explicit numpyro.deterministic call for now?
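
A minimal sketch of that workaround with plain numpyro, assuming a hypothetical model in which the day-of-week effect is a Dirichlet scaled by 7 (the names dow_model, dow_raw, and dow_effect are illustrative, not pyrenew API):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def dow_model():
    # sample the raw simplex-valued weights
    raw = numpyro.sample("dow_raw", dist.Dirichlet(jnp.ones(7)))
    # record the scaled day-of-week effect alongside the sampled sites
    numpyro.deterministic("dow_effect", 7 * raw)

Deterministic sites should then be available in the collected posterior samples (e.g., via mcmc.get_samples()), which would make the transformed values plottable.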

@damonbayer (Collaborator)

Feel free to re-open if you think this needs further discussion, @gvegayon.
