[SPMD] Support manual all-reduce #7576

Merged · 2 commits · Jun 26, 2024
40 changes: 40 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1255,6 +1255,46 @@ def test_spmd_reduce_scatter_canonical_index(self):
    expected_x = torch.ones(8, 8 // self.n_devices) * 4
    self.assertTrue(torch.allclose(x.cpu(), expected_x))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "Only runs on TPU v4+")
  def test_spmd_all_reduce(self):
    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
    x = torch.ones(8, 8).to(xm.xla_device())

    # all reduce across every device in the mesh
    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
    x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 1.0,
                                             [self.device_ids])
    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
    self.assertIn(
        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
        hlo)

    expected_x = torch.ones(8, 8) * 4
    self.assertTrue(torch.allclose(x.cpu(), expected_x))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "Only runs on TPU v4+")
  def test_spmd_all_reduce_scale(self):
    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
    x = torch.ones(8, 8).to(xm.xla_device())

    # all reduce with a post-scale that undoes the sum across four devices
    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
    x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 0.25,
                                             [self.device_ids])
    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
    self.assertIn(
        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
        hlo)

    expected_x = torch.ones(8, 8)
    self.assertTrue(torch.allclose(x.cpu(), expected_x))


if __name__ == '__main__':
  test = unittest.main()
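For context, a minimal end-to-end sketch of the flow these tests exercise, outside the test harness. The mesh construction, device count, and device_ids list are assumptions for illustration; the calls themselves mirror the tests above (TPU v4+ only):

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Assumption: an SPMD mesh laid out over all local devices, e.g. 4 TPU chips.
n_devices = xr.global_runtime_device_count()
device_ids = list(range(n_devices))
xs.set_global_mesh(xs.Mesh(np.array(device_ids), (1, n_devices)))

x = torch.ones(8, 8).to(xm.xla_device())

# Step out of the SPMD partitioner, issue the collective manually, then
# hand the tensor back with a replicated (None, None) partition spec.
x = xs.enable_manual_sharding(x, (None, None)).global_tensor
x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 1.0, [device_ids])
x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

print(x.cpu())  # every element equals n_devices after the summed all-reduce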
21 changes: 21 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -180,6 +180,27 @@ std::vector<xla::XlaOp> BuildAllReduce(
  return result;
}

xla::XlaOp BuildAllReduce(AllReduceType reduce_type, xla::XlaOp input,
                          double scale,
                          const std::vector<std::vector<int64_t>>& groups) {
  std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
  // Use a dummy channel handle; it is needed to set use_global_device_ids,
  // which is required for SPMD.
  xla::ChannelHandle channel_handle;
  channel_handle.set_handle(1);
  channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
  auto reduce_result = xla::AllReduce(
      input, GetReduceComutation(reduce_type, input_shape.element_type()),
      std::move(reduce_groups), std::move(channel_handle), std::nullopt, true);
  if (scale != 1.0) {
    xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
        scale, input_shape.element_type(), input.builder());
    reduce_result = reduce_result * scaling_value;
  }
  return reduce_result;
}

AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                             int64_t split_dimension, int64_t concat_dimension,
                             int64_t split_count,
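Numerically, the new overload reduces first and applies the scale afterwards, so a scale of 1/n turns the sum into a mean. A plain-Python sketch of that semantics for a single replica group (shard values and group size are assumptions chosen to match the tests):

# Reference semantics of the tokenless BuildAllReduce for REDUCE_SUM:
# element-wise sum across the group, then an optional post-scale.
def all_reduce_sum(shards, scale=1.0):
    summed = [sum(elems) for elems in zip(*shards)]
    return [v * scale for v in summed] if scale != 1.0 else summed

shards = [[1.0, 1.0]] * 4                          # 4 devices, each all-ones
assert all_reduce_sum(shards) == [4.0, 4.0]        # as in test_spmd_all_reduce
assert all_reduce_sum(shards, 0.25) == [1.0, 1.0]  # as in test_spmd_all_reduce_scale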
4 changes: 4 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
@@ -65,6 +65,10 @@ std::vector<xla::XlaOp> BuildAllReduce(
    xla::XlaOp token, double scale,
    const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

xla::XlaOp BuildAllReduce(AllReduceType reduce_type, xla::XlaOp operand,
                          double scale,
                          const std::vector<std::vector<int64_t>>& groups);

AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                             int64_t split_dimension, int64_t concat_dimension,
                             int64_t split_count,
10 changes: 10 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1434,6 +1434,16 @@ void InitXlaModuleBindings(py::module m) {
          return torch::autograd::make_variable(
              result, /*requires_grad=*/input.requires_grad());
        });
m.def("_xla_spmd_all_reduce", [](const std::string& reduce_type,
const at::Tensor& input, double scale,
const py::list& groups) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
auto result = tensor_methods::all_reduce(bridge::GetXlaTensor(input),
GetReduceType(reduce_type), scale,
std::move(replica_groups));
return bridge::AtenFromXlaTensor(std::move(result));
});
m.def("_xla_cast_int4",
[](const at::Tensor& weight,
const std::vector<int>& int4_weight_values) -> at::Tensor {
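The binding takes the reduce type as a string (the xm.REDUCE_* constants), the input tensor, the post-scale, and a nested Python list of replica groups that CreateReduceGroups converts for XLA. A sketch of a direct call, where x and device_ids are assumed to come from the manual-sharding flow shown earlier:

out = torch_xla._XLAC._xla_spmd_all_reduce(
    xm.REDUCE_SUM,  # reduce type string ('sum'); mapped via GetReduceType
    x,              # tensor currently under manual sharding
    0.25,           # scale applied after the reduction
    [device_ids])   # a single replica group spanning the whole mesh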
18 changes: 18 additions & 0 deletions torch_xla/csrc/ops/all_reduce.cpp
@@ -39,6 +39,18 @@ AllReduce::AllReduce(AllReduceType reduce_type,
      groups_(std::move(groups)),
      pin_layout_(pin_layout) {}

AllReduce::AllReduce(AllReduceType reduce_type, torch::lazy::Value operand,
                     double scale, std::vector<std::vector<int64_t>> groups)
    : XlaNode(xla_cross_replica_sum, {operand}, GetXlaShape(operand),
              /*num_outputs=*/1,
              torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
                                 groups)),
      reduce_type_(reduce_type),
      scale_(scale),
      groups_(std::move(groups)),
      pin_layout_(false),
      has_token_(false) {}

torch::lazy::NodePtr AllReduce::Clone(torch::lazy::OpList operands) const {
  std::vector<torch::lazy::Value> operand_list(operands.begin(),
                                               operands.end() - 1);
@@ -48,6 +60,12 @@ torch::lazy::NodePtr AllReduce::Clone(torch::lazy::OpList operands) const {
}

XlaOpVector AllReduce::Lower(LoweringContext* loctx) const {
  if (!has_token_) {
    // Tokenless path used by the SPMD all-reduce: lower the single operand.
    auto result = BuildAllReduce(
        reduce_type_, loctx->GetOutputOp(operands()[0]), scale_, groups_);
    return ReturnOp(result, loctx);
  }

  auto& operand_list = operands();
  std::vector<xla::XlaOp> inputs;
  inputs.reserve(operand_list.size());
5 changes: 4 additions & 1 deletion torch_xla/csrc/ops/all_reduce.h
@@ -12,6 +12,8 @@ class AllReduce : public XlaNode {
            c10::ArrayRef<torch::lazy::Value> operands,
            const torch::lazy::Value& token, double scale,
            std::vector<std::vector<int64_t>> groups, bool pin_layout);
  AllReduce(AllReduceType reduce_type, torch::lazy::Value operand, double scale,
            std::vector<std::vector<int64_t>> groups);

  std::string ToString() const override;

@@ -31,7 +33,8 @@ class AllReduce : public XlaNode {
  AllReduceType reduce_type_;
  double scale_;
  std::vector<std::vector<int64_t>> groups_;
  bool pin_layout_{false};
  bool has_token_{true};
};

} // namespace torch_xla
7 changes: 7 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -392,6 +392,13 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
  }
}

XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type,
                        double scale,
                        std::vector<std::vector<int64_t>> groups) {
  return input->CreateFrom(torch::lazy::MakeNode<AllReduce>(
      reduce_type, input->GetIrValue(), scale, std::move(groups)));
}

Review comment (Collaborator): Can you call it all_reduce_no_token? The only
difference in the signature is that it does not take pin_layout, but the main
difference in the op is that it does not set the token. It is better to
reflect that in the name.

Reply (Collaborator, Author): Sure, I can follow up with that.

std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    AllReduceType reduce_type, double scale, int64_t scatter_dim,
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -20,6 +20,9 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
                AllReduceType reduce_type, double scale,
                std::vector<std::vector<int64_t>> groups, bool pin_layout);

XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type,
                        double scale, std::vector<std::vector<int64_t>> groups);

std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    AllReduceType reduce_type, double scale, int64_t scatter_dim,