Commit

prototype

zpcore committed Aug 15, 2024
1 parent c6ba0e9 commit 3529d20
Showing 4 changed files with 87 additions and 7 deletions.
47 changes: 47 additions & 0 deletions collective_op_test.py
@@ -0,0 +1,47 @@
import torch
import torch.distributed as dist
import torch_xla
from typing import List
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("my_compiler() called with FX graph:")
  gm.graph.print_tabular()
  return gm.forward  # return a python callable



def dummy_collective_fn(input: torch.Tensor):
  # Earlier experiments, kept for reference:
  # output_tensor = xm.all_reduce(xm.REDUCE_SUM, input)
  # output_tensor = dist.all_reduce(input, dist.ReduceOp.SUM)
  # dist.all_gather expects a list of tensors, so use the *_into_tensor
  # variant this prototype lowers; the buffer holds `input` gathered from
  # every rank along dim 0.
  output_tensor = torch.zeros(
      dist.get_world_size() * input.numel(),
      dtype=input.dtype,
      device=input.device)
  dist.all_gather_into_tensor(output_tensor, input, None)
  return output_tensor

def _mp_fn(index):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()
  world_size = xr.world_size()
  if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
    print(f'skip this test for hw {xm.xla_device_hw(device)}')
    return
  ordinal_tensor = torch.tensor([index + 1], dtype=torch.float).to(device)
  met.clear_all()
  compiled_collective = torch.compile(
      dummy_collective_fn, backend=my_compiler, dynamic=False)
  res_tensor = compiled_collective(ordinal_tensor)
  print(res_tensor)
  # Verification left disabled while prototyping:
  # expected_tensor = torch.tensor(
  #     [world_size * world_size / 2] * world_size, dtype=torch.float) + 3.0
  # torch_xla.sync()
  # torch.allclose(res_tensor.cpu(), expected_tensor)
  # assert met.metric_data("ExecuteTime")[0] == 1
  print(met.metric_data("ExecuteTime"))


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=(), debug_single_process=False)
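
A note on what the compiled graph carries: under torch.compile, dynamo rewrites dist.all_gather_into_tensor into the functional collective op _c10d_functional.all_gather_into_tensor, which is the op the cross_replica_reduces.cpp change below lowers for XLA. A minimal sketch of that traced form, using the public functional-collectives wrapper (the helper name traced_equivalent is ours, for illustration only):

import torch
import torch.distributed._functional_collectives as funcol

def traced_equivalent(input: torch.Tensor, world_size: int) -> torch.Tensor:
  # Gathers `input` from every rank along dim 0; inside a compiled graph this
  # appears as _c10d_functional.all_gather_into_tensor plus a wait_tensor.
  return funcol.all_gather_tensor(input, gather_dim=0,
                                  group=list(range(world_size)))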
16 changes: 9 additions & 7 deletions torch_xla/core/xla_model.py
@@ -430,13 +430,15 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True):
"""
groups = groups or []

# No-op if there is only one device
if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
bool, False):
if isinstance(inputs, torch.Tensor):
return inputs.clone()
else:
return inputs

# PIZ: comment for debug
# # No-op if there is only one device
# if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
# bool, False):
# if isinstance(inputs, torch.Tensor):
# return inputs.clone()
# else:
# return inputs

if isinstance(inputs, torch.Tensor):
result = None
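
The hunk above removes the single-device fast path (keeping a commented copy) so that all_reduce always emits a collective into the graph, which makes the lowering observable while debugging. The guard itself already exposes an escape hatch, so the same effect is available without editing the source; a minimal sketch:

import os
# XLA_ALWAYS_ALLREDUCE is the env var the (now commented) guard checks;
# setting it keeps the all-reduce in the graph even when world_size() == 1.
os.environ['XLA_ALWAYS_ALLREDUCE'] = '1'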
21 changes: 21 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -18,6 +18,8 @@
#include "torch_xla/csrc/xla_graph_executor.h"
#include "xla/shape_util.h"

#include "torch_xla/csrc/XLANativeFunctions.h" // piz for unsqueeze

namespace torch_xla {
namespace {

@@ -114,6 +116,7 @@

at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
std::string /*group_name*/) {
std::cout << "trigger all_reduce lower" << std::endl;
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
auto self_tensor = bridge::GetXlaTensor(self);
// TODO(alanwaketan): Use group_name to generate groups. Currently we just
@@ -254,6 +257,24 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
}



// Function signature should match torch/csrc/distributed/c10d/Functional.cpp.
at::Tensor all_gather_into_tensor(const at::Tensor& self, int64_t group_size,
                                  std::string group_name) {
  TORCH_LAZY_FN_COUNTER("xla::");
  std::cout << "trigger all_gather lower" << std::endl;
  auto self_tensor = bridge::GetXlaTensor(self);
  // group_name is ignored for now; gather over all ranks [0, group_size).
  std::vector<int64_t> all_groups(group_size);
  std::iota(all_groups.begin(), all_groups.end(), 0);
  auto result = tensor_methods::all_gather(self_tensor, /*dim=*/0, group_size,
                                           {all_groups}, /*pin_layout=*/true);
  return bridge::AtenFromXlaTensor(result);
}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
m.impl("all_gather_into_tensor", all_gather_into_tensor);
}

AllGatherResultCoalesced BuildAllGatherCoalesced(
absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
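
The TORCH_LIBRARY_IMPL block above registers all_gather_into_tensor as the XLA-dispatch-key implementation of the existing _c10d_functional::all_gather_into_tensor operator, so compiled graphs containing that op dispatch here for XLA tensors. A minimal sketch of exercising the op directly through the dispatcher (the tensor t, group size 4, and group name "0" are placeholders for illustration):

import torch
# Assumes `t` lives on an XLA device and "0" names a registered process group.
gathered = torch.ops._c10d_functional.all_gather_into_tensor(t, 4, "0")
gathered = torch.ops._c10d_functional.wait_tensor(gathered)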
10 changes: 10 additions & 0 deletions torch_xla/distributed/xla_backend.py
@@ -46,6 +46,16 @@ def __init__(self, prefix_store, rank, size, timeout):
  def getBackendName(self):
    return 'xla'

  def _set_group_name(self, name: str) -> None:
    self._group_name = name

  @property
  def group_name(self):
    assert self._group_name
    return self._group_name

  def _get_reduce_type(self, reduce_op):
    if reduce_op == dist.ReduceOp.SUM:
      return xm.REDUCE_SUM
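
The _set_group_name hook and group_name property mirror what c10d's functional collectives expect: process groups are identified by name, c10d sets the name on the backend when the group is created, and the name is later resolved back to a group object. A minimal sketch of that round trip (the name "my_group" is a placeholder, and _resolve_process_group is the private c10d helper the functional ops use):

from torch.distributed.distributed_c10d import _resolve_process_group

# Assumes a group was registered under this name at init time.
pg = _resolve_process_group("my_group")
print(pg.size())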
