import torch
import os
import wget
import tarfile
import shutil
import codecs
import youtokentome
import math
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def download_data(data_folder):
    """
    Downloads the training, validation, and test files for the WMT '14 en-de translation task.

    Training: Europarl v7, Common Crawl, News Commentary v9
    Validation: newstest2013
    Testing: newstest2014

    The homepage for the WMT '14 translation task, https://www.statmt.org/wmt14/translation-task.html, contains links to
    the datasets.

    :param data_folder: the folder where the files will be downloaded
    """
    train_urls = ["http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
                  "https://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
                  "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz"]

    print("\n\nThis may take a while.")

    # Create a folder to store downloaded TAR files
    if not os.path.isdir(os.path.join(data_folder, "tar files")):
        os.mkdir(os.path.join(data_folder, "tar files"))
    # Create a fresh folder to extract downloaded TAR files; previous extractions deleted to prevent tarfile module errors
    if os.path.isdir(os.path.join(data_folder, "extracted files")):
        shutil.rmtree(os.path.join(data_folder, "extracted files"))
    os.mkdir(os.path.join(data_folder, "extracted files"))

    # Download and extract training data
    for url in train_urls:
        filename = url.split("/")[-1]
        if not os.path.exists(os.path.join(data_folder, "tar files", filename)):
            print("\nDownloading %s..." % filename)
            wget.download(url, os.path.join(data_folder, "tar files", filename))
        print("\nExtracting %s..." % filename)
        tar = tarfile.open(os.path.join(data_folder, "tar files", filename))
        members = [m for m in tar.getmembers() if "de-en" in m.path]
        tar.extractall(os.path.join(data_folder, "extracted files"), members=members)

    # Download validation and testing data using sacreBLEU since we will be using this library to calculate BLEU scores
    print("\n")
    os.system("sacrebleu -t wmt13 -l en-de --echo src > '" + os.path.join(data_folder, "val.en") + "'")
    os.system("sacrebleu -t wmt13 -l en-de --echo ref > '" + os.path.join(data_folder, "val.de") + "'")
    print("\n")
    os.system("sacrebleu -t wmt14/full -l en-de --echo src > '" + os.path.join(data_folder, "test.en") + "'")
    os.system("sacrebleu -t wmt14/full -l en-de --echo ref > '" + os.path.join(data_folder, "test.de") + "'")

    # Move files if they were extracted into a subdirectory
    for dir in [d for d in os.listdir(os.path.join(data_folder, "extracted files")) if
                os.path.isdir(os.path.join(data_folder, "extracted files", d))]:
        for f in os.listdir(os.path.join(data_folder, "extracted files", dir)):
            shutil.move(os.path.join(data_folder, "extracted files", dir, f),
                        os.path.join(data_folder, "extracted files"))
        os.rmdir(os.path.join(data_folder, "extracted files", dir))

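# Example usage (a minimal sketch, not part of the original pipeline; assumes a hypothetical
# "data" folder that already exists, and that the `wget` and `sacrebleu` packages are installed):
#
#     download_data(data_folder="data")
#
# After this call, "data/tar files" holds the downloaded archives, "data/extracted files" holds the
# de-en training corpora, and val.en / val.de / test.en / test.de are written directly to "data".
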
def prepare_data(data_folder, euro_parl=True, common_crawl=True, news_commentary=True, min_length=3, max_length=100,
                 max_length_ratio=1.5, retain_case=True):
    """
    Filters and prepares the training data, trains a Byte-Pair Encoding (BPE) model.

    :param data_folder: the folder where the files were downloaded
    :param euro_parl: include the Europarl v7 dataset in the training data?
    :param common_crawl: include the Common Crawl dataset in the training data?
    :param news_commentary: include the News Commentary v9 dataset in the training data?
    :param min_length: exclude sequence pairs where one or both are shorter than this minimum BPE length
    :param max_length: exclude sequence pairs where one or both are longer than this maximum BPE length
    :param max_length_ratio: exclude sequence pairs where one is much longer than the other
    :param retain_case: retain case?
    """
    # Read raw files and combine
    german = list()
    english = list()
    files = list()
    assert euro_parl or common_crawl or news_commentary, "Set at least one dataset to True!"
    if euro_parl:
        files.append("europarl-v7.de-en")
    if common_crawl:
        files.append("commoncrawl.de-en")
    if news_commentary:
        files.append("news-commentary-v9.de-en")
    print("\nReading extracted files and combining...")
    for file in files:
        with codecs.open(os.path.join(data_folder, "extracted files", file + ".de"), "r", encoding="utf-8") as f:
            if retain_case:
                german.extend(f.read().split("\n"))
            else:
                german.extend(f.read().lower().split("\n"))
        with codecs.open(os.path.join(data_folder, "extracted files", file + ".en"), "r", encoding="utf-8") as f:
            if retain_case:
                english.extend(f.read().split("\n"))
            else:
                english.extend(f.read().lower().split("\n"))
        assert len(english) == len(german)

    # Write to file so stuff can be freed from memory
    print("\nWriting to single files...")
    with codecs.open(os.path.join(data_folder, "train.en"), "w", encoding="utf-8") as f:
        f.write("\n".join(english))
    with codecs.open(os.path.join(data_folder, "train.de"), "w", encoding="utf-8") as f:
        f.write("\n".join(german))
    with codecs.open(os.path.join(data_folder, "train.ende"), "w", encoding="utf-8") as f:
        f.write("\n".join(english + german))
    del english, german  # free some RAM

    # Perform BPE
    print("\nLearning BPE...")
    youtokentome.BPE.train(data=os.path.join(data_folder, "train.ende"), vocab_size=37000,
                           model=os.path.join(data_folder, "bpe.model"))

    # Load BPE model
    print("\nLoading BPE model...")
    bpe_model = youtokentome.BPE(model=os.path.join(data_folder, "bpe.model"))

    # Re-read English, German
    print("\nRe-reading single files...")
    with codecs.open(os.path.join(data_folder, "train.en"), "r", encoding="utf-8") as f:
        english = f.read().split("\n")
    with codecs.open(os.path.join(data_folder, "train.de"), "r", encoding="utf-8") as f:
        german = f.read().split("\n")

    # Filter
    print("\nFiltering...")
    pairs = list()
    for en, de in tqdm(zip(english, german), total=len(english)):
        en_tok = bpe_model.encode(en, output_type=youtokentome.OutputType.ID)
        de_tok = bpe_model.encode(de, output_type=youtokentome.OutputType.ID)
        len_en_tok = len(en_tok)
        len_de_tok = len(de_tok)
        if min_length < len_en_tok < max_length and \
                min_length < len_de_tok < max_length and \
                1. / max_length_ratio <= len_de_tok / len_en_tok <= max_length_ratio:
            pairs.append((en, de))
        else:
            continue
    print("\nNote: %.2f per cent of en-de pairs were filtered out based on sub-word sequence length limits." % (100. * (
            len(english) - len(pairs)) / len(english)))

    # Rewrite files
    english, german = zip(*pairs)
    print("\nRe-writing filtered sentences to single files...")
    os.remove(os.path.join(data_folder, "train.en"))
    os.remove(os.path.join(data_folder, "train.de"))
    os.remove(os.path.join(data_folder, "train.ende"))
    with codecs.open(os.path.join(data_folder, "train.en"), "w", encoding="utf-8") as f:
        f.write("\n".join(english))
    with codecs.open(os.path.join(data_folder, "train.de"), "w", encoding="utf-8") as f:
        f.write("\n".join(german))
    del english, german, bpe_model, pairs

    print("\n...DONE!\n")

def get_positional_encoding(d_model, max_length=100):
    """
    Computes positional encoding as defined in the paper.

    :param d_model: size of vectors throughout the transformer model
    :param max_length: maximum sequence length up to which positional encodings must be calculated
    :return: positional encoding, a tensor of size (1, max_length, d_model)
    """
    positional_encoding = torch.zeros((max_length, d_model))  # (max_length, d_model)
    for i in range(max_length):
        for j in range(d_model):
            if j % 2 == 0:
                positional_encoding[i, j] = math.sin(i / math.pow(10000, j / d_model))
            else:
                positional_encoding[i, j] = math.cos(i / math.pow(10000, (j - 1) / d_model))
    positional_encoding = positional_encoding.unsqueeze(0)  # (1, max_length, d_model)

    return positional_encoding

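# Quick sanity check (a sketch; d_model=512 and max_length=100 are illustrative values, not
# prescribed by this file):
#
#     pe = get_positional_encoding(d_model=512, max_length=100)
#     print(pe.shape)  # torch.Size([1, 100, 512])
#
# The singleton first dimension lets the encoding broadcast over a batch of embedded sequences.
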
def get_lr(step, d_model, warmup_steps):
    """
    The learning rate schedule. The version below is twice the definition in the paper, as used in the official T2T repository.

    :param step: training step number
    :param d_model: size of vectors throughout the transformer model
    :param warmup_steps: number of warmup steps over which the learning rate is increased linearly; twice the value in the paper, as in the official T2T repo
    :return: updated learning rate
    """
    lr = 2. * math.pow(d_model, -0.5) * min(math.pow(step, -0.5), step * math.pow(warmup_steps, -1.5))

    return lr

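# Example (illustrative values only; d_model=512 and warmup_steps=8000 are assumptions, not values
# fixed by this file). The schedule rises linearly up to warmup_steps and then decays as step ** -0.5:
#
#     lr_early = get_lr(step=1000, d_model=512, warmup_steps=8000)    # still warming up
#     lr_peak = get_lr(step=8000, d_model=512, warmup_steps=8000)     # maximum of the schedule
#     lr_late = get_lr(step=100000, d_model=512, warmup_steps=8000)   # decayed
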
def save_checkpoint(epoch, model, optimizer, prefix=''):
    """
    Checkpoint saver. Each save overwrites the previous save.

    :param epoch: epoch number (0-indexed)
    :param model: transformer model
    :param optimizer: optimizer
    :param prefix: checkpoint filename prefix
    """
    state = {'epoch': epoch,
             'model': model,
             'optimizer': optimizer}
    filename = prefix + 'transformer_checkpoint.pth.tar'
    torch.save(state, filename)

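# Example usage (a sketch; `model` and `optimizer` stand for whatever transformer and optimizer the
# training script has constructed, and the prefix is arbitrary):
#
#     save_checkpoint(epoch=5, model=model, optimizer=optimizer, prefix='step100k_')
#     # -> writes 'step100k_transformer_checkpoint.pth.tar' in the current working directory
#
# The whole model and optimizer objects are pickled, so torch.load() later requires the matching
# class definitions to be importable.
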
def change_lr(optimizer, new_lr):
    """
    Change the learning rate to a specified value.

    :param optimizer: optimizer whose learning rate must be changed
    :param new_lr: new learning rate
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

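# Typical pattern (a sketch of how get_lr() and change_lr() could be combined inside a training
# loop; `optimizer`, `step`, and the hyperparameter values are placeholders):
#
#     new_lr = get_lr(step=step, d_model=512, warmup_steps=8000)
#     change_lr(optimizer, new_lr=new_lr)
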
class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

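# Example usage (a sketch; the loss values are made up):
#
#     losses = AverageMeter()
#     losses.update(val=2.71, n=32)  # mean loss 2.71 over a batch of 32 sequences
#     losses.update(val=2.43, n=32)
#     print(losses.val, losses.avg)  # most recent value and running average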