Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix mkldnn reshape #16455

Merged
merged 1 commit into from
Oct 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
6 changes: 5 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
const std::vector<NDArray> &outputs) {
std::vector<TBlob> in_blobs(inputs.size());
std::vector<NDArray> in_bufs;
std::vector<OpReqType> new_req = req;
for (size_t i = 0; i < in_blobs.size(); i++) {
// If the input data isn't stored in the default format, we shouldn't
// call data() directly, which will change the layout of the NDArray.
Expand All @@ -452,6 +453,9 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
// for inplace, we already converted & copied input above.
if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
const_cast<NDArray &>(output).InvalidateMKLDNNData();
if (req[i] == kWriteInplace) {
new_req[i] = kWriteTo;
}
} else if (req[i] == kAddTo && output.IsMKLDNNData()) {
NDArray temp = outputs[i].Reorder2Default();
temp_src.emplace_back(temp);
Expand All @@ -462,7 +466,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
out_blobs[i] = output.data();
}

fn(attrs_states, ctx, in_blobs, req, out_blobs);
fn(attrs_states, ctx, in_blobs, new_req, out_blobs);
for (size_t i = 0; i < out_blobs.size(); i++) {
if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
Expand Down
48 changes: 0 additions & 48 deletions src/operator/nn/mkldnn/mkldnn_flatten-inl.h

This file was deleted.

79 changes: 0 additions & 79 deletions src/operator/nn/mkldnn/mkldnn_flatten.cc

This file was deleted.

5 changes: 0 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,6 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpReqType &req,
const NDArray &output);

void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
Expand Down
27 changes: 9 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_reshape-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,21 @@ namespace mxnet {
namespace op {

class MKLDNNReshapeFwd {
protected:
public:
MKLDNNReshapeFwd(const OpReqType &req, const NDArray &input, const NDArray &output);
int GetWorkspaceSize();
void SetNewMem(const NDArray &input, const NDArray &output, void *workspace = nullptr);
void Execute(const NDArray &input, const NDArray &output, void *workspace = nullptr);

private:
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
std::shared_ptr<mkldnn::memory> temp_;
std::vector<mkldnn::primitive> prims_;
bool needInvalidateInput = false;

public:
MKLDNNReshapeFwd(const OpReqType &req,
const NDArray &input,
const NDArray &output);
int GetWorkspaceSize();
void SetNewMem(const NDArray &input,
const NDArray &output,
void* workspace = nullptr);
void Execute(const NDArray &input,
const NDArray &output,
void* workspace = nullptr);
};

typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
const OpReqType &req,
const NDArray &input,
typedef OpSignature MKLDNNReshapeSignature;
MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req, const NDArray &input,
const NDArray &output);

} // namespace op
Expand Down
Loading