fix IRModule parsing by resolving GlobalVars later (apache#41)
* fix IRModule parsing by resolving GlobalVars later

* disable the fast path that causes a type inference problem for now

* print the checked type on vars if present

* document ResolveGlobals
altanh authored and junrushao committed Feb 9, 2023
1 parent e441225 commit fdb3d71
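
In outline, the fix splits parsing and global-variable resolution into two steps: the parser emits name-only placeholder GlobalVars for identifiers it cannot see yet, and the new ResolveGlobals pass later swaps each placeholder for the module's canonical GlobalVar by comparing names. A minimal, library-independent sketch of that resolve-by-name step follows (Placeholder and Resolver are illustrative stand-ins, not TVM APIs; the real implementation is the C++ GlobalVarResolver added in src/relax/transform/resolve_globals.cc):

class Placeholder:
    """Stand-in for a GlobalVar that only carries a name (what the parser emits)."""
    def __init__(self, name):
        self.name = name

class Resolver:
    """Stand-in for the resolution pass: map names onto the module's canonical globals."""
    def __init__(self, module_globals):
        self.module_globals = module_globals  # name -> canonical global object

    def resolve(self, placeholder):
        if placeholder.name not in self.module_globals:
            raise NameError(f'undefined variable/global "{placeholder.name}"')
        return self.module_globals[placeholder.name]

globals_by_name = {"f": object(), "g": object(), "my_matmul": object()}
resolver = Resolver(globals_by_name)
assert resolver.resolve(Placeholder("g")) is globals_by_name["g"]  # resolved by string equality
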
Showing 11 changed files with 291 additions and 195 deletions.
40 changes: 27 additions & 13 deletions python/tvm/relax/transform/transform.py
@@ -19,68 +19,82 @@
import tvm.ir
from . import _ffi_api


@tvm._ffi.register_object("relax.FunctionPass")
class FunctionPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relax.Function in a module. A function
pass class should be created through `function_pass`.
"""

def FMARewrite() -> tvm.transform.Pass:

def FMARewrite() -> tvm.ir.transform.Pass:
"""Perform fused multiply add rewriting in dataflow blocks.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.FMARewrite()


def ToNonDataflow() -> tvm.transform.Pass:
def ToNonDataflow() -> tvm.ir.transform.Pass:
"""Transform all dataflow structure to non-dataflow version.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.ToNonDataflow()


def CallDPSRewrite() -> tvm.transform.Pass:
def CallDPSRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_dps.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.CallDPSRewrite()


def VMMemoryLower() -> tvm.transform.Pass:
def VMMemoryLower() -> tvm.ir.transform.Pass:
"""Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.VMMemoryLower()


def VMShapeLower() -> tvm.transform.Pass:
"""Lower the shape expressions in relax to VM shape heap manipulations and generate related
def VMShapeLower() -> tvm.ir.transform.Pass:
"""Lower the shape expressions in relax to VM shape heap manipulations and generate related
TIR functions to do shape calculations.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.VMShapeLower()


def ToANF() -> tvm.transform.Pass:
def ToANF() -> tvm.ir.transform.Pass:
"""Transforming Relax IR to A-normal form.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.ToANF()


def ResolveGlobals() -> tvm.ir.transform.Pass:
"""Resolve global variables using string equality. This ensures all GlobalVars in the IR refer
to the correct GlobalVar of the input IRModule. An error is reported if any GlobalVar cannot be
resolved.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.ResolveGlobals()
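
A hedged usage sketch of the new pass (assuming it is importable as tvm.relax.transform.ResolveGlobals as defined above; tvm.ir.transform passes are callable on an IRModule):

import tvm
from tvm import relax

def resolve(mod: tvm.IRModule) -> tvm.IRModule:
    # Replace every name-only GlobalVar in the module's Relax functions with the
    # module's own GlobalVar of the same name; an error is reported if a name is missing.
    return relax.transform.ResolveGlobals()(mod)
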
9 changes: 6 additions & 3 deletions python/tvm/script/relax/parser.py
@@ -970,9 +970,12 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr:
var_name = expr.id.name
if _is_registered(var_name, op_set=self._registered_ops):
return relay.op.get(var_name)
if var_name not in self.scope:
self.report_error("undefined variable", expr.span)
return self.scope[var_name]
if var_name in self.scope:
return self.scope[var_name]
# NOTE: this is a "hack" to get around Python eagerly parsing class method decorators
# first (meaning we need to resolve them after the functions are parsed). These
# GlobalVars need to be resolved using string equality only.
return relay.GlobalVar(var_name)

elif isinstance(expr, ast.Constant):
# FIXME(@altanh): use internal representation that doesn't have precision limits here
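The NOTE above comes down to plain Python semantics: method decorators execute while the class body runs, before the class decorator ever sees the finished class, so a sibling function such as g does not exist yet when f is parsed. A small stand-alone demonstration of that ordering (no TVM involved; method_dec and class_dec are hypothetical stand-ins for the R.function/T.prim_func and tvm.script.ir_module decorators):

order = []

def method_dec(func):          # stand-in for R.function / T.prim_func
    order.append(func.__name__)
    return func

def class_dec(cls):            # stand-in for tvm.script.ir_module
    order.append(cls.__name__)
    return cls

@class_dec
class MyModule:
    @method_dec
    def f(self): ...

    @method_dec
    def g(self): ...

print(order)  # ['f', 'g', 'MyModule']: f is processed before g (or the class) exists
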
28 changes: 14 additions & 14 deletions src/relax/ir/block_builder.cc
@@ -203,10 +203,10 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {

private:
/*!
* \brief Memoization map for expressions using Id for equality of variables.
*/
* \brief Memoization map for expressions using Id for equality of variables.
*/
class ExprMemo {
public:
public:
Optional<Expr> Get(const Expr& expr) {
if (const VarNode* var = expr.as<VarNode>()) {
auto it = var_memo_.find(var->vid);
@@ -230,7 +230,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
}
}

private:
private:
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
};
@@ -370,7 +370,9 @@ Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_
Var BlockBuilderNode::Emit(const VarBinding& binding) {
BlockFrame* cur_frame = CurrentFrame();
if (cur_frame->is_dataflow) {
ICHECK(binding->var.as<DataflowVarNode>());
ICHECK(binding->var.as<DataflowVarNode>())
<< "Emit can only be used for local bindings in a dataflow block, use EmitOutput for "
"output bindings instead";
}
cur_frame->bindings.push_back(binding);
binding_table_[binding->var->vid] = binding->value;
@@ -408,9 +410,11 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& p

Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) {
BlockFrame* cur_frame = CurrentFrame();
if (cur_frame->is_dataflow && binding->var.defined()) {
ICHECK(!binding->var.as<DataflowVarNode>())
<< "cannot bind DataflowVar outside dataflow block.";
if (binding->var.defined()) {
ICHECK(!cur_frame->is_dataflow || binding->var.as<DataflowVarNode>())
<< "EmitMatchShape can only be used for local bindings in a dataflow block.";
ICHECK(cur_frame->is_dataflow || !binding->var.as<DataflowVarNode>())
<< "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint();
}
cur_frame->bindings.push_back(binding);
// TODO(@altanh, @yuchen): what value should we bind? Consider
@@ -511,13 +515,9 @@ BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() {
return &block_stack_.top();
}

NameTable* BlockBuilderNode::name_table() {
return name_table_.get();
}
NameTable* BlockBuilderNode::name_table() { return name_table_.get(); }

BlockBuilder BlockBuilder::Create() {
return BlockBuilder(make_object<BlockBuilderNode>());
}
BlockBuilder BlockBuilder::Create() { return BlockBuilder(make_object<BlockBuilderNode>()); }

TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilder::Create);

29 changes: 17 additions & 12 deletions src/relax/ir/expr_functor.cc
@@ -354,11 +354,20 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) {
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);

if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
// no-op if there is no change
builder_->Emit(GetRef<VarBinding>(binding));
return;
}
auto emit = [this](VarBinding b) {
if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as<DataflowVarNode>()) {
this->builder_->EmitOutput(b);
} else {
this->builder_->Emit(b);
}
};

// FIXME(@altanh): try to clean up all the fast paths and ty/shape infer, it's getting unwieldy
// if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
// // no-op if there is no change
// emit(GetRef<VarBinding>(binding));
// return;
// }

{
Var temp = WithShapeAndType(new_var, new_value->shape_, new_value->checked_type_);
@@ -368,11 +377,7 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) {
}
}

if (builder_->CurrentBlockIsDataFlow() && !new_var.as<DataflowVarNode>()) {
builder_->EmitOutput(VarBinding(new_var, new_value));
} else {
builder_->Emit(VarBinding(new_var, new_value));
}
emit(VarBinding(new_var, new_value));
}

void ExprMutator::VisitBinding_(const MatchShapeNode* binding) {
@@ -387,8 +392,8 @@ void ExprMutator::VisitBinding_(const MatchShapeNode* binding) {
if (new_value->checked_type_.defined() && new_value->checked_type_.as<DynTensorTypeNode>()) {
new_shape = new_pattern;
}
Var temp =
WithShapeAndType(this->VisitVarDef(binding->var), new_shape, new_value->checked_type_);
new_var = this->VisitVarDef(binding->var);
Var temp = WithShapeAndType(new_var, new_shape, new_value->checked_type_);
if (!temp.same_as(new_var)) {
new_var = temp;
this->var_remap_[binding->var->vid] = new_var;
3 changes: 2 additions & 1 deletion src/relax/op/tensor/binary.h
@@ -81,7 +81,8 @@ Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
auto* t1 = rhs_type.as<DynTensorTypeNode>();
if (!t0 || !t1) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
<< "Both lhs and rhs should be DynTensor for broadcasting");
<< "Both lhs and rhs should be DynTensor for broadcasting, but got "
<< lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey());
}

DataType output_dtype;
65 changes: 65 additions & 0 deletions src/relax/transform/resolve_globals.cc
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/transform/resolve_globals.cc
* \brief Resolve GlobalVars using string equality.
*/
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

namespace tvm {
namespace relax {

class GlobalVarResolver : public ExprMutator {
public:
GlobalVarResolver(IRModule mod, DiagnosticContext diag_ctx) : mod_(mod), diag_ctx_(diag_ctx) {}

Expr VisitExpr_(const GlobalVarNode* gvar) {
if (!mod_->ContainGlobalVar(gvar->name_hint)) {
diag_ctx_.Emit(Diagnostic::Error(gvar->span)
<< "undefined variable/global \"" << gvar->name_hint << "\"");
return GetRef<GlobalVar>(gvar);
}
return mod_->GetGlobalVar(gvar->name_hint);
}

private:
/*! \brief the IRModule used for GlobalVar lookup. */
IRModule mod_;
DiagnosticContext diag_ctx_;
};

namespace transform {

Pass ResolveGlobals() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[](Function f, IRModule m, PassContext pc) {
// TODO(@altanh): make sure pc always has diag_ctx?
GlobalVarResolver resolver(m, pc->diag_ctx.value());
return Downcast<Function>(resolver.VisitExpr(f));
};
return CreateFunctionPass(pass_func, 0, "ResolveGlobals", {});
}

TVM_REGISTER_GLOBAL("relax.transform.ResolveGlobals").set_body_typed(ResolveGlobals);

} // namespace transform

} // namespace relax
} // namespace tvm
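
As an aside, the TVM_REGISTER_GLOBAL line above is what the Python wrapper in transform.py (via _ffi_api) resolves against; reaching it directly through the generic packed-function registry would look roughly like this (a sketch assuming the registered name above, with mod standing in for an IRModule you already have):

import tvm

make_pass = tvm.get_global_func("relax.transform.ResolveGlobals")
resolve_globals = make_pass()           # returns a tvm.ir.transform.Pass
# resolved_mod = resolve_globals(mod)   # apply it to an existing IRModule `mod`
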
11 changes: 8 additions & 3 deletions src/relay/printer/relax_script_printer.cc
@@ -497,13 +497,18 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
}

Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) {
// TODO(@altanh): we should consider moving annotation into binding
Doc doc;
if (var->type_annotation.defined()) {
Type annotation = var->checked_type_;
if (!annotation.defined()) {
annotation = var->type_annotation.value_or(Type());
}
if (annotation.defined()) {
doc << ": ";
if (const relax::DynTensorTypeNode* tty = var->type_annotation.as<relax::DynTensorTypeNode>()) {
if (const relax::DynTensorTypeNode* tty = annotation.as<relax::DynTensorTypeNode>()) {
doc << PrintTensorAnnotation(GetRef<DynTensorType>(tty), var->shape_);
} else {
doc << Print(var->type_annotation);
doc << Print(annotation);
}
}
return doc;
60 changes: 25 additions & 35 deletions tests/python/relax/test_parser.py
@@ -496,45 +496,35 @@ def f(x: Tensor):


def test_class_irmodule():
# FIXME(@altanh): Python class method decorators are executed eagerly before the class
# decorator, which means each function is parsed in isolation. This means we cannot resolve
# global variables at parsing time (or indeed any undefined identifier), so we either need to
# 1. defer parsing in the function decorators (so that the ir_module decorator can populate
# global variables first), although this means non-IRModule uses of the function decorators
# will no longer return Function/PrimFunc but some kind of wrapper type. This could cause
# problems if we pass them directly to things that expect Function/PrimFuncs.
# 2. parse every undefined identifier to a placeholder node (e.g. "UndefinedVar"), and run an
# IRModule -> IRModule pass that tries to resolve identifiers.
src = """@tvm.script.ir_module
class MyModule:
@T.prim_func
def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))
for i, j, k in T.grid(128, 128, 128):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] += A[vi, vk] * B[vj, vk]
@tvm.script.ir_module
class MyModule:
@T.prim_func
def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))

@R.function
def f(x: Tensor[(n, n), _]) -> Tensor:
return g(x)
for i, j, k in T.grid(128, 128, 128):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] += A[vi, vk] * B[vj, vk]

@R.function
def g(y: Tensor[(n, n), _]) -> Tensor:
return relax.call_dps((n, n), my_matmul, (y, y))
@R.function
def f(x: Tensor[(n, n), _]) -> Tensor:
return g(x)

@R.function
def h(x, y, z):
_ = my_matmul(x, y, z)
return z
"""
@R.function
def g(y: Tensor[(n, n), _]) -> Tensor:
return relax.call_dps((n, n), my_matmul, (y, y))

@R.function
def h(x, y, z):
_ = my_matmul(x, y, z)
return z

my_module = tvm.script.relax.parser.from_source(src)
my_module = MyModule
assert isinstance(my_module, tvm.IRModule)

var_f = my_module.get_global_var("f")