Skip to content

Commit

Permalink
Update Shape lowering pass (apache#38)
Browse files Browse the repository at this point in the history
* Update shape lowering pass.

* Rebase.
  • Loading branch information
YuchenJin authored and junrushao committed Feb 9, 2023
1 parent df3842f commit e441225
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 36 deletions.
77 changes: 59 additions & 18 deletions src/relax/backend/vm/vm_shape_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,38 @@
namespace tvm {
namespace relax {

/*!
* \brief Visitor to apply a function to every Expr it visits. Also applies the function
* to the shape field of the var definition site if the var's shape is a ShapeExpr.
*/
class ExprApplyVisitWithShape : public ExprVisitor {
public:
explicit ExprApplyVisitWithShape(std::function<void(const Expr&)> f) : f_(f) {}

void VisitVarDef(const Var& var) {
if (var.as<DataflowVarNode>()) {
this->VisitExpr(Downcast<DataflowVar>(var));
} else {
this->VisitExpr(var);
}
if (var->shape_.operator bool() && var->shape_.value().as<ShapeExprNode>()) {
f_(Downcast<ShapeExpr>(var->shape_.value()));
}
}

void VisitExpr(const Expr& e) final {
ExprVisitor::VisitExpr(e);
f_(e);
}

private:
std::function<void(const Expr&)> f_;
};

void PostOrderVisitWithShape(const Expr& e, std::function<void(const Expr&)> fvisit) {
ExprApplyVisitWithShape(fvisit).VisitExpr(e);
}

class VMShapeLowerMutator : public ExprMutator {
public:
static DataType ShapeDType() { return DataType::Int(64); };
Expand All @@ -58,18 +90,11 @@ class VMShapeLowerMutator : public ExprMutator {
}

void VisitBinding_(const MatchShapeNode* binding) override {
Expr shape = ExprMutator::VisitExpr(binding->value);
static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape");
auto store_shape_attr = make_object<ShapeHeapAttrs>();

Array<PrimExpr> pattern = binding->pattern;
Array<Integer> indices;
for (size_t i = 0; i < pattern.size(); ++i) {
int idx = expr2slot_.at(pattern[i]);
indices.push_back(idx);
}
store_shape_attr->indices = indices;
builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv");
Expr value = ExprMutator::VisitExpr(binding->value);

// TODO(@yuchen): match_shape overloaded semantic: value is ShapeType
Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {value}), "sh");
StoreShape(shape, binding->pattern);
}

Expr VisitExpr_(const ShapeExprNode* node) override {
Expand Down Expand Up @@ -97,16 +122,18 @@ class VMShapeLowerMutator : public ExprMutator {
}

Expr VisitExpr_(const FunctionNode* node) override {
builder_->BeginBindingBlock();
builder_->Emit(VarBinding(
shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
Array<Var> params;
for (Var param : node->params) {
params.push_back(this->VisitVarDef(param));
if (param->shape_.operator bool() && param->shape_.value().as<ShapeExprNode>()) {
Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh");
StoreShape(shape, Downcast<ShapeExpr>(param->shape_.value())->values);
}
}
Type ret_type = this->VisitType(node->ret_type);

builder_->BeginBindingBlock();
builder_->Emit(VarBinding(
shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));

Expr new_body = this->VisitExpr(node->body);

Array<BindingBlock> blocks;
Expand Down Expand Up @@ -174,10 +201,24 @@ class VMShapeLowerMutator : public ExprMutator {
}
}
};
PostOrderVisit(expr, func);
PostOrderVisitWithShape(expr, func);
return ret;
}

/*! \brief Store symbolic shape into indices of the VM shape heap. */
void StoreShape(Expr shape, Array<PrimExpr> pattern) {
static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape");
auto store_shape_attr = make_object<ShapeHeapAttrs>();

Array<Integer> indices;
for (size_t i = 0; i < pattern.size(); ++i) {
int idx = expr2slot_.at(pattern[i]);
indices.push_back(idx);
}
store_shape_attr->indices = indices;
builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv");
}

bool IsConstantShape(ShapeExpr shape) const {
for (PrimExpr e : shape->values) {
if (!e->IsInstance<IntImmNode>()) {
Expand Down
66 changes: 63 additions & 3 deletions tests/python/relax/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tvm.ir.module import IRModule

import tvm.script
from tvm.script import relax as R
from tvm.script import tir as T, relax as R


def test_fma_rewrite():
Expand Down Expand Up @@ -179,8 +179,7 @@ def test_vm_shape_lowering():
class TestVMShapeLower:
@R.function
def foo(x: Tensor[_, "float32"]) -> Shape:
sh = relax.call_packed("vm.builtin.shape_of", x)
relax.match_shape(sh, (n, m))
relax.match_shape(x, (n, m))
return (n * 2, m * 3)

mod = TestVMShapeLower
Expand All @@ -196,6 +195,7 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
s1 = func.body.blocks[0].bindings[0].value
assert isinstance(s1.op, relax.ExternFunc)
assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap"
assert s1.args[0].values[0] == 4
s2 = func.body.blocks[1].bindings[0].value
assert isinstance(s2.op, relax.ExternFunc)
assert s2.op.global_symbol == "vm.builtin.shape_of"
Expand All @@ -209,6 +209,65 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
assert isinstance(s5, tvm.relay.Call)
assert s5.op.name == "relax.vm.builtin.load_shape"


def test_vm_shape_lowering_func_param_with_shape():
src = """@tvm.script.ir_module
class InputModule:
@T.prim_func
def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
T.func_attr({"global_symbol": "tir_matmul"})
m = T.var("int32")
n = T.var("int32")
k = T.var("int32")
A = T.match_buffer(x, (m,n))
B = T.match_buffer(y, (n,k))
C = T.match_buffer(z, (m,k))
for i, j, k in T.grid(m, k, n):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
@R.function
def foo(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor:
gv0 = R.call_dps((m, k), tir_matmul, (x, w))
return gv0
"""
mod = tvm.script.relax.parser.from_source(src)

# after vm shape lowering
new_mod = relax.transform.VMShapeLower()(mod)

assert isinstance(new_mod, tvm.IRModule)
assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc)
assert isinstance(new_mod["tir_matmul"], tvm.tir.function.PrimFunc)
func = new_mod["foo"]
assert isinstance(func, tvm.relax.expr.Function)

x, w = func.params
s1 = func.body.blocks[0].bindings[0].value
assert isinstance(s1.op, relax.ExternFunc)
assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap"
assert s1.args[0].values[0] == 3

s2 = func.body.blocks[0].bindings[1].value
assert isinstance(s2.op, relax.ExternFunc)
assert s2.op.global_symbol == "vm.builtin.shape_of"
assert s2.args[0] == x
s3 = func.body.blocks[0].bindings[2].value
assert isinstance(s3, tvm.relay.Call)
assert s3.op.name == "relax.vm.builtin.store_shape"

s4 = func.body.blocks[0].bindings[3].value
assert isinstance(s4.op, relax.ExternFunc)
assert s4.op.global_symbol == "vm.builtin.shape_of"
assert s4.args[0] == w
s5 = func.body.blocks[0].bindings[2].value
assert isinstance(s5, tvm.relay.Call)
assert s5.op.name == "relax.vm.builtin.store_shape"


def test_to_anf():
x = relax.Var("x", type_annotation=relax.DynTensorType())
gv = relax.op.add(x, x)
Expand Down Expand Up @@ -241,4 +300,5 @@ def f(x: Tensor[_, "float32"]):
test_call_dps_rewrite()
test_vm_memory_lower()
test_vm_shape_lowering()
test_vm_shape_lowering_func_param_with_shape()
test_to_anf()
69 changes: 54 additions & 15 deletions tests/python/relax/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_vm_compile_stage0():
class TestVMCompileStage0:
@R.function
def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]):
z = relax.call_packed("test.vm.identity", x, y)
z = R.call_packed("test.vm.identity", x, y)
return y

mod = TestVMCompileStage0
Expand Down Expand Up @@ -272,13 +272,13 @@ def shape_func0(heap: T.handle) -> None:
@R.function
def foo(x: Tensor[_, "float32"]) -> Shape:
shape_heap: Tensor[(4,), "int64"] = relax.call_packed(
shape_heap: Tensor[(4,), "int64"] = R.call_packed(
"vm.builtin.alloc_shape_heap", (4,)
)
gv0 = relax.call_packed("vm.builtin.shape_of", x)
gv1 = relax.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1))
gv0 = R.call_packed("vm.builtin.shape_of", x)
gv1 = R.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1))
gv2 = shape_func0(shape_heap)
gv3 = relax.call_packed("vm.builtin.load_shape", shape_heap, (2, 3))
gv3 = R.call_packed("vm.builtin.load_shape", shape_heap, (2, 3))
return gv3
"""

Expand All @@ -301,8 +301,7 @@ def test_vm_compile_stage2():
class TestVMCompileStage2:
@R.function
def foo(x: Tensor[_, "float32"]) -> Shape:
sh = relax.call_packed("vm.builtin.shape_of", x)
relax.match_shape(sh, (n, m))
R.match_shape(x, (n, m))
return (n * 2, m * 3)

mod = TestVMCompileStage2
Expand All @@ -323,9 +322,9 @@ def test_vm_compile_stage3():
class TestVMCompileStage3:
@R.function
def foo(x: Tensor[(32, 16), "float32"]) -> Tensor:
with relax.dataflow():
y = relax.call_dps((32, 16), "test.vm.identity", (x))
relax.output(y)
with R.dataflow():
y = R.call_dps((32, 16), "test.vm.identity", (x))
R.output(y)
return y

mod = TestVMCompileStage3
Expand All @@ -345,11 +344,10 @@ def test_vm_compile_e2e():
class TestVMCompileE2E:
@R.function
def foo(x: Tensor[_, "float32"]) -> Tensor:
with relax.dataflow():
sh = relax.call_packed("vm.builtin.shape_of", x)
x0 = relax.match_shape(sh, (n, m))
y = relax.call_dps((n, m * 2), "test.vm.tile", (x))
relax.output(y)
with R.dataflow():
R.match_shape(x, (n, m))
y = R.call_dps((n, m * 2), "test.vm.tile", (x))
R.output(y)
return y

mod = TestVMCompileE2E
Expand All @@ -364,6 +362,46 @@ def foo(x: Tensor[_, "float32"]) -> Tensor:
res = vm["foo"](inp)
np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), res.asnumpy())

def test_vm_compile_e2e_func_param_with_shape():
src = """@tvm.script.ir_module
class TestVMCompileE2E2:
@T.prim_func
def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
T.func_attr({"global_symbol": "tir_matmul"})
m = T.var("int32")
n = T.var("int32")
k = T.var("int32")
A = T.match_buffer(x, (m,n))
B = T.match_buffer(y, (n,k))
C = T.match_buffer(z, (m,k))
for i, j, k in T.grid(m, k, n):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
@R.function
def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor:
gv0 = R.call_dps((m, k), tir_matmul, (x, w))
return gv0
"""

mod = tvm.script.relax.parser.from_source(src)

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

import numpy as np
data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
res = vm["func"](data, weight)
expected = np.dot(data.asnumpy(), weight.asnumpy())
np.testing.assert_allclose(expected, res.asnumpy(), rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
test_vm_execute()
Expand All @@ -380,3 +418,4 @@ def foo(x: Tensor[_, "float32"]) -> Tensor:
test_vm_compile_stage2()
test_vm_compile_stage3()
test_vm_compile_e2e()
test_vm_compile_e2e_func_param_with_shape()

0 comments on commit e441225

Please sign in to comment.