linter

yeounoh committed Mar 12, 2024
1 parent 319062e commit 91bf28d
Showing 5 changed files with 102 additions and 99 deletions.
182 changes: 92 additions & 90 deletions test/cpp/test_xla_sharding.cpp
@@ -345,100 +345,102 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(tensors_data[0], shards[0], at::kFloat));

// Returns multiple input shards, explicitly replicated
int64_t n_devices =
torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
if (n_devices > 1) {
auto sharded_xla_data =
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
tensors_data[1]);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
sharded_xla_data);
EXPECT_EQ(shards.size(), n_devices);
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(shards[0], shards[1], at::kFloat));
}
// Returns multiple input shards, explicitly replicated
int64_t n_devices =
torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
if (n_devices > 1) {
auto sharded_xla_data = std::dynamic_pointer_cast<
torch_xla::runtime::ComputationClient::Data>(tensors_data[1]);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
sharded_xla_data);
EXPECT_EQ(shards.size(), n_devices);
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(shards[0], shards[1], at::kFloat));
}

// Returns multiple input shards, implicitly replicated
if (n_devices > 1) {
auto sharded_xla_data =
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
tensors_data[2]);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
sharded_xla_data);
EXPECT_EQ(shards.size(), n_devices);
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(shards[0], shards[1], at::kFloat));
// Returns multiple input shards, implicitly replicated
if (n_devices > 1) {
auto sharded_xla_data = std::dynamic_pointer_cast<
torch_xla::runtime::ComputationClient::Data>(tensors_data[2]);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
sharded_xla_data);
EXPECT_EQ(shards.size(), n_devices);
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(shards[0], shards[1], at::kFloat));
}
}
}

TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
int64_t n_devices =
torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
xla::Array<int64_t> tile_assignment({1, n_devices});
tile_assignment.FillIota(0);
xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto();

// Build simple addition with a sharded input.
xla::XlaBuilder b("builder");
b.SetSharding(tiled);
auto x = xla::Parameter(&b, 0, shape, "p0");
b.ClearSharding();
auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
xla::XlaComputation xla_computation =
ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false));
std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
instances.push_back({std::move(xla_computation),
bridge::GetDefaultDevice()->toString(),
{bridge::GetDefaultDevice()->toString()},
&shape,
/*should_wrap_parameter=*/false,
/*is_sharded=*/true});

std::vector<
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
computations = torch_xla::runtime::GetComputationClient()->Compile(
std::move(instances));
torch_xla::runtime::ComputationClient::ComputationPtr computation =
std::make_shared<torch_xla::runtime::ComputationClient::Computation>(
"add", std::move(computations[0]->move_computation()));

// Prepare output sharding propagation, expect a sharded output placeholder.
std::vector<XLATensorPtr> tensors{XLATensor::Create(
torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
bridge::GetDefaultDevice()->toString(), std::move(shape)))};
std::vector<torch::lazy::BackendDataPtr> data_placeholders;
std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
ShardingUtil::PrepareOutputShardingPropagation(
&tensors, {0}, computation, &data_placeholders, &sharding_specs);

// Check if the output sharding spec is correctly extracted.
EXPECT_EQ(sharding_specs.size(), 1);
if (n_devices > 1) {
// Tiled sharding requires multiple devices.
EXPECT_TRUE(
xla::protobuf_util::ProtobufEquals(tiled, sharding_specs[0]->sharding));
} else {
// Single device execution defaults to replication sharding.
EXPECT_TRUE(xla::protobuf_util::ProtobufEquals(
xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding));
}
TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
xla::Shape shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
int64_t n_devices =
torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
xla::Array<int64_t> tile_assignment({1, n_devices});
tile_assignment.FillIota(0);
xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto();

// Build simple addition with a sharded input.
xla::XlaBuilder b("builder");
b.SetSharding(tiled);
auto x = xla::Parameter(&b, 0, shape, "p0");
b.ClearSharding();
auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
xla::XlaComputation xla_computation =
ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false));
std::vector<torch_xla::runtime::ComputationClient::CompileInstance>
instances;
instances.push_back({std::move(xla_computation),
bridge::GetDefaultDevice()->toString(),
{bridge::GetDefaultDevice()->toString()},
&shape,
/*should_wrap_parameter=*/false,
/*is_sharded=*/true});

std::vector<
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
computations = torch_xla::runtime::GetComputationClient()->Compile(
std::move(instances));
torch_xla::runtime::ComputationClient::ComputationPtr computation =
std::make_shared<torch_xla::runtime::ComputationClient::Computation>(
"add", std::move(computations[0]->move_computation()));

// Prepare output sharding propagation, expect a sharded output placeholder.
std::vector<XLATensorPtr> tensors{XLATensor::Create(
torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
bridge::GetDefaultDevice()->toString(), std::move(shape)))};
std::vector<torch::lazy::BackendDataPtr> data_placeholders;
std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
ShardingUtil::PrepareOutputShardingPropagation(
&tensors, {0}, computation, &data_placeholders, &sharding_specs);

// Check if the output sharding spec is correctly extracted.
EXPECT_EQ(sharding_specs.size(), 1);
if (n_devices > 1) {
// Tiled sharding requires multiple devices.
EXPECT_TRUE(xla::protobuf_util::ProtobufEquals(
tiled, sharding_specs[0]->sharding));
} else {
// Single device execution defaults to replication sharding.
EXPECT_TRUE(xla::protobuf_util::ProtobufEquals(
xla::HloSharding::Replicate().ToProto(),
sharding_specs[0]->sharding));
}

// Check if the placeholder is on an SPMD device (sharded) with no real values.
EXPECT_EQ(data_placeholders.size(), 1);
EXPECT_EQ(
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
data_placeholders[0])
->device(),
"SPMD:0");
EXPECT_FALSE(
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
data_placeholders[0])
->HasValue());
}
// Check if the placeholder is on an SPMD device (sharded) with no real
// values.
EXPECT_EQ(data_placeholders.size(), 1);
EXPECT_EQ(
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
data_placeholders[0])
->device(),
"SPMD:0");
EXPECT_FALSE(
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
data_placeholders[0])
->HasValue());
}

} // namespace cpp_test
} // namespace torch_xla
2 changes: 1 addition & 1 deletion torch_xla/csrc/device.h
@@ -46,7 +46,7 @@ torch::lazy::BackendDevice GetVirtualDevice();
bool UseVirtualDevice(bool force_spmd = false);

// Return true if device is of "SPMD" device type.
bool IsVirtualDevice(const std::string& device);
bool IsVirtualDevice(const std::string& device);

// Return true if the SPMD config can be switched. That is, no device has been
// initialized yet.
3 changes: 2 additions & 1 deletion torch_xla/csrc/helpers.cpp
@@ -916,7 +916,8 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
for (int i = 0; i < parameter_shapes.size(); ++i) {
*input_tuple_shape.add_tuple_shapes() = parameter_shapes[i];
}
xla::XlaOp input_tuple = xla::Parameter(&builder, 0, input_tuple_shape, "in.");
xla::XlaOp input_tuple =
xla::Parameter(&builder, 0, input_tuple_shape, "in.");

// Handle the results of the original computation.
std::vector<xla::XlaOp> inner_params;
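For context on the hunk above: WrapXlaComputation gathers every original parameter shape into one tuple shape and declares a single tuple parameter named "in.". A minimal, hypothetical sketch of that wrapping pattern using the plain XLA client API follows; the example shapes and the use of MakeTupleShape instead of the add_tuple_shapes() loop are illustrative assumptions, not code from this commit.

xla::XlaBuilder builder("wrapped");
// Two example parameter shapes standing in for parameter_shapes.
xla::Shape p0 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
xla::Shape p1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4});
xla::Shape input_tuple_shape = xla::ShapeUtil::MakeTupleShape({p0, p1});
// One tuple parameter replaces the individual parameters of the inner
// computation; each original parameter is read back out of the tuple.
xla::XlaOp input_tuple =
    xla::Parameter(&builder, 0, input_tuple_shape, "in.");
xla::XlaOp arg0 = xla::GetTupleElement(input_tuple, 0);
xla::XlaOp arg1 = xla::GetTupleElement(input_tuple, 1);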
10 changes: 6 additions & 4 deletions torch_xla/csrc/tensor.cpp
@@ -247,7 +247,8 @@ torch::lazy::BackendDataPtr XLATensor::GetXlaData() {
return data()->handle;
}

void XLATensor::SetShardingSpec(const ShardingSpec& sharding, bool allow_overwrite) {
void XLATensor::SetShardingSpec(const ShardingSpec& sharding,
bool allow_overwrite) {
// Existing annotation must be cleared explicitly. We do not clear and
// overwrite the existing sharding on the user's behalf. This is a no-op if
// the same sharding is already applied.
@@ -288,9 +289,10 @@ XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const {
// Re-sync the sharding annotation from the node to the tensor if there is
// one attached to the node. A new sharding annotation is attached
// directly to the node, and gets synced to the tensor after this.
// If sharding is attached via SetShardingSpec, then it flows from the tensor
// to the node. If sharding is attached by the compiler pass, then it first
// gets attached to the graph node, and then synced to the tensor here.
// If sharding is attached via SetShardingSpec, then it flows from the
// tensor to the node. If sharding is attached by the compiler pass, then
// it first gets attached to the graph node, and then synced to the tensor
// here.
if (!sharding ||
(sharding && !ShardingUtil::EqualOpShardings(*new_op_sharding,
sharding->sharding))) {
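The comments above describe two annotation paths: an explicit one where SetShardingSpec attaches the sharding to the tensor and it flows to the IR node, and a compiler-propagated one where sharding lands on the node first and is synced back inside sharding_spec(). Below is a minimal sketch of the explicit path, modeled on patterns already present in test_xla_sharding.cpp above; the ShardingSpec(sharding, shape) constructor and the default-device/placeholder calls are assumptions taken from that test file, not part of this commit.

// Hypothetical illustration, not part of this change.
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
XLATensorPtr tensor = XLATensor::Create(
    torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
        bridge::GetDefaultDevice()->toString(), shape));
// Explicitly annotate the tensor; the spec is also attached to its IR node.
XLATensor::ShardingSpec spec(xla::HloSharding::Replicate().ToProto(), shape);
tensor->SetShardingSpec(spec, /*allow_overwrite=*/false);
// Re-applying the same sharding is a no-op; a different sharding requires
// clearing the existing annotation first (or passing allow_overwrite=true).
XLATensor::ShardingSpecPtr attached = tensor->sharding_spec();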
4 changes: 1 addition & 3 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -850,7 +850,5 @@ void ShardingUtil::SetAutoSharding() {
// This stays on throughout the program.
use_auto_sharding = true;
}
bool ShardingUtil::GetAutoSharding() {
return use_auto_sharding;
}
bool ShardingUtil::GetAutoSharding() { return use_auto_sharding; }
} // namespace torch_xla
