torch: fix async bug (#137)
* torch: attempt to fix async bug

* torch: flush benchmark log

* torch: only local rank 0 prints the log
ymjiang authored Oct 29, 2019
1 parent 839d8e6 commit 89bbb3b
Showing 2 changed files with 18 additions and 7 deletions.
20 changes: 15 additions & 5 deletions byteps/torch/__init__.py
@@ -126,9 +126,8 @@ def _push_pull_grad_async(self, p):
         else:
             name = self._parameter_names.get(p)
         if self._enable_async:
-            tensor = p
-            _, ctx = self._compression.compress(tensor)
-            handle = byteps_push_pull(p, average=False, name="AsyncParam."+name)
+            # the real handle will be created in step()
+            handle, ctx = None, None
         else:
             tensor = p.grad
             tensor_compressed, ctx = self._compression.compress(tensor)
@@ -179,9 +178,20 @@ def step(self, closure=None):
                 old_weight_map[p] = p.data.clone().detach()
             # update
             loss = super(self.__class__, self).step(closure)
-            # get the diff for each weight (in-place)
-            for p, _ in self._handles.items():
+
+            for p, (h, _) in self._handles.items():
+                # get the diff for each weight (in-place)
                 p.data.sub_(old_weight_map.get(p))
+                if h is None:
+                    # create the handler now
+                    if self._is_tensor_instance:
+                        name = self._parameter_names.get(p.__hash__())
+                    else:
+                        name = self._parameter_names.get(p)
+                    handle = byteps_push_pull(p, average=False, name="AsyncParam."+name)
+                    _, ctx = self._compression.compress(p)
+                    self._handles[p] = (handle, ctx)
+
             self.synchronize()
             return loss
         else:
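The diff above defers handle creation: in async mode the old code started the push-pull as soon as the gradient hook fired, while the new code records a (None, None) placeholder and creates the real handle in step(), after the local update, presumably so the pushed tensor already holds the post-update weights. Below is a minimal, self-contained sketch of that deferred-handle pattern; fake_push_pull and AsyncParamSketch are illustrative stand-ins and assumptions, not the real BytePS API.

import torch

def fake_push_pull(tensor, name):
    # Stand-in for byteps_push_pull: pretend to launch asynchronous
    # communication and return an opaque handle.
    print("push_pull started for %s" % name)
    return object()

class AsyncParamSketch:
    def __init__(self):
        self._handles = {}  # param -> handle (None until step())

    def on_grad_ready(self, p):
        # Before this commit, communication started here; now only a
        # placeholder is recorded ("handle, ctx = None, None" in the diff).
        self._handles[p] = None

    def step(self, lr=0.01):
        # Apply the local update first...
        for p in self._handles:
            p.data.add_(-lr * p.grad)
        # ...then create the real handles, so the tensors being pushed
        # already contain the updated values.
        for p, h in self._handles.items():
            if h is None:
                self._handles[p] = fake_push_pull(p, name="AsyncParam.w")

w = torch.zeros(3, requires_grad=True)
w.grad = torch.ones(3)
opt = AsyncParamSketch()
opt.on_grad_ready(w)
opt.step()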
5 changes: 3 additions & 2 deletions example/pytorch/benchmark_byteps.py
@@ -9,7 +9,7 @@
 import byteps.torch as bps
 import timeit
 import numpy as np
-import os
+import os, sys

 # Benchmark settings
 parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark',
@@ -94,9 +94,10 @@ def benchmark_step():


 def log(s, nl=True):
-    if bps.rank() != 0:
+    if bps.local_rank() != 0:
         return
     print(s, end='\n' if nl else '')
+    sys.stdout.flush()


 log('Model: %s' % args.model)
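The benchmark change combines the other two items from the commit message: switching from bps.rank() to bps.local_rank() makes one process per machine print rather than a single process globally, and the explicit flush keeps buffered lines from being lost if a worker exits abruptly. A standalone version of the same pattern, using the LOCAL_RANK environment variable as an assumed stand-in for bps.local_rank():

import os
import sys

def log(s, nl=True):
    # Print from one process per node only, to keep output readable.
    # Reading LOCAL_RANK from the environment is an assumption for this
    # sketch; the benchmark calls bps.local_rank() instead.
    if int(os.environ.get("LOCAL_RANK", "0")) != 0:
        return
    print(s, end='\n' if nl else '')
    # Flush immediately so lines survive an abrupt process exit.
    sys.stdout.flush()

log('Model: resnet50')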