Fix ReplicateShardedData for int type #5404

Merged (1 commit, Aug 3, 2023)
8 changes: 8 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -764,6 +764,14 @@ def test_mark_sharding_ir_with_multiple_output(self):
     self.assertNotIn('convert(s32[8]{0} %get-tuple-element.25), sharding',
                      torch_xla._XLAC._get_xla_tensors_hlo([xt_index]))
 
+  def test_sharded_tensor_to_cpu_int_type(self):
+    partition_spec = (0, 1)
+    t1 = torch.arange(64).reshape(8, 8)
+    xt1 = t1.clone().to(xm.xla_device())
+    xst1 = xs.mark_sharding(xt1, self._get_mesh((self.n_devices, 1)),
+                            partition_spec)
+    self.assertTrue(torch.allclose(t1, xst1.cpu()))
+
 
 if __name__ == '__main__':
   test = unittest.main()
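
For context: torch.arange produces an int64 tensor, so the .cpu() round trip in the new test exercises the integer element types that the old divide-then-add identity computation mishandled. A quick sanity check in plain PyTorch (illustrative only, no XLA required):

    import torch

    t1 = torch.arange(64).reshape(8, 8)
    print(t1.dtype)  # torch.int64 -- integer dtypes take the path fixed in ReplicateShardedData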
8 changes: 6 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
@@ -140,20 +140,24 @@ class ComputationClient {
                     std::vector<std::string> devices,
                     const xla::Shape* output_shape,
                     bool parameter_is_tupled_arguments = false,
-                    bool is_sharded = false)
+                    bool is_sharded = false,
+                    bool allow_spmd_sharding_propagation_to_output = true)
         : computation(std::move(computation)),
           compilation_device(std::move(compilation_device)),
           devices(std::move(devices)),
           output_shape(output_shape),
           parameter_is_tupled_arguments(parameter_is_tupled_arguments),
-          is_sharded(is_sharded) {}
+          is_sharded(is_sharded),
+          allow_spmd_sharding_propagation_to_output(
+              allow_spmd_sharding_propagation_to_output) {}
 
     xla::XlaComputation computation;
     std::string compilation_device;
     std::vector<std::string> devices;
     const xla::Shape* output_shape = nullptr;
     bool parameter_is_tupled_arguments;
     bool is_sharded;
+    bool allow_spmd_sharding_propagation_to_output;
   };
 
   struct ExecuteComputationOptions : public ClientExecuteOptions {};
27 changes: 16 additions & 11 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -31,6 +31,8 @@
 #include "xla/shape.h"
 #include "xla/stream_executor/tpu/tpu_initializer_helper.h"
 
+using xla::internal::XlaBuilderFriend;
+
 namespace torch_xla {
 namespace runtime {
 
@@ -324,21 +326,22 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
     // Data is replicated, return the first shard
     return sharded_data->shards[0];
   }
-  xla::XlaBuilder b("ReplicateShardedData");
+  xla::XlaBuilder builder("ReplicateShardedData");
   xla::Shape shape = sharded_data->shape();
-  b.SetSharding(sharded_data->GetSharding());
+  builder.SetSharding(sharded_data->GetSharding());
 
   // perform a simple identity calculation to reassemble the input as
   // replicated output.
-  auto x = xla::Parameter(&b, 0, shape, "p0");
-  b.SetSharding(xla::HloSharding::Replicate().ToProto());
-  xla::XlaOp scalar_two_op =
-      xla::ConvertElementType(xla::ConstantR0(&b, 2), shape.element_type());
-  auto y = xla::Div(x, scalar_two_op);
-  auto z = xla::Add(y, y);
+  xla::XlaOp x = xla::Parameter(&builder, 0, shape, "p0");
+  builder.SetSharding(xla::HloSharding::Replicate().ToProto());
+  xla::XlaOp scalar_zero_op = xla::ConvertElementType(
+      xla::ConstantR0(&builder, 0), shape.element_type());
+  xla::XlaOp y = xla::Add(x, scalar_zero_op);

Review thread on the xla::Add above:

Collaborator: Is it possible to have a totally no-op computation and remove the Add? Since we're setting allow_spmd_sharding_propagation_to_output=false, the result will always be replicated regardless of the input sharding.

Collaborator (author): Let me try.

Collaborator (author): Commenting out the xla::Add gives an incorrect result. I think there are ways to make it work, but since it is not a priority I will leave that to a separate PR.

Collaborator: Got it, thanks for trying.

+  auto instruction = XlaBuilderFriend::GetInstruction(y);
+  *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto();
 
   xla::XlaComputation computation =
-      ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false));
+      ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false));
   xla::ProgramShape program_shape =
       ConsumeValue(computation.GetProgramShape());
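
To see why the identity computation changed from divide-then-add to add-zero, here is a rough PyTorch model of the two versions (illustrative only: plain torch on CPU, with // standing in for xla::Div, which truncates on integer element types; the values here are non-negative, so floor and truncating division agree):

    import torch

    x = torch.arange(8)      # int64 values 0..7
    old = x // 2 + x // 2    # old identity (Div by 2, then Add): truncates odd values
    new = x + 0              # new identity (Add zero): exact for every dtype
    print(old)               # tensor([0, 0, 2, 2, 4, 4, 6, 6])
    assert torch.equal(new, x)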

@@ -348,7 +351,8 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
   instances.push_back({std::move(computation), device,
                        GetCompilationDevices(device, {}), &shape,
                        /*should_wrap_parameter=*/false,
-                       /*is_sharded=*/true});
+                       /*is_sharded=*/true,
+                       /*allow_spmd_sharding_propagation_to_output=*/false});
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
       computations = Compile(std::move(instances));
@@ -426,7 +430,8 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     // outputs. Setting this to true would wrapping the sharded outputs in
     // PjRtShardedData.
     compile_options.executable_build_options
-        .set_allow_spmd_sharding_propagation_to_output({true});
+        .set_allow_spmd_sharding_propagation_to_output(
+            {instance.allow_spmd_sharding_propagation_to_output});
     compile_options.executable_build_options.set_num_partitions(
         client_->device_count());
     compile_options.executable_build_options.set_num_replicas(1);