Skip to content

Commit

Permalink
[TIR] Improved error message if tir.Schedule passed to lower/build (a…
Browse files Browse the repository at this point in the history
…pache#11913)

Previously, if a TIR Schedule is passed to `tvm.lower`, the error
message is returned `ValueError: ('Expected input to be an IRModule,
PrimFunc or Schedule, but got, ', <class
'tvm.tir.schedule.schedule.Schedule'>)`.  This can cause user
confusion, as the expected class name in the error message does not
differentiate between between a `tvm.te.Schedule` and a
`tvm.tir.Schedule`.  Updated error message to explicitly state that
this should be a `te.Schedule`.
  • Loading branch information
Lunderberg authored and Mikael Sevenier committed Jul 26, 2022
1 parent ecb03f6 commit ac36cfe
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@

import tvm.tir

from tvm import te

from tvm.runtime import Module
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
Expand Down Expand Up @@ -62,7 +63,7 @@ def get_binds(args, compact=False, binds=None):


def schedule_to_module(
sch: schedule.Schedule,
sch: te.Schedule,
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
Expand Down Expand Up @@ -91,7 +92,7 @@ def schedule_to_module(


def lower(
inp: Union[schedule.Schedule, PrimFunc, IRModule],
inp: Union[te.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
Expand Down Expand Up @@ -129,13 +130,15 @@ def lower(
return ffi.lower_module(inp, simple_mode)
if isinstance(inp, PrimFunc):
return ffi.lower_primfunc(inp, name, simple_mode)
if isinstance(inp, schedule.Schedule):
if isinstance(inp, te.Schedule):
return ffi.lower_schedule(inp, args, name, binds, simple_mode)
raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp))
raise ValueError(
f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}"
)


def build(
inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
inputs: Union[te.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
target: Optional[Union[str, Target]] = None,
target_host: Optional[Union[str, Target]] = None,
Expand Down Expand Up @@ -219,7 +222,7 @@ def build(
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(inputs, schedule.Schedule):
if isinstance(inputs, te.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
input_mod = lower(inputs, args, name=name, binds=binds)
Expand All @@ -234,7 +237,8 @@ def build(
input_mod = lower(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
f"Inputs must be Schedule, IRModule or dict of target to IRModule, "
f"Inputs must be te.Schedule, IRModule, PrimFunc, "
f"or dict of target to IRModule, "
f"but got {type(inputs)}."
)

Expand Down

0 comments on commit ac36cfe

Please sign in to comment.