Skip to content

Commit

Permalink
[StackVM] Updated CodeGenStackVM to handle DeclBuffer (apache#15036)
Browse files Browse the repository at this point in the history
Part of changes being split out from
apache#14778 into independent
portions. This commit allows DeclBuffer to occur in the lowered TIR
passed to CodeGenStackVM.
  • Loading branch information
Lunderberg authored Jun 7, 2023
1 parent 26907f9 commit 3d72d4b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
LOG(FATAL) << "Dynamic allocation not supported";
}

void CodeGenStackVM::VisitStmt_(const DeclBufferNode* op) { VisitStmt(op->body); }

void CodeGenStackVM::VisitExpr_(const CallNode* op) {
if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
Expand Down
1 change: 1 addition & 0 deletions src/target/stackvm/codegen_stackvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class CodeGenStackVM : public ExprFunctor<void(const PrimExpr&)>,
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const IfThenElseNode* op) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const DeclBufferNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
void VisitStmt_(const AssertStmtNode* op) final;
void VisitStmt_(const EvaluateNode* op) final;
Expand Down
22 changes: 18 additions & 4 deletions tests/python/unittest/test_target_codegen_vm_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import tvm
import tvm.testing
from tvm import te
from tvm.script import tir as T, ir as I

import numpy as np


Expand Down Expand Up @@ -122,8 +124,20 @@ def check(f):
run_jit(mod, check)


def test_codegen_decl_buffer():
"""The codegen should accept DeclBuffer nodes in its input"""

@I.ir_module
class mod:
@T.prim_func
def kernel(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "kernel"})
A_buf = T.decl_buffer([256], dtype="float32", scope="global", data=A_data)

target = tvm.target.Target("stackvm")
stackvm_codegen = tvm.get_global_func("target.build.stackvm")
stackvm_codegen(mod, target)


if __name__ == "__main__":
test_vm_parallel()
test_stack_vm_loop()
test_stack_vm_basic()
test_stack_vm_cond()
tvm.testing.main()

0 comments on commit 3d72d4b

Please sign in to comment.