Skip to content

Commit

Permalink
[GLCC]Part-3: Support jit.save and jit.load for pylayer op (PaddlePad…
Browse files Browse the repository at this point in the history
…dle#57066)

* complete static_pylayer op

* finish static_pylayer op context manager

* finish single test

* append import path

* maybe modify test/ir/inference

* percept static_pylayer op in dy2st
  • Loading branch information
MarioLulab authored Sep 22, 2023
1 parent a828804 commit 8e89dd3
Show file tree
Hide file tree
Showing 9 changed files with 1,078 additions and 221 deletions.
181 changes: 155 additions & 26 deletions paddle/fluid/framework/prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const char kRecurrent[] = "recurrent"; // NOLINT
const char kStates[] = "states"; // NOLINT
const char kExStates[] = "ex_states"; // NOLINT

const char kPyLayer[] = "pylayer"; // NOLINT

bool HasDependentInputVar(
const proto::OpDesc& op_desc,
const std::unordered_set<std::string>& dependent_vars) {
Expand Down Expand Up @@ -86,6 +88,23 @@ int GetSubBlockIndex(const proto::OpDesc& op_desc) {
return -1;
}

void GetSubBlocksIndices(const proto::OpDesc& op_desc,
std::vector<int>* indices) {
for (auto& attr : op_desc.attrs()) {
if (attr.type() == proto::AttrType::BLOCKS) {
PADDLE_ENFORCE_GT(
attr.blocks_idx_size(),
0,
platform::errors::NotFound(
"Attribute blocks is not found in operator %s", op_desc.type()));
indices->resize(attr.blocks_idx_size());
for (int i = 0; i < attr.blocks_idx_size(); i++) {
(*indices)[i] = attr.blocks_idx(i);
}
}
}
}

void SetSubBlockIndex(proto::OpDesc* op_desc, int sub_idx) {
for (auto& attr : *op_desc->mutable_attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
Expand All @@ -99,10 +118,43 @@ void SetSubBlockIndex(proto::OpDesc* op_desc, int sub_idx) {
}
}

void SetSubBlocksIndices(proto::OpDesc* op_desc,
const std::vector<int>& sub_indices) {
for (auto& attr : *op_desc->mutable_attrs()) {
if (attr.type() == proto::AttrType::BLOCKS) {
PADDLE_ENFORCE_GT(
attr.blocks_idx_size(),
0,
platform::errors::NotFound(
"Attribute blocks is not found in operator %s", op_desc->type()));
attr.clear_blocks_idx();
for (auto idx : sub_indices) {
attr.add_blocks_idx(idx);
}
}
}
}

bool HasSubBlock(const proto::OpDesc& op_desc) {
return GetSubBlockIndex(op_desc) > 0;
}

bool HasSubBlocks(const proto::OpDesc& op_desc) {
// ``blocks_idx_size() == 0`` indicates no sub blocks.
for (auto& attr : op_desc.attrs()) {
if (attr.type() == proto::AttrType::BLOCKS) {
PADDLE_ENFORCE_GT(
attr.blocks_idx_size(),
0,
platform::errors::NotFound(
"Attribute blocks is not found in operator %s", op_desc.type()));
return true;
}
}

return false;
}

int GetOpRole(const proto::OpDesc& op_desc) {
for (auto& attr : op_desc.attrs()) {
if (attr.name() == OpProtoAndCheckerMaker::OpRoleAttrName()) {
Expand Down Expand Up @@ -150,14 +202,15 @@ int FindMapByValue(const std::map<int, int>& m, int val) {
}

// In other two cases, the op that has feed vars as output vars is dependent:
// 1. op has subblock, like while/for/ifelse/recurrent
// 1. op has subblock, like while/for/ifelse/recurrent/pylayer
// 2. op is in subblock
bool IsSubBlockDependent(const proto::OpDesc& op_desc,
const std::set<std::string>& feed_vars,
int parent_block_id) {
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if ((HasSubBlock(op_desc) || parent_block_id != -1) &&
if ((HasSubBlock(op_desc) || HasSubBlocks(op_desc) ||
parent_block_id != -1) &&
feed_vars.count(argu) != 0) {
return true;
}
Expand Down Expand Up @@ -289,7 +342,7 @@ void prune_impl(const proto::ProgramDesc& input,
if (should_run[i]) {
auto* op = op_field->Add();
*op = input.blocks(block_id).ops(static_cast<int>(i));
if (HasSubBlock(*op)) {
if (HasSubBlock(*op) || HasSubBlocks(*op)) {
VLOG(2) << "Pruning op which has sub block: " << op->type();
// create sub_block_dependent_vars here to help prune the sub block
std::unordered_set<std::string> sub_block_dependent_vars;
Expand Down Expand Up @@ -321,15 +374,41 @@ void prune_impl(const proto::ProgramDesc& input,
}
}
}
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input,
output,
GetSubBlockIndex(*op),
output_block_id,
&sub_block_dependent_vars,
feed_var_names,
pruned_origin_block_id_map);
if (HasSubBlock(*op)) {
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input,
output,
GetSubBlockIndex(*op),
output_block_id,
&sub_block_dependent_vars,
feed_var_names,
pruned_origin_block_id_map);
} else if (HasSubBlocks(*op)) {
// GetSubBlocksIndices(*op) are the indices of the sub_blocks in the
// input desc output_block_id is the idx of the current block in the
// output desc
std::vector<int> sub_indices;
GetSubBlocksIndices(*op, &sub_indices);
for (auto& sub_index : sub_indices) {
// create a copy of dependent_vars to avoid being overwrited by the
// other sub_block
std::unordered_set<std::string> dependent_vars_copy =
sub_block_dependent_vars;
prune_impl(input,
output,
sub_index,
output_block_id,
&dependent_vars_copy,
feed_var_names,
pruned_origin_block_id_map);
}
} else {
PADDLE_ENFORCE(false,
platform::errors::PreconditionNotMet(
"Attr Block or Blocks must exist when recursively "
"calling prune_impl"));
}
}
}
}
Expand Down Expand Up @@ -402,12 +481,29 @@ std::map<int, int> Prune(const proto::ProgramDesc& input,
int origin_sub_idx = GetSubBlockIndex(op_desc);
auto sub_idx =
FindMapByValue(pruned_origin_block_id_map, origin_sub_idx);
PADDLE_ENFORCE_NE(sub_idx,
-1,
platform::errors::NotFound(
"The origin sub block id should be found in "
"pruned_progin_block_id_map"));
PADDLE_ENFORCE_NE(
sub_idx,
-1,
platform::errors::NotFound(
"The origin sub block id should be found in "
"pruned_progin_block_id_map when the op has sub_block"));
SetSubBlockIndex(&op_desc, sub_idx);
} else if (HasSubBlocks(op_desc)) {
std::vector<int> origin_sub_indices;
GetSubBlocksIndices(op_desc, &origin_sub_indices);
std::vector<int> sub_indices;
for (int index : origin_sub_indices) {
auto sub_idx = FindMapByValue(pruned_origin_block_id_map, index);
PADDLE_ENFORCE_NE(
sub_idx,
-1,
platform::errors::NotFound(
"The origin sub block id should be found in "
"pruned_progin_block_id_map when the op has sub_blocks"));
sub_indices.push_back(sub_idx);
}

SetSubBlocksIndices(&op_desc, sub_indices);
}
}
}
Expand Down Expand Up @@ -441,6 +537,19 @@ void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
AppendOpInputVarNames(op_desc, &op_input_vars);
AppendOpOutputVarNames(op_desc, &op_output_vars);
*op = op_desc;

// if the type of op is "pylayer", we need to update the ``blocks``
// attribute because the backward block will be pruned
if (op->type() == kPyLayer && HasSubBlocks(*op)) {
std::vector<int> sub_indices;
GetSubBlocksIndices(*op, &sub_indices);
if (sub_indices.size() > 1) {
// sub_indices contains both forward block id and backward block id
std::vector<int> new_sub_indices(sub_indices.begin(),
sub_indices.end() - 1);
SetSubBlocksIndices(op, new_sub_indices);
}
}
}
}

Expand Down Expand Up @@ -471,9 +580,10 @@ std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
// Copy original ProgramDesc, origin can't be change
framework::ProgramDesc origin_clone(origin);

// Step 1. check if the program contains grad loss operator.
// If not, the program need no pruning.
// Step 1. check if the program contains grad loss operator or pylayer
// operator. If not, the program need no pruning.
bool has_loss_grad_op = false;
bool has_pylayer_op = false;
std::queue<int> block_contains_loss;
std::queue<int> block_contains_loss_grad;
for (size_t i = 0; i < origin_clone.Size(); i++) {
Expand All @@ -485,13 +595,15 @@ std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
static_cast<int>(OpRole::kLoss))) {
op->SetIsTarget(false);
has_loss_grad_op = true;
break;
}
if (op->Type() == kPyLayer) {
has_pylayer_op = true;
}
}
}

