Skip to content

Commit

Permalink
merge upstream and resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfeiyu committed Dec 10, 2021
2 parents f1f7f4b + 43f19cc commit 9dcbf5c
Show file tree
Hide file tree
Showing 100 changed files with 5,024 additions and 813 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_test
cc_test(test_repeated_fc_relu_fuse_pass_cc SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass)
cc_test(test_fc_elementwise_layernorm_fuse_pass_cc SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass)
cc_test(test_skip_layernorm_fuse_pass SRCS skip_layernorm_fuse_pass_tester.cc DEPS skip_layernorm_fuse_pass)
cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass)
cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -338,3 +339,9 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {

// Register the fuse pass under the name "fc_elementwise_layernorm_fuse_pass".
REGISTER_PASS(fc_elementwise_layernorm_fuse_pass,
              paddle::framework::ir::FCElementwiseLayerNormFusePass);
// Declare the op-version combination this pass is compatible with
// (presumably used to skip the pass on incompatible programs — see
// op_version_registry.h for the exact semantics).
REGISTER_PASS_CAPABILITY(fc_elementwise_layernorm_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("fc", 0)
            .LE("elementwise_add", 1)
            .EQ("layer_norm", 0));
56 changes: 56 additions & 0 deletions paddle/fluid/framework/ir/ipu/avg_shard_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/avg_shard_pass.h"

#include "paddle/fluid/platform/device/ipu/ipu_backend.h"

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Evenly shards the graph's ops across IPUs: after a topological sort, the
// op sequence is split into `num_ipus` contiguous chunks and every op in a
// chunk gets that chunk's device id via the "ipu_index"/"ipu_stage" attrs.
// No-op unless the IpuStrategy requests average sharding.
void AvgShardPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter AvgShardPass::ApplyImpl";

  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();

  if (ipu_backend->GetIpuStrategy()->need_avg_shard) {
    VLOG(10) << "start AvgShardPass";
    auto nodes = ir::TopologySortOperations(*graph);
    auto num_ipus = ipu_backend->GetIpuStrategy()->num_ipus;

    // Ops per IPU. Guard against zero (fewer ops than IPUs), which would
    // otherwise trigger a division-by-zero in the modulo below.
    int shard_position = nodes.size() / num_ipus;
    if (shard_position == 0) {
      shard_position = 1;
    }
    int index_and_stage = -1;
    for (size_t i = 0; i < nodes.size(); i++) {
      // Advance to the next device at each chunk boundary, clamped so we
      // never exceed the last available IPU.
      if ((i % shard_position) == 0 && index_and_stage < num_ipus - 1) {
        index_and_stage++;
      }
      nodes[i]->Op()->SetAttr("ipu_index", index_and_stage);
      nodes[i]->Op()->SetAttr("ipu_stage", index_and_stage);
    }
    VLOG(10) << "end AvgShardPass";
  }

  VLOG(10) << "leave AvgShardPass::ApplyImpl";
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(avg_shard_pass, paddle::framework::ir::AvgShardPass);
30 changes: 30 additions & 0 deletions paddle/fluid/framework/ir/ipu/avg_shard_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

// IPU pass that evenly distributes the graph's ops across the available
// IPUs by setting "ipu_index"/"ipu_stage" op attributes (implementation in
// avg_shard_pass.cc).
class AvgShardPass : public IPUPassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
133 changes: 133 additions & 0 deletions paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h"

#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Extracts the forward-only sub-graph: classifies every op node by its
// "op_role" attribute, then removes backward ops, optimizer-produced vars,
// control-dependency vars, feed/fetch nodes, and any node that is neither a
// forward/loss op nor an input/output of one.
void ForwardGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter ForwardGraphExtractPass::ApplyImpl";

  // Bucket op nodes by OpRole. Every role is pre-inserted so the later
  // operator[] lookups always find an existing (possibly empty) bucket.
  std::unordered_map<OpRole, std::unordered_set<ir::Node*>> all_ops{
      {OpRole::kForward, {}}, {OpRole::kBackward, {}},
      {OpRole::kOptimize, {}}, {OpRole::kRPC, {}},
      {OpRole::kDist, {}}, {OpRole::kLRSched, {}},
      {OpRole::kLoss, {}}, {OpRole::kNotSpecified, {}}};
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    auto op_role = BOOST_GET_MUTABLE(int, node->Op()->GetAttr("op_role"));
    if (op_role == static_cast<int>(OpRole::kForward)) {
      all_ops[OpRole::kForward].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kBackward)) {
      all_ops[OpRole::kBackward].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
      all_ops[OpRole::kOptimize].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kRPC)) {
      // kRPC/kDist/kLRSched ops are not collected (empty branches); not
      // being in any kept set, they are removed by the "neither forward
      // nor loss" branch below.
    } else if (op_role == static_cast<int>(OpRole::kDist)) {
    } else if (op_role == static_cast<int>(OpRole::kLRSched)) {
    } else if (op_role == static_cast<int>(OpRole::kLoss)) {
      all_ops[OpRole::kLoss].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kNotSpecified)) {
      LOG(WARNING) << "Op: " << node->Name() << " OpRole is NotSpecified ";
    }
  }

  std::unordered_set<ir::Node*> forward_vars;
  std::unordered_set<ir::Node*> backward_vars;
  std::unordered_set<ir::Node*> control_vars;
  // forward_vars: every input/output var of a forward or loss op.
  for (auto& nodes : std::array<std::unordered_set<ir::Node*>, 2>{
           all_ops[OpRole::kForward], all_ops[OpRole::kLoss]}) {
    for (auto* node : nodes) {
      for (auto* in_node : node->inputs) {
        forward_vars.insert(in_node);
      }
      for (auto* out_node : node->outputs) {
        forward_vars.insert(out_node);
      }
    }
  }
  // control_vars & backward_vars: control-dependency vars, plus every var
  // produced by an optimizer op.
  for (auto* node : graph->Nodes()) {
    if (!node->IsVar()) {
      continue;
    }
    if (node->IsCtrlVar()) {
      control_vars.insert(node);
    }
    for (auto* in_node : node->inputs) {
      if (all_ops[OpRole::kOptimize].count(in_node)) {
        backward_vars.insert(node);
      }
    }
  }
  // Collect every node slated for removal.
  std::unordered_set<ir::Node*> rm_nodes;
  for (auto* node : graph->Nodes()) {
    if (backward_vars.count(node)) {
      rm_nodes.insert(node);
    } else if (control_vars.count(node)) {
      rm_nodes.insert(node);
    } else if (all_ops[OpRole::kBackward].count(node)) {
      rm_nodes.insert(node);
    } else if (all_ops[OpRole::kForward].count(node) == 0 &&
               all_ops[OpRole::kLoss].count(node) == 0 &&
               forward_vars.count(node) == 0) {
      // Neither a forward/loss op nor a var touched by one.
      rm_nodes.insert(node);
    } else if (node->Name() == "feed" || node->Name() == "fetch") {
      rm_nodes.insert(node);
    }
  }

  VLOG(10) << "Remove Node: ";
  for (auto* node : rm_nodes) {
    // Unlink the node from its neighbours' input/output lists before
    // removing it from the graph.
    for (auto* node_in : node->inputs) {
      for (size_t i = 0; i < node_in->outputs.size(); ++i) {
        if (node_in->outputs[i] == node) {
          node_in->outputs.erase(node_in->outputs.begin() + i);
          break;
        }
      }
    }
    for (auto* node_out : node->outputs) {
      for (size_t i = 0; i < node_out->inputs.size(); ++i) {
        if (node_out->inputs[i] == node) {
          node_out->inputs.erase(node_out->inputs.begin() + i);
          break;
        }
      }
    }
    VLOG(10) << "\t" << node->Name();
    graph->RemoveNode(node);
  }

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);

  VLOG(10) << "leave ForwardGraphExtractPass::ApplyImpl";
}

} // namespace ir
} // namespace framework
} // namespace paddle

// Register the pass under the name "forward_graph_extract_pass".
REGISTER_PASS(forward_graph_extract_pass,
              paddle::framework::ir::ForwardGraphExtractPass);
31 changes: 31 additions & 0 deletions paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

// IPU pass that strips backward/optimizer ops, control vars, and feed/fetch
// nodes from the graph, leaving only the forward (and loss) sub-graph
// (implementation in forward_graph_extract_pass.cc).
class ForwardGraphExtractPass : public IPUPassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
108 changes: 108 additions & 0 deletions paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/infer_shape_pass.h"

#include "paddle/fluid/platform/device/ipu/ipu_backend.h"

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Statically infers shapes for the whole graph: feed vars first get any
// dynamic (negative) batch dimension replaced by the configured batch size
// (and INT64 feeds demoted to INT32), then every op is run through
// RuntimeInferShape against a temporary scope and the resulting dims are
// written back into the output VarDescs.
// Requires the "feed_list" pass attribute (names of feed variables).
void InferShapePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferShapePass::ApplyImpl";
  VLOG(10) << "Raw Graph: ";
  VLOG(10) << DebugString(graph);

  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();
  auto batch_size = ipu_backend->GetIpuStrategy()->batch_size;

  auto feed_list = Get<std::vector<std::string>>("feed_list");
  for (auto node : graph->Nodes()) {
    if (!node->IsVar()) {
      continue;
    }
    bool is_feed = std::find(feed_list.begin(), feed_list.end(),
                             node->Name()) != feed_list.end();
    if (is_feed) {
      // Pin the dynamic batch dimension to the concrete batch size from
      // the IpuStrategy so downstream shape inference sees fixed shapes.
      auto input_shape = node->Var()->GetShape();
      if (input_shape[0] <= -1) {
        input_shape[0] = batch_size;
        node->Var()->SetShape(input_shape);
      }
      // int64->int32 (presumably the IPU side expects 32-bit int feeds —
      // kept from the original implementation).
      if (node->Var()->GetDataType() == proto::VarType::INT64) {
        node->Var()->SetDataType(proto::VarType::INT32);
      }
    }
  }

  // Temp scope for shape inference: one variable per graph var, each
  // backed by a LoDTensor resized to the var's current declared shape.
  auto scope = std::make_shared<paddle::framework::Scope>();
  for (auto node : graph->Nodes()) {
    if (!node->IsVar()) {
      continue;
    }
    auto var_desc = node->Var();
    auto* ptr = scope->Var(var_desc->Name());
    paddle::framework::InitializeVariable(ptr, var_desc->GetType());

    auto tensor = ptr->GetMutable<paddle::framework::LoDTensor>();
    tensor->Resize(paddle::framework::make_ddim(var_desc->GetShape()));
  }

  // Infer shapes op by op in topological order, then copy every inferred
  // output dim back onto the matching output var node's VarDesc.
  auto nodes = ir::TopologySortOperations(*graph);
  for (auto node : nodes) {
    auto op_desc = node->Op();
    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
    paddle::framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), *scope);
    op->RuntimeInferShape(*scope, paddle::platform::CPUPlace(), ctx);

    for (auto it = ctx.outputs.begin(); it != ctx.outputs.end(); it++) {
      for (size_t i = 0; i < it->second.size(); i++) {
        auto output_name = op_desc->Output(it->first)[i];
        auto dim =
            it->second[i]->GetMutable<paddle::framework::LoDTensor>()->dims();
        auto new_shape = paddle::framework::vectorize(dim);
        for (auto output_node : node->outputs) {
          if (output_node->Name() == output_name) {
            output_node->Var()->SetShape(new_shape);
          }
        }
      }
    }
  }
  // Release the temp scope (and all its tensors) eagerly.
  scope.reset();

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);
  VLOG(10) << "leave InferShapePass::ApplyImpl";
}

} // namespace ir
} // namespace framework
} // namespace paddle

// Register the pass under the name "infer_shape_pass"; "feed_list" is a
// required pass attribute and must be set before the pass is applied.
REGISTER_PASS(infer_shape_pass, paddle::framework::ir::InferShapePass)
    .RequirePassAttr("feed_list");
Loading

0 comments on commit 9dcbf5c

Please sign in to comment.