discrete_hmm.py
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Example: Discrete HMM
=====================

This example trains the transition and emission parameters of a discrete
hidden Markov model by exactly marginalizing out the latent states with
funsor's ``logaddexp`` reduction. With ``--lazy``, the log-probability is
built as a lazy expression and optimized with ``apply_optimizer`` before
being evaluated.
"""
import argparse
from collections import OrderedDict

import torch

import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.interpreter import reinterpret
from funsor.optimizer import apply_optimizer
def main(args):
    funsor.set_backend("torch")

    # Declare learnable parameters: transition and emission probabilities.
    trans_probs = torch.tensor([[0.2, 0.8], [0.7, 0.3]], requires_grad=True)
    emit_probs = torch.tensor([[0.4, 0.6], [0.1, 0.9]], requires_grad=True)
    params = [trans_probs, emit_probs]

    # A discrete HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.0)

        # Transition distribution over the current state, indexed by "prev".
        trans = dist.Categorical(
            probs=funsor.Tensor(
                trans_probs,
                inputs=OrderedDict([("prev", funsor.Bint[args.hidden_dim])]),
            )
        )

        # Emission distribution over observations, indexed by "latent".
        emit = dist.Categorical(
            probs=funsor.Tensor(
                emit_probs,
                inputs=OrderedDict([("latent", funsor.Bint[args.hidden_dim])]),
            )
        )

        # Start deterministically in state 0.
        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable("x_{}".format(t), funsor.Bint[args.hidden_dim])
            log_prob += trans(prev=x_prev, value=x_curr)

            # Under eager interpretation, marginalize out the previous state
            # as soon as it is no longer needed; under --lazy this is
            # deferred to apply_optimizer.
            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

        # Marginalize out all remaining latent variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters by maximum likelihood.
    data = torch.ones(args.time_steps, dtype=torch.long)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            # Build the full lazy expression, then optimize the contraction
            # order before evaluating it.
            with funsor.interpretations.lazy:
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, "free variables remain"
        loss = -log_prob.data
        loss.backward()
        optim.step()
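
# For comparison only: a minimal plain-PyTorch forward-algorithm sketch (not
# part of the original example) that computes the same log-likelihood eagerly.
# It assumes the model above: a fixed initial state 0, trans_probs indexed as
# [prev, curr], and emit_probs indexed as [latent, observation].
def forward_algorithm(data, trans_probs, emit_probs):
    log_trans = trans_probs.log()
    log_emit = emit_probs.log()
    # One-hot initial state in log space, matching funsor.Number(0, hidden_dim).
    log_alpha = torch.full((trans_probs.size(0),), float("-inf"))
    log_alpha[0] = 0.0
    for y in data:
        # Marginalize out the previous state, then score the observation.
        log_alpha = torch.logsumexp(log_alpha[:, None] + log_trans, dim=0)
        log_alpha = log_alpha + log_emit[:, y]
    return torch.logsumexp(log_alpha, dim=0)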
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Discrete HMM example")
    parser.add_argument("-t", "--time-steps", default=10, type=int)
    parser.add_argument("-n", "--train-steps", default=101, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.05, type=float)
    parser.add_argument("-d", "--hidden-dim", default=2, type=int)
    parser.add_argument("--lazy", action="store_true")
    parser.add_argument("--filter", action="store_true")  # parsed but unused in this example
    parser.add_argument("--xfail-if-not-implemented", action="store_true")
    args = parser.parse_args()

    if args.xfail_if_not_implemented:
        try:
            main(args)
        except NotImplementedError:
            print("XFAIL")
    else:
        main(args)
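
# Example invocations (a quick sketch using the flags defined above):
#   python discrete_hmm.py
#   python discrete_hmm.py --lazy -t 20 -n 201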