Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[CPU] Added improvements for StridedSlice #6658

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ void MKLDNNStridedSliceNode::createPrimitive() {
auto srcOrder = srcBlockingDesc.getOrder();
params.srcDims = srcBlockingDesc.getBlockDims();
params.dstDims = dstBlockingDesc.getBlockDims();
params.srcMemPtr = srcMemPtr;
params.dstMemPtr = dstMemPtr;
params.dataSize = getSelectedPrimitiveDescriptor()->getConfig().inConfs[DATA_ID].desc->getPrecision().size();

if (params.parametersAreConstant) {
Expand All @@ -282,9 +284,7 @@ void MKLDNNStridedSliceNode::createPrimitive() {
SizeVector newSrcDims, newDstDims;
dimsNormalization(newSrcDims, newDstDims);
dimsGluing(realNDims, newSrcDims, newDstDims);

if (params.dstDims.size() == 1 || params.nDimsForWork != 1)
indicesCalculation();
indicesCalculation();
}
}

Expand Down Expand Up @@ -510,14 +510,35 @@ void MKLDNNStridedSliceNode::dimsGluing(const size_t realNDims, const SizeVector
if (params.dstDims.size() > 2)
params.lastDstDim /= newDstDims[secondDim.first];
}

// some parameter calculations for common execution
params.isOptimized = params.nDimsForWork == 1 && params.dstDims.size() > 1;
if (params.isOptimized) {
if (params.dstDims.size() == 2)
params.dstDims[1] = 1;

params.workAmount = params.dstDims[0] * params.dstDims[1];
params.srcShift = (begin[0] * params.srcStrides[0] + begin[1] * params.srcStrides[1]) * params.dataSize;
} else {
params.srcShift = stride.back() == 1 && stride.size() > 1 ?
begin[params.nDimsForWork] * params.srcStrides[params.nDimsForWork] * params.dataSize : 0;
}
}

void MKLDNNStridedSliceNode::indicesCalculation() {
// indices calculation before execution for the best performance
params.nThreads = parallel_get_max_threads();
params.srcIndices.resize(params.workAmount, 0);
params.dstIndices.resize(params.workAmount, 0);

// should choose more optimal thread count
const size_t nthr = parallel_get_max_threads();
params.nThreads = nthr > params.workAmount ? params.workAmount : nthr;

if (params.isOptimized) {
indicesCalculationForOptimized();
return;
}

auto getSrcIdx = [this](const SizeVector& indexes){
size_t srcIdx = 0;
for (int i = 0; i < params.nDimsForWork; ++i)
Expand All @@ -542,10 +563,10 @@ void MKLDNNStridedSliceNode::indicesCalculation() {
if (coords[k] < params.dstDims[k]) {
srcIdx += stride[k] * params.srcStrides[k] * params.dataSize;
break;
} else {
coords[k] = 0;
out = true;
}

coords[k] = 0;
out = true;
}

if (out)
Expand All @@ -554,6 +575,25 @@ void MKLDNNStridedSliceNode::indicesCalculation() {
});
}

// Fills params.srcIndices / params.dstIndices for the optimized path.
//
// One entry is written per (outer, inner) pair taken over the first two
// destination dims; the entry lives at flat position outer * dstDims[1] + inner.
// Each entry is an offset scaled by params.dataSize: the destination offset
// follows the destination strides, the source offset additionally applies the
// slice stride of the corresponding axis.
//
// The index vectors are not resized here; they must already be large enough
// to hold every position written below.
void MKLDNNStridedSliceNode::indicesCalculationForOptimized() {
    // Per-step offsets along axis 0 and axis 1 for destination and source.
    const size_t dstStepOuter = params.dstStrides[0] * params.dataSize;
    const size_t srcStepOuter = stride[0] * params.srcStrides[0] * params.dataSize;
    const size_t dstStepInner = params.dstStrides[1] * params.dataSize;
    const size_t srcStepInner = stride[1] * params.srcStrides[1] * params.dataSize;

    for (size_t outer = 0; outer < params.dstDims[0]; ++outer) {
        const size_t rowStart = outer * params.dstDims[1];

        // The first element of the row is always written, even before the
        // inner loop runs; the remaining elements are derived from it.
        const size_t dstRowBase = outer * dstStepOuter;
        const size_t srcRowBase = outer * srcStepOuter;
        params.dstIndices[rowStart] = dstRowBase;
        params.srcIndices[rowStart] = srcRowBase;

        for (size_t inner = 1; inner < params.dstDims[1]; ++inner) {
            const size_t pos = rowStart + inner;
            params.dstIndices[pos] = dstRowBase + inner * dstStepInner;
            params.srcIndices[pos] = srcRowBase + inner * srcStepInner;
        }
    }
}