std::map<int, int> pruned_progin_block_id_map;
if (!has_loss_grad_op) {
if (!has_loss_grad_op && !has_pylayer_op) {
// No pruning, fast return a copy of the origin ProgramDesc with an empty
// map, means default mapped, i.e.{0:0, 1:1, ..., n:n}.
return std::make_tuple(framework::ProgramDesc(origin_clone),
Expand Down Expand Up @@ -544,12 +656,29 @@ std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
int origin_sub_idx = GetSubBlockIndex(op_desc);
auto sub_idx =
FindMapByValue(pruned_progin_block_id_map, origin_sub_idx);
PADDLE_ENFORCE_NE(sub_idx,
-1,
platform::errors::NotFound(
"The origin sub block id is not found in "
"pruned_progin_block_id_map"));
PADDLE_ENFORCE_NE(
sub_idx,
-1,
platform::errors::NotFound(
"The origin sub block id is not found in "
"pruned_progin_block_id_map when the op has sub_block"));
SetSubBlockIndex(&op_desc, sub_idx);
} else if (HasSubBlocks(op_desc)) {
std::vector<int> origin_sub_indices;
GetSubBlocksIndices(op_desc, &origin_sub_indices);
std::vector<int> sub_indices;
for (int index : origin_sub_indices) {
auto sub_idx = FindMapByValue(pruned_progin_block_id_map, index);
PADDLE_ENFORCE_NE(
sub_idx,
-1,
platform::errors::NotFound(
"The origin sub block id should be found in "
"pruned_progin_block_id_map when the op has sub_blocks"));
sub_indices.push_back(sub_idx);
}

SetSubBlocksIndices(&op_desc, sub_indices);
}
}
}
Expand Down
15 changes: 13 additions & 2 deletions python/paddle/jit/dy2static/py_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import functools
import inspect

