Skip to content

Commit

Permalink
[Metaschedule] Support rocm and spirv
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 12, 2022
1 parent eb0cae2 commit daea033
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[Sche
# pylint: disable=protected-access
if target.kind.name == "llvm":
return DefaultLLVM._sch_rules()
if target.kind.name == "cuda":
if target.kind.name in ["cuda", "rocm", "vulkan"]:
return DefaultCUDA._sch_rules()
# pylint: enable=protected-access
raise ValueError(f"Unsupported target: {target}")
Expand All @@ -426,7 +426,7 @@ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]:
# pylint: disable=protected-access
if target.kind.name == "llvm":
return DefaultLLVM._postproc()
if target.kind.name == "cuda":
if target.kind.name in ["cuda", "rocm", "vulkan"]:
return DefaultCUDA._postproc()
# pylint: enable=protected-access
raise ValueError(f"Unsupported target: {target}")
Expand All @@ -445,7 +445,7 @@ def _mutator_probs(
# pylint: disable=protected-access
if target.kind.name == "llvm":
return DefaultLLVM._mutator_probs()
if target.kind.name == "cuda":
if target.kind.name in ["cuda", "rocm", "vulkan"]:
return DefaultCUDA._mutator_probs()
# pylint: enable=protected-access
raise ValueError(f"Unsupported target: {target}")
Expand Down
3 changes: 3 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mtriple")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("max_threads_per_block", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(64))
.add_attr_option<Integer>("max_shared_memory_per_block", Integer(64000))
.set_default_keys({"rocm", "gpu"})
.set_attrs_preprocessor(UpdateROCmAttrs);

Expand Down Expand Up @@ -349,6 +351,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Integer>("supported_subgroup_operations")
// Physical device limits
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("max_threads_per_block", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(1))
.add_attr_option<Integer>("max_block_size_x")
.add_attr_option<Integer>("max_block_size_y")
Expand Down

0 comments on commit daea033

Please sign in to comment.