import torch
import torch.nn as nn
import torch.nn.functional as F

from FiniteStateMachine import DFA

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
class LSTM_model(nn.Module):
    def __init__(self, hidden_dim, vocab_size, tagset_size):
        super(LSTM_model, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = 2
        self.input_size = vocab_size
        self.lstm = nn.LSTM(vocab_size, hidden_dim, self.num_layers, batch_first=True)
        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        # ReLU keeps the scores non-negative; the +0.5 offset below keeps them strictly positive
        self.positive_activation = torch.nn.ReLU()

    def forward(self, x):
        # x: one-hot sequences of shape (batch, seq_len, vocab_size)
        batch_size = x.size(0)
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        lstm_out, (hn, cn) = self.lstm(x, (h0, c0))
        tag_space = self.positive_activation(self.hidden2tag(lstm_out)) + 0.5
        return tag_space, (hn, cn)

    def forward_from_state(self, x, state):
        # Single step from a given recurrent state (h, c)
        lstm_out, (hn, cn) = self.lstm(x, state)
        tag_space = self.positive_activation(self.hidden2tag(lstm_out)) + 0.5
        return tag_space, (hn, cn)

    def next_sym_prob(self, x, state):
        # Next-symbol distribution: normalize the positive scores with a softmax
        tag_space, state = self.forward_from_state(x, state)
        tag_space = F.softmax(tag_space, dim=-1)
        return tag_space, state

    def predict(self, sentence):
        # forward() returns (tag_space, (hn, cn)); keep only the scores
        tag_space, _ = self.forward(sentence)
        # Normalize over the tag dimension and keep the last slice
        out = F.softmax(tag_space, dim=-1)[-1]
        return out
class RNN_with_constraints_model(nn.Module):
    def __init__(self, rnn, ltl_formula):
        super(RNN_with_constraints_model, self).__init__()
        self.rnn = rnn
        # Formula evaluator: a DFA built from the LTL formula, used to mask predictions
        dfa = DFA(ltl_formula, 2, "random DNF declare", ['c0', 'c1', 'end'])
        self.deep_dfa_constraint = dfa.return_deep_dfa_constraint()

    def forward(self, x):
        pred_sym, hidden_states = self.rnn(x)
        # Transform one-hot encodings into symbol indices for the DFA
        x_indices = torch.argmax(x, dim=-1).long()
        # The constraint module returns, per position, a mask over the symbols
        # allowed by the formula, plus the resulting DFA state
        masks, dfa_state = self.deep_dfa_constraint(x_indices)
        # Zero out the scores of symbols the constraint forbids
        pred_sym = pred_sym * masks
        return pred_sym, (hidden_states, dfa_state)
    def forward_from_state(self, x, tot_state):
        state_rnn, state_dfa = tot_state
        # One RNN step from the stored recurrent state
        next_event, next_state_rnn = self.rnn.forward_from_state(x, state_rnn)
        next_event = next_event.squeeze()
        # One-hot symbol -> index for the DFA transition
        x = torch.argmax(x, -1).squeeze()
        # Advance the DFA and get the mask to put on the next-event prediction
        next_dfa_state, mask = self.deep_dfa_constraint.step(state_dfa, x)
        # Next event according to the RNN, filtered by the constraint
        next_event = next_event * mask
        return next_event.unsqueeze(1), (next_state_rnn, next_dfa_state)
'''
class LSTM_next_activity_predictor(nn.Module):
def __init__(self, hidden_dim, vocab_size):
'''
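
# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): how the pieces
# above fit together. The sizes below (hidden_dim=16, vocab_size=5, a batch of
# 4 one-hot sequences of length 10) are assumptions made for the example only.
# The constrained model is left out because it depends on the DFA machinery
# from FiniteStateMachine; LSTM_model alone is self-contained.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, hidden_dim, tagset_size = 5, 16, 5  # assumed toy sizes
    model = LSTM_model(hidden_dim, vocab_size, tagset_size).to(device)

    # A batch of 4 one-hot sequences of length 10
    indices = torch.randint(0, vocab_size, (4, 10))
    x = F.one_hot(indices, num_classes=vocab_size).float().to(device)

    # Full-sequence scores plus the final (h, c) state
    scores, state = model(x)
    print(scores.shape)  # torch.Size([4, 10, 5])

    # Step-by-step continuation: feed the last symbol from the stored state
    # and normalize to get the next-symbol distribution
    last = x[:, -1:, :]
    probs, state = model.next_sym_prob(last, state)
    print(probs.shape)  # torch.Size([4, 1, 5])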