diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 86157e0fb32e8..d517c09349281 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -412,7 +412,7 @@ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[Sche # pylint: disable=protected-access if target.kind.name == "llvm": return DefaultLLVM._sch_rules() - if target.kind.name == "cuda": + if target.kind.name in ["cuda", "rocm", "vulkan"]: return DefaultCUDA._sch_rules() # pylint: enable=protected-access raise ValueError(f"Unsupported target: {target}") @@ -426,7 +426,7 @@ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]: # pylint: disable=protected-access if target.kind.name == "llvm": return DefaultLLVM._postproc() - if target.kind.name == "cuda": + if target.kind.name in ["cuda", "rocm", "vulkan"]: return DefaultCUDA._postproc() # pylint: enable=protected-access raise ValueError(f"Unsupported target: {target}") @@ -445,7 +445,7 @@ def _mutator_probs( # pylint: disable=protected-access if target.kind.name == "llvm": return DefaultLLVM._mutator_probs() - if target.kind.name == "cuda": + if target.kind.name in ["cuda", "rocm", "vulkan"]: return DefaultCUDA._mutator_probs() # pylint: enable=protected-access raise ValueError(f"Unsupported target: {target}") diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 6fef8b48c396d..de969cae1ccee 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -308,7 +308,9 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option("mtriple") .add_attr_option("system-lib") .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) .add_attr_option("thread_warp_size", Integer(64)) + .add_attr_option("max_shared_memory_per_block", Integer(64000)) .set_default_keys({"rocm", "gpu"}) .set_attrs_preprocessor(UpdateROCmAttrs); @@ -349,6 +351,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supported_subgroup_operations") // Physical device limits .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) .add_attr_option("thread_warp_size", Integer(1)) .add_attr_option("max_block_size_x") .add_attr_option("max_block_size_y")