optimize reshape-related fusion (#53066)
JiabinYang authored Apr 19, 2023
1 parent e669528 commit c29dc34
Showing 2 changed files with 21 additions and 4 deletions.
19 changes: 17 additions & 2 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
@@ -510,8 +510,16 @@ void AnalyseClusterVariables(
     bool is_inference_stage,
     const std::unordered_set<std::string>& skip_gc_var_names) {
   // collecting all inputs and outputs of the ops
+  std::unordered_set<std::string> unused_outputs;
+  std::unordered_set<std::string> legacy_ops{"reshape2", "transpose2"};
   for (auto* op_node : cluster) {
     const auto& op_name = op_node->Name();
+    if (legacy_ops.count(op_name) && op_node->Op()->HasOutput("XShape")) {
+      for (const auto& var_name :
+           (*(op_node->Op()->MutableOutputs()))["XShape"]) {
+        unused_outputs.insert(var_name);
+      }
+    }
     for (auto* input_var_node : op_node->inputs) {
       if (!deny_var_set.count(input_var_node->Name())) {
         // ignore deny var node
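Context for the hunk above: Paddle's legacy reshape2 and transpose2 ops emit an extra XShape output whose only purpose is to record the input's pre-transform shape for the backward pass; essentially no tensor data flows through it. Collecting those names up front lets the pass treat them as unused rather than as genuine cluster outputs. A minimal Python sketch of that bookkeeping, with plain dicts standing in for Paddle's graph nodes (the structures and names here are hypothetical):

LEGACY_OPS = {"reshape2", "transpose2"}

def collect_unused_xshape_outputs(cluster):
    """Collect XShape output names; they carry no real tensor data."""
    unused_outputs = set()
    # each op is modeled as {"name": str, "outputs": {slot: [var_name, ...]}}
    for op in cluster:
        if op["name"] in LEGACY_OPS and "XShape" in op["outputs"]:
            unused_outputs.update(op["outputs"]["XShape"])
    return unused_outputs

cluster = [
    {"name": "reshape2", "outputs": {"Out": ["y"], "XShape": ["x_shape"]}},
    {"name": "relu", "outputs": {"Out": ["z"]}},
]
assert collect_unused_xshape_outputs(cluster) == {"x_shape"}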
@@ -527,9 +535,12 @@ void AnalyseClusterVariables(
   // remove output node from cluster_inputs,
   // and add cluster_internals node
   for (auto* var_node : *cluster_outputs) {
-    if (cluster_inputs->count(var_node) > 0) {
+    if ((cluster_inputs->count(var_node) > 0) ||
+        (unused_outputs.count(var_node->Name()))) {
       // if an input node also exists in the output list, remove it
-      cluster_inputs->erase(var_node);
+      if (cluster_inputs->count(var_node) > 0) {
+        cluster_inputs->erase(var_node);
+      }

       // the internal node must be an output node of the sub-graph,
       // but not an input node of the outer graph.
@@ -538,8 +549,12 @@ void AnalyseClusterVariables(
     for (size_t i = 0; i < var_node->outputs.size() && is_only_used_internal;
         ++i) {
      is_only_used_internal &= (cluster.count(var_node->outputs[i]) > 0);
+      VLOG(3) << "var_node->outputs[" << i << "]: " << var_node->Name()
+              << ", is_only_used_internal: " << is_only_used_internal;
     }
+
     if (is_only_used_internal) {
+      VLOG(3) << "insert internal var: " << var_node->Name();
       cluster_internals->insert(var_node);
     }
   }
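With unused_outputs folded into the check above, an XShape var that nothing outside the sub-graph consumes is reclassified as a cluster-internal var instead of a cluster output, so it no longer forces a boundary between the CINN sub-graph and the outer graph. A rough Python model of the classification (sets of names instead of graph nodes; classify_internals and the consumers map are hypothetical stand-ins for the node traversal):

def classify_internals(cluster_ops, cluster_inputs, cluster_outputs,
                       consumers, unused_outputs):
    """Promote candidate vars to internals when all consumers are in-cluster."""
    internals = set()
    for var in cluster_outputs:
        if var in cluster_inputs or var in unused_outputs:
            cluster_inputs.discard(var)  # no longer a real cluster input
            # internal iff every consumer op lies inside the sub-graph
            if all(op in cluster_ops for op in consumers.get(var, ())):
                internals.add(var)
    return internals

inputs = {"x"}
internals = classify_internals(
    cluster_ops={"reshape2", "relu"},
    cluster_inputs=inputs,
    cluster_outputs=["x_shape", "z"],
    consumers={"x_shape": [], "z": ["fetch_op"]},  # fetch_op sits outside
    unused_outputs={"x_shape"},
)
assert internals == {"x_shape"} and inputs == {"x"}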
6 changes: 4 additions & 2 deletions python/paddle/incubate/autograd/composite_rules.py
@@ -164,10 +164,12 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     out = difference * rsqrt_var

     if scale is not None:
-        scale = reshape(scale, x.shape[begin_norm_axis:])
+        if x.shape[begin_norm_axis:] != scale.shape:
+            scale = reshape(scale, x.shape[begin_norm_axis:])
         out = out * scale
     if bias is not None:
-        bias = reshape(bias, x.shape[begin_norm_axis:])
+        if x.shape[begin_norm_axis:] != bias.shape:
+            bias = reshape(bias, x.shape[begin_norm_axis:])
         out = out + bias

     mean_ = reshape(mean_, [-1])
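The composite layer_norm rule now reshapes scale and bias only when their shape genuinely differs from the normalized trailing shape of x, so the common case where the weights already match emits no redundant reshape for the pass above to clean up. A small sketch of the guard, with NumPy standing in for Paddle tensors and maybe_reshape a hypothetical helper:

import numpy as np

def maybe_reshape(t, target_shape):
    # reshape only on a genuine mismatch; compare shapes by value
    if list(t.shape) != list(target_shape):
        t = t.reshape(target_shape)
    return t

x = np.zeros((4, 8, 16))                    # begin_norm_axis == 1
scale = np.ones(8 * 16)                     # flat weight: needs a reshape
scale = maybe_reshape(scale, x.shape[1:])
assert scale.shape == (8, 16)
bias = np.zeros((8, 16))                    # already matching: left untouched
assert maybe_reshape(bias, x.shape[1:]) is bias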