void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
if (!params.parametersAreConstant) {
auto srcDims = getParentEdgeAt(DATA_ID)->getShape().getStaticDims();
Expand Down Expand Up @@ -586,42 +626,15 @@ void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
SizeVector newSrcDims, newDstDims;
dimsNormalization(newSrcDims, newDstDims);
dimsGluing(dstDims.size(), newSrcDims, newDstDims);

if (params.dstDims.size() == 1 || params.nDimsForWork != 1)
indicesCalculation();
indicesCalculation();
}

if (params.dstDims.size() > 1 && params.nDimsForWork == 1)
stridedSliceV();
else
stridedSlice();
}

void MKLDNNStridedSliceNode::stridedSliceV() {
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(this->getParentEdgeAt(DATA_ID)->getMemoryPtr()->GetPtr()) +
(begin[0] * params.srcStrides[0] + begin[1] * params.srcStrides[1]) * params.dataSize;
uint8_t* dstData = reinterpret_cast<uint8_t*>(this->getChildEdgeAt(0)->getMemoryPtr()->GetPtr());

const size_t dstIdx = params.dstStrides[0] * params.dataSize;
const size_t srcIdx = stride[0] * params.srcStrides[0] * params.dataSize;
const size_t dstShift = params.dstStrides[1] * params.dataSize;
const size_t srcShift = stride[1] * params.srcStrides[1] * params.dataSize;

if (params.dstDims.size() > 2) {
parallel_for2d(params.dstDims[0], params.dstDims[1], [&](const size_t i, const size_t j) {
cpu_memcpy(&dstData[i * dstIdx + j * dstShift], &srcData[i * srcIdx + j * srcShift], params.lastDstDim);
});
} else {
parallel_for(params.dstDims[0], [&](const size_t i) {
cpu_memcpy(&dstData[i * dstIdx], &srcData[i * srcIdx], params.lastDstDim);
});
}
stridedSlice();
}

void MKLDNNStridedSliceNode::stridedSlice() {
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(this->getParentEdgeAt(DATA_ID)->getMemoryPtr()->GetPtr()) +
(stride.back() == 1 && stride.size() > 1 ? begin[params.nDimsForWork] * params.srcStrides[params.nDimsForWork] * params.dataSize : 0);
uint8_t* dstData = reinterpret_cast<uint8_t*>(this->getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
inline void MKLDNNStridedSliceNode::stridedSlice() {
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(params.srcMemPtr->GetPtr()) + params.srcShift;
uint8_t* dstData = reinterpret_cast<uint8_t*>(params.dstMemPtr->GetPtr());

parallel_nt(params.nThreads, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ class MKLDNNStridedSliceNode : public MKLDNNNode {
static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;

private:
void stridedSliceV();
void stridedSlice();
inline void stridedSlice();

void addHiddenDims(const size_t nSrcDims);
void orderParametersByLayouts();
void dimsNormalization(InferenceEngine::SizeVector& newSrcDims, InferenceEngine::SizeVector& newDstDims);
void dimsGluing(const size_t realNDims, const InferenceEngine::SizeVector& newSrcDims, const InferenceEngine::SizeVector& newDstDims);
void indicesCalculation();
void indicesCalculationForOptimized();

const size_t DATA_ID = 0;
const size_t BEGIN_ID = 1;
Expand All @@ -56,6 +56,8 @@ class MKLDNNStridedSliceNode : public MKLDNNNode {
InferenceEngine::SizeVector strideDims;

struct {
MKLDNNMemoryPtr srcMemPtr = nullptr;
MKLDNNMemoryPtr dstMemPtr = nullptr;
InferenceEngine::SizeVector srcDims;
InferenceEngine::SizeVector dstDims;
InferenceEngine::SizeVector srcStrides;
Expand All @@ -69,6 +71,8 @@ class MKLDNNStridedSliceNode : public MKLDNNNode {
size_t workAmount = 0;
size_t lastDstDim = 0;
size_t dataSize = 0;
size_t srcShift = 0;
bool isOptimized = false;
bool equalDims = false;
bool parametersAreConstant = true;
} params;
Expand Down