Skip to content

Commit

Permalink
minor
Browse files: browse the repository at this point in the history
  • Loading branch information
lingfanyu committed Jan 14, 2019
1 parent 3cd6802 commit 3f6ff9b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
3 changes: 1 addition & 2 deletions examples/pytorch/transformer/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def __call__(self, y_pred, y, norm):
return self.loss.item() * norm

class MultiGPULossCompute(SimpleLossCompute):
def __init__(self, criterion, dev_id, ndev, grad_accum, model, opt=None):
def __init__(self, criterion, ndev, grad_accum, model, opt=None):
super(MultiGPULossCompute, self).__init__(criterion, opt)
self.dev_id = dev_id
self.ndev = ndev
self.grad_accum = grad_accum
self.model = model
Expand Down
14 changes: 7 additions & 7 deletions examples/pytorch/transformer/translation_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from functools import partial

def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
universal = isinstance(model, UTransformer)
for i, g in enumerate(data_iter):
#print("Dev {} start batch {}".format(dev_rank, i))
Expand All @@ -39,8 +39,8 @@ def run_epoch(data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
for step in range(1, model.MAX_DEPTH + 1):
print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0])))
model.reset_stat()
print('{}: Dev {} average loss: {}, accuracy {}'.format(
"Training" if is_train else "Evaluting",
print('Epoch {} {}: Dev {} average loss: {}, accuracy {}'.format(
epoch, "Training" if is_train else "Evaluating",
dev_rank, loss_compute.avg_loss, loss_compute.accuracy))

def run(dev_id, args):
Expand Down Expand Up @@ -82,8 +82,8 @@ def main(dev_id, args):
if args.ngpu > 1:
dev_rank = dev_id # current device id
ndev = args.ngpu # number of devices (including cpu)
loss_compute = partial(MultiGPULossCompute, criterion, dev_id,
args.ngpu, args.grad_accum, model)
loss_compute = partial(MultiGPULossCompute, criterion, args.ngpu,
args.grad_accum, model)
else: # cpu or single gpu case
dev_rank = 0
ndev = 1
Expand All @@ -96,11 +96,11 @@ def main(dev_id, args):
valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch,
device=device, dev_rank=dev_rank, ndev=ndev)
model.train(True)
run_epoch(train_iter, dev_rank, ndev, model,
run_epoch(epoch, train_iter, dev_rank, ndev, model,
loss_compute(opt=model_opt), is_train=True)
model.att_weight_map = None
model.eval()
run_epoch(valid_iter, dev_rank, ndev, model,
run_epoch(epoch, valid_iter, dev_rank, ndev, model,
loss_compute(opt=None), is_train=False)
end = time.time()
if dev_rank == 0:
Expand Down

0 comments on commit 3f6ff9b

Please sign in to comment.