Skip to content

Commit

Permalink
[TensorIR][Minor] Allow Tuple/Array in TE lowering (#8916)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Sep 3, 2021
1 parent ac9bfd9 commit a890bb9
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import tvm._ffi
import tvm.tir
import tvm.tir._ffi_api

from tvm._ffi.base import string_types
from tvm.ir import Array
from tvm.runtime import convert

from . import _ffi_api
from . import tag as _tag
from . import tensor as _tensor
from . import _ffi_api


def placeholder(shape, dtype=None, name="placeholder"):
Expand Down Expand Up @@ -431,6 +431,7 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):

def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
"""Create a TensorIR PrimFunc from tensor expression
Parameters
----------
ops : List[Tensor]
Expand Down Expand Up @@ -473,6 +474,6 @@ def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
func : tir.PrimFunc
The created function.
"""
if not isinstance(ops, list):
if not isinstance(ops, (list, tuple, Array)):
ops = [ops]
return _ffi_api.CreatePrimFunc(ops)

0 comments on commit a890bb9

Please sign in to comment.