from paddle.base.framework import Variable
from paddle.common_ops_import import LayerHelper
Expand Down Expand Up @@ -73,9 +74,19 @@ def __init__(self, dyfunc_self):
)

# NOTE: only support position args and Variables Now
def apply(self, *args):
def apply(self, *args, **kwargs):
# rearrange `position-args + keyword-args` into `position-args`
dyfunc_sig = inspect.signature(self.dyfunc_self.forward)
bound_args = dyfunc_sig.bind(self.dyfunc_self, *args, **kwargs)
bound_args.apply_defaults()
input_args = [
item
for i, item in enumerate(bound_args.arguments.values())
if i > 0
] # index 0 indicate `dyfunc_self` which shouldn't be put into `input_args`

return static_pylayer(
forward_fn=self.forward_fn_with_ctx,
inputs=list(args),
inputs=input_args,
backward_fn=self.backward_fn_with_ctx,
)
14 changes: 14 additions & 0 deletions python/paddle/static/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,20 @@ def normalize_program(program, feed_vars, fetch_vars, **kwargs):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
remove_op_idx.append(i)

if op.type == "pylayer":
sub_blocks_ids = op._blocks_attr_ids("blocks")
if len(sub_blocks_ids) > 1:
# pylayer op ``blocks`` attr contains forward block id and backward block id
backward_block_id = sub_blocks_ids[-1]
# remove backward block
copy_program.blocks.pop(backward_block_id)
# update attrs ``blocks``
reserverd_blocks = []
for block_id in sub_blocks_ids[:-1]:
reserverd_blocks.append(copy_program.block(block_id))
op._update_desc_attr("blocks", reserverd_blocks)

for idx in remove_op_idx[::-1]:
global_block._remove_op(idx)
copy_program.desc.flush()
Expand Down
Loading

0 comments on commit 8e89dd3

Please sign in to comment.