diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 6818d757ab79..39c2880d627b 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -32,6 +32,22 @@ DMLC_REGISTER_PARAMETER(CachedOpConfig); constexpr uint32_t kEidNotExist = std::numeric_limits::max(); +const char CachedOp::FULL[] = "full"; +const char CachedOp::FORWARD[] = "forward"; +const char CachedOp::BACKWARD[] = "backward"; +const char CachedOp::REF_COUNT[] = "ref_count"; +const char CachedOp::MEM_PLAN[] = "mem_plan"; +const char CachedOp::STORAGE_PLAN[] = "storage_plan"; + +namespace { + +std::string AddPrefix(const std::string& prefix, + const std::string& s) { + return prefix + "_" + s; +} + +} // namespace + struct CachedOp::GraphInfo { nnvm::Graph fwd_graph; nnvm::Graph full_graph; @@ -136,7 +152,7 @@ CachedOp::CachedOp( for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; } - fwd_graph_.attrs["forward_ref_count"] = + fwd_graph_.attrs[AddPrefix(FORWARD, REF_COUNT)] = std::make_shared(std::move(ref_count)); inlining_ = !config_.static_alloc && @@ -201,9 +217,9 @@ CachedOp::CachedOp( } } - auto full_ref_count = fwd_graph_.GetAttr >("forward_ref_count"); + auto full_ref_count = fwd_graph_.GetAttr >(AddPrefix(FORWARD, REF_COUNT)); for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += ref_count[i]; - fwd_graph_.attrs["full_ref_count"] = + fwd_graph_.attrs[AddPrefix(FULL, REF_COUNT)] = std::make_shared(std::move(full_ref_count)); size_t num_forward_inputs = num_inputs(); @@ -336,14 +352,15 @@ bool CachedOp::SetForwardGraph( // When dynmaic shape exists, it is not feasible to plan memory ahead of time if (contain_dynamic_shape) { - g.attrs.erase("forward_mem_plan"); - g.attrs.erase("full_mem_plan"); + g.attrs.erase(AddPrefix(FORWARD, MEM_PLAN)); + g.attrs.erase(AddPrefix(FULL, MEM_PLAN)); return false; } + const std::string& prefix = recording ? FULL : FORWARD; if (!match) { - g.attrs.erase("forward_mem_plan"); - g.attrs.erase("full_mem_plan"); - } else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) { + g.attrs.erase(AddPrefix(FORWARD, MEM_PLAN)); + g.attrs.erase(AddPrefix(FULL, MEM_PLAN)); + } else if (g.attrs.count(AddPrefix(prefix, MEM_PLAN))) { return true; } @@ -363,9 +380,9 @@ bool CachedOp::SetForwardGraph( } auto mem_plan = PlanMemory( - &g, std::move(storage), g.GetAttr >( - recording ? "full_ref_count" : "forward_ref_count")); - g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] = + &g, std::move(storage), g.GetAttr >(AddPrefix(prefix, REF_COUNT)), + AddPrefix(prefix, STORAGE_PLAN)); + g.attrs[AddPrefix(prefix, MEM_PLAN)] = std::make_shared(std::move(mem_plan)); return false; @@ -432,7 +449,7 @@ bool CachedOp::SetBackwardGraph( size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries(); - if (!g.attrs.count("backward_ref_count")) { + if (!g.attrs.count(AddPrefix(BACKWARD, REF_COUNT))) { std::vector ref_count(idx.num_node_entries(), 0); for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; @@ -443,7 +460,7 @@ bool CachedOp::SetBackwardGraph( } } for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; - g.attrs["backward_ref_count"] = std::make_shared(std::move(ref_count)); + g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared(std::move(ref_count)); } auto shapes = info->fwd_graph.GetAttr("shape"); @@ -476,8 +493,8 @@ bool CachedOp::SetBackwardGraph( false, node_range, entry_range); if (!match) { - g.attrs.erase("backward_mem_plan"); - } else if (g.attrs.count("backward_mem_plan")) { + g.attrs.erase(AddPrefix(BACKWARD, MEM_PLAN)); + } else if (g.attrs.count(AddPrefix(BACKWARD, MEM_PLAN))) { return true; } @@ -491,11 +508,13 @@ bool CachedOp::SetBackwardGraph( for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID; auto mem_plan = PlanMemory( - &g, std::move(storage), g.GetAttr >("backward_ref_count"), + &g, std::move(storage), + g.GetAttr >(AddPrefix(BACKWARD, REF_COUNT)), + AddPrefix(BACKWARD, STORAGE_PLAN), {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, idx.num_node_entries()}, detect_inplace_addto); - g.attrs["backward_mem_plan"] = std::make_shared(std::move(mem_plan)); + g.attrs[AddPrefix(BACKWARD, MEM_PLAN)] = std::make_shared(std::move(mem_plan)); return false; } @@ -526,9 +545,10 @@ void CachedOp::StaticAllocMemory( const auto& default_ctx = state.context; nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; const auto& idx = g.indexed_graph(); - const auto& vstorage_inplace = g.GetAttr >("storage_inplace_index"); - const auto& mem_plan = g.GetAttr( - keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan")); + const std::string& graph_type = keep_fwd ? BACKWARD : (recording ? FULL : FORWARD); + const auto& storage_plan_attr = AddPrefix(graph_type, STORAGE_PLAN); + const auto& storage_plan = g.GetAttr >(storage_plan_attr); + const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); std::vector addto_entry; if (g.attrs.count("addto_entry")) { addto_entry = g.GetAttr >("addto_entry"); @@ -558,9 +578,9 @@ void CachedOp::StaticAllocMemory( for (size_t i = start_eid; i < end_eid; ++i) { if (addto_entry.size() && addto_entry[i]) { state.array_reqs[i] = kAddTo; - } else if (vstorage_inplace[i] >= 0) { + } else if (storage_plan[i] >= 0) { state.array_reqs[i] = kWriteInplace; - } else if (vstorage_inplace[i] == -2) { + } else if (storage_plan[i] == -2) { // -2 indicate that the entry is never referenced. state.array_reqs[i] = kNullOp; } else { @@ -862,8 +882,9 @@ OpStatePtr CachedOp::DynamicForward( } // Allocate NDArrays - std::vector ref_count = g.GetAttr >( - recording ? "full_ref_count" : "forward_ref_count"); + const std::string& graph_type = recording ? FULL : FORWARD; + std::vector ref_count = + g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); std::vector array_reqs(arrays.size(), kWriteTo); for (size_t i = 0; i < idx.num_node_entries(); ++i) { @@ -871,8 +892,7 @@ OpStatePtr CachedOp::DynamicForward( } const auto& dispatch_modes = g.GetAttr("dispatch_mode"); if (!use_naive_run) { - const auto& mem_plan = g.GetAttr( - recording ? "full_mem_plan" : "forward_mem_plan"); + const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), mem_plan, arrays, &array_reqs); const auto& dtypes = g.GetAttr("dtype"); @@ -1011,7 +1031,7 @@ void CachedOp::DynamicBackward( } // Allocate NDArrays - auto ref_count = g.GetAttr >("backward_ref_count"); + auto ref_count = g.GetAttr >(AddPrefix(BACKWARD, REF_COUNT)); if (retain_graph) { for (size_t i = 0; i < num_forward_entries; ++i) ++ref_count[i]; } @@ -1027,7 +1047,7 @@ void CachedOp::DynamicBackward( if (ref_count[i] == 0) array_reqs[i] = kNullOp; } - const auto& mem_plan = g.GetAttr("backward_mem_plan"); + const auto& mem_plan = g.GetAttr(AddPrefix(BACKWARD, MEM_PLAN)); AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(), mem_plan, arrays, &array_reqs); diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index db049d59ed80..84f96300c27b 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -140,6 +140,13 @@ class CachedOp { void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, bool monitor_all = false); + static const char FULL[]; + static const char FORWARD[]; + static const char BACKWARD[]; + static const char REF_COUNT[]; + static const char MEM_PLAN[]; + static const char STORAGE_PLAN[]; + private: struct GraphInfo; struct DynamicRuntime; diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index c5932bb3bbfe..8317c6073a24 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -814,6 +814,7 @@ inline MemoryPlanVector PlanMemory( nnvm::Graph* p_g, nnvm::StorageVector&& storage, const std::vector& ref_count, + const std::string& storage_plan, const std::pair& node_range = {0, 0}, const std::pair& entry_range = {0, 0}, bool detect_inplace_addto = false) { @@ -831,6 +832,7 @@ inline MemoryPlanVector PlanMemory( const auto& dtypes = g.GetAttr("dtype"); const auto& shapes = g.GetAttr("shape"); const auto& storage_inplace = g.GetAttr >("storage_inplace_index"); + g.attrs[storage_plan] = std::make_shared(storage_inplace); const auto& storage_ids = g.GetAttr("storage_id"); uint32_t entry_start = entry_range.first; uint32_t entry_end =