-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain_svi.py
117 lines (102 loc) · 5.34 KB
/
main_svi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import pdb
import sys
import jax.random as jrandom
from jax.config import config
config.update("jax_enable_x64", True)
from data_generation import gen_slds_nica
from train_svi import full_train
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# uncomment to debug NaNs
#config.update("jax_debug_nans", True)
def parse():
    """Build and parse all command-line arguments.

    Covers data-generation, seeding, inference/training/optimization,
    and saving/loading options for the SVI experiment script.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='')
    # data generation args
    parser.add_argument('-n', type=int, default=3,
                        help="number of ICs")
    parser.add_argument('-m', type=int, default=12,
                        help="dimension of observed data")
    parser.add_argument('-l', type=int, default=0,
                        help="number of nonlinear layers; 0 = linear ICA")
    parser.add_argument('-d', type=int, default=2,
                        help="dimension of lds state")
    parser.add_argument('-k', type=int, default=2,
                        help="number of HMM states")
    parser.add_argument('-t', type=int, default=100000,
                        help="number of timesteps")
    parser.add_argument('--whiten', action='store_true', default=False,
                        help="PCA whiten data as preprocessing")
    parser.add_argument('--gt-gm-params', action='store_true', default=False,
                        help="debug with GM parameters at ground truth")
    # set seeds
    parser.add_argument('--param-seed', type=int, default=50,
                        help="seed for initializing data generation params")
    parser.add_argument('--data-seed', type=int, default=1,
                        help="seed for initializing data generation sampling")
    parser.add_argument('--est-seed', type=int, default=99,
                        help="seed for initializing function estimators")
    # inference & training & optimization parameters
    parser.add_argument('--inference-iters', type=int, default=5,
                        help="num. of inference iterations")
    parser.add_argument('--num-samples', type=int, default=1,
                        help="num. of samples for elbo")
    parser.add_argument('--hidden-units-enc', type=int, default=64,
                        help="num. of hidden units in encoder estimator MLP")
    parser.add_argument('--hidden-units-dec', type=int, default=32,
                        help="num. of hidden units in decoder estimator MLP")
    parser.add_argument('--hidden-layers-enc', type=int, default=0,
                        help="num. of hidden layers in encoder estimator MLP")
    parser.add_argument('--hidden-layers-dec', type=int, default=0,
                        help="num. of hidden layers in decoder estimator MLP")
    parser.add_argument('--nn-learning-rate', type=float, default=1e-2,
                        help="learning rate for training function estimators")
    parser.add_argument('--gm-learning-rate', type=float, default=1e-2,
                        help="learning rate for training GM parameters")
    # NOTE(review): burnin counts iterations but is declared as float;
    # kept as-is for backward compatibility with existing callers/configs.
    parser.add_argument('--burnin', type=float, default=500,
                        help="keep output precision fixed for _ iterations")
    parser.add_argument('--num-epochs', type=int, default=100000,
                        help="number of training epochs")
    parser.add_argument('--decay-rate', type=float, default=1.,
                        help="decay rate for training (default to no decay)")
    # argparse applies `type` only to command-line strings, not to the
    # default, so the default must itself be an int (1e10 is a float).
    parser.add_argument('--decay-interval', type=int, default=int(1e10),
                        help="interval (in iterations) for full decay of LR")
    parser.add_argument('--subseq-len', type=int, default=100,
                        help="T is split into this length sub-chains")
    parser.add_argument('--minib-size', type=int, default=32,
                        help="number of subchains in a single minibatch")
    parser.add_argument('--plot-freq', type=int, default=100,
                        help="plotting frequency")
    parser.add_argument('--eval-freq', type=int, default=10,
                        help="evaluation frequency")
    # saving and loading
    parser.add_argument('--out-dir', type=str, default="output/",
                        help="location where data is saved")
    parser.add_argument('--resume-best', action='store_true', default=False,
                        help="resume from best chkpoint for current args")
    parser.add_argument('--eval-only', action='store_true', default=False,
                        help="eval only without training")
    args = parser.parse_args()
    return args
def main():
    """Entry point: generate simulated SLDS-NICA data and run SVI training."""
    args = parse()

    # independent PRNG keys for parameter init and for data sampling
    param_key = jrandom.PRNGKey(args.param_seed)
    data_key = jrandom.PRNGKey(args.data_seed)

    # simulate data
    # !BEWARE d=2, k=2 fixed in data generation
    x, f, z, z_mu, states, *params = gen_slds_nica(
        args.n, args.m, args.t, args.k, args.d, args.l,
        param_key, data_key, repeat_layers=True)

    # optional PCA whitening of the observations
    # (we have not tried this option but it could be useful in some cases)
    if args.whiten:
        x = PCA(whiten=True).fit_transform(x.T).T

    # run training
    est_params, posteriors, best_elbo = full_train(
        x, f, z, z_mu, states, params, args, args.est_seed)
# Script entry point; main() has no return statement, so sys.exit(None)
# yields exit status 0 on success.
if __name__ == "__main__":
    sys.exit(main())