Skip to content

Commit

Permalink
Remove bucket-cap division logic; separate bucket cap for allgather/r…
Browse files Browse the repository at this point in the history
…educescatter
  • Loading branch information
jeffhataws committed Mar 21, 2024
1 parent d7c9958 commit 5006388
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 29 deletions.
24 changes: 24 additions & 0 deletions test/test_mp_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,30 @@ def _mp_fn(index):
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized, zero bucket size)
# Reuse ordinal_tensors from previous test
output_tensors = [
torch.zeros([world_size], dtype=torch.float).to(device)
for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output!=None
result_list = xm.all_gather_bucketized(
ordinal_tensors,
dim=0,
output=output_tensors,
pin_layout=False,
bucket_cap_mb=0)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
'xm.all_gather() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# TODO: add test for torch.compile when support for list input is ready

else:
Expand Down
31 changes: 31 additions & 0 deletions test/test_mp_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,37 @@ def _mp_fn(index):
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output_bucketized')

# Testing reduce-scatter with list input and output (bucketized, but zero bucket size)
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
bucket_cap_mb=0,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous(
'test_reduce_scatter_list_input_output_bucketized, zero bucket size')
else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
Expand Down
29 changes: 9 additions & 20 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,12 +635,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):

class CoalescingBuckets(object):

def __init__(self,
func,
input_list,
output_list=None,
groups=None,
bucket_cap_mb=160):
def __init__(self, func, input_list, output_list=None, bucket_cap_mb=160):
if not isinstance(input_list, list) or any(
not isinstance(v, torch.Tensor) for v in input_list):
raise TypeError(
Expand All @@ -663,15 +658,14 @@ def __init__(self,
self._tensor_bucket = []
self._output_bucket = [] if output_list else None
self._bucket_cap = bucket_cap_mb * 1024 * 1024
if groups:
divisor = len(groups[0]) if type(groups[0]) == list else len(groups)
else:
divisor = xrt_world_size()
self._bucket_cap = self._bucket_cap / divisor
self._out_tensors = []

def flush(self):
if len(self._tensor_bucket):
if len(self._tensor_bucket) == 1:
# Use non-coalesced CCOp if it's just one tensor
output = self._output_bucket[0] if self._output_bucket else None
self._out_tensors.append(self._func(self._tensor_bucket[0], output))
elif len(self._tensor_bucket):
self._out_tensors.extend(
self._func(self._tensor_bucket, self._output_bucket))
self._total = 0
Expand Down Expand Up @@ -712,7 +706,7 @@ def all_gather_bucketized(input_list,
dim=0,
groups=None,
output=None,
pin_layout=True,
pin_layout=False,
bucket_cap_mb=160):
"""Performs an all-gather operation along a given dimension, with bucketization.
Expand All @@ -739,11 +733,7 @@ def _all_gather_coalesced(_input_list, _output_list=None):
pin_layout=pin_layout)

buckets = CoalescingBuckets(
_all_gather_coalesced,
input_list,
output,
groups=groups,
bucket_cap_mb=bucket_cap_mb)
_all_gather_coalesced, input_list, output, bucket_cap_mb=bucket_cap_mb)
return buckets()


Expand Down Expand Up @@ -967,7 +957,7 @@ def reduce_scatter_bucketized(reduce_type,
shard_count,
groups=None,
output=None,
pin_layout=True,
pin_layout=False,
bucket_cap_mb=160):
"""Performs a XLA `ReduceScatter()` operation on a list of tensors (bucketized).
Expand Down Expand Up @@ -1000,7 +990,6 @@ def _reduce_scatter_coalesced(_input_list, _output_list=None):
_reduce_scatter_coalesced,
input_list,
output,
groups=groups,
bucket_cap_mb=bucket_cap_mb)
return buckets()

Expand Down
21 changes: 12 additions & 9 deletions torch_xla/distributed/zero_redundancy_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def __init__(
sharding_groups: Optional[Any] = None,
grad_norm_groups: Optional[Any] = None,
lazy_init: bool = False,
bucket_cap_mb: int = 0,
bucket_cap_mb_all_gather: int = 0,
bucket_cap_mb_reduce_scatter: int = 0,
**defaults: Any,
):
super().__init__(params, defaults)
Expand All @@ -80,8 +81,10 @@ def __init__(
self.grad_clipping = grad_clipping
self.max_norm = max_norm if max_norm is not None else 1.0
self.pin_layout = pin_layout
self.bucket_cap_mb = bucket_cap_mb
self.coalesce_cc = bucket_cap_mb > 0
self.bucket_cap_mb_all_gather = bucket_cap_mb_all_gather
self.bucket_cap_mb_reduce_scatter = bucket_cap_mb_reduce_scatter
self.coalesce_cc_all_gather = bucket_cap_mb_all_gather > 0
self.coalesce_cc_reduce_scatter = bucket_cap_mb_reduce_scatter > 0

self._grad_norm = None

Expand Down Expand Up @@ -282,7 +285,7 @@ def step(self, closure=None, **kwargs):
if param.grad is not None:
padded_grad = self._pad_to_world_size(param.grad,
self.local_world_size)
if self.coalesce_cc:
if self.coalesce_cc_reduce_scatter:
padded_grads.append(padded_grad)
else:
grad_shard = xm.reduce_scatter(
Expand All @@ -298,7 +301,7 @@ def step(self, closure=None, **kwargs):
grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
shard.grad = grad_shard

if self.coalesce_cc:
if self.coalesce_cc_reduce_scatter:
grad_shards = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
padded_grads,
Expand All @@ -307,7 +310,7 @@ def step(self, closure=None, **kwargs):
shard_count=self.local_world_size,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
bucket_cap_mb=self.bucket_cap_mb,
bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
)
index = 0
for param_group, sharded_param_group in zip(
Expand Down Expand Up @@ -341,7 +344,7 @@ def step(self, closure=None, **kwargs):
shard_data = shard.data
if param.dtype != self.optimizer_dtype:
shard_data = shard_data.to(dtype=param.dtype)
if self.coalesce_cc:
if self.coalesce_cc_all_gather:
sharded_data.append(shard_data)
else:
padded_param = xm.all_gather(
Expand All @@ -352,13 +355,13 @@ def step(self, closure=None, **kwargs):
)
param.data.copy_(padded_param.data[:param.size(0)])

if self.coalesce_cc:
if self.coalesce_cc_all_gather:
padded_params = xm.all_gather_bucketized(
sharded_data,
dim=0,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
bucket_cap_mb=self.bucket_cap_mb,
bucket_cap_mb=self.bucket_cap_mb_all_gather,
)
index = 0
for param_group, sharded_param_group in zip(
Expand Down

0 comments on commit 5006388

Please sign in to comment.