[Inference] inplace all reshape op #49146
paddle/fluid/framework/ir/inplace_op_var_pass.cc
@@ -0,0 +1,129 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/inplace_op_var_pass.h" | ||
|
||
#include "paddle/fluid/framework/ir/graph_helper.h" | ||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" | ||
#include "paddle/fluid/framework/ir/node.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
class Graph; | ||
|
||
void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
  FusePassBase::Init("inplace_op_var", graph);
  int found_subgraph_count = 0;
  MapToReshape(graph);

  auto nodes = graph->Nodes();
  auto is_valid_reshape = [](Node* node) {
    // Only a reshape2 whose input X is non-persistable (persistable vars may
    // be shared weights) and is consumed by exactly one op can be safely
    // inplaced. For the corner cases behind these checks, please refer to
    // https://github.com/PaddlePaddle/Paddle/pull/49146
    if (node->IsOp() && node->Op()->Type() == "reshape2") {
      auto x_name = node->Op()->Input("X").front();
      for (auto* var_node : node->inputs) {
        if (var_node->Name() == x_name) {
          if (!var_node->Var()->Persistable() && var_node->outputs.size() == 1)
            return true;
        }
      }
    }
    return false;
  };
  // Record the input and output variable names of every valid reshape2 op in
  // block 0. If any of these names is also used in another block, the
  // corresponding reshape op must not be inplaced.
  std::unordered_set<std::string> var_names, deny_var_names;
  for (auto* node : nodes) {
    if (is_valid_reshape(node)) {
      for (auto n : node->inputs) var_names.insert(n->Name());
      for (auto n : node->outputs) var_names.insert(n->Name());
    }
  }
  for (size_t i = 1; i < graph->SubGraphsSize(); ++i) {
    auto sub_graph = graph->GetSubGraph(i);
    for (auto* node : sub_graph->Nodes()) {
      if (node->IsOp()) {
        for (auto var_node : node->inputs) {
          if (var_names.count(var_node->Name()))
            deny_var_names.insert(var_node->Name());
        }
        for (auto var_node : node->outputs) {
          if (var_names.count(var_node->Name()))
            deny_var_names.insert(var_node->Name());
        }
      }
    }
  }
  // Inplace every eligible reshape op by renaming its output to its input.
  auto topo_nodes = TopologySortOperations(*graph);
  for (auto* node : topo_nodes) {
    if (!is_valid_reshape(node)) continue;
    auto* op_node = node->Op();
    auto input_name = op_node->Input("X")[0];
    auto output_name = op_node->Output("Out")[0];
    if (deny_var_names.count(input_name) || deny_var_names.count(output_name)) {
      continue;
    }
    ++found_subgraph_count;
    for (auto* out_var : node->outputs) {
      if (out_var->Name() == output_name) {
        out_var->RenameVar(input_name);
        for (auto* next_op : out_var->outputs) {
          next_op->Op()->RenameInput(output_name, input_name);
          next_op->Op()->Flush();
        }
      }
    }

    op_node->RenameOutput(output_name, input_name);
    op_node->Flush();
  }
  AddStatis(found_subgraph_count);
}
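To make the rewrite concrete, here is a hypothetical three-op graph before and after the pass (variable names invented for illustration):

// before: scale(X=a, Out=b) -> reshape2(X=b, Out=c) -> relu(X=c, Out=d)
// after:  scale(X=a, Out=b) -> reshape2(X=b, Out=b) -> relu(X=b, Out=d)
// The reshape2 output is renamed to its input, so variable c disappears and
// the op writes into the same buffer it reads from.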
Review comment: "map? How about changing map_depthwise_conv_to_conv_pass into a generic pass dedicated to op mapping, and moving this logic into it?"
Reply: "Got it. In the next PR, the map-style operations will be consolidated into one pass." (A hypothetical sketch of that idea follows this function.)

void InplaceOpVarPass::MapToReshape(ir::Graph* graph) const {
  // Map flatten_contiguous_range to reshape2 where the two are equivalent:
  // flattening a 4-D input of shape [N, C, 1, 1] with start_axis = 1 and
  // stop_axis = 3 is exactly a reshape to [N, -1] (shape attr {0, -1}).
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "flatten_contiguous_range") {
      auto* op_node = node->Op();
      auto start_axis = PADDLE_GET_CONST(int, op_node->GetAttr("start_axis"));
      auto stop_axis = PADDLE_GET_CONST(int, op_node->GetAttr("stop_axis"));
      auto input_name = op_node->Input("X")[0];
      auto* block = op_node->Block();
      auto input_shape = block->FindVar(input_name)->GetShape();
      if (start_axis == 1 && stop_axis == 3 && input_shape.size() == 4 &&
          input_shape[2] == 1 && input_shape[3] == 1) {
        op_node->SetType("reshape2");
        op_node->SetAttr("shape", std::vector<int>{0, -1});
        op_node->Flush();
      }
    }
  }
}
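As a hypothetical sketch of the reviewer's suggestion above (not part of this PR; all names are invented), a generic op-mapping pass could be driven by a table from source op type to a rewrite callback, so each new mapping (depthwise_conv2d to conv2d, flatten_contiguous_range to reshape2, and so on) only adds a table entry:

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"

// Hypothetical table-driven op mapping; the surrounding pass would walk the
// graph and, for each op node whose type has an entry, apply the callback.
using OpRewriter = std::function<void(paddle::framework::OpDesc*)>;

const std::unordered_map<std::string, OpRewriter>& OpMappingTable() {
  static const std::unordered_map<std::string, OpRewriter> table = {
      {"flatten_contiguous_range",
       [](paddle::framework::OpDesc* op) {
         op->SetType("reshape2");
         op->SetAttr("shape", std::vector<int>{0, -1});
         op->Flush();
       }},
      // {"depthwise_conv2d", /* rewrite to conv2d */},
  };
  return table;
}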
}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(inplace_op_var_pass, paddle::framework::ir::InplaceOpVarPass);
REGISTER_PASS_CAPABILITY(inplace_op_var_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "reshape2", 0));
paddle/fluid/framework/ir/inplace_op_var_pass.h
@@ -0,0 +1,36 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

class InplaceOpVarPass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  virtual ~InplaceOpVarPass() = default;
  void MapToReshape(ir::Graph* graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
Python unit test for inplace_op_var_pass (PassAutoScanTest):
@@ -0,0 +1,168 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig

import paddle.fluid.core as core
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestInplaceOpPass(PassAutoScanTest):
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        return True

    def sample_program_config(self, draw):
        def generate_input():
            return np.random.random(x_shape).astype(np.float32)

        def generate_tmp1(val):
            return np.array([val]).astype(np.int32)

        def generate_tmp2(val):
            return np.array([val]).astype(np.int32)

        def generate_tmp3(val):
            return np.array([val]).astype(np.int32)

        def generate_shape(val):
            return np.array(val).astype(np.int32)
        x_shape = draw(
            st.lists(
                st.integers(min_value=1, max_value=10), min_size=4, max_size=4
            )
        )

Review comment: "Is this unit test actually executed? Where is draw imported from?"
Reply: "It is executed. draw is provided by the hypothesis library that the auto-scan test framework brings in."

        shape = [0, -1, x_shape[-1]]
        scale_op = OpConfig(
            "scale",
            inputs={"X": ["scale_in"]},
            outputs={"Out": ["scale_out"]},
            scale=1.3,
            bias=0.1,
            bias_after_scale=False,
        )
        test_case = draw(
            st.sampled_from(
                ["simple_reshape", "shape_tensor1", "shape_tensor2"]
            )
        )

        if test_case == "simple_reshape":
            reshape_op = OpConfig(
                "reshape2",
                inputs={"X": ["scale_out"]},
                outputs={
                    "Out": ["reshape_out"],
                    "XShape": ["reshape_xshape_out"],
                },
                shape=shape,
            )
            ops = [scale_op, reshape_op]
            program_config = ProgramConfig(
                ops=ops,
                inputs={
                    "scale_in": TensorConfig(data_gen=partial(generate_input)),
                },
                weights={},
                outputs=["reshape_out"],
            )
            return program_config
        elif test_case == "shape_tensor1":
            shape = [-1, -1, x_shape[-1]]
            reshape_op = OpConfig(
                "reshape2",
                inputs={
                    "X": ["scale_out"],
                    "ShapeTensor": ["tmp1", "tmp2", "tmp3"],
                },
                outputs={
                    "Out": ["reshape_out"],
                    "XShape": ["reshape_xshape_out"],
                },
                shape=shape,
            )
            ops = [scale_op, reshape_op]
            program_config = ProgramConfig(
                ops=ops,
                inputs={
                    "scale_in": TensorConfig(data_gen=partial(generate_input)),
                    "tmp1": TensorConfig(
                        data_gen=partial(generate_tmp1, x_shape[0])
                    ),
                    "tmp2": TensorConfig(
                        data_gen=partial(generate_tmp2, x_shape[1] * x_shape[2])
                    ),
                    "tmp3": TensorConfig(
                        data_gen=partial(generate_tmp3, x_shape[-1])
                    ),
                },
                weights={},
                outputs=["reshape_out"],
            )
            return program_config
        else:
            shape = [0, -1, x_shape[-1]]
            reshape_op = OpConfig(
                "reshape2",
                inputs={"X": ["scale_out"], "Shape": ["shape"]},
                outputs={
                    "Out": ["reshape_out"],
                    "XShape": ["reshape_xshape_out"],
                },
                shape=shape,
            )
            ops = [scale_op, reshape_op]
            program_config = ProgramConfig(
                ops=ops,
                inputs={
                    "scale_in": TensorConfig(data_gen=partial(generate_input)),
                    "shape": TensorConfig(
                        data_gen=partial(
                            generate_shape,
                            [x_shape[0], x_shape[1] * x_shape[2], x_shape[3]],
                        )
                    ),
                },
                weights={},
                outputs=["reshape_out"],
            )
            return program_config
    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_gpu=True)
        yield config, ['scale', 'reshape2'], (1e-5, 1e-5)

    def add_ignore_pass_case(self):
        pass

    def test(self):
        self.run_and_statis(
            quant=False,
            passes=["inplace_op_var_pass"],
        )


if __name__ == "__main__":
    unittest.main()
Review comment: "The pass name is generic? It is not only for reshape?"
Reply: "Currently both reshape and squeeze operations are known to be inplaceable, and there may be other such ops as well."
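Building on that reply, a hypothetical generalization (not part of this PR; the helper name is invented) could widen the candidate check from reshape2 alone to a set of inplaceable op types, keeping the same non-persistable and single-consumer conditions:

#include <string>
#include <unordered_set>

// Hypothetical helper: op types whose output could reuse the input buffer.
// squeeze2 is named in the review as another known inplaceable op.
bool IsInplaceCandidateType(const std::string& type) {
  static const std::unordered_set<std::string> kInplaceTypes{"reshape2",
                                                             "squeeze2"};
  return kInplaceTypes.count(type) > 0;
}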