-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathpenntreebank_inspect.py
76 lines (63 loc) · 2.67 KB
/
penntreebank_inspect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import sys, cPickle
import matplotlib.pyplot as plt
def split_path(pathlike):
    """Resolve an ``(index, spec)`` pair into a ``(name, path)`` pair.

    ``pathlike`` is a two-item pair such as one produced by ``enumerate``
    over the command-line arguments. If ``spec.split(":")`` yields exactly
    two parts, they become the name and the path. Otherwise the index is
    used as the name and ``spec`` is returned unchanged as the path. This
    covers a spec with no colon, one with several colons, and one with no
    ``split`` method. The resolved pair is printed as ``"name: path"``
    before being returned.
    """
    index, spec = pathlike
    try:
        label, location = spec.split(":")
    except (ValueError, AttributeError):
        # No single "label:" prefix (or not a string): name it by position.
        label, location = index, spec
    message = "%s: %s" % (label, location)
    print(message)
    return label, location
def load_instance(pathlike):
    """Load one pickled result file and tag it with its label and path.

    ``pathlike`` is an ``(index, "[label:]path")`` pair, resolved by
    ``split_path`` into ``(name, path)``. The unpickled object is expanded
    as keyword arguments into a new dict. It must therefore be a mapping
    with string keys, and it must not itself contain ``name`` or ``path``
    keys, since duplicate keyword arguments raise TypeError.

    Returns a dict with ``name``, ``path`` and every key of the pickled
    mapping.
    """
    label, location = split_path(pathlike)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(location, "rb") as handle:
        loaded = cPickle.load(handle)
    return dict(name=label, path=location, **loaded)
# Command-line arguments: pickle files written by penntreebank_evaluate.py,
# each given as "[label:]path". An argument without a single "label:" prefix
# is named by its 0-based position among the arguments (see split_path).
paths = sys.argv[1:]
instances = [load_instance(indexed_path) for indexed_path in enumerate(paths)]
import math
def natstobits(x):
    """Convert an information quantity from nats to bits (divide by ln 2)."""
    ln2 = math.log(2)
    return x / ln2
# Plot bits-per-character against sequence length for each data split,
# print each model's bpc on the full test set, then stop in the debugger
# so the loaded `instances` can be inspected before the figures are shown.

colors = "blue red green cyan magenta yellow black white".split()

# Legend suffix for BN-LSTM curves, keyed by situation.
BN_STATISTICS_LABELS = dict(training="batch statistics",
                            inference="population statistics")

# Set to True to also plot each instance's "new_popstats" as heatmaps.
SHOW_POPSTATS = False


def _bpc_curve(results):
    """Turn a ``{sequence_length: metrics}`` mapping into a plottable curve.

    Returns ``(lengths, bpc)``. ``lengths`` holds the sequence lengths in
    ascending order. ``bpc`` holds the matching ``metrics["cross_entropy"]``
    values converted from nats to bits. Lengths are sorted explicitly
    because dict iteration order is arbitrary in Python 2. Relying on it
    made the "first entry is length 50" check, and the order of the plotted
    points, depend on hashing.

    The shortest length (required to be 50) is dropped. We don't care about
    the length-50 result because we now train on length 100.

    Raises ValueError if ``results`` is empty or its shortest length is
    not 50.
    """
    lengths = sorted(results)
    if not lengths or lengths[0] != 50:
        raise ValueError("expected shortest sequence length 50, got %r"
                         % (lengths[:1],))
    lengths = lengths[1:]
    bpc = [natstobits(results[t]["cross_entropy"]) for t in lengths]
    return lengths, bpc


for which_set in "train valid test".split():
    plt.figure()
    for situation, kwargs in [("inference", dict(linestyle="solid")),
                              ("training", dict(linestyle="dashed"))]:
        # NOTE: zip stops at len(colors), so any instance beyond the 8th is
        # silently not plotted.
        for color, instance in zip(colors, instances):
            name = instance["name"]
            # baseline training/inference performances will be identical
            if name == "LSTM" and situation == "training":
                continue
            label = name
            if name == "BN-LSTM":
                label += ", " + BN_STATISTICS_LABELS[situation]
            lengths, bpc = _bpc_curve(instance["results"][situation][which_set])
            plt.plot(lengths, bpc, label=label, c=color, linewidth=3, **kwargs)
    #plt.yscale("log")
    #plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.legend()
    #plt.title("performance on slices of the " + which_set + " string")
    plt.xlabel("sequence length")
    plt.ylabel("mean bits per character")

for instance in instances:
    # Same output as the old Python 2 print statement, but valid on 2 and 3.
    print("bpc on full test %s %s" % (
        instance["name"],
        natstobits(instance["results"]["proper_test"]["cross_entropy"])))

import pdb
# Deliberate breakpoint: inspect results before the figures are displayed.
pdb.set_trace()
plt.show()

if SHOW_POPSTATS:
    for instance in instances:
        for variable, value in instance["new_popstats"].items():
            plt.figure()
            plt.imshow(value, cmap="bone", aspect="auto")
            plt.colorbar()
            plt.title("%s %s" % (instance["name"], variable.name))
    pdb.set_trace()
    plt.show()

pdb.set_trace()