[Semi-Auto] add split spmd rule (PaddlePaddle#55397)
* add split spmd rule

* add pytest in cmake file

* small fix
pkuzyc authored and wyf committed Aug 30, 2023
1 parent 0be3ada commit 0d9c0cd
Showing 6 changed files with 379 additions and 2 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
@@ -182,8 +182,8 @@ TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr) {
void VerifySpecs(const std::vector<DistTensorSpec>& specs,
const std::string& op_name) {
for (size_t i = 0, n = specs.size(); i < n; ++i) {
std::vector<int64_t> shape = specs[i].shape();
std::vector<int64_t> dims_mapping = specs[i].dims_mapping();
const std::vector<int64_t>& shape = specs[i].shape();
const std::vector<int64_t>& dims_mapping = specs[i].dims_mapping();
PADDLE_ENFORCE_EQ(shape.size(),
dims_mapping.size(),
phi::errors::InvalidArgument(
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -23,6 +23,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"

// TODO(ljz) Automate this process in the cmake file.
namespace paddle {
@@ -150,6 +151,10 @@ REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
REGISTER_SPMD_RULE(cross_entropy_with_softmax, CrossEntropyWithSoftmaxSPMDRule);
REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);

// split rule
REGISTER_SPMD_RULE(split, SplitSPMDRule);
REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
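// both registrations are served by one rule: InferForward derives the number
// of outputs from the "num" attribute when it is present, otherwise from the
// length of "sections" (see split_spmd_rule.cc below).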

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
126 changes: 126 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc
@@ -0,0 +1,126 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
#include <algorithm>
#include <typeinfo>
#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

using phi::distributed::auto_parallel::str_join;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SplitSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in split must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(input_specs, "split");

// step1: Build Einsum Notation
int64_t ndim = input_specs[0].shape().size();
int64_t noutput = 0;
// the split api uses either num or sections as its attribute
if (attrs.find("num") != attrs.end()) {
noutput = ExtractAttr<int64_t>("num", attrs);
} else if (attrs.find("sections") != attrs.end()) {
std::vector<int64_t> sections =
ExtractAttr<std::vector<int64_t>>("sections", attrs);
noutput = sections.size();
}
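// e.g. (hypothetical values) num = 2 gives noutput = 2, while
// sections = {2, 3, 5} gives noutput = 3.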
int64_t axis = ExtractAttr<int>("axis", attrs);
if (axis < 0) {
axis += ndim;
}
std::string alphabet = "abcdefghijlmnopqrstuvwxyz";

// get the einsum notation for the input; the special character 'k'
// (deliberately omitted from the alphabet above) marks the split axis
std::vector<std::string> input_axes_vec;
std::string input_axes = alphabet.substr(0, ndim);
input_axes[axis] = 'k';
input_axes_vec.emplace_back(input_axes);

// get einsum notation for output
std::string output_axes(input_axes);
// the split axis cannot be sharded, so mark it with the special
// character '1', which forces its dims mapping to -1.
output_axes[axis] = '1';

// step2: Sharding Propagation
// step2.1: merge input shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

// step2.2: infer output dims mapping from merged input dims mapping
std::vector<int64_t> output_dims_mapping =
GetDimsMappingForAxes(output_axes, axis_to_dim_map);

// get the dist attributes for all outputs; the dist attributes
// are the same for all outputs.
std::vector<TensorDistAttr> output_dist_attrs;
for (int64_t i = 0; i < noutput; i++) {
output_dist_attrs.emplace_back(
CopyTensorDistAttrForOutput(input_specs[0].dist_attr()));
output_dist_attrs[i].set_dims_mapping(output_dims_mapping);
}

// step2.3: get the new dist attribute for the input. the split axis
// cannot be sharded; if it was sharded, reset its dims mapping to -1 (replicated).
std::vector<TensorDistAttr> new_input_dist_attrs;
new_input_dist_attrs.emplace_back(input_specs[0].dist_attr());
std::vector<int64_t> new_input_dims_mapping(input_specs[0].dims_mapping());
new_input_dims_mapping[axis] = -1;
new_input_dist_attrs[0].set_dims_mapping(new_input_dims_mapping);

// step2.4: handle input tensor partial (TODO)
VLOG(4) << "SplitSPMDRule InferForward: ";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "einsum_notation: " << input_axes << " src_dims_mapping: ["
<< str_join(input_specs[i].dims_mapping()) << "] "
<< "dst_dims_mapping: ["
<< str_join(new_input_dist_attrs[i].dims_mapping()) << "]";
}
for (int64_t i = 0; i < noutput; i++) {
VLOG(4) << "Output" << std::to_string(i) << " dims_mapping: ["
<< str_join(output_dims_mapping) << "]";
}

return {new_input_dist_attrs, output_dist_attrs};
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SplitSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of SplitPMDRule is NOT implemented yet."));

return {};
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
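As a hypothetical worked example (not part of the commit), take a 3-D input with dims_mapping [0, 1, -1], split along axis 1 with num = 2; tracing the steps above:

input_axes             = "akc"           ('k' marks the split axis)
output_axes            = "a1c"           ('1' forces that axis to -1)
merged axis_to_dim_map = {a: 0, k: 1, c: -1}
output dims_mapping    = [0, -1, -1]     (identical for both outputs)
new input dims_mapping = [0, -1, -1]     (sharding on the split axis is cleared)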
40 changes: 40 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h
@@ -0,0 +1,40 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

class SplitSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
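A minimal usage sketch of the class declared above (an illustration, not part of the commit); the DistTensorSpec inputs are taken as given, and the attribute names and types mirror those read in InferForward:

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"

// Hypothetical example: forward sharding inference for split_with_num.
void InferSplitForwardExample(
    const std::vector<paddle::distributed::auto_parallel::DistTensorSpec>&
        inputs) {
  using paddle::distributed::auto_parallel::SplitSPMDRule;
  SplitSPMDRule rule;
  paddle::framework::AttributeMap attrs;
  attrs["num"] = static_cast<int64_t>(2);  // split_with_num style attribute
  attrs["axis"] = 1;                       // read via ExtractAttr<int>
  auto result = rule.InferForward(inputs, attrs);
  // result.first  : updated dist attrs for the input tensor
  // result.second : dist attrs shared by the two output tensors
  (void)result;
}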
1 change: 1 addition & 0 deletions test/auto_parallel/spmd_rules/CMakeLists.txt
@@ -8,6 +8,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_matmul_rule MODULES test_embedding_rule)
py_test_modules(test_matmul_rule MODULES test_replicated_rule)
py_test_modules(test_matmul_rule MODULES test_softmax_rule)
py_test_modules(test_split_rule MODULES test_split_rule)
# End of unittests WITH single card WITHOUT timeout

endif()