Skip to content

Commit

Permalink
[MXNET-750] fix nested call on CachedOp. (apache#11951)
Browse files Browse the repository at this point in the history
* fix nested call on cachedop.

* fix.
  • Loading branch information
zheng-da authored and piiswrong committed Jul 31, 2018
1 parent 1eb7c5d commit 1486c5f
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -821,12 +821,11 @@ OpStatePtr CachedOp::DynamicForward(

const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");

if (recording && !inlining_) Imperative::Get()->set_is_recording(false);

// If we are already recording, we don't need RunGraph to record all
// computation again.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes);

Imperative::Get()->set_is_recording(recording);
std::move(ref_count), &states, dispatch_modes,
!recording || inlining_);

return op_state;
}
Expand Down Expand Up @@ -947,7 +946,8 @@ void CachedOp::DynamicBackward(
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");

RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
Imperative::Get()->is_recording());

if (retain_graph) {
buff.resize(num_forward_entries);
Expand Down
3 changes: 2 additions & 1 deletion src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ std::vector<NDArray*> Imperative::Backward(
int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);

RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
is_recording());

Engine::Get()->set_bulk_size(prev_bulk_size);
set_is_recording(prev_recording);
Expand Down
4 changes: 2 additions & 2 deletions src/imperative/imperative_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ void RunGraph(
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector &dispatch_modes) {
const DispatchModeVector &dispatch_modes,
bool recording) {
using namespace nnvm;
using namespace imperative;
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
Expand All @@ -40,7 +41,6 @@ void RunGraph(
const auto imp = Imperative::Get();

std::vector<OpStatePtr>& states = *p_states;
bool recording = imp->is_recording();

std::vector<NDArray*> ndinputs, ndoutputs;
ShapeVector arg_shapes;
Expand Down
3 changes: 2 additions & 1 deletion src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,8 @@ void RunGraph(const bool retain_graph,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector &dispatch_modes);
const DispatchModeVector &dispatch_modes,
bool recording);

} // namespace imperative
} // namespace mxnet
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_contrib_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,7 @@ def check_contrib_rnn(cell_type, num_states):

configs = [
{},
{'inline_limit': 0},
{'static_alloc': True},
{'static_alloc': True, 'static_shape': True} ]
for config in configs:
Expand Down

0 comments on commit 1486c5f

Please sign in to comment.