diff --git a/src/ATen/native/xpu/sycl/MaxUnpoolingKernels.cpp b/src/ATen/native/xpu/sycl/MaxUnpoolingKernels.cpp index afc28f456..1a8f33192 100644 --- a/src/ATen/native/xpu/sycl/MaxUnpoolingKernels.cpp +++ b/src/ATen/native/xpu/sycl/MaxUnpoolingKernels.cpp @@ -7,16 +7,15 @@ #include #include #include -#include #include #include #include +#include +#include #include #include - -using namespace dnnl; -using namespace at::native; -using namespace at::native::onednn; +#include +#include namespace at::native::xpu { diff --git a/src/comm/MemoryFormat.h b/src/comm/MemoryFormat.h index 63df5c663..af816aaaa 100644 --- a/src/comm/MemoryFormat.h +++ b/src/comm/MemoryFormat.h @@ -131,5 +131,23 @@ inline bool is_smf_channels_last(const Tensor& t) { return is_channels_last(suggest_memory_format_sycl(t)); } +inline MemoryFormat get_cl_tag_by_ndim(const int64_t ndim) { + TORCH_CHECK( + 3 == ndim || 4 == ndim || 5 == ndim, + "ndim must be 3, 4 or 5 when get cl tag"); + if (3 == ndim) { +#ifdef USE_CHANNELS_LAST_1D + return CHANNELSLAST1D_SYCL; +#else + // if doesn't enable channels last 1d + return at::MemoryFormat::Contiguous; +#endif + } else if (5 == ndim) { + return at::MemoryFormat::ChannelsLast3d; + } else { + return at::MemoryFormat::ChannelsLast; + } +} + } // namespace sycl } // namespace xpu