UtilClass.py
import torch
import torch.nn as nn
import allennlp.modules.seq2vec_encoders
from allennlp.nn.util import sort_batch_by_length
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

def bottle(v):
    # Flatten (batch, seq, dim) -> (batch*seq, dim); passes None through.
    return v.view(-1, v.size(2)) if v is not None else None

def unbottle(v, batch_size):
    # Inverse of bottle(): (batch*seq, dim) -> (batch, seq, dim).
    return v.view(batch_size, -1, v.size(1))

def shiftLeft(t, pad):
    # Drop the first time step and append a pad column on the right.
    shifted_t = t[:, 1:]  # first dim is batch
    padding = torch.zeros(t.size(0), 1, dtype=torch.long, device=t.device).fill_(pad)
    return torch.cat((shifted_t, padding), 1)
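
# Illustrative usage sketch (not part of the original module); the tensor
# shapes below are assumptions chosen only to show the intended call pattern.
def _demo_bottle_and_shift():
    x = torch.randn(4, 7, 32)                # (batch=4, seq=7, dim=32)
    flat = bottle(x)                         # (4*7, 32)
    restored = unbottle(flat, x.size(0))     # back to (4, 7, 32)
    tgt = torch.randint(1, 10, (4, 7))       # e.g. gold token ids
    shifted = shiftLeft(tgt, pad=0)          # tgt[:, 1:] plus a trailing pad column
    return flat.shape, restored.shape, shifted.shape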

class ConfigurationError(Exception):
    """
    The exception raised by any AllenNLP object when it's misconfigured
    (e.g. missing properties, invalid properties, unknown properties).
    """
    def __init__(self, message):
        super(ConfigurationError, self).__init__()
        self.message = message

    def __str__(self):
        return repr(self.message)

class ProperLSTM(nn.LSTM):
    """An nn.LSTM that handles padded, variable-length (including zero-length) batches."""

    def forward(self, seq, seq_lens):
        if not self.batch_first:
            raise ConfigurationError("Our encoder semantics assumes batch is always first!")
        non_zero_length_mask = seq_lens.ne(0).float()
        # Turn zero lengths into length=1 so the sequences can still be packed.
        seq_lens = seq_lens + seq_lens.eq(0).float()
        sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices = \
            sort_batch_by_length(seq, seq_lens)
        packed_input = pack(sorted_inputs, sorted_sequence_lengths.data.long().tolist(), batch_first=True)
        outputs, final_states = super(ProperLSTM, self).forward(packed_input)
        unpacked_sequence, _ = unpack(outputs, batch_first=True)
        # Restore the original (unsorted) batch order for outputs and final states.
        outputs = unpacked_sequence.index_select(0, restoration_indices)
        new_unsorted_states = [self.fix_hidden(state.index_select(1, restoration_indices))
                               for state in final_states]
        # Zero out the outputs and states of genuinely zero-length inputs.
        outputs = outputs * non_zero_length_mask.view(-1, 1, 1).expand_as(outputs)
        new_unsorted_states[0] = new_unsorted_states[0] * non_zero_length_mask.view(1, -1, 1).expand_as(new_unsorted_states[0])
        new_unsorted_states[1] = new_unsorted_states[1] * non_zero_length_mask.view(1, -1, 1).expand_as(new_unsorted_states[1])
        return outputs, new_unsorted_states

    def fix_hidden(self, h):
        # (layers*directions) x batch x dim  ->  layers x batch x (directions*dim)
        if self.bidirectional:
            h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
        return h
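
# Illustrative usage sketch (not part of the original module). It assumes a
# small bidirectional encoder and a padded batch with float lengths, one of
# which is zero; that row comes back with all-zero outputs and states.
def _demo_proper_lstm():
    lstm = ProperLSTM(input_size=16, hidden_size=8, num_layers=1,
                      batch_first=True, bidirectional=True)
    seq = torch.randn(3, 5, 16)               # (batch=3, max_len=5, dim=16)
    seq_lens = torch.tensor([5., 3., 0.])     # per-sequence lengths
    outputs, (h, c) = lstm(seq, seq_lens)
    # outputs: (3, 5, 2*hidden); h and c: (num_layers, 3, 2*hidden) after fix_hidden()
    return outputs.shape, h.shape, c.shape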

class Bottle(nn.Module):
    """Mixin that lets a 2-D module accept 3-D input by folding the first two dims."""

    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0]*size[1], -1))
        return out.contiguous().view(size[0], size[1], -1)

class BottleLinear(Bottle, nn.Linear):
    pass
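
# Illustrative usage sketch (not part of the original module): BottleLinear
# behaves like nn.Linear but also accepts (batch, seq, in_features) input via
# the Bottle mixin. The sizes below are assumptions for the example.
def _demo_bottle_linear():
    proj = BottleLinear(32, 10)
    x = torch.randn(4, 7, 32)     # (batch, seq, in_features)
    y = proj(x)                   # (4, 7, 10)
    return y.shape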

class BottleEmb(nn.Module):
    """Mixin that lets an embedding lookup accept 3-D or 4-D index tensors."""

    def forward(self, input):
        size = input.size()
        if len(size) <= 2:
            return super(BottleEmb, self).forward(input)
        if len(size) == 3:
            out = super(BottleEmb, self).forward(input.view(size[0]*size[1], -1))
            return out.contiguous().view(size[0], size[1], size[2], -1)
        elif len(size) == 4:
            out = super(BottleEmb, self).forward(input.view(size[0]*size[1]*size[2], -1))
            return out.contiguous().view(size[0], size[1], size[2], size[3], -1)

class BottleLSTMHelper(nn.Module):
    """Mixin that runs the LSTM over 4-D input (batch, groups, seq, dim) by folding batch and groups."""

    def forward(self, input, lengths):
        size = input.size()
        if len(size) <= 3:
            return super(BottleLSTMHelper, self).forward(input, lengths)
        if len(size) == 4:
            out = super(BottleLSTMHelper, self).forward(input.view(size[0]*size[1], size[2], -1), lengths.view(-1))
            return (out[0].contiguous().view(size[0], size[1], size[2], -1),
                    (out[1][0].contiguous().view(out[1][0].size(0), size[0], size[1], -1),
                     out[1][1].contiguous().view(out[1][1].size(0), size[0], size[1], -1)))

class BottleLSTM(BottleLSTMHelper, ProperLSTM):
    pass

class BottleEmbedding(BottleEmb, nn.Embedding):
    pass
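
# Illustrative usage sketch (not part of the original module): embed a
# (batch, num_sents, sent_len) tensor of token ids with BottleEmbedding, then
# encode every sentence with BottleLSTM, which folds batch and num_sents
# before running the length-aware ProperLSTM. All sizes here are assumptions.
def _demo_bottle_embedding_and_lstm():
    emb = BottleEmbedding(num_embeddings=100, embedding_dim=16)
    enc = BottleLSTM(input_size=16, hidden_size=8, num_layers=1,
                     batch_first=True, bidirectional=True)
    tokens = torch.randint(1, 100, (2, 3, 5))   # (batch, num_sents, sent_len)
    lengths = torch.full((2, 3), 5.)            # per-sentence float lengths
    embedded = emb(tokens)                      # (2, 3, 5, 16)
    outputs, (h, c) = enc(embedded, lengths)
    # outputs: (2, 3, 5, 2*hidden); h and c: (num_layers, 2, 3, 2*hidden)
    return embedded.shape, outputs.shape, h.shape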