Skip to content

Commit

Permalink
add new function
Browse files Browse the repository at this point in the history
  • Loading branch information
yucai-intel committed Sep 5, 2024
1 parent b1087b8 commit b047ddd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/ATen/native/xpu/sycl/MaxUnpoolingKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/xpu/sycl/Atomics.h>
#include <ATen/native/xpu/sycl/BatchKernel.h>
#include <ATen/native/xpu/sycl/NumericLimits.h>
#include <c10/core/Scalar.h>
#include <c10/util/Exception.h>
#include <comm/MemoryFormat.h>
#include <comm/SYCLHelpers.h>

using namespace dnnl;
using namespace at::native;
using namespace at::native::onednn;
#include <torch/library.h>
#include <optional>

namespace at::native::xpu {

Expand Down
18 changes: 18 additions & 0 deletions src/comm/MemoryFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,23 @@ inline bool is_smf_channels_last(const Tensor& t) {
return is_channels_last(suggest_memory_format_sycl(t));
}

inline MemoryFormat get_cl_tag_by_ndim(const int64_t ndim) {
TORCH_CHECK(
3 == ndim || 4 == ndim || 5 == ndim,
"ndim must be 3, 4 or 5 when get cl tag");
if (3 == ndim) {
#ifdef USE_CHANNELS_LAST_1D
return CHANNELSLAST1D_SYCL;
#else
// if doesn't enable channels last 1d
return at::MemoryFormat::Contiguous;
#endif
} else if (5 == ndim) {
return at::MemoryFormat::ChannelsLast3d;
} else {
return at::MemoryFormat::ChannelsLast;
}
}

} // namespace sycl
} // namespace xpu

0 comments on commit b047ddd

Please sign in to comment.