diff --git a/collective_op_test.py b/collective_op_test.py
new file mode 100755
index 000000000000..87d11ea9d59f
--- /dev/null
+++ b/collective_op_test.py
@@ -0,0 +1,47 @@
+import torch
+import torch.distributed as dist
+import torch_xla
+from typing import List
+from torch_xla import runtime as xr
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+
+
+def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
+  print("my_compiler() called with FX graph:")
+  gm.graph.print_tabular()
+  return gm.forward  # return a python callable
+
+
+def dummy_collective_fn(input: torch.Tensor):
+  # output_tensor = xm.all_reduce(xm.REDUCE_SUM, input)
+  # output_tensor = dist.all_reduce(input, dist.ReduceOp.SUM)
+  output_tensor = torch.Tensor([[0, 0, 0, 0]])
+  # dist.all_gather_into_tensor(output_tensor, input, None)
+  dist.all_gather(output_tensor, input, None)
+  return output_tensor
+
+
+def _mp_fn(index):
+  dist.init_process_group("xla", init_method='xla://')
+  device = xm.xla_device()
+  world_size = xr.world_size()
+  if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
+    print(f'skip this test for hw {xm.xla_device_hw(device)}')
+    return
+  ordinal_tensor = torch.tensor([index + 1], dtype=torch.float).to(device)
+  met.clear_all()
+  compiled_collective = torch.compile(
+      dummy_collective_fn, backend=my_compiler, dynamic=False)
+  res_tensor = compiled_collective(ordinal_tensor)
+  print(res_tensor)
+  # expected_tensor = torch.tensor(
+  #     [world_size * world_size / 2] * world_size, dtype=torch.float) + 3.0
+  # torch_xla.sync()
+  # torch.allclose(res_tensor.cpu(), expected_tensor)
+  # assert met.metric_data("ExecuteTime")[0] == 1
+  print(met.metric_data("ExecuteTime"))
+
+
+if __name__ == '__main__':
+  torch_xla.launch(_mp_fn, args=(), debug_single_process=False)
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index d3c243335f3d..d8a8e72a2e1f 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -430,13 +430,15 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True):
   """
   groups = groups or []
-  # No-op if there is only one device
-  if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
-                                                    bool, False):
-    if isinstance(inputs, torch.Tensor):
-      return inputs.clone()
-    else:
-      return inputs
+
+  # PIZ: comment for debug
+  # # No-op if there is only one device
+  # if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
+  #                                                   bool, False):
+  #   if isinstance(inputs, torch.Tensor):
+  #     return inputs.clone()
+  #   else:
+  #     return inputs
 
   if isinstance(inputs, torch.Tensor):
     result = None
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index 6cb689fad655..67d7e4dd5281 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -18,6 +18,8 @@
 #include "torch_xla/csrc/xla_graph_executor.h"
 #include "xla/shape_util.h"
 
+#include "torch_xla/csrc/XLANativeFunctions.h"  // piz for unsqueeze
+
 namespace torch_xla {
 namespace {
 
@@ -114,6 +116,7 @@ std::shared_ptr<torch::lazy::Value> CreateToken(
 at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                       std::string /*group_name*/) {
+  std::cout << "trigger all_reduce lower" << std::endl;
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   auto self_tensor = bridge::GetXlaTensor(self);
   // TODO(alanwaketan): Use group_name to generate groups. Currently we just
@@ -254,6 +257,24 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
   return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
 }
 
+
+// function signature should match torch/csrc/distributed/c10d/Functional.cpp
+at::Tensor all_gather_into_tensor(const at::Tensor& self, int64_t group_size,
+                                  std::string group_name) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+  std::cout << "trigger all_gather lower" << std::endl;
+  auto self_tensor = bridge::GetXlaTensor(self);
+  std::vector<int64_t> all_groups(group_size);
+  std::iota(all_groups.begin(), all_groups.end(), 0);
+  auto result = tensor_methods::all_gather(self_tensor, 0, group_size,
+                                           {all_groups}, true);
+  return bridge::AtenFromXlaTensor(result);
+}
+
+TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
+  m.impl("all_gather_into_tensor", all_gather_into_tensor);
+}
+
 AllGatherResultCoalesced BuildAllGatherCoalesced(
     absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
     int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 31266fa078d4..763b2d5111ee 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -46,6 +46,16 @@ def __init__(self, prefix_store, rank, size, timeout):
   def getBackendName(self):
     return 'xla'
 
+  def _set_group_name(self, name: str) -> None:
+    self._group_name = name
+
+  @property
+  def group_name(self):
+    assert self._group_name
+    return self._group_name
+
   def _get_reduce_type(self, reduce_op):
     if reduce_op == dist.ReduceOp.SUM:
       return xm.REDUCE_SUM
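
Note on the new TORCH_LIBRARY_IMPL registration: once the block above is
compiled into torch_xla, the op should also be reachable without going through
dynamo. The sketch below is illustrative only: it assumes the process group has
been initialized as in collective_op_test.py and that the default group's name
is visible through a group_name property (the same lookup the xla_backend.py
change caters to); call_op_directly is a made-up helper name.

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr

def call_op_directly(index):
  # Dispatches straight to the XLA kernel registered via TORCH_LIBRARY_IMPL
  # in cross_replica_reduces.cpp above.
  device = xm.xla_device()
  inp = torch.tensor([float(index + 1)], device=device)
  out = torch.ops._c10d_functional.all_gather_into_tensor(
      inp, xr.world_size(), dist.group.WORLD.group_name)
  return out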
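
A caveat on dummy_collective_fn in the test: in eager mode, dist.all_gather
takes a list of per-rank output tensors as its first argument, so the call as
written only makes sense as a trace target under torch.compile. A sketch of
the eager-mode equivalent (eager_all_gather is a hypothetical helper):

def eager_all_gather(inp: torch.Tensor, world_size: int) -> torch.Tensor:
  # dist.all_gather fills one preallocated output tensor per rank.
  outputs = [torch.empty_like(inp) for _ in range(world_size)]
  dist.all_gather(outputs, inp)
  return torch.cat(outputs, dim=0)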