-
Notifications
You must be signed in to change notification settings - Fork 21
/
train.py
134 lines (110 loc) · 4.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from modules.entities import EntityTracker
from modules.bow import BoW_encoder
from modules.lstm_net import LSTM_net
from modules.embed import UtteranceEmbed
from modules.actions import ActionTracker
from modules.data_utils import Data
import modules.util as util
import numpy as np
import sys
class Trainer():
    """Train and evaluate the LSTM dialog network on the loaded dialog data.

    Construction builds the utterance encoders, loads the dataset together
    with its per-dialog index records, splits those records into a training
    and a development part, and creates the network.
    """

    def __init__(self):
        """Load the data, split it into train/dev dialogs and build the net."""
        et = EntityTracker()
        self.bow_enc = BoW_encoder()
        self.emb = UtteranceEmbed()
        at = ActionTracker(et)

        # `trainset` unpacks into the dataset and a list of index records;
        # each record is read below through its 'start' and 'end' keys.
        self.dataset, dialog_indices = Data(et, at).trainset
        # First 200 dialogs are used for training, the next 50 for evaluation.
        self.dialog_indices_tr = dialog_indices[:200]
        self.dialog_indices_dev = dialog_indices[200:250]

        # Network input width: sum of the three feature sizes that
        # `_featurize` concatenates.
        obs_size = self.emb.dim + self.bow_enc.vocab_size + et.num_features
        self.action_templates = at.get_action_templates()
        action_size = at.action_size
        nb_hidden = 128

        self.net = LSTM_net(obs_size=obs_size,
                            action_size=action_size,
                            nb_hidden=nb_hidden)

    def train(self, epochs=20, target_accuracy=0.99):
        """Run the training loop, evaluating on the dev dialogs every epoch.

        Args:
            epochs: maximum number of passes over the training dialogs.
            target_accuracy: once the dev accuracy exceeds this value the
                network is saved and training stops.

        NOTE(review): the network is saved only when `target_accuracy` is
        exceeded; if that never happens within `epochs`, nothing is saved.
        """
        print('\n:: training started\n')

        # Loop-invariant, so computed once instead of once per epoch.
        num_tr_examples = len(self.dialog_indices_tr)

        for j in range(epochs):
            loss = 0.
            # iterate through dialogs
            for i, dialog_idx in enumerate(self.dialog_indices_tr):
                # get start and end index
                start, end = dialog_idx['start'], dialog_idx['end']
                # train on dialogue
                loss += self.dialog_train(self.dataset[start:end])
                # print #iteration; flush because the line has no newline and
                # would otherwise sit in the stdout buffer.
                sys.stdout.write('\r{}.[{}/{}]'.format(j+1, i+1, num_tr_examples))
                sys.stdout.flush()

            # Guard against an empty training split (would divide by zero).
            mean_loss = loss / num_tr_examples if num_tr_examples else 0.
            print('\n\n:: {}.tr loss {}'.format(j+1, mean_loss))

            # evaluate every epoch
            accuracy = self.evaluate()
            print(':: {}.dev accuracy {}\n'.format(j+1, accuracy))

            if accuracy > target_accuracy:
                self.net.save()
                break

    def _featurize(self, et, utterance):
        """Build the network input vector for one user utterance.

        `et.extract_entities` is called first and its return value is not
        used here; presumably it updates the tracker state that
        `context_features` then reads -- verify against EntityTracker.
        """
        et.extract_entities(utterance)
        u_ent_features = et.context_features()
        u_emb = self.emb.encode(utterance)
        u_bow = self.bow_enc.encode(utterance)
        # concat features (entity context, embedding, bag-of-words)
        return np.concatenate((u_ent_features, u_emb, u_bow), axis=0)

    def dialog_train(self, dialog):
        """Train the network on one dialog and return the mean per-turn loss.

        Args:
            dialog: sequence of (utterance, response) pairs.

        Returns:
            Sum of the values returned by `net.train_step` divided by the
            number of turns, or 0.0 for an empty dialog.
        """
        num_turns = len(dialog)
        if num_turns == 0:
            # Nothing to train on; avoids a ZeroDivisionError below.
            return 0.

        # fresh entity/action trackers and network state for every dialog
        et = EntityTracker()
        at = ActionTracker(et)
        self.net.reset_state()

        loss = 0.
        # iterate through dialog
        for (u, r) in dialog:
            features = self._featurize(et, u)
            # get action mask (queried after the entity tracker saw `u`)
            action_mask = at.action_mask()
            # train step
            loss += self.net.train_step(features, r, action_mask)
        return loss / num_turns

    def evaluate(self):
        """Return the mean per-dialog accuracy over the dev dialogs.

        For each dev dialog the fraction of turns where `net.forward` equals
        the recorded response is computed; these fractions are averaged over
        all dev dialogs. Returns 0.0 when there are no dev dialogs. An empty
        dialog contributes an accuracy of 0 but still counts in the average.
        """
        num_dev_examples = len(self.dialog_indices_dev)
        if num_dev_examples == 0:
            # Previously this path hit an unbound `num_dev_examples`.
            return 0.

        dialog_accuracy = 0.
        for dialog_idx in self.dialog_indices_dev:
            start, end = dialog_idx['start'], dialog_idx['end']
            dialog = self.dataset[start:end]
            num_turns = len(dialog)
            if num_turns == 0:
                continue

            # fresh entity/action trackers and network state for every dialog
            et = EntityTracker()
            at = ActionTracker(et)
            self.net.reset_state()

            # iterate through dialog
            correct_examples = 0
            for (u, r) in dialog:
                features = self._featurize(et, u)
                # get action mask
                action_mask = at.action_mask()
                # forward propagation only (no training step)
                prediction = self.net.forward(features, action_mask)
                # assumes `prediction == r` yields a single truth value -- TODO confirm
                correct_examples += int(prediction == r)

            # get dialog accuracy (float() keeps true division explicit)
            dialog_accuracy += correct_examples / float(num_turns)

        return dialog_accuracy / num_dev_examples
if __name__ == '__main__':
    # Script entry point: construct the trainer (its constructor loads the
    # data and builds the network), then start the training loop.
    Trainer().train()