Skip to content

Commit

Permalink
End2End Lowering (apache#23)
Browse files Browse the repository at this point in the history
* call_dps lowering.

* Improve shape lowering.

* Support alloc_storage for dynamic shape.

* Implement ToNonDF to transform the program to non-dataflow format.

* Fix the mutator issue.

* Update build api, an issue occurred.

* vm tests can pass.

* Support shape tuple in executable serialization.

* Fix for test.

* Minor fixes.

* Address comments.

* Add mutate binding var back.

* Visit binding var and fix tests.

Co-authored-by: YuchenJin <yuchenj@cs.washington.edu>
  • Loading branch information
2 people authored and junrushao committed Feb 9, 2023
1 parent c6749b0 commit 8f8cb8b
Show file tree
Hide file tree
Showing 14 changed files with 568 additions and 185 deletions.
6 changes: 2 additions & 4 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
* \return expr.
*/
Expr Mutate(const Expr& expr) {
if (memo_.count(expr) == 0) {
memo_[expr] = this->VisitExpr(expr);
}
return Downcast<Expr>(memo_[expr]);
return this->VisitExpr(expr);
}

Expr VisitExpr(const Expr& expr) override;
Expand Down Expand Up @@ -226,6 +223,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
virtual void VisitBinding(const Binding& binding);
virtual Var VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);

virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

Expand Down
7 changes: 6 additions & 1 deletion python/tvm/relax/exec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm._ffi._ctypes.packed_func import TVMRetValueHandle
from tvm.runtime import Object
from tvm.runtime.container import ShapeTuple
from tvm._ffi.base import _LIB, check_call
from . vm import Executable
from . import _ffi_api
Expand Down Expand Up @@ -89,7 +90,11 @@ def emit_call(
dst = SpecialReg.VOID_ARG
args_ = []
for arg in args:
if isinstance(arg, tvm.nd.NDArray) or isinstance(arg, tvm.DataType):
if isinstance(arg, tuple):
shape_tuple = ShapeTuple(arg)
new_arg = self.emit_constant(shape_tuple)
args_.append(new_arg)
elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)):
new_arg = self.emit_constant(arg)
args_.append(new_arg)
else:
Expand Down
38 changes: 31 additions & 7 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
from tvm import IRModule
from tvm import IRModule
from . import _ffi_api


def fma_rewrite(expr):
"""Perform fused multiply add rewriting in dataflow blocks.
Expand All @@ -29,22 +30,45 @@ def fma_rewrite(expr):
"""
return _ffi_api.fma_rewrite(expr)

def explicit_memory_rewrite(expr):
"""Perform explicit memory allocation for call_dps in dataflow blocks.
def to_non_dataflow(mod: IRModule) -> IRModule:
"""Transform all dataflow structure to non-dataflow version.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : tvm.IRModule
The input module.
"""
return _ffi_api.explicit_memory_rewrite(expr)
return _ffi_api.to_non_dataflow(mod)


def call_dps_rewrite(mod: IRModule) -> IRModule:
    """Run the call_dps rewrite, which makes memory allocation for call_dps explicit.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module.

    Returns
    -------
    tvm.IRModule
        The module produced by the underlying FFI pass.
    """
    rewritten = _ffi_api.call_dps_rewrite(mod)
    return rewritten


def memory_lower(mod: IRModule) -> IRModule:
    """Run memory lowering: relax.builtin.alloc_tensor is lowered to VM builtin functions.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module.

    Returns
    -------
    tvm.IRModule
        The module produced by the underlying FFI pass.
    """
    lowered = _ffi_api.memory_lower(mod)
    return lowered


def shape_lower(mod: IRModule) -> IRModule:
"""Lower the shape expression in relax to shape heap and TIR functions.
Parameters
----------
expr : tvm.IRModule
mod : tvm.IRModule
The input module.
"""
return _ffi_api.shape_lower(mod)
7 changes: 6 additions & 1 deletion python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm.runtime import Object, Device, Module, PackedFunc
from tvm._ffi.base import _LIB, check_call
from . import _ffi_api
from . import transform
from ..rpc.base import RPC_SESS_MASK


Expand Down Expand Up @@ -164,5 +165,9 @@ def build(mod: tvm.IRModule,
lib: tvm.runtime.Module
A runtime module that contains generated code.
"""
ex, lib = _ffi_api.VMBuild(mod, target, target_host)
new_mod = transform.to_non_dataflow(mod)
new_mod = transform.call_dps_rewrite(new_mod)
new_mod = transform.memory_lower(new_mod)
new_mod = transform.shape_lower(new_mod)
ex, lib = _ffi_api.VMBuild(new_mod, target, target_host)
return ex, lib
24 changes: 21 additions & 3 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,19 @@ void ExprVisitor::VisitBinding(const Binding& binding) {
}
}

void ExprVisitor::VisitVarBinding(const VarBinding& binding) { this->VisitExpr(binding->value); }
void ExprVisitor::VisitVarBinding(const VarBinding& binding) {
this->VisitExpr(binding->value);
this->VisitExpr(binding->var);
}

void ExprVisitor::VisitMatchShape(const MatchShape& binding) {
this->VisitExpr(binding->value);
// TODO(ziheng): should we change pattern from
// Array<PrimExpr> to ShapeExpr?
this->VisitExpr(ShapeExpr(binding->pattern));
if (binding->var.defined()) {
this->VisitExpr(binding->var);
}
}

void ExprVisitor::VisitBindingBlock(const BindingBlock& block) {
Expand Down Expand Up @@ -214,6 +220,10 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
}

Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) {
auto it = var_remap_.find(GetRef<Var>(op));
if (it != var_remap_.end()) {
return it->second;
}
if (op->type_annotation.defined()) {
Type type = this->VisitType(op->type_annotation.value());
if (!op->type_annotation.same_as(type)) {
Expand Down Expand Up @@ -339,7 +349,7 @@ void ExprMutator::VisitBinding(const Binding& binding) {

Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
Expr new_value = builder_->Normalize(this->Mutate(binding->value));
Var new_var = Downcast<Var>(this->Mutate(binding->var));

// TODO(@altanh): this probably shouldn't live here, all passes would have to make sure to do it
// in this method...
// if (new_value->shape_.defined()) {
Expand All @@ -356,6 +366,7 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
// new_var->checked_type_ = new_value->checked_type_;
// }

Var new_var = Downcast<Var>(this->Mutate(binding->var));
if (!builder_->CanProveShapeEqual(new_var->shape(), new_value->shape()) ||
!StructuralEqual()(new_var->checked_type(), new_value->checked_type())) {
new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span);
Expand All @@ -380,7 +391,14 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
void ExprMutator::VisitMatchShape(const MatchShape& binding) {
Expr new_value = this->Mutate(binding->value);
Expr new_pattern = this->Mutate(ShapeExpr(binding->pattern));
Var new_var = Downcast<Var>(this->Mutate(binding->var));
Var new_var;
if (binding->var.defined()){
new_var = Downcast<Var>(this->Mutate(binding->var));
} else {
new_var = binding->var;
}

// TODO: when value's shape/type changed, create new var
builder_->EmitMatchShape(
MatchShape(new_value, Downcast<ShapeExpr>(new_pattern)->values, new_var));
}
Expand Down
86 changes: 86 additions & 0 deletions src/relax/transform/call_dps_rewrite.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/transform/call_dps_rewrite.cc
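* \brief Rewrite relax.call_dps into an explicit relax.builtin.alloc_tensor followed by a call to the callee.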
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>

#include "../../relay/transforms/pattern_utils.h"

namespace tvm {
namespace relax {

// ==================
// CallDPSMutator
// Example:
// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
// -->
// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
// rx.call_packed(op.identity, x, lv0)

/*!
 * \brief Mutator that rewrites every call to the relax.call_dps op.
 *
 * A call `call_dps(shape, callee, args)` is replaced by two emitted bindings:
 *   1. a call to relax.builtin.alloc_tensor with `shape`, bound to a new var;
 *   2. a call to `callee` with the arguments `{args, <allocated var>}`.
 * The allocated var is returned in place of the original call expression.
 */
class CallDPSMutator : public ExprMutator {
 public:
  /*!
   * \brief Construct the mutator.
   * \param mod The module whose functions will be rewritten by Lower().
   */
  explicit CallDPSMutator(IRModule mod) : mod_(mod) {}

  /*!
   * \brief Rewrite every relax Function in the stored module.
   *
   * Entries that are not instances of FunctionNode are copied to the result
   * unchanged.
   *
   * \return A new IRModule holding the (possibly rewritten) functions.
   */
  IRModule Lower() {
    IRModule ret_mod = IRModule();
    for (const auto& kv : mod_->functions) {
      Expr func = kv.second;
      if (kv.second->IsInstance<FunctionNode>()) {
        func = this->Mutate(kv.second);
      }
      ret_mod->Add(kv.first, Downcast<BaseFunc>(func));
    }
    return ret_mod;
  }

  /*!
   * \brief Rewrite a single call node (post-order).
   * \param call The call node being visited.
   * \return The var bound to the allocated tensor when the call targets
   *         relax.call_dps; otherwise the result of the default mutation.
   */
  Expr VisitExpr_(const CallNode* call) override {
    // post-order mutation
    Expr expr = ExprMutator::VisitExpr_(call);
    // TODO(@yuchen, @altanh): using mutate cause infinite recursion
    // Expr expr = ExprMutator::Mutate(GetRef<Call>(call));

    // The default mutation is not guaranteed by its signature to return a
    // Call; avoid dereferencing a null node if it returned something else.
    call = expr.as<CallNode>();
    if (call == nullptr) {
      return expr;
    }

    static const Op& call_dps_op = Op::Get("relax.call_dps");
    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");

    if (call->op == call_dps_op) {
      // Downcast validates that the first argument really is a ShapeExpr.
      Expr output_shape = Downcast<ShapeExpr>(call->args[0]);
      Var tensor = builder_->Emit(Call(alloc_tensor_op, {output_shape}), "alloc");
      builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
      return tensor;
    }

    return expr;
  }

 private:
  /*! \brief The module to be rewritten. */
  IRModule mod_;
};

// Registers the call_dps rewrite under the global name
// "relax.transform.call_dps_rewrite"; it takes an IRModule and returns the
// module produced by CallDPSMutator::Lower().
TVM_REGISTER_GLOBAL("relax.transform.call_dps_rewrite")
    .set_body_typed([](IRModule mod) {
      CallDPSMutator mutator(mod);
      return mutator.Lower();
    });

} // namespace relax
} // namespace tvm
81 changes: 51 additions & 30 deletions src/relax/transform/memory_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* \file src/relax/transform/memory_rewrite.cc
* \brief
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>
Expand All @@ -30,14 +31,31 @@ namespace tvm {
namespace relax {

// ==================
// ExplicitMemMutator
// MemLowerMutator
// Lower the relax.builtin.alloc_tensor op to VM builtin functions.
// Example:
// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
// x = relax.builtin.alloc_tensor((m, n))
// -->
// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
// rx.call_packed(op.identity, x, lv0)
// gv0 = relax.call_packed("vm.builtin.alloc_storage", (m * n), alignment, device_type,
//                         relax.attrs.AllocStorageAttrs)
// gv1 = relax.call_packed("vm.builtin.alloc_tensor", gv0, offset, (m, n),
//                         relax.attrs.AllocTensorAttrs)

class MemLowerMutator : public ExprMutator {
public:
explicit MemLowerMutator(IRModule mod) { mod_ = mod; }

IRModule Lower() {
IRModule ret_mod = IRModule();
for (auto& p : mod_->functions) {
Expr func = p.second;
if (p.second->IsInstance<FunctionNode>()) {
func = this->Mutate(p.second);
}
ret_mod->Add(p.first, Downcast<BaseFunc>(func));
}
return ret_mod;
}

class ExplicitMemMutator : public ExprMutator {
Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
DynTensorType tensor_type = Downcast<DynTensorType>(type);
DataType dtype = DataType(tensor_type->dtype);
Expand All @@ -63,44 +81,47 @@ class ExplicitMemMutator : public ExprMutator {
return ret;
}

BindingBlock VisitBindingBlock(const BindingBlock& block) {
builder_->BeginBindingBlock();
for (Binding binding : block->bindings) {
this->VisitBinding(binding);
}
return builder_->EndBlock();
}

Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
Expr expr = ExprMutator::VisitExpr_(call);
call = expr.as<CallNode>();
// TODO(@yuchen, @altanh): using mutate cause infinite recursion
// Expr expr = ExprMutator::Mutate(GetRef<Call>(call));

static const Op& call_dps_op = Op::Get("relax.call_dps");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");

if (call->op == call_dps_op) {
ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
Type arg_type = Downcast<Tuple>(call->args[2])->fields[0]->checked_type();
Expr output_size = ComputeStorageSize(output_shape, arg_type);
Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
return tensor;
if (call->op == alloc_tensor_op) {
ShapeExpr tensor_shape = Downcast<ShapeExpr>(call->args[0]);
// TODO(@yuchen): Get the type of input x, options: add an attr to relax.builtin.alloc_tensor
Type tensor_type = DynTensorType(tensor_shape->values.size(), DataType::Float(32));
Expr storage_size = ComputeStorageSize(tensor_shape, tensor_type);
ShapeExpr alignment = ShapeExpr({IntImm(DataType::Int(64), 64)});
ShapeExpr device_type = ShapeExpr({IntImm(DataType::Int(64), 1)});
auto storage_attr = make_object<AllocStorageAttrs>();
storage_attr->dtype = DataType::Float(32);
storage_attr->device_type = 1;

Var storage =
builder_->Emit(Call(ExternFunc("vm.builtin.alloc_storage"),
{storage_size, alignment}, Attrs(storage_attr)),
"storage");

ShapeExpr offset = ShapeExpr({IntImm(DataType::Int(64), 0)});
auto tensor_attr = make_object<AllocTensorAttrs>();
tensor_attr->dtype = DataType::Float(32);
Expr shape = call->args[0];
return builder_->Emit(
Call(ExternFunc("vm.builtin.alloc_tensor"), {storage, offset, shape}, Attrs(tensor_attr)),
"tensor");
}

return GetRef<Expr>(call);
}
};

Expr ExplicitMemRewrite(const Expr& e) {
return ExplicitMemMutator().Mutate(e);
}
private:
IRModule mod_;
};

TVM_REGISTER_GLOBAL("relax.transform.explicit_memory_rewrite")
.set_body_typed([](Expr expr) {
return ExplicitMemRewrite(expr);
TVM_REGISTER_GLOBAL("relax.transform.memory_lower").set_body_typed([](IRModule mod) {
return MemLowerMutator(mod).Lower();
});

} // namespace relax
Expand Down
Loading

0 comments on commit 8f8cb8b

Please sign in to comment.