From 8feb310c7538552e0b6c5e9b64b13edaf685a422 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Sat, 12 Oct 2019 08:44:48 +0800
Subject: [PATCH] Fix mkldnn reshape

Change-Id: I7e3ed84d02eaac4dbc0637167519e53e7eb8e168
---
 src/operator/nn/mkldnn/mkldnn_base-inl.h    |   1 -
 src/operator/nn/mkldnn/mkldnn_base.cc       |   6 +-
 src/operator/nn/mkldnn/mkldnn_flatten-inl.h |  48 -------
 src/operator/nn/mkldnn/mkldnn_flatten.cc    |  79 ------------
 src/operator/nn/mkldnn/mkldnn_ops-inl.h     |   5 -
 src/operator/nn/mkldnn/mkldnn_reshape-inl.h |  27 ++--
 src/operator/nn/mkldnn/mkldnn_reshape.cc    | 118 ++++++------------
 .../mkldnn/mkldnn_quantized_flatten.cc      |   4 +-
 src/operator/tensor/matrix_op.cc            |  42 ++-----
 9 files changed, 69 insertions(+), 261 deletions(-)
 delete mode 100644 src/operator/nn/mkldnn/mkldnn_flatten-inl.h
 delete mode 100644 src/operator/nn/mkldnn/mkldnn_flatten.cc
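
Notes (illustrative sketches, not part of the patch):

The FallBackCompute change downgrades kWriteInplace to kWriteTo before
invoking the fallback kernel: once an MKL-DNN-layout input has been
reordered into a fresh default-layout buffer, the input and output blobs
no longer alias each other, so an in-place request would mislead the
kernel. A minimal standalone sketch of that invariant (the names below
are illustrative, not MXNet APIs):

    #include <cassert>
    #include <vector>

    enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

    // Copy the request vector, turning every kWriteInplace into kWriteTo.
    std::vector<OpReqType> DowngradeInplace(const std::vector<OpReqType> &req) {
      std::vector<OpReqType> new_req = req;
      for (auto &r : new_req) {
        if (r == kWriteInplace) r = kWriteTo;
      }
      return new_req;
    }

    int main() {
      const std::vector<OpReqType> req = {kWriteInplace, kAddTo};
      const auto new_req = DowngradeInplace(req);
      assert(new_req[0] == kWriteTo);  // in-place request downgraded
      assert(new_req[1] == kAddTo);    // other requests untouched
      return 0;
    }

Similarly, the kWriteInplace path in MKLDNNReshapeFwd stages the data in a
default-layout scratch buffer and then copies it back over the address the
input and output share, because reordering a buffer onto itself would
overwrite elements before they are read. A toy demonstration with a 2x3
matrix standing in for an MKL-DNN blocked layout (again a sketch, not the
mkldnn API):

    #include <algorithm>
    #include <array>
    #include <cstddef>

    int main() {
      // "Internal" (column-major) layout of the 2x3 matrix {{0,1,2},{3,4,5}}.
      std::array<float, 6> buf = {0, 3, 1, 4, 2, 5};
      std::array<float, 6> scratch{};  // default (row-major) layout
      for (std::size_t r = 0; r < 2; ++r)    // reorder to default layout
        for (std::size_t c = 0; c < 3; ++c)
          scratch[r * 3 + c] = buf[c * 2 + r];
      // Copy back over the shared input/output address.
      std::copy(scratch.begin(), scratch.end(), buf.begin());
      return buf[1] == 1.0f ? 0 : 1;  // buf is now row-major {0,1,2,3,4,5}
    }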
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 9e8725e776e5..961aa8b05a84 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -194,7 +194,6 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
-bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 862947eb726a..fca908fc8e39 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -428,6 +428,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
                      const std::vector<NDArray> &outputs) {
   std::vector<TBlob> in_blobs(inputs.size());
   std::vector<NDArray> in_bufs;
+  std::vector<OpReqType> new_req = req;
   for (size_t i = 0; i < in_blobs.size(); i++) {
     // If the input data isn't stored in the default format, we shouldn't
     // call data() directly, which will change the layout of the NDArray.
@@ -452,6 +453,9 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
     // for inplace, we already converted & copied input above.
     if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
       const_cast<NDArray &>(output).InvalidateMKLDNNData();
+      if (req[i] == kWriteInplace) {
+        new_req[i] = kWriteTo;
+      }
     } else if (req[i] == kAddTo && output.IsMKLDNNData()) {
       NDArray temp = outputs[i].Reorder2Default();
       temp_src.emplace_back(temp);
@@ -462,7 +466,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
       out_blobs[i] = output.data();
   }
 
-  fn(attrs_states, ctx, in_blobs, req, out_blobs);
+  fn(attrs_states, ctx, in_blobs, new_req, out_blobs);
   for (size_t i = 0; i < out_blobs.size(); i++) {
     if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
       mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
deleted file mode 100644
index ae890d8f3d91..000000000000
--- a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mkldnn_flatten-inl.h
- * \brief Implement flatten operator by using mkldnn reorder primitive
- * \author Wuxun Zhang
- */
-
-#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
-#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
-#if MXNET_USE_MKLDNN == 1
-
-#include "mkldnn_reshape-inl.h"
-
-namespace mxnet {
-namespace op {
-
-class MKLDNNFlattenFwd : public MKLDNNReshapeFwd {
- public:
-  explicit MKLDNNFlattenFwd(const OpReqType &req, const NDArray &input, const NDArray &output)
-      : MKLDNNReshapeFwd(req, input, output) {}
-};
-
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const NDArray &input,
-                          const OpReqType &req, const NDArray &output);
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_USE_MKLDNN == 1
-#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc
deleted file mode 100644
index 4090eb026cfc..000000000000
--- a/src/operator/nn/mkldnn/mkldnn_flatten.cc
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mkldnn_flatten.cc
- * \brief Implement flatten operator by using mkldnn reorder primitive
- * \author Wuxun Zhang
- */
-
-#if MXNET_USE_MKLDNN == 1
-
-#include "mkldnn_flatten-inl.h"
-
-namespace mxnet {
-namespace op {
-
-static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
-                                           const NDArray &input,
-                                           const NDArray &output) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<OpSignature, MKLDNNFlattenFwd, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNFlattenFwd, OpHash> fwds;
-#endif
-  OpSignature key;
-  key.AddSign(req);
-  key.AddSign(input);
-
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    MKLDNNFlattenFwd fwd(req, input, output);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
-}
-
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output) {
-  if (req == kNullOp) return;
-  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
-
-  auto fwd = GetFlattenForward(req, input, output);
-  auto ws_size = fwd.GetWorkspaceSize();
-  void* ws_ptr = nullptr;
-  if (ws_size) {
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
-        .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
-    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
-  }
-
-  fwd.Execute(input, output, ws_ptr);
-}
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index b564a3318402..c0218f4100b5 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -131,11 +131,6 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const OpReqType &req,
                           const NDArray &output);
 
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output);
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
index 63e367b4dc7f..726d72156718 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -35,30 +35,21 @@ namespace mxnet {
 namespace op {
 
 class MKLDNNReshapeFwd {
- protected:
+ public:
+  MKLDNNReshapeFwd(const OpReqType &req, const NDArray &input, const NDArray &output);
+  int GetWorkspaceSize();
+  void SetNewMem(const NDArray &input, const NDArray &output, void *workspace = nullptr);
+  void Execute(const NDArray &input, const NDArray &output, void *workspace = nullptr);
+
+ private:
   std::shared_ptr<mkldnn::memory> data_;
   std::shared_ptr<mkldnn::memory> out_;
   std::shared_ptr<mkldnn::memory> temp_;
   std::vector<mkldnn::primitive> prims_;
-  bool needInvalidateInput = false;
-
- public:
-  MKLDNNReshapeFwd(const OpReqType &req,
-                   const NDArray &input,
-                   const NDArray &output);
-  int GetWorkspaceSize();
-  void SetNewMem(const NDArray &input,
-                 const NDArray &output,
-                 void* workspace = nullptr);
-  void Execute(const NDArray &input,
-               const NDArray &output,
-               void* workspace = nullptr);
 };
 
-typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
-MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
-                                    const OpReqType &req,
-                                    const NDArray &input,
+typedef OpSignature MKLDNNReshapeSignature;
+MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req, const NDArray &input,
                                     const NDArray &output);
 
 }  // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index 063c85dae39a..9c226a052b0b 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -24,65 +24,37 @@
  */
 
 #if MXNET_USE_MKLDNN == 1
-
-#include <mkldnn.hpp>
-#include "mkldnn_reshape-inl.h"
+#include "../../tensor/elemwise_unary_op.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_reshape-inl.h"
 
 namespace mxnet {
 namespace op {
 
-bool SupportMKLDNNReshape(const ReshapeParam &param,
-                          const NDArray &data) {
-  auto data_ndim = data.shape().ndim();
-
-  if (data_ndim > 4 ||
-      data.dtype() != mshadow::kFloat32 ||
-      param.shape.ndim() > 4)
-    return false;
-
-  return true;
-}
-
-MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
-                                   const NDArray &input,
+MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req, const NDArray &input,
                                    const NDArray &output) {
-  auto engine = CpuEngine::Get()->get_engine();
-
-  // data_
-  auto in_mem = input.GetMKLDNNData();
-  auto in_pd = in_mem->get_primitive_desc();
-  data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
-
-  // temp_
+  const auto engine = CpuEngine::Get()->get_engine();
+  data_ = std::make_shared<mkldnn::memory>(input.GetMKLDNNData()->get_primitive_desc(), nullptr);
+  // Create temp memory
   auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
-  auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
-  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
+  auto temp_type = static_cast<mkldnn::memory::data_type>(get_mkldnn_type(input.dtype()));
+  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(input.shape().ndim()));
   auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
   auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
-  temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
-  // destination
   out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
   if (req == kWriteInplace) {
     // If the input has an MKL-DNN internal layout, we need to reorder it to a temporary buffer
    // with default layout and copy from that temporary buffer back to the output buffer, which
     // shares its address with the input buffer.
     // If the input has default layout, then nothing needs to be done.
     if (input.IsMKLDNNData()) {
-      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
-      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy back
-      needInvalidateInput = true;
+      temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy back
     }
   } else if (req == kWriteTo) {
-    if (input.IsMKLDNNData()) {
-      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
-      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy to the output buffer
-      needInvalidateInput = false;
-    } else {
-      prims_.push_back(mkldnn::reorder(*data_, *out_));   // copy directly from input to output
-      needInvalidateInput = false;
-    }
+    prims_.push_back(mkldnn::reorder(*data_, *out_));
   } else {
     LOG(FATAL) << "not supported req type: " << req;
   }
@@ -95,22 +67,8 @@ int MKLDNNReshapeFwd::GetWorkspaceSize() {
 void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
                                  const NDArray &output,
                                  void* workspace) {
-  if (input.IsMKLDNNData()) {
-    this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
-  } else {
-    MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
-      this->data_->set_data_handle(input.data().dptr<DTYPE>());
-    })
-  }
-
-  if (output.IsMKLDNNData()) {
-    this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
-  } else {
-    MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
-      this->out_->set_data_handle(output.data().dptr<DTYPE>());
-    })
-  }
-
+  this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
+  this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
   if (workspace) {
     this->temp_->set_data_handle(workspace);
   }
@@ -119,22 +77,21 @@ void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
 void MKLDNNReshapeFwd::Execute(const NDArray &input,
                                const NDArray &output,
                                void* workspace) {
-  // set memory handles
-  SetNewMem(input, output, workspace);
-  // register primitives
-  auto stream = MKLDNNStream::Get();
-  for (auto &v : this->prims_) {
-    stream->RegisterPrim(v);
-  }
-  stream->Submit();
-  // invalidate mkldnn memory in input
-  if (needInvalidateInput) {
-    const_cast<NDArray &>(input).InvalidateMKLDNNData();
+  if (this->prims_.size()) {
+    // set memory handles
+    SetNewMem(input, output, workspace);
+    // register primitives
+    auto stream = MKLDNNStream::Get();
+    for (auto &v : this->prims_) {
+      stream->RegisterPrim(v);
+    }
+    stream->Submit();
   }
+  // invalidate mkldnn memory in output
+  const_cast<NDArray &>(output).InvalidateMKLDNNData();
 }
 
-MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
-                                    const OpReqType &req,
+MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req,
                                     const NDArray &input,
                                     const NDArray &output) {
 #if DMLC_CXX11_THREAD_LOCAL
@@ -144,10 +101,9 @@ MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
   static MX_THREAD_LOCAL std::unordered_map<MKLDNNReshapeSignature,
                                             MKLDNNReshapeFwd, OpHash> fwds;
 #endif
-  MKLDNNReshapeSignature key(param);
+  MKLDNNReshapeSignature key;
   key.AddSign(req);
   key.AddSign(input);
-  key.AddSign(output);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
@@ -162,22 +118,28 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const NDArray &input,
                           const OpReqType &req,
                           const NDArray &output) {
-  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
+  // Input that MKL-DNN doesn't support shouldn't hold MKL-DNN memory, so simply fall back to
+  // the naive implementation.
+  if (input.shape().ndim() > 4 || !SupportMKLDNNQuantize(input.dtype())) {
+    if (req != kWriteInplace) {
+      FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, {input}, {req}, {output});
+    }
+    return;
+  }
   if (req == kNullOp) return;
   CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
-
-  auto fwd = GetReshapeForward(param, req, input, output);
+  auto fwd = GetReshapeForward(req, input, output);
   auto ws_size = fwd.GetWorkspaceSize();
   void* ws_ptr = nullptr;
   if (ws_size) {
     mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
     mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
         .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
-    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
+    ws_ptr = static_cast<void*>(ws.dptr_);
   }
-
   fwd.Execute(input, output, ws_ptr);
 }
+
 }  // namespace op
 }  // namespace mxnet
 #endif
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc
index 31da936915e6..c059f9868ea0 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc
@@ -24,7 +24,7 @@
  */
 
 #if MXNET_USE_MKLDNN == 1
-#include "../../nn/mkldnn/mkldnn_flatten-inl.h"
+#include "../../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../quantization_utils.h"
 
 namespace mxnet {
@@ -42,7 +42,7 @@ static void MKLDNNQuantizedFlattenForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                                           const std::vector<NDArray>& inputs,
                                           const std::vector<OpReqType>& req,
                                           const std::vector<NDArray>& outputs) {
-  MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+  MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
   outputs[1].data().dptr<float>()[0] = inputs[1].data().dptr<float>()[0];
   outputs[2].data().dptr<float>()[0] = inputs[2].data().dptr<float>()[0];
 }
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index c60402488b65..ed4afc0cbcbc 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -111,17 +111,12 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
-  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   // If inputs are supposed to be in MKLDNN format and
   // MKLDNN supports the data type and the shape, then convert
   // them to the output format and shape
-  if (SupportMKLDNNReshape(param, inputs[0])) {
-    MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
-    return;
-  }
-  FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
 }
 
 inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
@@ -226,6 +221,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.")
 .add_arguments(ReshapeParam::__FIELDS__());
 
+#if MXNET_USE_MKLDNN == 1
 static void FlattenEx(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<NDArray>& inputs,
@@ -233,22 +229,12 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
                       const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-#if MXNET_USE_MKLDNN == 1
-  auto data_ndim = inputs[0].shape().ndim();
-  if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) {
-    MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
-    return;
-  } else {
-    // This happens if inputs are supposed to be in MKLDNN format
-    // but MKLDNN doesn't support the data type or the shape. We're
-    // forced to convert it to the default format.
-    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
-    return;
-  }
-#endif
+  // If inputs are supposed to be in MKLDNN format and
+  // MKLDNN supports the data type and the shape, then convert
+  // them to the output format and shape
+  MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
 }
 
-#if MXNET_USE_MKLDNN == 1
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
                                       const int dev_mask,
                                       DispatchMode* dispatch_mode,
@@ -294,14 +280,12 @@ Example::
 .set_num_outputs(1)
 .set_attr<mxnet::FInferShape>("FInferShape", FlattenShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-#if MXNET_USE_MKLDNN == 1
-.set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
-#endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_copy" })
 .set_attr<FCompute>("FCompute", UnaryOp::IdentityCompute<cpu>)
-.set_attr<FComputeEx>("FComputeEx", FlattenEx)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx", FlattenEx)
+.set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
@@ -1033,7 +1017,7 @@ NNVM_REGISTER_OP(depth_to_space)
 .describe(R"code(Rearranges(permutes) data from depth into blocks of spatial data.
 Similar to ONNX DepthToSpace operator:
 https://github.com/onnx/onnx/blob/master/docs/Operators.md#DepthToSpace.
-The output is a new tensor where the values from depth dimension are moved in spatial blocks 
+The output is a new tensor where the values from depth dimension are moved in spatial blocks
 to height and width dimension. The reverse of this operation is ``space_to_depth``.
 
 .. math::
 
     \begin{gather*}
     x \prime = reshape(x, [N, block\_size, block\_size, C / (block\_size ^ 2), H * W]) \\
     x \prime \prime = transpose(x \prime, [0, 3, 4, 1, 2]) \\
     y = reshape(x \prime \prime, [N, C / (block\_size ^ 2), H * block\_size, W * block\_size])
     \end{gather*}
 
-where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] 
+where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width]
 and :math:`y` is the output tensor of layout :math:`[N, C / (block\_size ^ 2), H * block\_size, W * block\_size]`
 
 Example::
 
 NNVM_REGISTER_OP(space_to_depth)
 .describe(R"code(Rearranges(permutes) blocks of spatial data into depth.
 Similar to ONNX SpaceToDepth operator:
-https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth 
+https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth
 
-The output is a new tensor where the values from height and width dimension are 
+The output is a new tensor where the values from height and width dimension are
 moved to the depth dimension. The reverse of this operation is ``depth_to_space``.
 
 .. math::
 
     \begin{gather*}
     x \prime = reshape(x, [N, C, H / block\_size, block\_size, W / block\_size, block\_size]) \\
     x \prime \prime = transpose(x \prime, [0, 3, 5, 1, 2, 4]) \\
     y = reshape(x \prime \prime, [N, C * (block\_size ^ 2), H / block\_size, W / block\_size])
     \end{gather*}
 
-where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] 
+where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width]
 and :math:`y` is the output tensor of layout :math:`[N, C * (block\_size ^ 2), H / block\_size, W / block\_size]`
 
 Example::