forked from echen/restricted-boltzmann-machines
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path rbmcmd
executable file
·93 lines (80 loc) · 3.25 KB
/
rbmcmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python
from __future__ import print_function
import pickle
import rbm
import numpy
import sys
# Python 2 compatibility: make ``range`` lazy like Python 3's builtin.
# Compare the numeric version tuple; comparing the ``sys.version`` string is
# lexicographic and would misclassify a hypothetical two-digit major version.
if sys.version_info[0] < 3:
    range = xrange  # noqa: F821 -- xrange only exists on Python 2
def output_array(a):
    """Write the values of *a* to stdout as one space-separated line.

    Values below 0.001 are written as ``0`` and values above 0.999 as ``1``;
    anything in between is written with ``str()``. Every value, including the
    last one, is followed by a single space, and the line ends with ``\\n``.
    """
    pieces = []
    for value in a:
        if value < 0.001:
            text = "0"
        elif value > 0.999:
            text = "1"
        else:
            text = str(value)
        pieces.append(text + " ")
    pieces.append("\n")
    sys.stdout.write("".join(pieces))
def _print_usage():
    """Print the command-line synopsis for every supported command."""
    print("Usage: rbmcmd statefile command parameters...")
    print(" rbmcmd statefile init num_visible num_hidden")
    print(" rbmcmd statefile train max_epochs learning_rate < whitespace-separated lines of num_visible numbers 0 or 1.")
    print(" rbmcmd statefile run_visible < whitespace-separated lines of num_visible numbers > whitespace-separated lines of num_hidden numbers")
    print(" rbmcmd statefile run_hidden < whitespace-separated lines of num_hidden numbers > whitespace-separated lines of num_visible numbers")
    print(" rbmcmd statefile daydream_trace num_samples > num_samples whitespace-separated lines of num_visible numbers")
    print(" rbmcmd statefile daydream num_samples num_dreams > num_dreams whitespace-separated lines of num_visible numbers")


# Number of positional parameters each command needs after "statefile command".
_REQUIRED_PARAMS = {
    "init": 2,
    "train": 2,
    "run_visible": 0,
    "run_hidden": 0,
    "daydream_trace": 1,
    "daydream": 2,
}


def _read_chunk(stream, maxchunk):
    """Read up to *maxchunk* lines from *stream* and parse them into rows.

    Lines whose first character is ``#`` are comments and blank lines are
    ignored; both still count toward *maxchunk*. Returns ``(rows, eof)``
    where *rows* is a list of lists of floats and *eof* is True once the
    stream returned no more data.
    """
    rows = []
    for _ in range(maxchunk):
        line = stream.readline()
        if not line:
            return rows, True
        fields = line.split()
        if not fields or line[0] == "#":
            continue
        # Build a real list: on Python 3 map() is a lazy iterator, and
        # numpy.array() would turn a list of them into an object array.
        rows.append([float(x) for x in fields])
    return rows, False


def main():
    """Command-line entry point: ``rbmcmd statefile command parameters...``.

    ``init`` creates a fresh ``rbm.RBM``; every other command loads the RBM
    pickled in ``statefile``. ``train``, ``run_visible`` and ``run_hidden``
    read whitespace-separated rows from stdin in chunks of up to 1000 lines.
    ``init`` and ``train`` write the RBM back to ``statefile`` afterwards.

    Raises ``SystemExit`` (message on stderr) when a command is missing
    parameters or the state file cannot be loaded.
    """
    if len(sys.argv) < 3 or sys.argv[1] == "--help":
        _print_usage()
        return
    fname = sys.argv[1]
    cmd = sys.argv[2]
    params = sys.argv[3:]

    if cmd not in _REQUIRED_PARAMS:
        print("Unknown command")
        return
    if len(params) < _REQUIRED_PARAMS[cmd]:
        sys.exit("rbmcmd: '%s' expects %d parameter(s), got %d"
                 % (cmd, _REQUIRED_PARAMS[cmd], len(params)))

    if cmd == "init":
        # A fresh model; any existing state file is overwritten below, so
        # there is no need to unpickle it first.
        r = rbm.RBM(num_visible=int(params[0]), num_hidden=int(params[1]))
        r.debug_print = False
    else:
        # NOTE(security): pickle.load can execute arbitrary code -- only
        # point rbmcmd at state files you created yourself.
        try:
            with open(fname, "rb") as f:
                r = pickle.load(f)
        except (IOError, EOFError, pickle.UnpicklingError) as e:
            sys.exit("rbmcmd: cannot load state file %r: %s" % (fname, e))

    if cmd in ("train", "run_visible", "run_hidden"):
        if cmd == "train":
            # Parse once, not once per chunk.
            max_epochs = int(params[0])
            learning_rate = float(params[1])
        maxchunk = 1000
        eof = False
        while not eof:
            rows, eof = _read_chunk(sys.stdin, maxchunk)
            if not rows:
                # Nothing parsed (empty input, an all-comment chunk, or input
                # whose length is an exact multiple of maxchunk): skip rather
                # than hand an empty array to the model.
                continue
            data = numpy.array(rows)
            if cmd == "train":
                r.train(data, max_epochs=max_epochs, learning_rate=learning_rate)
            elif cmd == "run_visible":
                for state in r.run_visible(data):
                    output_array(state)
            else:
                for state in r.run_hidden(data):
                    output_array(state)
    elif cmd == "daydream_trace":
        # Print every row returned by daydream(), values unthresholded.
        for sample in r.daydream(int(params[0])):
            for value in sample:
                sys.stdout.write(str(value))
                sys.stdout.write(" ")
            sys.stdout.write("\n")
    elif cmd == "daydream":
        num_s = int(params[0])
        num_d = int(params[1])
        for _ in range(num_d):
            # Each dream is a separate daydream() call; only its last row
            # is printed.
            output_array(r.daydream(num_s)[-1])

    if cmd in ("init", "train"):
        with open(fname, "wb") as f:
            pickle.dump(r, f)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()