Skip to content

Commit

Permalink
[IR][TRANSFORM] Enable CopyOnWrite for passes. (apache#5309)
Browse files Browse the repository at this point in the history
This PR enables the copy on write optimizations passes:
- Enable COW for IRModule both TIR and relay passes.
- Enabled COW for PrimFunc in TIR passes.

Need more thoughts into whether/how to enable COW
for relay::Function, due to some function passes depend
on the presence of IRModule for context information,
and the std::move of the related function to nullptr
might affect the related behavior.
  • Loading branch information
tqchen authored and masahi committed Apr 12, 2020
1 parent 240150a commit 11f2826
Show file tree
Hide file tree
Showing 23 changed files with 253 additions and 146 deletions.
7 changes: 3 additions & 4 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class PrimExpr : public BaseExpr {
private:
// Internal function for conversion.
friend struct runtime::PackedFuncValueConverter<PrimExpr>;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
};

/*!
Expand Down Expand Up @@ -464,9 +464,8 @@ struct PackedFuncValueConverter<PrimExpr> {
if (val.type_code() == kDLFloat) {
return PrimExpr(static_cast<float>(val.operator double()));
}
TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
Object* ptr = val.ptr<Object>();
return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));

return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};
} // namespace runtime
Expand Down
23 changes: 16 additions & 7 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <string>
#include <utility>

namespace tvm {
namespace transform {
Expand Down Expand Up @@ -251,8 +252,8 @@ class PassNode : public Object {
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
return this->operator()(mod, PassContext::Current());
IRModule operator()(IRModule mod) const {
return this->operator()(std::move(mod), PassContext::Current());
}

/*!
Expand All @@ -263,7 +264,7 @@ class PassNode : public Object {
*
* \return The transformed module.
*/
virtual IRModule operator()(const IRModule& mod,
virtual IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const = 0;

void VisitAttrs(AttrVisitor* v) {}
Expand All @@ -277,14 +278,22 @@ class Pass : public ObjectRef {
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \code
*
* // If you do no longer need the input module
* // it is recommended to use std::move to move your input module.
* mod = pass(std::move(mod));
*
* \endcode
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
IRModule operator()(IRModule mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
return node->operator()(std::move(mod));
}
/*!
* \brief Transform mod using a functor under a given pass context.
Expand All @@ -294,11 +303,11 @@ class Pass : public ObjectRef {
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod,
IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
return node->operator()(std::move(mod), pass_ctx);
}

TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ inline const char* TypeCode2Str(int type_code) {
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
case kTVMObjectRValueRefArg: return "ObjectRValueRefArg";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {

template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
CHECK(!ref.defined() || ref->template IsInstance<typename SubRef::ContainerType>())
<< "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(std::move(ref.data_));
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
} else if (std::is_rvalue_reference<T>::value) {
} else if (std::is_rvalue_reference<decltype(value)>::value) {
values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
type_codes_[i] = kTVMObjectRValueRefArg;
} else {
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,12 @@ TVM_DLL Pass CombineContextCall();
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
* \param target_bits The target bits
*
* \note Run this pass after storage flatten.
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();
TVM_DLL Pass NarrowDataType(int target_bits);

} // namespace transform
} // namespace tir
Expand Down
1 change: 1 addition & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, msg):
register_error("ValueError", ValueError)
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)


@register_error
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def Apply(ftransform):
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return ftransform(func)
return _fpass.prim_func_pass(_transform, opt_level=0)
return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply")


def Filter(fcond):
Expand All @@ -57,7 +57,7 @@ def Filter(fcond):
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return func if fcond(func) else None
return _fpass.prim_func_pass(_transform, opt_level=0)
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")


def LowerCustomDatatypes():
Expand Down Expand Up @@ -221,9 +221,14 @@ def CombineContextCall():
return _ffi_api.CombineContextCall()


def NarrowDataType():
def NarrowDataType(target_bits):
"""Narrow down PrimExpr datatype in stmt to target_bits.
Parameters
----------
target_bits : int
The target bit configuration.
Returns
-------
fpass : tvm.ir.transform.Pass
Expand All @@ -233,4 +238,4 @@ def NarrowDataType():
----
Run this pass after StorageFlatten.
"""
return _ffi_api.NarrowDataType()
return _ffi_api.NarrowDataType(target_bits)
20 changes: 10 additions & 10 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,21 @@ PrimExpr::PrimExpr(float value)
PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {}

PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
if (ptr->IsInstance<tir::IterVarNode>()) {
return tir::IterVar(ptr)->var;
if (auto* ptr = ref.as<tir::IterVarNode>()) {
return GetRef<tir::IterVar>(ptr)->var;
}
if (ptr->IsInstance<te::TensorNode>()) {
return te::Tensor(ptr)();
if (auto* ptr = ref.as<te::TensorNode>()) {
return GetRef<te::Tensor>(ptr)();
}
if (ptr->IsInstance<runtime::StringObj>()) {
return tir::StringImmNode::make(runtime::String(ptr));
if (auto* ptr = ref.as<runtime::StringObj>()) {
return tir::StringImmNode::make(GetRef<runtime::String>(ptr));
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
CHECK(ObjectTypeChecker<PrimExpr>::Check(ref.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return PrimExpr(ptr);
<< " but get " << ref->GetTypeKey();
return Downcast<PrimExpr>(ref);
}


Expand Down
16 changes: 14 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,20 @@ bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {

GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Module";
if (it == global_var_map_.end()) {
std::ostringstream msg;
msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
<< "candidates are: [";
int counter = 0;
for (auto kv : global_var_map_) {
if (counter++ != 0) {
msg << ", ";
}
msg << "\"" << kv.first << "\"";
}
msg << "]";
LOG(FATAL) << msg.str();
}
return (*it).second;
}

Expand Down
29 changes: 13 additions & 16 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class ModulePassNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

/*!
* \brief Get the pass information/meta data.
Expand Down Expand Up @@ -205,7 +205,7 @@ class SequentialNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
Expand All @@ -231,19 +231,20 @@ ModulePass::ModulePass(
}

// Module -> Module optimizations.
IRModule ModulePassNode::operator()(const IRModule& mod,
IRModule ModulePassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;

CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
mod = pass_func(std::move(mod), pass_ctx);
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, false);
return mod;
}

Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
Expand Down Expand Up @@ -314,18 +315,17 @@ Pass GetPass(const std::string& pass_name) {
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(const IRModule& module,
IRModule SequentialNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
IRModule mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(mod, pass_ctx);
mod = GetPass(it)(std::move(mod), pass_ctx);
}
mod = pass(mod, pass_ctx);
mod = pass(std::move(mod), pass_ctx);
}
return mod;
}
Expand Down Expand Up @@ -375,11 +375,8 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass")
});

TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
IRModule mod = args[1];
ObjectRef ref = args[1];
*ret = pass(mod);
.set_body_typed([](Pass pass, IRModule mod) {
return pass(std::move(mod));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down
9 changes: 8 additions & 1 deletion src/node/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <cstring>
#include "../support/str_escape.h"

namespace tvm {

Expand Down Expand Up @@ -63,6 +63,13 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
static_cast<const runtime::StringObj*>(n)).operator std::string();
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const runtime::StringObj*>(node.get());
p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
});


struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;

Expand Down
5 changes: 3 additions & 2 deletions src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class FunctionPassNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

/*!
* \brief Get the pass information/meta data.
Expand Down Expand Up @@ -113,7 +113,7 @@ FunctionPass::FunctionPass(
}

// Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(const IRModule& mod,
IRModule FunctionPassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
Expand All @@ -122,6 +122,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< " with opt level: "
<< pass_info->opt_level;
pass_ctx.Trace(mod, pass_info, true);

// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
Expand Down
Loading

0 comments on commit 11f2826

Please sign in to comment.