From d60d45187b1dd891ec8aa2abc42eca8eda5cb242 Mon Sep 17 00:00:00 2001
From: Kanya-Mo <167922169+Kanya-Mo@users.noreply.github.com>
Date: Sun, 16 Jun 2024 18:46:47 -0700
Subject: [PATCH] [device check] replace dpcppsupportfp64 with
 has_2d_block_array (#4138) (#4354)

* [Fix] Replace the dpcppSupportFP64 device check method with
  has_2d_block_array

---------

Signed-off-by: Chen, Zejun
Co-authored-by: Jinghui
(cherry picked from commit fff816b944b032d9f11ca5c18888175f2f238aae)

Co-authored-by: zejun
---
 csrc/gpu/aten/operators/GRU.cpp | 4 +++-
 csrc/gpu/oneDNN/Utils.h         | 8 ++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/csrc/gpu/aten/operators/GRU.cpp b/csrc/gpu/aten/operators/GRU.cpp
index ac62c73b3..72bf5fad8 100644
--- a/csrc/gpu/aten/operators/GRU.cpp
+++ b/csrc/gpu/aten/operators/GRU.cpp
@@ -682,8 +682,10 @@ bool is_xetla_gru_available(
     const int input_size,
     const int hidden_size,
     const ScalarType dtype) {
+  DeviceIndex curDevID;
+  AT_DPCPP_CHECK(dpcppGetDevice(&curDevID));
   // TODO: XeTLA will proive a general API to check supported platform
-  if (dpcppSupportFP64()) {
+  if (Settings::I().has_2d_block_array(curDevID)) {
     if (dtype == ScalarType::BFloat16) { // TODO: support fp16
       // More shapes could be supported by adding kernel configs manually
       if (batch_size <= 1024 && input_size <= 512 && hidden_size <= 1024) {
diff --git a/csrc/gpu/oneDNN/Utils.h b/csrc/gpu/oneDNN/Utils.h
index 0b618fbfd..9c3583c7b 100644
--- a/csrc/gpu/oneDNN/Utils.h
+++ b/csrc/gpu/oneDNN/Utils.h
@@ -644,13 +644,16 @@ static inline int get_memory_layout_for_conv(
     const at::Tensor& src,
     const at::Tensor& weight,
     bool is_transpose) {
+  DeviceIndex curDevID;
+  AT_DPCPP_CHECK(dpcppGetDevice(&curDevID));
+
   if (!src.defined() || src.is_sparse()) {
     // suggest channels_first
     return MEMORY_LAYOUT_FOR_CONV::ChannelsFirst;
   }
 
   if (is_transpose || src.is_quantized() || weight.is_quantized() ||
-      (!dpcppSupportFP64())) {
+      (!Settings::I().has_2d_block_array(curDevID))) {
     if (Settings::I().is_onednn_layout_enabled()) {
       // suggest blocked
       return MEMORY_LAYOUT_FOR_CONV::Blocked;
@@ -666,7 +669,8 @@ static inline int get_memory_layout_for_conv(
 
   // inference workloads on ATSM platform, the conv will use blocked format
   // used double support to distinguish is atsm or not
-  auto suggest_block_format = !dpcppSupportFP64() // on ATSM platform
+  auto suggest_block_format =
+      !Settings::I().has_2d_block_array(curDevID) // on ATSM platform
       && (c10::InferenceMode::is_enabled() ||
           !at::GradMode::is_enabled()); // for inference workload
   if (suggest_block_format) {
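
Note (not part of the patch): all three hunks replace the same idiom, so the
migration they perform can be summarized in one place. The sketch below shows
the device-check pattern this patch standardizes on. DeviceIndex,
dpcppGetDevice, AT_DPCPP_CHECK, and Settings::I().has_2d_block_array() are
IPEX-internal names taken directly from the diff; the helper name
device_has_2d_block_array is hypothetical, and the sketch assumes the same
headers that GRU.cpp and Utils.h already include.

    // Sketch only: the capability check this patch standardizes on.
    static bool device_has_2d_block_array() {
      // Query the index of the currently active DPC++ device;
      // AT_DPCPP_CHECK raises an error if the runtime call fails.
      DeviceIndex curDevID;
      AT_DPCPP_CHECK(dpcppGetDevice(&curDevID));
      // Ask for 2D block array support directly. Per the comments in
      // Utils.h, dpcppSupportFP64() was only a proxy used to distinguish
      // ATSM (which lacks FP64) from platforms with 2D block arrays.
      return Settings::I().has_2d_block_array(curDevID);
    }

Call sites would then branch on the returned capability exactly as the hunks
above do, e.g. to enable the XeTLA GRU path or pick the conv memory layout.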