Fix ReplicateShardedData for int type (#5404)
JackCaoG authored and will-cromar committed Sep 14, 2023
1 parent 1d99226 commit 33500a5
Showing 3 changed files with 30 additions and 13 deletions.
8 changes: 8 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -763,6 +763,14 @@ def test_mark_sharding_ir_with_multiple_output(self):
     self.assertNotIn('convert(s32[8]{0} %get-tuple-element.25), sharding',
                      torch_xla._XLAC._get_xla_tensors_hlo([xt_index]))
 
+  def test_sharded_tensor_to_cpu_int_type(self):
+    partition_spec = (0, 1)
+    t1 = torch.arange(64).reshape(8, 8)
+    xt1 = t1.clone().to(xm.xla_device())
+    xst1 = xs.mark_sharding(xt1, self._get_mesh((self.n_devices, 1)),
+                            partition_spec)
+    self.assertTrue(torch.allclose(t1, xst1.cpu()))
+
 
 if __name__ == '__main__':
   test = unittest.main()
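For context, the new test relies on the suite's self._get_mesh helper. A standalone sketch of the same scenario might look like the following; the module path torch_xla.experimental.xla_sharding and the XLA_USE_SPMD environment switch reflect the SPMD API of this era and are assumptions that may differ in other releases.

# Hypothetical standalone version of test_sharded_tensor_to_cpu_int_type;
# the module paths and the XLA_USE_SPMD switch are assumptions for this era.
import os
os.environ["XLA_USE_SPMD"] = "1"  # enable the SPMD partitioner

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_runtime_device_count()
# Mesh over all devices, mirroring self._get_mesh((self.n_devices, 1)).
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

t1 = torch.arange(64).reshape(8, 8)          # integer (int64) tensor
xt1 = t1.clone().to(xm.xla_device())
xst1 = xs.mark_sharding(xt1, mesh, (0, 1))   # shard dim 0 across devices

# Gathering the sharded integer tensor back to the host should round-trip
# exactly; this is the path ReplicateShardedData serves.
assert torch.allclose(t1, xst1.cpu())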
8 changes: 6 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
@@ -140,20 +140,24 @@ class ComputationClient {
                     std::vector<std::string> devices,
                     const xla::Shape* output_shape,
                     bool parameter_is_tupled_arguments = false,
-                    bool is_sharded = false)
+                    bool is_sharded = false,
+                    bool allow_spmd_sharding_propagation_to_output = true)
         : computation(std::move(computation)),
           compilation_device(std::move(compilation_device)),
           devices(std::move(devices)),
           output_shape(output_shape),
           parameter_is_tupled_arguments(parameter_is_tupled_arguments),
-          is_sharded(is_sharded) {}
+          is_sharded(is_sharded),
+          allow_spmd_sharding_propagation_to_output(
+              allow_spmd_sharding_propagation_to_output) {}
 
     xla::XlaComputation computation;
     std::string compilation_device;
     std::vector<std::string> devices;
     const xla::Shape* output_shape = nullptr;
     bool parameter_is_tupled_arguments;
     bool is_sharded;
+    bool allow_spmd_sharding_propagation_to_output;
   };
 
   struct ExecuteComputationOptions : public ClientExecuteOptions {};
27 changes: 16 additions & 11 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -29,6 +29,8 @@
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
 
+using xla::internal::XlaBuilderFriend;
+
 namespace torch_xla {
 namespace runtime {
 
@@ -300,21 +302,22 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
       // Data is replicated, return the first shard
       return sharded_data->shards[0];
     }
-    xla::XlaBuilder b("ReplicateShardedData");
+    xla::XlaBuilder builder("ReplicateShardedData");
     xla::Shape shape = sharded_data->shape();
-    b.SetSharding(sharded_data->GetSharding());
+    builder.SetSharding(sharded_data->GetSharding());
 
     // perform a simple identity calculation to reassemble the input as
     // replicated output.
-    auto x = xla::Parameter(&b, 0, shape, "p0");
-    b.SetSharding(xla::HloSharding::Replicate().ToProto());
-    xla::XlaOp scalar_two_op =
-        xla::ConvertElementType(xla::ConstantR0(&b, 2), shape.element_type());
-    auto y = xla::Div(x, scalar_two_op);
-    auto z = xla::Add(y, y);
+    xla::XlaOp x = xla::Parameter(&builder, 0, shape, "p0");
+    builder.SetSharding(xla::HloSharding::Replicate().ToProto());
+    xla::XlaOp scalar_zero_op = xla::ConvertElementType(
+        xla::ConstantR0(&builder, 0), shape.element_type());
+    xla::XlaOp y = xla::Add(x, scalar_zero_op);
+    auto instruction = XlaBuilderFriend::GetInstruction(y);
+    *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto();
 
     xla::XlaComputation computation =
-        ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false));
+        ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false));
     xla::ProgramShape program_shape =
         ConsumeValue(computation.GetProgramShape());
 
@@ -324,7 +327,8 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
     instances.push_back({std::move(computation), device,
                          GetCompilationDevices(device, {}), &shape,
                          /*should_wrap_parameter=*/false,
-                         /*is_sharded=*/true});
+                         /*is_sharded=*/true,
+                         /*allow_spmd_sharding_propagation_to_output=*/false});
     std::vector<
         std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
         computations = Compile(std::move(instances));
@@ -402,7 +406,8 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       // outputs. Setting this to true would wrapping the sharded outputs in
       // PjRtShardedData.
       compile_options.executable_build_options
-          .set_allow_spmd_sharding_propagation_to_output({true});
+          .set_allow_spmd_sharding_propagation_to_output(
+              {instance.allow_spmd_sharding_propagation_to_output});
       compile_options.executable_build_options.set_num_partitions(
           client_->device_count());
       compile_options.executable_build_options.set_num_replicas(1);
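Why the old identity computation in the ReplicateShardedData hunk above was wrong for integers: it computed y = x / 2 followed by y + y, and under integer division any odd element loses its low bit (for example 7 / 2 + 7 / 2 = 6), while the new x + 0 form is exact for every element type. Building on the hedged sketch after the test file diff above, an illustrative host-side check over a couple of integer dtypes could look like this; the dtype list and the reuse of that sketch's mesh are assumptions, not part of the commit.

# Illustrative only: reuses `mesh`, `xm`, and `xs` from the sketch above.
# The old divide-by-two round trip would corrupt odd values for these dtypes.
for dtype in (torch.int32, torch.int64):
    t = torch.arange(1, 65, dtype=dtype).reshape(8, 8)  # includes odd values
    xt = t.clone().to(xm.xla_device())
    xst = xs.mark_sharding(xt, mesh, (0, 1))
    assert torch.equal(t, xst.cpu()), f"round-trip mismatch for {dtype}"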
