
Commit

upd
yzh119 committed Jan 15, 2019
1 parent 3f6ff9b commit 6e40f29
Showing 4 changed files with 21 additions and 83 deletions.
14 changes: 7 additions & 7 deletions examples/pytorch/transformer/dataset/__init__.py
@@ -22,17 +22,17 @@ def __init__(self, path, exts, train='train', valid='valid', test='test', vocab=
         vocab_path = os.path.join(path, vocab)
         self.src = {}
         self.tgt = {}
-        with open(os.path.join(path, train + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, train + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['train'] = f.readlines()
-        with open(os.path.join(path, train + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, train + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['train'] = f.readlines()
-        with open(os.path.join(path, valid + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, valid + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['valid'] = f.readlines()
-        with open(os.path.join(path, valid + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, valid + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['valid'] = f.readlines()
-        with open(os.path.join(path, test + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, test + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['test'] = f.readlines()
-        with open(os.path.join(path, test + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, test + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['test'] = f.readlines()

         if not os.path.exists(vocab_path):
@@ -103,7 +103,7 @@ def __call__(self, graph_pool, mode='train', batch_size=32, k=1,
         '''
         src_data, tgt_data = self.src[mode], self.tgt[mode]
         n = len(src_data)
-        # make sure all devices have the same number of batches
+        # make sure all devices have the same number of batch
         n = n // ndev * ndev

         # XXX: is partition then shuffle equivalent to shuffle then partition?
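The XXX question above is worth pinning down: the two orders are not equivalent in general. If the partition is deterministic and each device shuffles only its own slice, every device sees the same fixed subset each epoch; shuffling globally before partitioning reassigns examples across devices every epoch. A minimal standalone sketch of the difference (illustrative, not code from this repo):

import random

data = list(range(8))
ndev = 2

# Partition first, then shuffle within each slice: device 0 only ever
# draws from items 0..3 and device 1 from items 4..7, epoch after epoch.
parts = [data[i * len(data) // ndev:(i + 1) * len(data) // ndev]
         for i in range(ndev)]
for part in parts:
    random.shuffle(part)

# Shuffle first, then partition: any item can land on any device, so
# the per-device subsets are resampled each epoch.
pool = data[:]
random.shuffle(pool)
parts = [pool[i * len(pool) // ndev:(i + 1) * len(pool) // ndev]
         for i in range(ndev)]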
2 changes: 1 addition & 1 deletion examples/pytorch/transformer/dataset/fields.py
@@ -16,7 +16,7 @@ def load(self, path):
             self.vocab_lst.append(self.pad_token)
         if self.unk_token is not None:
             self.vocab_lst.append(self.unk_token)
-        with open(path, 'r') as f:
+        with open(path, 'r', encoding='utf-8') as f:
             for token in f.readlines():
                 token = token.strip()
                 self.vocab_lst.append(token)
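Both files get the same fix. In Python 3, open() without an explicit encoding falls back to locale.getpreferredencoding(), which is commonly cp1252 on Windows, so reading UTF-8 corpora and vocab files can raise UnicodeDecodeError or silently produce mojibake. A small standalone illustration (the file name is made up):

with open('vocab.txt', 'w', encoding='utf-8') as f:
    f.write('naïve\n')

# Simulate a cp1252 locale: the two UTF-8 bytes of 'ï' decode as two
# separate characters instead of failing loudly.
with open('vocab.txt', 'r', encoding='cp1252') as f:
    print(f.readline().strip())   # naÃ¯ve

# Pinning encoding='utf-8', as this commit does, makes the read
# deterministic on every platform.
with open('vocab.txt', 'r', encoding='utf-8') as f:
    print(f.readline().strip())   # naïve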
56 changes: 0 additions & 56 deletions examples/pytorch/transformer/parallel.py

This file was deleted.

32 changes: 13 additions & 19 deletions examples/pytorch/transformer/translation_train.py
@@ -3,7 +3,6 @@
 Multi-GPU support is required to train the model on WMT14.
 """
 from modules import *
-from parallel import *
 from loss import *
 from optims import *
 from dataset import *
@@ -20,19 +19,13 @@ def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=Tr
     for i, g in enumerate(data_iter):
         #print("Dev {} start batch {}".format(dev_rank, i))
         with T.set_grad_enabled(is_train):
-            if isinstance(model, list):
-                model = model[:len(gs)]
-                output = parallel_apply(model, g)
-                tgt_y = [g.tgt_y for g in gs]
-                n_tokens = [g.n_tokens for g in gs]
-            else:
-                if universal:
-                    output, loss_act = model(g)
-                    if is_train: loss_act.backward(retain_graph=True)
-                else:
-                    output = model(g)
-                tgt_y = g.tgt_y
-                n_tokens = g.n_tokens
+            if universal:
+                output, loss_act = model(g)
+                if is_train: loss_act.backward(retain_graph=True)
+            else:
+                output = model(g)
+            tgt_y = g.tgt_y
+            n_tokens = g.n_tokens
             loss = loss_compute(output, tgt_y, n_tokens)

         if universal:
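With parallel.py deleted, run_epoch no longer special-cases a list of model replicas driven in a single process by parallel_apply; the dev_rank/ndev arguments imply one training process per device instead. A minimal sketch of that pattern, assuming workers are spawned per device and gradients are averaged through torch.distributed (the backend, port, and tensor below are placeholders, not the example's actual wiring):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(dev_rank, ndev):
    # One process per device; gradient averaging goes through the
    # process group rather than a single-process parallel_apply.
    dist.init_process_group('gloo',
                            init_method='tcp://127.0.0.1:23456',
                            rank=dev_rank, world_size=ndev)
    grad = torch.ones(3) * dev_rank   # stand-in for a local gradient
    dist.all_reduce(grad)             # sum across all workers
    grad /= ndev                      # average before the optimizer step
    print(dev_rank, grad)

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)   # two simulated devices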
Expand Down Expand Up @@ -75,7 +68,7 @@ def main(dev_id, args):
     model.generator.proj.weight = model.tgt_embed.lut.weight

     model, criterion = model.to(device), criterion.to(device)
-    model_opt = NoamOpt(dim_model, 1, 400,
+    model_opt = NoamOpt(dim_model, 1, 4000 * 1300 / (args.batch * max(1, args.ngpu)),
                         T.optim.Adam(model.parameters(), lr=1e-3,
                                      betas=(0.9, 0.98), eps=1e-9))
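Assuming this NoamOpt follows the schedule from The Annotated Transformer, which its (dim_model, factor, warmup, optimizer) signature suggests, the learning rate is lr(step) = factor * dim_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)): linear warmup for warmup steps, then inverse-square-root decay. The change replaces a fixed 400-step warmup with one scaled by throughput; read that way, 1300 looks like a reference batch size at which 4000 warmup steps are appropriate, so the warmup phase stays roughly constant in examples seen as the batch size or GPU count grows. That reading of the constant is an inference, not something stated in the diff. A sketch:

def noam_lr(step, dim_model, factor=1.0, warmup=4000):
    # Annotated-Transformer-style rate: linear warmup, then step**-0.5
    # decay; the peak occurs at step == warmup.
    return factor * dim_model ** -0.5 * min(step ** -0.5,
                                            step * warmup ** -1.5)

# Doubling the per-device batch or the GPU count halves the number of
# warmup steps, keeping warmup roughly constant in examples processed.
batch, ngpu = 128, 4   # stand-ins for args.batch and args.ngpu
warmup = 4000 * 1300 / (batch * max(1, ngpu))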

@@ -98,12 +91,13 @@ def main(dev_id, args):
         model.train(True)
         run_epoch(epoch, train_iter, dev_rank, ndev, model,
                   loss_compute(opt=model_opt), is_train=True)
-        model.att_weight_map = None
-        model.eval()
-        run_epoch(epoch, valid_iter, dev_rank, ndev, model,
-                  loss_compute(opt=None), is_train=False)
-        end = time.time()
         if dev_rank == 0:
+            model.att_weight_map = None
+            model.eval()
+            run_epoch(epoch, valid_iter, dev_rank, 1, model,
+                      loss_compute(opt=None), is_train=False)
+            end = time.time()
+            time.sleep(1)
             print("epoch time: {}".format(end - start))

 """