Skip to content

Commit

Permalink
fix save inference model conditional op (PaddlePaddle#37579)
Browse files Browse the repository at this point in the history
  • Loading branch information
JZZ-NOTE committed Jan 6, 2022
1 parent 1e8432f commit 69e117f
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 25 deletions.
44 changes: 19 additions & 25 deletions paddle/fluid/framework/prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,23 @@ int FindMapByValue(const std::map<int, int>& m, int val) {
return -1;
}

// In other two cases,the op that has feed vars as output vars is dependent:
// 1. op has subblock, like while/for/ifelse/recurrent
// 2. op is in subblock
bool IsSubBlockDependent(const proto::OpDesc& op_desc,
const std::set<std::string>& feed_vars,
int parent_block_id) {
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if ((HasSubBlock(op_desc) || parent_block_id != -1) &&
feed_vars.count(argu) != 0) {
return true;
}
}
}
return false;
}

// block_id is the idx of the current block in the input desc
// parent_block_id is the idx of the parent of the current block
// in the output desc, -1 means the current block is global block
Expand Down Expand Up @@ -210,7 +227,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// }

if (IsTarget(op_desc) ||
(HasDependentOutputVar(op_desc, *dependent_vars) &&
((HasDependentOutputVar(op_desc, *dependent_vars) ||
(IsSubBlockDependent(op_desc, feed_var_names, parent_block_id))) &&
(GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) {
// NOTE(zhiqiu): since optimize op takes the trainable parameters as
// inputs and output, it may introduce wrong dependency graph.
Expand All @@ -227,30 +245,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
should_run.push_back(true);
} else {
should_run.push_back(false);
// If the output of an op modifies feed vars, the op should not clip.
// For example, in the transformer structure, the third parameter returned
// by beam_search op is generally assigned to a feed var. Cutting the
// assign op will cause an error.
if (parent_block_id != -1) {
bool flag = false;
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu)) {
flag = true;
}
}
}
if (flag) {
should_run.back() = true;

// If any op should run, then there inputs are dependent_vars
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars->insert(argu);
}
}
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import unittest
import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F


def getModelOp(model_path):
model_bytes = paddle.static.load_from_file(model_path)
pg = paddle.static.deserialize_program(model_bytes)
main_block = pg.desc.block(0)
size = main_block.op_size()

result = set()
for i in range(0, size):
#print(main_block.op(i).type())
result.add(main_block.op(i).type())

return result


class WhileNet(paddle.nn.Layer):
def __init__(self):
super(WhileNet, self).__init__()

def forward(self, x):
y = paddle.rand(shape=[1, 3, 4, 4])

w1 = paddle.shape(y)[0]
w2 = paddle.shape(x)[0]

while w2 != w1:
x = F.avg_pool2d(x, kernel_size=3, padding=1, stride=2)
w2 = paddle.shape(x)[0]

return x + y


class ForNet(paddle.nn.Layer):
def __init__(self):
super(ForNet, self).__init__()

def forward(self, x):
y = paddle.randint(low=0, high=5, shape=[1], dtype='int32')
z = paddle.randint(low=0, high=5, shape=[1], dtype='int32')
for i in range(0, z):
x = x + i

return x + y


class IfElseNet(paddle.nn.Layer):
def __init__(self):
super(IfElseNet, self).__init__()

def forward(self, x):
y = paddle.to_tensor([5])
if x > y:
x = x + 1
else:
x = x - 1
return x


class TestConditionalOp(unittest.TestCase):
def test_while_op(self):
paddle.disable_static()
net = WhileNet()
net = paddle.jit.to_static(
net,
input_spec=[
paddle.static.InputSpec(
shape=[1, 3, 8, 8], dtype='float32')
])
paddle.jit.save(net, './while_net')

right_pdmodel = set([
"uniform_random", "shape", "slice", "not_equal", "while",
"elementwise_add"
])
paddle.enable_static()
pdmodel = getModelOp("while_net.pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0,
"The while op is pruned by mistake.")

def test_for_op(self):
paddle.disable_static()
net = ForNet()
net = paddle.jit.to_static(
net,
input_spec=[paddle.static.InputSpec(
shape=[1], dtype='int32')])
paddle.jit.save(net, './for_net')

right_pdmodel = set([
"randint", "fill_constant", "cast", "less_than", "while",
"elementwise_add"
])
paddle.enable_static()
pdmodel = getModelOp("for_net.pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0,
"The for op is pruned by mistake.")

def test_if_op(self):
paddle.disable_static()
net = IfElseNet()
net = paddle.jit.to_static(
net,
input_spec=[paddle.static.InputSpec(
shape=[1], dtype='int32')])
paddle.jit.save(net, './if_net')

right_pdmodel = set([
"assign_value", "greater_than", "cast", "conditional_block",
"logical_not", "select_input"
])
paddle.enable_static()
pdmodel = getModelOp("if_net.pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0,
"The if op is pruned by mistake.")


if __name__ == '__main__':
unittest.main()

0 comments on commit 69e117f

Please sign in to comment.