train_rnn_updated.py
import torch
import torch.nn as nn
from nltk.tokenize import word_tokenize
from torch.optim import Adam
from torch.utils import data
import config
from attention_rnn import SpeakerAttentionClassifier, SpeakerAttentionRNN
from dataset_updated import Dataset
from rnn import SpeakerClassifier, SpeakerRNN
from utils import (generate_batch_vectors, generate_sent_vector,
                   get_boundary_mapping, get_word_vector, load_model,
                   read_data_from_csv)
use_attention = config.USE_ATTENTION
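
# Trains two sentence encoders (SpeakerRNN, or SpeakerAttentionRNN when
# config.USE_ATTENTION is set) plus a small classifier that predicts whether a
# pair of sentences shares a speaker ('[SAME]') or marks a speaker change
# ('[CHANGE]'). emb_size=300 below assumes 300-dimensional pre-trained word
# vectors; adjust it if the embeddings loaded via load_model differ.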
if __name__ == '__main__':
    print("Loading pre-trained embeddings...")
    w2v_model = load_model(config.PATH_TO_PRETRAINED_EMBEDDINGS)

    print("Loading training data...")
    if config.EQUALIZE_CLASS_COUNTS is True:
        print("\tEqualizing class counts!")
    train_data = read_data_from_csv(
        filename=config.CSV_FILENAME_TRAIN,
        train=True,
        equalize=config.EQUALIZE_CLASS_COUNTS,
        num_records=config.RNN_NUM_RECORDS
    )
    print("\tTotal length of training data: {}".format(len(train_data)))
    print("\tNumber of SAME records: {}".format(
        len([a for a in train_data if a['boundary'] == '[SAME]'])))
    print("\tNumber of CHANGE records: {}".format(
        len([a for a in train_data if a['boundary'] == '[CHANGE]'])))

    print("Creating data generator...")
    train_set = Dataset(train_data)
    train_generator = data.DataLoader(
        dataset=train_set,
        drop_last=True,
        batch_size=config.RNN_BATCH_SIZE,
        shuffle=True
    )
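    # drop_last=True keeps every batch at exactly RNN_BATCH_SIZE, which the
    # non-attention SpeakerRNN appears to rely on (it is constructed with a
    # fixed batch_size below).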
print("Initializing models...")
device = torch.device(config.DEVICE)
if use_attention is True:
print("\tUsing attention!")
model1 = SpeakerAttentionRNN(
emb_size=300,
hidden_size=config.RNN_HIDDEN_SIZE,
num_layers=1,
dev=device
)
model2 = SpeakerAttentionRNN(
emb_size=300,
hidden_size=config.RNN_HIDDEN_SIZE,
num_layers=1,
dev=device
)
classifier = SpeakerAttentionClassifier(
hidden_size=config.RNN_HIDDEN_SIZE * 2,
num_classes=1
)
classifier.to(device)
else:
model1 = SpeakerRNN(
device=device,
emb_size=300,
hidden_size=config.RNN_HIDDEN_SIZE,
num_classes=1,
batch_size=config.RNN_BATCH_SIZE,
num_layers=1,
bidirectionality=False
)
model2 = SpeakerRNN(
device=device,
emb_size=300,
hidden_size=config.RNN_HIDDEN_SIZE,
num_classes=1,
batch_size=config.RNN_BATCH_SIZE,
num_layers=1,
bidirectionality=False
)
classifier = SpeakerClassifier(
device=device,
input_size=config.RNN_HIDDEN_SIZE * 2,
output_size=1
)
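    # Both classifier variants are sized for an input of RNN_HIDDEN_SIZE * 2,
    # i.e. one encoding from each of the two sentence encoders.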
    model1 = model1.to(device)
    optimizer1 = Adam(model1.parameters(), lr=config.RNN_LEARNING_RATE)
    model2 = model2.to(device)
    optimizer2 = Adam(model2.parameters(), lr=config.RNN_LEARNING_RATE)
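    # Note: only the two encoders are handed to an optimizer here; the
    # classifier's own parameters are never updated as written. If that is
    # unintended, a third optimizer (e.g. Adam(classifier.parameters(), ...))
    # would be needed.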
    classifier = classifier.to(device)
    criterion = nn.BCELoss()

    print("Training model...")
    for epoch in range(config.RNN_NUM_EPOCHS):
        print("Current epoch: {}".format(epoch + 1))
        epoch_loss = 0.0
        for sent1_batch, sent2_batch, boundary_batch in train_generator:
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            if use_attention is True:
                # when using attention, seq_len for both sentences needs to be the same
                max_sent_len = max(
                    max([len(word_tokenize(a)) for a in sent1_batch]),
                    max([len(word_tokenize(a)) for a in sent2_batch])
                )
                sent1_batch_vectors = generate_batch_vectors(sent1_batch, w2v_model, max_sent_len=max_sent_len)
                sent2_batch_vectors = generate_batch_vectors(sent2_batch, w2v_model, max_sent_len=max_sent_len)
            else:
                sent1_batch_vectors = generate_batch_vectors(sent1_batch, w2v_model)
                sent2_batch_vectors = generate_batch_vectors(sent2_batch, w2v_model)

            boundary_batch = get_boundary_mapping(boundary_batch)

            sent1_batch_vectors = sent1_batch_vectors.to(device)
            sent2_batch_vectors = sent2_batch_vectors.to(device)
            boundary_batch = torch.Tensor(boundary_batch).to(device)

            output1, hidden1 = model1(sent1_batch_vectors)
            output2, hidden2 = model2(sent2_batch_vectors)

            if use_attention is True:
                output = classifier(output1, output2)
            else:
                hidden1 = hidden1.squeeze(dim=0)
                hidden2 = hidden2.squeeze(dim=0)
                combined_hidden = torch.cat([hidden1, hidden2], dim=1)
                output = classifier(combined_hidden)

            # binary cross-entropy needs this; originally the shape of output is [batch_size, 1]
            output = output.squeeze(dim=1)
            loss = criterion(output, boundary_batch)
            epoch_loss += loss.item()

            loss.backward()
            optimizer1.step()
            optimizer2.step()

        print("Loss: {}".format(epoch_loss))
        print()
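    # Persist the trained modules; the destination paths come from config and
    # depend on whether attention and class-count equalization were enabled.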
    if use_attention is True and config.EQUALIZE_CLASS_COUNTS is True:
        torch.save(model1, config.RNN_EQ_ATTENTION_MODEL1)
        torch.save(model2, config.RNN_EQ_ATTENTION_MODEL2)
        torch.save(classifier, config.RNN_EQ_ATTENTION_CLASSIFIER)
    elif use_attention is True and config.EQUALIZE_CLASS_COUNTS is False:
        torch.save(model1, config.RNN_ATTENTION_MODEL1)
        torch.save(model2, config.RNN_ATTENTION_MODEL2)
        torch.save(classifier, config.RNN_ATTENTION_CLASSIFIER)
    elif use_attention is False and config.EQUALIZE_CLASS_COUNTS is True:
        torch.save(model1, config.RNN_EQ_MODEL1)
        torch.save(model2, config.RNN_EQ_MODEL2)
        torch.save(classifier, config.RNN_EQ_CLASSIFIER)
    else:
        torch.save(model1, config.RNN_MODEL1)
        torch.save(model2, config.RNN_MODEL2)
        torch.save(classifier, config.RNN_CLASSIFIER)
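    # torch.save above pickles the full modules, so downstream evaluation can
    # presumably restore them with torch.load on the matching config path
    # (e.g. torch.load(config.RNN_MODEL1)); that loading code is not part of
    # this file.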