Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict CK solvers to only run on MI100, MI200 and MI300 #2533

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions src/include/miopen/solver/ck_utility_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@ namespace miopen {
namespace solver {
namespace ck_utility {

// Disclaimer: Currently CK is only supported in MI100, MI200 and MI300.
// Please use is_ck_whitelist instead of this function.
static inline bool is_ck_supported_hardware(const Handle& handle)
{
return (StartsWith(handle.GetDeviceName(), "gfx803") && handle.GetMaxComputeUnits() == 64) ||
StartsWith(handle.GetDeviceName(), "gfx900") ||
StartsWith(handle.GetDeviceName(), "gfx906") ||
StartsWith(handle.GetDeviceName(), "gfx908") ||
StartsWith(handle.GetDeviceName(), "gfx90a") ||
StartsWith(handle.GetDeviceName(), "gfx940") ||
StartsWith(handle.GetDeviceName(), "gfx941") ||
StartsWith(handle.GetDeviceName(), "gfx942") ||
StartsWith(handle.GetDeviceName(), "gfx1030") ||
StartsWith(handle.GetDeviceName(), "gfx1031") ||
StartsWith(handle.GetDeviceName(), "gfx1100") ||
StartsWith(handle.GetDeviceName(), "gfx1101") ||
StartsWith(handle.GetDeviceName(), "gfx1102");
}

// MI100 : gfx908
// MI200 : gfx90a
// MI300 : gfx940, gfx941, gfx942
Expand Down
2 changes: 1 addition & 1 deletion src/solver/batchnorm/backward_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ bool BnCKBwdBackward::IsApplicable(
return false;
if(!bn_problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
if(!ck_utility::is_ck_whitelist(context.GetStream()))
return false;
if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/batchnorm/forward_training_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ bool BnCKFwdTraining::IsApplicable(
return false;
if(!bn_problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
if(!ck_utility::is_ck_whitelist(context.GetStream()))
return false;

switch(bn_problem.GetXDesc().GetType())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx,
return false;
if(!ctx.use_hip_kernels)
return false;
if(!ck_utility::is_ck_supported_hardware(ctx.GetStream()))
if(!ck_utility::is_ck_whitelist(ctx.GetStream()))
return false;
if(!problem.IsLayoutDefault())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/norm/forward_layernorm2d_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ bool Layernorm2DCKForward::IsApplicable(
return false;
if(!problem.IsLargeSize())
return false;
if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
if(!ck_utility::is_ck_whitelist(context.GetStream()))
return false;

switch(problem.GetXDesc().GetType())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/norm/forward_layernorm4d_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ bool Layernorm4DCKForward::IsApplicable(
return false;
if(!problem.IsLargeSize())
return false;
if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
if(!ck_utility::is_ck_whitelist(context.GetStream()))
return false;

switch(problem.GetXDesc().GetType())
Expand Down