Skip to content

Commit

Permalink
Update of PHI transpose_grad (#47311)
Browse files Browse the repository at this point in the history
* - halfway transforming transpose grad

- Fixes

- buildable

* - lint

* rerunning the process
  • Loading branch information
jczaja authored Oct 27, 2022
1 parent 77dbb31 commit 493fbfd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 82 deletions.
64 changes: 0 additions & 64 deletions paddle/phi/backends/onednn/onednn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -1085,69 +1085,5 @@ class ClipOneDNNHandler
to_void_cast<T>(input_data));
}
};
template <typename T>
class TransposeOneDNNHandler {
public:
TransposeOneDNNHandler(const OneDNNContext& dev_ctx,
std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
dnnl::engine engine)
: dev_ctx_(dev_ctx),
dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0),
engine_(engine) {}

std::shared_ptr<dnnl::memory> AcquireSrcMemory(const OneDNNMemoryFormat& fmt,
void* ptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
this->logical_axis_[i] = i;
}

auto src_md = fmt != OneDNNMemoryFormat::nchw
? OneDNNMemDesc(dims_, OneDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
}

std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output,
Place place) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
auto dst_data = dev_ctx_.Alloc<T>(output);
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}

std::shared_ptr<dnnl::reorder> AcquireTranspose(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p) {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
}

protected:
dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();

std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
dnnl::memory::desc mem_d(nchw_tz, OneDNNGetDataType<T>(), strides);

return mem_d;
}

private:
const OneDNNContext& dev_ctx_;
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
dnnl::engine engine_;
};
} // namespace funcs
} // namespace phi
34 changes: 16 additions & 18 deletions paddle/phi/kernels/onednn/transpose_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,33 @@ void TransposeGradKernel(const Context& dev_ctx,
if (!x_grad) return;

const auto& onednn_engine = dev_ctx.GetEngine();
std::vector<int> reversed_axis(axis);

if (axis.size() == 1) {
paddle::framework::TensorCopy(out_grad, out_grad.place(), x_grad);
x_grad->set_format(out_grad.format());
x_grad->set_mem_desc(out_grad.mem_desc());
return;
}

for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
std::vector<int64_t> out_grad_tz = vectorize(out_grad.dims());
funcs::ReorderOneDNNHandler reorder_handler(
out_grad_tz,
out_grad.dtype(),
funcs::ToOneDNNDataType(out_grad.dtype()),
onednn_engine);

const T* out_grad_data = out_grad.data<T>();
dev_ctx.template Alloc<T>(x_grad);
auto nchw_tz = vectorize<int64_t>(out_grad.dims());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
out_grad.mem_desc(), funcs::to_void_cast(out_grad.data<T>()));

funcs::TransposeOneDNNHandler<T> handler(
dev_ctx, nchw_tz, reversed_axis, onednn_engine);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
x_grad, out_grad.mem_desc(), dev_ctx.GetPlace());

auto transpose_src_memory_p = handler.AcquireSrcMemory(
out_grad.format(), funcs::to_void_cast<T>(out_grad_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(x_grad, dev_ctx.GetPlace());
auto transpose_p =
handler.AcquireTranspose(transpose_dst_memory_p, transpose_src_memory_p);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);

auto& astream = OneDNNContext::tls().get_stream();
transpose_p->execute(
astream, *transpose_src_memory_p, *transpose_dst_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
x_grad->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(axis));
}

} // namespace phi
Expand Down

0 comments on commit 493fbfd

Please sign in to comment.