[IR] Default to empty attributes, instead of NULL (#16745)
* [IR] Default to empty attributes, instead of NULL

Prior to this commit, the default `DictAttrs` for an `IRModule`,
`tir::PrimFunc`, `relax::Function`, and `relay::Function` was a null
value.  Every callsite therefore had to treat the absence of a
`DictAttrs` as equivalent to an empty `DictAttrs`.  In C++, this was
typically handled by the `foo->GetAttr` helper function, but in Python
the check had to be written explicitly: every callsite needed
`if func.attrs is not None and attr_name in func.attrs`, rather than
simply `if attr_name in func.attrs`.

Since most functions have at least one attribute (the global symbol),
callsites that omitted the `None` check usually worked anyway, so these
latent bugs tended to surface only while working on unrelated changes.

This commit changes the default attribute dictionary from
`NullValue<DictAttrs>()` to `DictAttrs()`.  This avoids having two
separate representations of an object without any attributes, and
allows the `if attr_name in func.attrs` pattern in the Python API.
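
As a minimal sketch of the Python-side difference (a hypothetical
attribute name `my_attr`; the trivial `PrimFunc` is only for
illustration):

    import tvm
    from tvm import tir

    # A PrimFunc constructed with no explicit attributes.
    func = tir.PrimFunc(params=[], body=tir.Evaluate(0))

    # Before this commit, func.attrs defaulted to None, so callsites needed:
    #     if func.attrs is not None and "my_attr" in func.attrs: ...
    # After this commit, func.attrs is an empty DictAttrs, so this suffices:
    assert "my_attr" not in func.attrs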

* Remove no-longer-needed checks on attrs being present

* Fix up unit tests

* More unit test fixes

* Undo erroneous find/replace

* A few more unit tests

* Provide `DictAttrs.get`
Lunderberg authored Mar 25, 2024
1 parent ef46f4e commit b2204ae
Showing 39 changed files with 127 additions and 107 deletions.
7 changes: 2 additions & 5 deletions include/tvm/ir/attrs.h
@@ -230,7 +230,7 @@ class DictAttrs : public Attrs {
    * \brief Consruct a Attrs backed by DictAttrsNode.
    * \param dict The attributes.
    */
-  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
+  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict = {});

   // Utils for accessing attributes
   // This needs to be on DictAttrs, not DictAttrsNode because we return the default
@@ -298,7 +298,7 @@ class DictAttrs : public Attrs {
     return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
   }

-  TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
+  TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
 };

@@ -415,9 +415,6 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
   if (input->attrs.defined()) {
     TNode* node = input.CopyOnWrite();
     node->attrs.CopyOnWrite()->dict.erase(attr_key);
-    if (node->attrs->dict.size() == 0) {
-      node->attrs = NullValue<DictAttrs>();
-    }
   }
   return input;
 }
3 changes: 2 additions & 1 deletion include/tvm/ir/module.h
@@ -376,7 +376,8 @@ class IRModule : public ObjectRef {
   TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
                             Map<GlobalTypeVar, TypeData> type_definitions = {},
                             std::unordered_set<String> import_set = {}, SourceMap map = {},
-                            DictAttrs attrs = {}, Map<String, Array<GlobalInfo>> global_infos = {});
+                            DictAttrs attrs = DictAttrs(),
+                            Map<String, Array<GlobalInfo>> global_infos = {});

   /*! \brief default constructor */
   IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
5 changes: 2 additions & 3 deletions include/tvm/relax/expr.h
@@ -983,15 +983,14 @@ class FunctionNode : public BaseFuncNode {
 class Function : public BaseFunc {
  public:
   TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
-                            bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
-                            Span span = Span());
+                            bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());

   /*!
    * \brief Mimics the constructor but without body Expr.
    * \note ret_struct_info is required, since it can not deduced by the body.
    */
   TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
-                                      bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
+                                      bool is_pure = true, DictAttrs attrs = DictAttrs(),
                                       Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
2 changes: 1 addition & 1 deletion include/tvm/relay/function.h
@@ -114,7 +114,7 @@ class Function : public BaseFunc {
    * \param span The span of the function.
    */
   TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,
-                   tvm::DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+                   tvm::DictAttrs attrs = DictAttrs(), Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
14 changes: 12 additions & 2 deletions include/tvm/runtime/object.h
@@ -741,14 +741,24 @@ struct ObjectPtrEqual {
  * \param ParentType The parent type of the objectref
  * \param ObjectName The type name of the object.
  */
-#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
-  TypeName() = default;                                                 \
+#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \
+                                                                  ObjectName)           \
   explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {}    \
   TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
   const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
   const ObjectName* get() const { return operator->(); }                                       \
   using ContainerType = ObjectName;

+/*
+ * \brief Define object reference methods.
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ */
+#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+  TypeName() = default;                                                 \
+  TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, ObjectName)
+
 /*
  * \brief Define object reference methods that is not nullable.
  *
2 changes: 1 addition & 1 deletion include/tvm/script/ir_builder/tir/frame.h
@@ -78,7 +78,7 @@ class PrimFuncFrameNode : public TIRFrameNode {
   /*! \brief Maps some parameters to specific Buffer data structures. */
   Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
   /*! \brief Additional attributes storing the meta-data */
-  Optional<Map<String, ObjectRef>> attrs;
+  Map<String, ObjectRef> attrs;
   /*! \brief The variable map bound to thread env. */
   Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
   /*! \brief The buffer allocated in root block. */
2 changes: 1 addition & 1 deletion include/tvm/tir/function.h
@@ -164,7 +164,7 @@ class PrimFunc : public BaseFunc {
    */
   TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
                    Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
-                   DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+                   DictAttrs attrs = DictAttrs(), Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
2 changes: 1 addition & 1 deletion python/tvm/contrib/cutlass/build.py
@@ -977,7 +977,7 @@ def handle_norm(self, f, op_type):
         return f.with_attrs(attrs)

     def visit_function_(self, f):
-        if f.attrs is None or "Composite" not in f.attrs:
+        if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
             return relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)
23 changes: 11 additions & 12 deletions python/tvm/contrib/relay_viz/interface.py
@@ -213,14 +213,14 @@ def _function(
     node_to_id: Dict[relay.Expr, str],
 ) -> Tuple[Union[VizNode, None], List[VizEdge]]:
     """Render rule for a relay function node"""
-    node_details = []
-    name = ""
     func_attrs = node.attrs
-    if func_attrs:
-        node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
-        # "Composite" might from relay.transform.MergeComposite
-        if "Composite" in func_attrs.keys():
-            name = func_attrs["Composite"]
+    node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
+    # "Composite" might from relay.transform.MergeComposite
+    if "Composite" in func_attrs.keys():
+        name = func_attrs["Composite"]
+    else:
+        name = ""

     node_id = node_to_id[node]

     # Body -> FunctionNode
@@ -244,11 +244,10 @@ def _call(
     elif isinstance(node.op, relay.Function):
         func_attrs = node.op.attrs
         op_name = "Anonymous Func"
-        if func_attrs:
-            node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
-            # "Composite" might from relay.transform.MergeComposite
-            if "Composite" in func_attrs.keys():
-                op_name = func_attrs["Composite"]
+        node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
+        # "Composite" might from relay.transform.MergeComposite
+        if "Composite" in func_attrs.keys():
+            op_name = func_attrs["Composite"]
     elif isinstance(node.op, relay.GlobalVar):
         op_name = "GlobalVar"
         node_detail = [f"GlobalVar.name_hint: {node.op.name_hint}"]
2 changes: 0 additions & 2 deletions python/tvm/dlight/base/transform.py
@@ -31,8 +31,6 @@
 def _is_scheduled(func: tir.PrimFunc) -> bool:
     if not isinstance(func, tir.PrimFunc):
         return False
-    if not func.attrs:
-        return False
     if "tir.is_scheduled" not in func.attrs:
         return False
     return func.attrs["tir.is_scheduled"] == 1
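
Since `attrs` is now guaranteed to be present, this helper could be tightened further with the `DictAttrs.get` accessor added later in this commit; a sketch (not a change the commit itself makes):

    def _is_scheduled(func: tir.PrimFunc) -> bool:
        if not isinstance(func, tir.PrimFunc):
            return False
        # .get() falls back to 0 when the attribute is absent.
        return func.attrs.get("tir.is_scheduled", 0) == 1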
4 changes: 2 additions & 2 deletions python/tvm/dlight/gpu/matmul.py
@@ -335,7 +335,7 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         root_block = analysis.get_root_block(sch)
         blocks = sch.get_child_blocks(root_block)

-        if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
+        if "dlight.do_not_tensorize" in func.attrs.keys():
             return None

         reduction_blocks = get_reduction_blocks(sch, blocks)
@@ -556,7 +556,7 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         root_block = analysis.get_root_block(sch)
         blocks = sch.get_child_blocks(root_block)

-        if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
+        if "dlight.do_not_tensorize" in func.attrs.keys():
             return None

         reduction_blocks = get_reduction_blocks(sch, blocks)
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
@@ -249,7 +249,7 @@ def build(
     if target is None and isinstance(input_mod, tvm.IRModule):
         target_mod = {}
         for gvar, func in input_mod.functions.items():
-            tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm"
+            tgt = func.attrs["target"] if "target" in func.attrs else "llvm"
             if tgt not in target_mod:
                 target_mod[tgt] = {}
             target_mod[tgt][gvar] = func
4 changes: 4 additions & 0 deletions python/tvm/ir/attrs.py
@@ -114,6 +114,10 @@ def keys(self):
     def __getitem__(self, k):
         return self._dict().__getitem__(k)

+    def get(self, key, default=None):
+        """Get an element with a default value."""
+        return self._dict().get(key, default)
+
     def __contains__(self, k):
         return self._dict().__contains__(k)
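
A small usage sketch for the new `DictAttrs.get` (the attribute names here are invented for illustration):

    import tvm
    from tvm import tir

    func = tir.PrimFunc(params=[], body=tir.Evaluate(0))
    func = func.with_attr("my_attr", 1)

    assert func.attrs.get("my_attr") is not None  # present: returns the stored value
    assert func.attrs.get("missing", 42) == 42    # absent: returns the default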
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/relax_integration.py
@@ -138,7 +138,7 @@ def extracted_tasks_to_tune_contexts(
         get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]),
         fork_seed(seed, n=len(extracted_tasks)),
     ):
-        if task.mod.attrs is not None and task.mod.attrs.get("tir.is_scheduled", False):
+        if task.mod.attrs.get("tir.is_scheduled", False):
             warnings.warn("The task {task.task_name} is already scheduled, skipping it.")
             continue
         tasks.append(
4 changes: 2 additions & 2 deletions python/tvm/relax/backend/contrib/cutlass.py
@@ -526,11 +526,11 @@ def __init__(self, mod):
         super().__init__(mod)

     def visit_function_(self, f):
-        if f.attrs is None or "Composite" not in f.attrs:
+        if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
             new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)

-            if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
+            if "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
                 composite_func = body.blocks[0].bindings[0].value
                 if "WorkspaceSize" in composite_func.attrs:
                     return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"])
2 changes: 1 addition & 1 deletion python/tvm/relax/frontend/common.py
@@ -42,7 +42,7 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n
     detached_mod = tvm.IRModule()
     params_dict = dict()
     for gv, func in mod.functions_items():
-        if func.attrs is not None and "params" in func.attrs:
+        if "params" in func.attrs:
             params = list(func.attrs["params"])
             if not all([isinstance(param, tvm.nd.NDArray) for param in params]):
                 raise ValueError(
12 changes: 4 additions & 8 deletions python/tvm/relax/training/setup_trainer.py
@@ -138,19 +138,15 @@ def _check_well_formed(self, mod: IRModule):
             ) from exc

         # Check function attrs
-        if (
-            mod.attrs is None
-            or not self.PARAM_NUM_ATTR_KEY in mod.attrs
-            or not isinstance(mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm)
+        if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance(
+            mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm
         ):
             raise ValueError(
                 f"SetupTrainer: The backbone module should has an integer attribute named "
                 f"{self.PARAM_NUM_ATTR_KEY}"
             )
-        if (
-            mod.attrs is None
-            or not self.STATE_NUM_ATTR_KEY in mod.attrs
-            or not isinstance(mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm)
+        if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance(
+            mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm
         ):
             raise ValueError(
                 f"SetupTrainer: The backbone module should has an integer attribute named "
4 changes: 2 additions & 2 deletions python/tvm/relax/transform/lazy_transform_params.py
@@ -138,7 +138,7 @@ def __init__(
         self.memory_free_insertion = None

     def transform(self, func: relax.Function) -> relax.Function:
-        if func.attrs is not None and "num_input" in func.attrs:
+        if "num_input" in func.attrs:
             num_input = func.attrs["num_input"].value
         else:
             num_input = 0
@@ -235,7 +235,7 @@ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
         super().__init__(mod)

     def visit_function_(self, func: relax.Function) -> relax.Expr:
-        if func.attrs is not None and "num_input" in func.attrs:
+        if "num_input" in func.attrs:
             num_input = func.attrs["num_input"].value
         else:
             num_input = 0
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/util.py
@@ -156,7 +156,7 @@ class QPadArgs(Enum):

 def is_npu_func(func: relay.Function) -> bool:
     """Check if the given function is an NPU function."""
-    return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u"
+    return "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u"


 def is_composite_func(func: relay.Function, name: str) -> bool:
3 changes: 3 additions & 0 deletions python/tvm/relay/function.py
@@ -54,6 +54,9 @@ def __init__(self, params, body, ret_type=None, type_params=None, attrs=None, sp
         if type_params is None:
             type_params = convert([])

+        if attrs is None:
+            attrs = tvm.ir.make_node("DictAttrs")
+
         self.__init_handle_by_constructor__(
             _ffi_api.Function, params, body, ret_type, type_params, attrs, span
         )
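
A quick sanity sketch of the new Python-side default (behavior inferred from the change above; `my_attr` is an invented name):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,))
    func = relay.Function([x], x)

    # attrs now defaults to an empty DictAttrs rather than None.
    assert func.attrs is not None
    assert "my_attr" not in func.attrs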
4 changes: 2 additions & 2 deletions python/tvm/relay/quantize/_partition_conversions.py
@@ -215,7 +215,7 @@ def partition_prefix(mod, quantized_dtypes):
     prefix_cutter = PrefixCutter(func.params, quantized_dtypes)
     mid_body = prefix_cutter.visit(func.body)
     assert not func.type_params, "unimplemented"
-    assert func.attrs is None, "unimplemented"
+    assert not func.attrs, "unimplemented"
     mid_func = relay.Function(relay.analysis.free_vars(mid_body), mid_body)
     mid_mod = tvm.IRModule.from_expr(mid_func)
     mid_mod = relay.transform.InferType()(mid_mod)
@@ -288,7 +288,7 @@ def partition_suffix(mod, quantized_dtypes):
     suffix_cutter = SuffixCutter(quantized_dtypes)
     post_body = suffix_cutter.visit(func.body)
     assert not func.type_params, "unimplemented"
-    assert func.attrs is None, "unimplemented"
+    assert not func.attrs, "unimplemented"
     post_func = relay.Function(relay.analysis.free_vars(post_body), post_body, func.ret_type)
     post_mod = tvm.IRModule.from_expr(post_func)
     post_mod = relay.transform.InferType()(post_mod)
6 changes: 5 additions & 1 deletion python/tvm/relay/testing/py_converter.py
@@ -553,7 +553,11 @@ def visit_call(self, call: Expr):

         # lowered operator: generate a call to a function that gets the PackedFunc
         # from TVM's registry
-        if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
+        if (
+            isinstance(func, Function)
+            and hasattr(func.attrs, "Primitive")
+            and int(func.attrs.Primitive) == 1
+        ):
             op_call_def, op_call = self.create_op_call(func, call.args, fields)
             return (op_call, field_defs + [op_call_def])
6 changes: 3 additions & 3 deletions python/tvm/testing/aot.py
@@ -1076,12 +1076,12 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
         main = mod
     else:
         main = mod["main"]
-    if main.attrs is None or main.attrs["output_tensor_names"] is None:
+    if "output_tensor_names" in main.attrs:
+        output_tensor_names = main.attrs["output_tensor_names"]
+    else:
         output_tensor_names = (
             ["output"] if output_count == 1 else [f"output{i}" for i in range(output_count)]
         )
-    else:
-        output_tensor_names = main.attrs["output_tensor_names"]

     return dict(zip(output_tensor_names, out))
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/tir/function.py
@@ -80,6 +80,9 @@ def __init__(
         else:
             raise TypeError("params can only contain Var or Buffer")

+        if attrs is None:
+            attrs = tvm.ir.make_node("DictAttrs")
+
         self.__init_handle_by_constructor__(
             _ffi_api.PrimFunc,
             param_list,
2 changes: 1 addition & 1 deletion src/relay/analysis/type_solver.cc
@@ -659,7 +659,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
       auto module = IRModule({}, {});
       DiagnosticContext diag_ctx = DiagnosticContext::Default(module);
       auto dummy_fn_name = GlobalVar("test");
-      module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}, {}));
+      module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}));
       auto solver = std::make_shared<TypeSolver>(dummy_fn_name, diag_ctx);

       auto mod = [module, solver, diag_ctx](std::string name) -> PackedFunc {
2 changes: 1 addition & 1 deletion src/relay/backend/vm/lambda_lift.cc
@@ -194,7 +194,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
       CHECK_EQ(before_arity, after_arity);
       lifted_func =
           Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
-                   free_type_vars, /*attrs=*/{}, func->span);
+                   free_type_vars, DictAttrs(), func->span);
       lifted_func->virtual_device_ = result_virtual_device;
       lifted_func = MarkClosure(lifted_func);
     }
2 changes: 1 addition & 1 deletion src/relay/ir/dataflow_matcher.cc
@@ -438,7 +438,7 @@ Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
   if (expr.as<FunctionNode>()) {
     func = Downcast<Function>(expr);
   } else {
-    func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
+    func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod));
   }
   mod->Add(gvar, func);
   mod = transform::InferType()(mod);
4 changes: 3 additions & 1 deletion src/relay/ir/function.cc
@@ -32,6 +32,8 @@ namespace relay {

 Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
                    tvm::Array<TypeVar> type_params, DictAttrs attrs, Span span) {
+  CHECK(attrs.defined());
+
   ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
   ICHECK(params.defined());
   ICHECK(type_params.defined());
@@ -251,7 +253,7 @@ TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer")

 TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext")
     .set_body_typed([](RelayExpr expr, IRModule mod) -> Function {
-      return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
+      return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod));
     });

 TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr")