Add out-of-place reduce-scatter coalescing #6058

Merged: 2 commits, Dec 11, 2023
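For context, here is a minimal sketch of the user-facing call this PR enables, adapted from the new test below. It assumes a multiprocess XLA context (e.g. launched via torch_xla.distributed.xla_multiprocessing.spawn) and the 2023-era xm.xrt_world_size() helper; pin_layout=False is used because pin_layout=True is still broken per the TODO in the test.

import torch
import torch_xla.core.xla_model as xm

def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  shard_size = 2

  # A list input coalesces all of these reduce-scatters into one collective.
  inputs = [
      torch.rand((32, shard_size * world_size, 32)).to(device)
      for _ in range(5)
  ]
  # New in this PR: passing an `output=` list selects the coalesced
  # "_out" path, and the results are written into these tensors.
  outputs = [torch.zeros_like(t) for t in inputs]

  res = xm.reduce_scatter(
      xm.REDUCE_SUM,
      inputs,
      scale=1.0 / world_size,
      scatter_dim=1,
      shard_count=world_size,
      output=outputs,
      pin_layout=False)
  assert res is outputs  # the same `output` list object is returned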
28 changes: 28 additions & 0 deletions test/test_mp_reduce_scatter.py
@@ -55,6 +55,34 @@ def _mp_fn(index):

xm.rendezvous('test_reduce_scatter_list_input')

# Testing reduce-scatter with list input and output
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output')
else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
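The oracle here uses the identity that each rank's reduce-scatter(SUM) result equals the all-reduce(SUM) result sliced along scatter_dim at that rank's shard. A pure-CPU sketch of that identity (no XLA needed; shapes and names are illustrative):

import torch

world_size, shard_size, scatter_dim = 2, 3, 0
per_rank = [torch.arange(6.).reshape(6, 1) + r for r in range(world_size)]
summed = torch.stack(per_rank).sum(0)  # what all_reduce(SUM) would produce
shards = []
for rank in range(world_size):
  idx = torch.arange(rank * shard_size, (rank + 1) * shard_size)
  # rank's reduce-scatter(SUM) result: its slice of the fully reduced tensor
  shards.append(summed.index_select(scatter_dim, idx))
assert torch.equal(torch.cat(shards, dim=scatter_dim), summed)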
21 changes: 16 additions & 5 deletions torch_xla/core/xla_model.py
@@ -785,13 +785,24 @@ def reduce_scatter(reduce_type,
elif isinstance(input, list) and all(
isinstance(v, torch.Tensor) for v in input):
if output != None:
raise RuntimeError(
"For xm.reduce_scatter with list of tensors input, output != None is not yet supported."
)
if not isinstance(output, list) or any(
not isinstance(v, torch.Tensor) for v in output):
raise TypeError(
f"`output` needs to be a list of Tensors, but given {type(output)}."
)
if len(output) != len(input):
raise ValueError("`output` length doesn't match `input` length: "
f"{len(output)} vs {len(input)}.")
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
reduce_type, output, input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, output or [], input, token, scale, scatter_dim,
shard_count, groups or [], pin_layout)
reduce_type, input, token, scale, scatter_dim, shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
else:
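The new checks give list inputs a concrete error contract; a short sketch of the failure modes (hypothetical `inputs`/`outputs` lists of XLA tensors, positional arguments as in the test above):

# `output` must be a list of tensors when `input` is a list ...
xm.reduce_scatter(xm.REDUCE_SUM, inputs, 1.0, 0, world_size,
                  output=outputs[0])    # raises TypeError
# ... and its length must match `input`.
xm.reduce_scatter(xm.REDUCE_SUM, inputs, 1.0, 0, world_size,
                  output=outputs[:-1])  # raises ValueError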
94 changes: 62 additions & 32 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -236,39 +236,52 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
std::make_shared<torch::lazy::Value>(new_token));
}

std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
XLATensorPtr out = bridge::GetXlaTensor(output);
torch::lazy::Value new_token;
new_token = tensor_methods::reduce_scatter_out(
out, bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type),
scale, scatter_dim, shard_count, replica_groups, pin_layout);
return std::make_shared<torch::lazy::Value>(new_token);
}

std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
ReduceScatterCoalesced(const std::string& reduce_type,
const std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token,
double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
std::vector<XLATensorPtr> result;
torch::lazy::Value new_token;
std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced(
xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
scatter_dim, shard_count, replica_groups, pin_layout);
xtensors, *token, GetReduceType(reduce_type), scale, scatter_dim,
shard_count, replica_groups, pin_layout);
std::vector<at::Tensor> aten_result;
for (auto& xt : result) {
aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
}
return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
std::shared_ptr<torch::lazy::Value> ReduceScatterCoalescedOut(
const std::string& reduce_type, std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
XLATensorPtr out = bridge::GetXlaTensor(output);
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
torch::lazy::Value new_token;
new_token = tensor_methods::reduce_scatter_out(
out, bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type),
scale, scatter_dim, shard_count, replica_groups, pin_layout);
new_token = tensor_methods::reduce_scatter_coalesced_out(
xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
scatter_dim, shard_count, replica_groups, pin_layout);
return std::make_shared<torch::lazy::Value>(new_token);
}

@@ -1346,45 +1359,62 @@ void InitXlaModuleBindings(py::module m) {
result_tuple[1] = new_token;
return result_tuple;
});
m.def("_xla_reduce_scatter_coalesced",
[](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
m.def("_xla_reduce_scatter_out",
[](const std::string& reduce_type, at::Tensor& output,
const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count, const py::list& groups,
bool pin_layout) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
std::vector<at::Tensor> result;
at::Tensor result;
std::shared_ptr<torch::lazy::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) = ReduceScatterCoalesced(
reduce_type, outputs, inputs, token, scale, scatter_dim,
shard_count, replica_groups, pin_layout);
}
auto result_list = py::list(result.size() + 1);
for (int i = 0; i < result.size(); ++i) {
result_list[i] = torch::autograd::make_variable(
result[i], /*requires_grad=*/result[i].requires_grad());
new_token = ReduceScatterOut(reduce_type, output, input, token,
scale, scatter_dim, shard_count,
replica_groups, pin_layout);
}
result_list[result.size()] = new_token;
return result_list;
return new_token;
});
m.def("_xla_reduce_scatter_out",
[](const std::string& reduce_type, at::Tensor& output,
const at::Tensor& input,
m.def(
"_xla_reduce_scatter_coalesced",
[](const std::string& reduce_type, const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count, const py::list& groups,
bool pin_layout) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
std::vector<at::Tensor> result;
std::shared_ptr<torch::lazy::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) = ReduceScatterCoalesced(
reduce_type, inputs, token, scale, scatter_dim, shard_count,
replica_groups, pin_layout);
}
auto result_list = py::list(result.size() + 1);
for (int i = 0; i < result.size(); ++i) {
result_list[i] = torch::autograd::make_variable(
result[i], /*requires_grad=*/result[i].requires_grad());
}
result_list[result.size()] = new_token;
return result_list;
});
m.def("_xla_reduce_scatter_coalesced_out",
[](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count, const py::list& groups,
bool pin_layout) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
at::Tensor result;
std::shared_ptr<torch::lazy::Value> new_token;
{
NoGilSection nogil;
new_token = ReduceScatterOut(reduce_type, output, input, token,
scale, scatter_dim, shard_count,
replica_groups, pin_layout);
new_token = ReduceScatterCoalescedOut(
reduce_type, outputs, inputs, token, scale, scatter_dim,
shard_count, replica_groups, pin_layout);
}
return new_token;
});
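The two coalesced bindings now have distinct return conventions; a sketch of how the Python layer consumes them (argument setup as in reduce_scatter() above, all names taken from this diff):

# Out-of-place coalesced: returns [result_0, ..., result_{n-1}, new_token].
result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
    reduce_type, inputs, token, scale, scatter_dim, shard_count,
    groups or [], pin_layout)
tensors, new_token = result[:-1], result[-1]

# "_out" variant: writes into `outputs` and returns only the new token.
new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
    reduce_type, outputs, inputs, token, scale, scatter_dim, shard_count,
    groups or [], pin_layout)

# Either way, the token is threaded back so later collectives stay ordered.
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)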
27 changes: 21 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
@@ -393,14 +393,12 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
}

std::pair<std::vector<XLATensorPtr>, torch::lazy::Value>
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs,
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& inputs,
const torch::lazy::Value& token,
AllReduceType reduce_type, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
bool pin_layout) {
XLA_CHECK(outputs.empty() || outputs.size() == inputs.size());
std::vector<torch::lazy::Value> input_values;
input_values.reserve(inputs.size());
for (auto& input : inputs) {
@@ -412,13 +410,30 @@ reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
std::vector<XLATensorPtr> result;
for (size_t i = 0; i < inputs.size(); ++i) {
result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
if (!outputs.empty()) {
outputs[i]->SetIrValue(torch::lazy::Value(node, i));
}
}
return {result, torch::lazy::Value(node, inputs.size())};
}

torch::lazy::Value reduce_scatter_coalesced_out(
const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
AllReduceType reduce_type, double scale, int64_t scatter_dim,
int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout) {
std::vector<torch::lazy::Value> input_values;
input_values.reserve(inputs.size());
for (auto& input : inputs) {
input_values.push_back(input->GetIrValue());
}
torch::lazy::NodePtr node = torch::lazy::MakeNode<ReduceScatterCoalesced>(
reduce_type, input_values, token, scale, scatter_dim, shard_count,
std::move(groups), pin_layout);
for (size_t i = 0; i < inputs.size(); ++i) {
outputs[i]->SetIrValue(torch::lazy::Value(node, i));
}
return torch::lazy::Value(node, inputs.size());
}

std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
const XLATensorPtr& input, const torch::lazy::Value& token,
int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
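Both variants emit a single ReduceScatterCoalesced node whose outputs 0..n-1 are the reduced tensors and whose output n is the token; they differ only in how results are wired back. A tiny plain-Python model of that difference (illustrative only; tuples stand in for torch::lazy::Value):

from typing import Dict, List, Tuple

Value = Tuple[object, int]  # (producing node, output index)

def coalesced(inputs: List[Dict]) -> Tuple[List[Dict], Value]:
  node = object()
  # Out-of-place: wrap each node output in a *new* tensor (CreateFrom).
  results = [{'ir': (node, i)} for i in range(len(inputs))]
  return results, (node, len(inputs))  # token is output index n

def coalesced_out(outputs: List[Dict], inputs: List[Dict]) -> Value:
  node = object()
  # In-place: rebind the caller's existing tensors (SetIrValue).
  for i, out in enumerate(outputs):
    out['ir'] = (node, i)
  return (node, len(inputs))

outs = [{}, {}]
token = coalesced_out(outs, [{'ir': None}, {'ir': None}])
assert token[1] == 2 and all('ir' in o for o in outs)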
10 changes: 8 additions & 2 deletions torch_xla/csrc/tensor_methods.h
@@ -34,14 +34,20 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
bool pin_layout);

std::pair<std::vector<XLATensorPtr>, torch::lazy::Value>
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs,
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& inputs,
const torch::lazy::Value& token,
AllReduceType reduce_type, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
bool pin_layout);

torch::lazy::Value reduce_scatter_coalesced_out(
const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
AllReduceType reduce_type, double scale, int64_t scatter_dim,
int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout);

std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
const XLATensorPtr& input, const torch::lazy::Value& token,
int64_t split_dimension, int64_t concat_dimension, int64_t split_count,