# utils.py
import matplotlib.pyplot as plt
import seaborn

seaborn.set_context(context="talk")


def attention_visualization(model, trans, sent):
    """Plot encoder self-attention, decoder self-attention, and decoder
    source-attention heatmaps for a single translated sentence."""
    N = len(model.encoder.layers)
    tgt_sent = trans

    def draw(data, x, y, ax):
        # Render one attention head as a heatmap: rows attend over columns.
        seaborn.heatmap(data.cpu(),
                        xticklabels=x, square=True, yticklabels=y,
                        vmin=0.0, vmax=1.0, cbar=False, ax=ax)

    # Encoder self-attention: source tokens attending to source tokens.
    for layer in range(N):
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        print("Encoder Layer", layer + 1)
        for h in range(4):
            draw(model.encoder.layers[layer].self_attn.attn[0, h].data,
                 sent, sent if h == 0 else [], ax=axs[h])
        plt.show()

    # Decoder self-attention: target tokens attending to earlier target tokens.
    for layer in range(N):
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        print("Decoder Self Layer", layer + 1)
        for h in range(4):
            draw(model.decoder.layers[layer].self_attn.attn[0, h].data[:len(tgt_sent), :len(tgt_sent)],
                 tgt_sent, tgt_sent if h == 0 else [], ax=axs[h])
        plt.show()

        # Decoder source attention: target tokens attending over the source sentence.
        print("Decoder Src Layer", layer + 1)
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        for h in range(4):
            draw(model.decoder.layers[layer].src_attn.attn[0, h].data[:len(tgt_sent), :len(sent)],
                 sent, tgt_sent if h == 0 else [], ax=axs[h])
        plt.show()
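
# --- Usage sketch (illustrative, not from the original file) ---
# A minimal example of how attention_visualization might be called, assuming an
# Annotated-Transformer-style `model` whose attention modules cache their
# weights in `.attn` after a forward pass. `model`, `src_tokens`, and
# `tgt_tokens` are hypothetical placeholders; the decode step that populates
# the attention buffers is elided.
#
#   src_tokens = ["ein", "kleines", "haus", "</s>"]   # tokenized source sentence
#   tgt_tokens = ["a", "small", "house", "</s>"]      # tokens of the decoded translation
#   model.eval()
#   # ... greedy-decode src_tokens with model so each .attn buffer is filled ...
#   attention_visualization(model, tgt_tokens, src_tokens)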