fix runtime crash when rnn model inference, test=develop (#31833) (#3…
winter-wang authored Mar 25, 2021
1 parent d44d173 commit c7a6a1f
Showing 3 changed files with 18 additions and 17 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc

@@ -105,6 +105,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
       "merge_lod_tensor",
       "equal",
       "sequence_pool",
+      "recurrent",
       "lod_reset"};
   for (auto* tmp : node->inputs) {
     CHECK(tmp->IsOp());
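For context: CollectVarMemorySize keeps a list of ops whose variables the inference memory-optimize pass must leave out of buffer reuse, because their shapes or buffer lifetimes are not reliably visible to the pass. This commit adds recurrent to that list, so tensors still referenced by a recurrent op's step scopes are no longer reuse candidates. A minimal sketch of how such a list is consulted, extrapolated from the visible loop (the wrapper function, its name, and the early return are assumptions, not the pass's verbatim code):

  // Sketch under assumptions: a variable qualifies for buffer reuse only if
  // no op touching it appears in the skip list. Only the op names and the
  // loop header are verbatim from the hunk above; Node/CHECK are Paddle's.
  #include <string>
  #include <unordered_set>

  bool VarIsReusable(Node* node) {  // hypothetical helper, for illustration
    static const std::unordered_set<std::string> invalid_op = {
        "merge_lod_tensor", "equal", "sequence_pool",
        "recurrent",  // added by this commit: recurrent runs its step block
                      // in private step scopes the pass cannot see into
        "lod_reset"};
    for (auto* tmp : node->inputs) {
      CHECK(tmp->IsOp());
      if (invalid_op.count(tmp->Op()->Type())) return false;  // leave it alone
    }
    return true;
  }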
25 changes: 12 additions & 13 deletions paddle/fluid/operators/recurrent_op.cc
@@ -211,9 +211,10 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
   auto *block = Attr<framework::BlockDesc *>(kStepBlock);

   auto *program = block->Program();
-  auto ctx = executor.Prepare(
-      *program, block->ID(), Attr<std::vector<std::string>>(
-                                 kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/);
+  auto ctx = executor.Prepare(*program, block->ID(),
+                              Attr<std::vector<std::string>>(
+                                  kSkipEagerDeletionVars), /*skip_ref_cnt_vars*/
+                              true);

   static std::mutex mutex;
   std::lock_guard<std::mutex> lock(mutex);
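The reworked Prepare call threads one extra argument through. Against the Executor interface of this era, the trailing true lands on the force_disable_gc parameter, switching off eager garbage collection for the step block: since outputs are no longer pre-linked to outside storage (see the hunks below), the step scope's tensors must stay alive until they are copied out after execution. Assumed shape of the declaration (paraphrased from paddle/fluid/framework/executor.h; not verbatim from this commit, defaults approximate):

  // Assumed declaration, shown for reference only.
  std::unique_ptr<ExecutorPrepareContext> Prepare(
      const ProgramDesc& program, int block_id,
      const std::vector<std::string>& skip_ref_cnt_vars = {},
      bool force_disable_gc = false);  // the new `true` binds here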
@@ -256,16 +257,6 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
     // Link inside::output -> outside::output
     // outside::output[seq_offset: seq_offset + 1] = inside::output
     executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
-    if (i > 0) {
-      LinkTensorWithCallback(scope, Outputs(kOutputs), cur_scope,
-                             Outputs(kOutputs),
-                             [&](const framework::LoDTensor &src_tensor,
-                                 framework::LoDTensor *dst_tensor) {
-                               framework::Tensor src_slice =
-                                   src_tensor.Slice(seq_offset, seq_offset + 1);
-                               dst_tensor->ShareDataWith(src_slice);
-                             });
-    }

     // Linked now, execute!
     executor.RunPreparedContext(ctx.get(), &cur_scope,
@@ -285,6 +276,14 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
             // early.
             framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
           });
+    } else {
+      LinkTensorWithCallback(
+          cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
+          [&](const framework::LoDTensor &src_tensor,
+              framework::LoDTensor *dst_tensor) {
+            auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
+            framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
+          });
     }

     scopes.ForwardNext();
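Taken together, these hunks move the inside-to-outside output link from before step execution to after it. The removed code made each step scope's output tensor alias a slice of the outside output via ShareDataWith before the step block ran; once the memory-optimize pass (see the first file) rebinds or reuses that outside buffer, the alias dangles and inference crashes at runtime. The added else branch instead lets the step block write into the step scope's own storage and then copies the result into the outside slice, so every step now publishes its output through an explicit TensorCopy. A schematic of the two strategies (simplified: the LinkTensorWithCallback plumbing and surrounding loop are omitted, and outside/inside/inside_result are stand-in names):

  // Old (removed): zero-copy alias established BEFORE the step runs.
  framework::Tensor src_slice = outside.Slice(seq_offset, seq_offset + 1);
  inside->ShareDataWith(src_slice);  // step block writes through shared memory

  // New (added): run the step block first, THEN copy the result out.
  auto dst_out = outside.Slice(seq_offset, seq_offset + 1);
  framework::TensorCopy(inside_result, place, dev_ctx, &dst_out);  // real copy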
9 changes: 5 additions & 4 deletions python/paddle/nn/functional/norm.py
@@ -189,10 +189,10 @@ def batch_norm(x,

     if in_dygraph_mode():
         # for dygraph need tuple
-        attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout",
-                 data_format, "use_mkldnn", False, "fuse_with_relu", False,
-                 "use_global_stats", use_global_stats, "trainable_statistics",
-                 trainable_statistics)
+        attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
+                 not training, "data_layout", data_format, "use_mkldnn", False,
+                 "fuse_with_relu", False, "use_global_stats", use_global_stats,
+                 "trainable_statistics", trainable_statistics)
         batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
             x, weight, bias, running_mean, running_var, mean_out, variance_out,
             *attrs)
@@ -207,6 +207,7 @@ def batch_norm(x,
     attrs = {
         "momentum": momentum,
         "epsilon": epsilon,
+        "is_test": not training,
         "data_layout": data_format,
         "use_mkldnn": False,
         "fuse_with_relu": False,
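The norm.py change is the Python half of the fix: paddle.nn.functional.batch_norm now forwards is_test = not training to the underlying batch_norm op, both in the dygraph attrs tuple and in the static-graph attrs dict. Previously the attribute was omitted, so the op fell back to its default of training-mode behavior (normalizing with per-batch statistics) even when the layer was run for inference; with the flag set, training=False makes the op use the stored running mean and variance.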
