From a890bb9943ab789db5d6b25efc8cfcab359033eb Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 3 Sep 2021 03:51:48 -0700 Subject: [PATCH] [TensorIR][Minor] Allow Tuple/Array in TE lowering (#8916) --- python/tvm/te/operation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 6af3429b3eef..a0b9b4373535 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -22,13 +22,13 @@ import tvm._ffi import tvm.tir import tvm.tir._ffi_api - from tvm._ffi.base import string_types +from tvm.ir import Array from tvm.runtime import convert +from . import _ffi_api from . import tag as _tag from . import tensor as _tensor -from . import _ffi_api def placeholder(shape, dtype=None, name="placeholder"): @@ -431,6 +431,7 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None): def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression + Parameters ---------- ops : List[Tensor] @@ -473,6 +474,6 @@ def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: func : tir.PrimFunc The created function. """ - if not isinstance(ops, list): + if not isinstance(ops, (list, tuple, Array)): ops = [ops] return _ffi_api.CreatePrimFunc(ops)