From a44175fdda5721dea97f1647340efb6082d0c0d9 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Mon, 22 Jul 2024 17:20:43 -0500 Subject: [PATCH] Update hospital_admissions_model.qmd --- docs/source/tutorials/hospital_admissions_model.qmd | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 7e89df33..b6f2e716 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -183,12 +183,14 @@ gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") class MyRt(metaclass.RandomVariable): + def __init__(self, sd_rv): + self.sd_rv = sd_rv def validate(self): pass def sample(self, n_steps: int, **kwargs) -> tuple: - sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) + sd_rt, *_ = self.sd_rv() rt_rv = metaclass.TransformedRandomVariable( "Rt_rv", @@ -207,7 +209,9 @@ class MyRt(metaclass.RandomVariable): return rt_rv.sample(n_steps=n_steps, **kwargs) -rtproc = MyRt() +rtproc = MyRt( + metaclass.DistributionalRV(dist.HalfNormal(0.025), "Rt_random_walk_sd") +) # The observation model