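"""Plots and saves attention maps for a trained simpleNMT model.

For each example, the script decodes the model's output sequence and
renders the attention probabilities over the input characters as a
grayscale map, saved as a PDF under attention_maps/.
"""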
import argparse
import os

import numpy as np
import matplotlib.pyplot as plt

from models.NMT import simpleNMT
from utils.examples import run_example
from data.reader import Vocabulary

HERE = os.path.realpath(os.path.join(os.path.realpath(__file__), '..'))


def load_examples(file_name):
    with open(file_name) as f:
        return [s.replace('\n', '') for s in f.readlines()]


# create the output directory if it doesn't already exist
if not os.path.exists(os.path.join(HERE, 'attention_maps')):
    os.makedirs(os.path.join(HERE, 'attention_maps'))

SAMPLE_HUMAN_VOCAB = os.path.join(HERE, 'data', 'sample_human_vocab.json')
SAMPLE_MACHINE_VOCAB = os.path.join(HERE, 'data', 'sample_machine_vocab.json')
SAMPLE_WEIGHTS = os.path.join(HERE, 'weights', 'sample_NMT.49.0.01.hdf5')


class Visualizer(object):

    def __init__(self,
                 padding=None,
                 input_vocab=SAMPLE_HUMAN_VOCAB,
                 output_vocab=SAMPLE_MACHINE_VOCAB):
        """
        Visualizes attention maps.

        :param padding: the padding length to use for the sequences
        :param input_vocab: path to the input (human) vocabulary file
        :param output_vocab: path to the output (machine) vocabulary file
        """
        self.padding = padding
        self.input_vocab = Vocabulary(input_vocab, padding=padding)
        self.output_vocab = Vocabulary(output_vocab, padding=padding)

    def set_models(self, pred_model, proba_model):
        """
        Sets the models to use.

        :param pred_model: the prediction model
        :param proba_model: the model that outputs the attention
                            activation maps
        """
        self.pred_model = pred_model
        self.proba_model = proba_model

    def attention_map(self, text):
        """
        Plots and saves the attention map for a single input string.

        :param text: the text to visualize the attention map for
        """
        # encode the string as integer indices
        d = self.input_vocab.string_to_int(text)
        # decode the output sequence with the prediction model
        predicted_text = run_example(
            self.pred_model, self.input_vocab, self.output_vocab, text)
        text_ = list(text) + ['<eot>'] + ['<unk>'] * self.input_vocab.padding

        # lengths of the input and output sequences, each including
        # its trailing '<eot>' token
        input_length = len(text) + 1
        output_length = predicted_text.index('<eot>') + 1

        # get the activation map and crop away the padded positions
        activation_map = np.squeeze(self.proba_model.predict(np.array([d])))[
            0:output_length, 0:input_length]

        plt.clf()
        f = plt.figure(figsize=(8, 8.5))
        ax = f.add_subplot(1, 1, 1)

        # plot the attention probabilities as a grayscale image
        i = ax.imshow(activation_map, interpolation='nearest', cmap='gray')

        # add a horizontal colorbar below the map
        cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
        cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
        cbar.ax.set_xlabel('Probability', labelpad=2)

        # label the axes with the output and input tokens
        ax.set_yticks(range(output_length))
        ax.set_yticklabels(predicted_text[:output_length])
        ax.set_xticks(range(input_length))
        ax.set_xticklabels(text_[:input_length], rotation=45)
        ax.set_xlabel('Input Sequence')
        ax.set_ylabel('Output Sequence')
        ax.grid()

        f.savefig(os.path.join(HERE, 'attention_maps',
                               text.replace('/', '') + '.pdf'),
                  bbox_inches='tight')
        f.show()


def main(examples, args):
    print('Total Number of Examples:', len(examples))

    weights_file = os.path.expanduser(args.weights)
    print('Weights loading from:', weights_file)

    viz = Visualizer(padding=args.padding,
                     input_vocab=args.human_vocab,
                     output_vocab=args.machine_vocab)

    print('Loading models')
    # the prediction model decodes the output sequence
    pred_model = simpleNMT(trainable=False,
                           pad_length=args.padding,
                           n_chars=viz.input_vocab.size(),
                           n_labels=viz.output_vocab.size())
    pred_model.load_weights(weights_file, by_name=True)
    pred_model.compile(optimizer='adam', loss='categorical_crossentropy')

    # the same network with the same weights, but configured to return
    # the attention probabilities instead of the predictions
    proba_model = simpleNMT(trainable=False,
                            pad_length=args.padding,
                            n_chars=viz.input_vocab.size(),
                            n_labels=viz.output_vocab.size(),
                            return_probabilities=True)
    proba_model.load_weights(weights_file, by_name=True)
    proba_model.compile(optimizer='adam', loss='categorical_crossentropy')

    viz.set_models(pred_model, proba_model)
    print('Models loaded')

    for example in examples:
        viz.attention_map(example)
    print('Completed visualizations')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    named_args = parser.add_argument_group('named arguments')

    named_args.add_argument('-e', '--examples', metavar='|',
                            help="""Example string or file to visualize
                            attention maps for; a file must end with
                            '.txt'""",
                            required=True)
    named_args.add_argument('-w', '--weights', metavar='|',
                            help="""Location of the weights file""",
                            required=False,
                            default=SAMPLE_WEIGHTS)
    named_args.add_argument('-p', '--padding', metavar='|',
                            help="""Length of padding""",
                            required=False, default=50, type=int)
    named_args.add_argument('-hv', '--human-vocab', metavar='|',
                            help="""Path to the human vocabulary""",
                            required=False,
                            default=SAMPLE_HUMAN_VOCAB,
                            type=str)
    named_args.add_argument('-mv', '--machine-vocab', metavar='|',
                            help="""Path to the machine vocabulary""",
                            required=False,
                            default=SAMPLE_MACHINE_VOCAB,
                            type=str)
    args = parser.parse_args()

    # a '.txt' file holds one example per line; anything else is
    # treated as a single example string
    if args.examples.endswith('.txt'):
        examples = load_examples(args.examples)
    else:
        examples = [args.examples]
    main(examples, args)
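
# Example invocations (illustrative; the example string and file name
# below are placeholders, and the weights path is the repo's bundled
# sample checkpoint referenced by SAMPLE_WEIGHTS):
#
#   python visualize.py -e 'Saturday 17 February 2018'
#   python visualize.py -e examples.txt -w weights/sample_NMT.49.0.01.hdf5 -p 50
#
# Each example is saved as a PDF attention map under attention_maps/.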