[device check] replace dpcppsupportfp64 with has_2d_block_array (#4138) (#4354)

* [Fix] Replace the dpcppSupportFP64 device check method with has_2d_block_array

---------

Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
Co-authored-by: Jinghui <jinghui.gu@intel.com>
(cherry picked from commit fff816b)

Co-authored-by: zejun <zejun.chen@intel.com>
Kanya-Mo and zejun-chen authored Jun 17, 2024
1 parent 03e7d8f commit d60d451
Showing 2 changed files with 9 additions and 3 deletions.
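
In short, the commit swaps the FP64-capability probe (dpcppSupportFP64) for a per-device 2D-block-array probe (Settings::I().has_2d_block_array) when deciding whether a platform is supported. A minimal sketch of the new check pattern, built only from the identifiers that appear in the diffs below (the wrapper function name is mine, and the repo-internal headers declaring DeviceIndex, AT_DPCPP_CHECK, dpcppGetDevice, and Settings are assumed to be on the include path):

// Sketch only: has_supported_platform() is a hypothetical wrapper.
static bool has_supported_platform() {
  DeviceIndex curDevID;
  AT_DPCPP_CHECK(dpcppGetDevice(&curDevID)); // query the current GPU device id
  // Old check: dpcppSupportFP64(), i.e. gate on FP64 capability.
  // New check: gate on 2D block array support for the current device.
  return Settings::I().has_2d_block_array(curDevID);
}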
4 changes: 3 additions & 1 deletion csrc/gpu/aten/operators/GRU.cpp
@@ -682,8 +682,10 @@ bool is_xetla_gru_available(
     const int input_size,
     const int hidden_size,
     const ScalarType dtype) {
+  DeviceIndex curDevID;
+  AT_DPCPP_CHECK(dpcppGetDevice(&curDevID));
   // TODO: XeTLA will provide a general API to check supported platforms
-  if (dpcppSupportFP64()) {
+  if (Settings::I().has_2d_block_array(curDevID)) {
     if (dtype == ScalarType::BFloat16) { // TODO: support fp16
       // More shapes could be supported by adding kernel configs manually
       if (batch_size <= 1024 && input_size <= 512 && hidden_size <= 1024) {
8 changes: 6 additions & 2 deletions csrc/gpu/oneDNN/Utils.h
@@ -644,13 +644,16 @@ static inline int get_memory_layout_for_conv(
     const at::Tensor& src,
     const at::Tensor& weight,
     bool is_transpose) {
+  DeviceIndex curDevID;
+  AT_DPCPP_CHECK(dpcppGetDevice(&curDevID));
+
   if (!src.defined() || src.is_sparse()) {
     // suggest channels_first
     return MEMORY_LAYOUT_FOR_CONV::ChannelsFirst;
   }
 
   if (is_transpose || src.is_quantized() || weight.is_quantized() ||
-      (!dpcppSupportFP64())) {
+      (!Settings::I().has_2d_block_array(curDevID))) {
     if (Settings::I().is_onednn_layout_enabled()) {
       // suggest blocked
       return MEMORY_LAYOUT_FOR_CONV::Blocked;
@@ -666,7 +669,8 @@
 
   // For inference workloads on the ATSM platform, the conv will use blocked format
   // use double support to distinguish whether the platform is ATSM or not
-  auto suggest_block_format = !dpcppSupportFP64() // on ATSM platform
+  auto suggest_block_format =
+      !Settings::I().has_2d_block_array(curDevID) // on ATSM platform
       && (c10::InferenceMode::is_enabled() ||
           !at::GradMode::is_enabled()); // for inference workload
   if (suggest_block_format) {
