-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrtrandomwalk.py
125 lines (102 loc) · 3.41 KB
/
rtrandomwalk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.metaclass import RandomVariable
from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess
from pyrenew.transform import AbstractTransform, LogTransform
class RtRandomWalkProcess(RandomVariable):
    r"""
    Rt random walk process.

    Rt is modelled as a random walk on a transformed scale (for example,
    log scale when ``Rt_transform`` is a ``LogTransform``). The walk is
    mapped back to the natural scale through the inverse of
    ``Rt_transform``.

    Notes
    -----
    The process is defined as follows:

    .. math::

        Rt(0) &\sim \text{Rt0\_dist} \\
        W &= \text{SimpleRandomWalk}\left(\text{Rt\_rw\_dist};\;
            \text{init} = \text{Rt\_transform}(Rt(0))\right) \\
        Rt(t) &= \text{Rt\_transform}^{-1}\left(W(t)\right)

    The walk :math:`W` is sampled under the site name
    ``"Rt_transformed_rw"``, and :math:`Rt` is recorded as the
    deterministic site ``"Rt"``.
    """

    def __init__(
        self,
        Rt0_dist: dist.Distribution = dist.TruncatedNormal(
            loc=1.2, scale=0.2, low=0
        ),
        Rt_transform: AbstractTransform = LogTransform(),
        Rt_rw_dist: dist.Distribution = dist.Normal(0, 0.025),
    ) -> None:
        """
        Default constructor.

        Parameters
        ----------
        Rt0_dist : dist.Distribution, optional
            Distribution of the initial value Rt(0) on the natural scale.
            Defaults to ``dist.TruncatedNormal(loc=1.2, scale=0.2, low=0)``.
        Rt_transform : AbstractTransform, optional
            Transformation applied to the sampled Rt0 to obtain the starting
            point of the walk. Its ``inverse`` maps the walk back to the
            natural scale. Defaults to ``LogTransform()``.
        Rt_rw_dist : dist.Distribution, optional
            Distribution handed to ``SimpleRandomWalkProcess`` to drive the
            walk on the transformed scale (presumably its step/increment
            distribution; see that class). Defaults to
            ``dist.Normal(0, 0.025)``.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If any argument has the wrong type (see ``validate``).
        """
        RtRandomWalkProcess.validate(Rt0_dist, Rt_transform, Rt_rw_dist)

        self.Rt0_dist = Rt0_dist
        self.Rt_transform = Rt_transform
        self.Rt_rw_dist = Rt_rw_dist

        return None

    @staticmethod
    def validate(
        Rt0_dist: dist.Distribution,
        Rt_transform: AbstractTransform,
        Rt_rw_dist: dist.Distribution,
    ) -> None:
        """
        Validate the constructor arguments.

        Parameters
        ----------
        Rt0_dist : dist.Distribution
            Initial distribution of Rt; must be a ``dist.Distribution``.
        Rt_transform : AbstractTransform
            Transformation for the walk; must be an ``AbstractTransform``.
        Rt_rw_dist : dist.Distribution
            Random-walk distribution; must be a ``dist.Distribution``.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If Rt0_dist or Rt_rw_dist is not a dist.Distribution, or if
            Rt_transform is not an AbstractTransform.
        """
        # Raise explicitly instead of using ``assert`` statements: those are
        # stripped when Python runs with -O, which would silently disable
        # validation. AssertionError is kept so existing callers that catch
        # it keep working.
        if not isinstance(Rt0_dist, dist.Distribution):
            raise AssertionError(
                "Rt0_dist must be a numpyro Distribution, got "
                f"{type(Rt0_dist).__name__}"
            )
        if not isinstance(Rt_transform, AbstractTransform):
            raise AssertionError(
                "Rt_transform must be an AbstractTransform, got "
                f"{type(Rt_transform).__name__}"
            )
        if not isinstance(Rt_rw_dist, dist.Distribution):
            raise AssertionError(
                "Rt_rw_dist must be a numpyro Distribution, got "
                f"{type(Rt_rw_dist).__name__}"
            )
        return None

    def sample(
        self,
        n_timepoints: int,
        **kwargs,
    ) -> tuple:
        """
        Generate samples from the process.

        Parameters
        ----------
        n_timepoints : int
            Number of timepoints to sample; forwarded as ``duration`` to the
            underlying ``SimpleRandomWalkProcess``.
        **kwargs : dict, optional
            Accepted for interface compatibility with other ``sample()``
            methods. They are currently ignored and not forwarded to any
            internal call.

        Returns
        -------
        tuple
            One-element tuple ``(Rt,)``. ``Rt`` is the walk mapped back to
            the natural scale via ``Rt_transform.inverse``. It is also
            recorded as the numpyro deterministic site ``"Rt"``. Its length
            is presumably ``n_timepoints`` (set by SimpleRandomWalkProcess).
        """
        # Draw the initial value on the natural scale, then move it to the
        # transformed scale where the walk takes place.
        Rt0 = npro.sample("Rt0", self.Rt0_dist)
        Rt0_trans = self.Rt_transform(Rt0)

        Rt_trans_proc = SimpleRandomWalkProcess(self.Rt_rw_dist)
        # Only the first element of the returned tuple (the walk itself) is
        # needed here.
        Rt_trans_ts, *_ = Rt_trans_proc.sample(
            duration=n_timepoints,
            name="Rt_transformed_rw",
            init=Rt0_trans,
        )

        # Map the walk back to the natural scale and record it as a site.
        Rt = npro.deterministic(
            "Rt", self.Rt_transform.inverse(Rt_trans_ts)
        )

        return (Rt,)