-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdump_charcnn_weights.lua
53 lines (39 loc) · 1.3 KB
/
dump_charcnn_weights.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
-- script for dumping statistics about charCNN weights
-- works best with cudnn (the alternative -temp_conv option doesn't dump all the statistics)
-- Acknowledgement: uses code from https://github.com/harvardnlp/seq2seq-attn
require 'nn'
require 'nngraph'
require 's2sa.data'
require 's2sa.models'
require 'cudnn'
-- Command-line interface: selects the checkpoint to inspect and tells the
-- script which convolution implementation the model was trained with.
local cmd = torch.CmdLine()
-- file location
-- The default used to be 'model.t7.' (stray trailing dot), which does not
-- match the '.t7' extension named in the help text.
cmd:option('-model', 'model.t7', [[Path to model .t7 file]])
--cmd:option('weightFile', 'weights.txt', [[Path to save char cnn weights]])
cmd:option('-temp_conv', false, [[Model trained with temporal convolution (not cudnn option)]])
-- `opt` is intentionally left global, as in the original script: the code
-- further down reads opt.model and opt.temp_conv.
-- NOTE(review): the required s2sa modules may also expect a global `opt` --
-- confirm before turning this into a local.
opt = cmd:parse(arg)
-- Visitor passed to model[1]:apply(): remembers the layer whose `name` field
-- is 'charcnn_enc' by storing it in the global `charcnn`.
-- Layers without a name, or with any other name, are ignored.
-- @param layer  object with an optional `name` field
function get_layer(layer)
  local layer_name = layer.name
  -- A nil name can never equal the target string, so a single comparison
  -- covers both the "unnamed" and the "differently named" case.
  if layer_name == 'charcnn_enc' then
    charcnn = layer
  end
end
-- Load the checkpoint and look for the character CNN. The first element of
-- the checkpoint is taken as the model table, and get_layer is applied to
-- every layer of that table's first element; it stores the layer named
-- 'charcnn_enc' in the global `charcnn`.
-- `checkpoint`, `model` and `weights` are left global, as in the original.
checkpoint = torch.load(opt.model)
model = checkpoint[1]
model[1]:apply(get_layer)
--print(model[1])
--print(charcnn)

-- Fail with a readable message instead of "attempt to index a nil value"
-- when the checkpoint contains no matching layer.
if charcnn == nil then
  error("no layer named 'charcnn_enc' found in " .. tostring(opt.model))
end

if opt.temp_conv then
  -- Read the layer's weight tensor. The previous code called :double() on the
  -- module itself, which returns the module (not a tensor), so the :var(2)
  -- call below failed.
  -- NOTE(review): assumes modules[2] is the temporal convolution layer --
  -- confirm against the model definition.
  weights = charcnn.modules[2].weight:double()
else
  -- NOTE(review): assumes modules[3] is the cudnn convolution layer and that
  -- squeezing its weight drops a singleton input-plane dimension -- confirm
  -- against the model definition.
  weights = charcnn.modules[3].weight:double()
  weights = weights:squeeze()
end

-- Variance along dimension 2 of the weights, computed once and shared by all
-- statistics below (none of them modifies it).
local var_dim2 = weights:var(2)

print('')
print('average variance overall: ' .. var_dim2:mean())
if not opt.temp_conv then
  -- These two statistics are only reported for the cudnn weight layout.
  print('average variance per char embedding dim:')
  print(var_dim2:mean(1):squeeze())
  print('variance of variances: ' .. var_dim2:var())
end