[SPMD] Support manual all-reduce (#7576)
Summary:
This adds manual all-reduce support to SPMD. It currently supports only one input tensor; multi-tensor ("array") support can be handled in the Python layer instead.

Test Plan:
python ./test/spmd/test_xla_sharding.py -v -k test_spmd_all_reduce
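For the multi-tensor case that the summary defers to the Python layer, a thin wrapper could simply loop over the single-tensor binding. A minimal sketch, assuming a helper named spmd_all_reduce that is not part of this commit:

import torch_xla
import torch_xla.core.xla_model as xm

def spmd_all_reduce(tensors, scale=1.0, groups=None):
  # Hypothetical Python-layer wrapper: applies the single-tensor
  # _xla_spmd_all_reduce binding to each tensor in turn.
  return [
      torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, t, scale, groups)
      for t in tensors
  ]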
alanwaketan authored Jun 26, 2024
1 parent 016f7cc commit 0df5c29
Showing 8 changed files with 107 additions and 1 deletion.
40 changes: 40 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1255,6 +1255,46 @@ def test_spmd_reduce_scatter_canonical_index(self):
    expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
    self.assertTrue(torch.allclose(x.cpu(), expected_x))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "Only runs on TPUv4")
  def test_spmd_all_reduce(self):
    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
    x = torch.ones(8, 8).to(xm.xla_device())

    # all reduce
    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
    x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 1.0,
                                             [self.device_ids])
    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
    self.assertIn(
        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
        hlo)

    expected_x = torch.ones(8, 8) * 4
    self.assertTrue(torch.allclose(x.cpu(), expected_x))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "Only runs on TPUv4")
  def test_spmd_all_reduce_scale(self):
    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
    x = torch.ones(8, 8).to(xm.xla_device())

    # all reduce
    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
    x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 0.25,
                                             [self.device_ids])
    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
    self.assertIn(
        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
        hlo)

    expected_x = torch.ones(8, 8)
    self.assertTrue(torch.allclose(x.cpu(), expected_x))


if __name__ == '__main__':
  test = unittest.main()
21 changes: 21 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -180,6 +180,27 @@ std::vector<xla::XlaOp> BuildAllReduce(
  return result;
}

xla::XlaOp BuildAllReduce(AllReduceType reduce_type, xla::XlaOp input,
                          double scale,
                          const std::vector<std::vector<int64_t>>& groups) {
  std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
  // Just a dummy channel handle; it is needed in order to set
  // use_global_device_ids, which is required for SPMD.
  xla::ChannelHandle channel_handle;
  channel_handle.set_handle(1);
  channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
  auto reduce_result = xla::AllReduce(
      input, GetReduceComutation(reduce_type, input_shape.element_type()),
      std::move(reduce_groups), std::move(channel_handle), std::nullopt, true);
  if (scale != 1.0) {
    xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
        scale, input_shape.element_type(), input.builder());
    reduce_result = reduce_result * scaling_value;
  }
  return reduce_result;
}

AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                             int64_t split_dimension, int64_t concat_dimension,
                             int64_t split_count,
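The new overload lowers to a grouped all-reduce followed by an optional scale multiply. A minimal NumPy emulation of those semantics (not part of this commit; emulate_all_reduce_sum is a hypothetical helper):

import numpy as np

def emulate_all_reduce_sum(shards, groups, scale=1.0):
  # Every device in a replica group receives the group-wide sum,
  # optionally multiplied by `scale`.
  out = [None] * len(shards)
  for group in groups:
    reduced = sum(shards[d] for d in group) * scale
    for d in group:
      out[d] = reduced.copy()
  return out

# Mirrors test_spmd_all_reduce_scale: 4 devices of ones, scale 0.25 -> ones.
shards = [np.ones((8, 8)) for _ in range(4)]
out = emulate_all_reduce_sum(shards, [[0, 1, 2, 3]], scale=0.25)
assert np.allclose(out[0], np.ones((8, 8)))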
4 changes: 4 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
@@ -65,6 +65,10 @@ std::vector<xla::XlaOp> BuildAllReduce(
    xla::XlaOp token, double scale,
    const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

xla::XlaOp BuildAllReduce(AllReduceType reduce_type, xla::XlaOp operand,
                          double scale,
                          const std::vector<std::vector<int64_t>>& groups);

AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                             int64_t split_dimension, int64_t concat_dimension,
                             int64_t split_count,
10 changes: 10 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1434,6 +1434,16 @@ void InitXlaModuleBindings(py::module m) {
          return torch::autograd::make_variable(
              result, /*requires_grad=*/input.requires_grad());
        });
  m.def("_xla_spmd_all_reduce", [](const std::string& reduce_type,
                                   const at::Tensor& input, double scale,
                                   const py::list& groups) {
    std::vector<std::vector<int64_t>> replica_groups =
        CreateReduceGroups(groups);
    auto result = tensor_methods::all_reduce(bridge::GetXlaTensor(input),
                                             GetReduceType(reduce_type), scale,
                                             std::move(replica_groups));
    return bridge::AtenFromXlaTensor(std::move(result));
  });
  m.def("_xla_cast_int4",
        [](const at::Tensor& weight,
           const std::vector<int>& int4_weight_values) -> at::Tensor {
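End to end, the binding is meant to be called inside a manual-sharding region, as in the test above. A condensed usage sketch, assuming a (1, 4) global mesh has already been set and device ids 0-3:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

# Assumes xs.set_global_mesh(...) was already called with a (1, 4) mesh.
x = torch.ones(8, 8).to(xm.xla_device())
x = xs.enable_manual_sharding(x, (None, None)).global_tensor
x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 1.0, [[0, 1, 2, 3]])
x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor
# x now holds the sum across the 4 devices: torch.ones(8, 8) * 4.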
18 changes: 18 additions & 0 deletions torch_xla/csrc/ops/all_reduce.cpp
@@ -39,6 +39,18 @@ AllReduce::AllReduce(AllReduceType reduce_type,
      groups_(std::move(groups)),
      pin_layout_(pin_layout) {}

AllReduce::AllReduce(AllReduceType reduce_type, torch::lazy::Value operand,
                     double scale, std::vector<std::vector<int64_t>> groups)
    : XlaNode(xla_cross_replica_sum, {operand}, GetXlaShape(operand),
              /*num_outputs=*/1,
              torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
                                 groups)),
      reduce_type_(reduce_type),
      scale_(scale),
      groups_(std::move(groups)),
      pin_layout_(false),
      has_token_(false) {}

torch::lazy::NodePtr AllReduce::Clone(torch::lazy::OpList operands) const {
  std::vector<torch::lazy::Value> operand_list(operands.begin(),
                                               operands.end() - 1);
@@ -48,6 +60,12 @@ torch::lazy::NodePtr AllReduce::Clone(torch::lazy::OpList operands) const {
}

XlaOpVector AllReduce::Lower(LoweringContext* loctx) const {
  if (!has_token_) {
    auto result = BuildAllReduce(
        reduce_type_, loctx->GetOutputOp(operands()[0]), scale_, groups_);
    return ReturnOp(result, loctx);
  }

  auto& operand_list = operands();
  std::vector<xla::XlaOp> inputs;
  inputs.reserve(operand_list.size());
5 changes: 4 additions & 1 deletion torch_xla/csrc/ops/all_reduce.h
@@ -12,6 +12,8 @@ class AllReduce : public XlaNode {
            c10::ArrayRef<torch::lazy::Value> operands,
            const torch::lazy::Value& token, double scale,
            std::vector<std::vector<int64_t>> groups, bool pin_layout);
  AllReduce(AllReduceType reduce_type, torch::lazy::Value operand, double scale,
            std::vector<std::vector<int64_t>> groups);

  std::string ToString() const override;

@@ -31,7 +33,8 @@

  AllReduceType reduce_type_;
  double scale_;
  std::vector<std::vector<int64_t>> groups_;
-  bool pin_layout_;
+  bool pin_layout_{false};
+  bool has_token_{true};
};

}  // namespace torch_xla
7 changes: 7 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -392,6 +392,13 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
  }
}

XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type,
                        double scale,
                        std::vector<std::vector<int64_t>> groups) {
  return input->CreateFrom(torch::lazy::MakeNode<AllReduce>(
      reduce_type, input->GetIrValue(), scale, std::move(groups)));
}

std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    AllReduceType reduce_type, double scale, int64_t scatter_dim,
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -20,6 +20,9 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
                AllReduceType reduce_type, double scale,
                std::vector<std::vector<int64_t>> groups, bool pin_layout);

XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type,
                        double scale, std::vector<std::vector<int64_t>> groups);

std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    AllReduceType reduce_type, double scale, int64_t scatter_dim,
