Skip to content

Commit

Permalink
Added improvements for StridedSlice
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jul 27, 2021
1 parent 9acedbd commit 6359fd8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ void MKLDNNStridedSliceNode::createPrimitive() {
auto srcOrder = srcBlockingDesc.getOrder();
params.srcDims = srcBlockingDesc.getBlockDims();
params.dstDims = dstBlockingDesc.getBlockDims();
params.srcMemPtr = srcMemPtr;
params.dstMemPtr = dstMemPtr;
params.dataSize = getSelectedPrimitiveDescriptor()->getConfig().inConfs[DATA_ID].desc.getPrecision().size();

if (params.parametersAreConstant) {
Expand All @@ -280,9 +282,7 @@ void MKLDNNStridedSliceNode::createPrimitive() {
SizeVector newSrcDims, newDstDims;
dimsNormalization(newSrcDims, newDstDims);
dimsGluing(realNDims, newSrcDims, newDstDims);

if (params.dstDims.size() == 1 || params.nDimsForWork != 1)
indicesCalculation();
indicesCalculation();
}
}

Expand Down Expand Up @@ -507,13 +507,30 @@ void MKLDNNStridedSliceNode::dimsGluing(const size_t realNDims, const SizeVector
if (params.dstDims.size() > 2)
params.lastDstDim /= newDstDims[secondDim.first];
}

// some parameter calculations for common execution
params.isOptimized = params.nDimsForWork == 1 && params.dstDims.size() > 1;
if (params.isOptimized) {
if (params.dstDims.size() == 2)
params.dstDims[1] = 1;

params.workAmount = params.dstDims[0] * params.dstDims[1];
params.srcShift = (begin[0] * params.srcStrides[0] + begin[1] * params.srcStrides[1]) * params.dataSize;
} else {
params.srcShift = stride.back() == 1 && stride.size() > 1 ?
begin[params.nDimsForWork] * params.srcStrides[params.nDimsForWork] * params.dataSize : 0;
}
}

void MKLDNNStridedSliceNode::indicesCalculation() {
// indices calculation before execution for the best performance
params.nThreads = parallel_get_max_threads();
params.srcIndices.resize(params.workAmount, 0);
params.dstIndices.resize(params.workAmount, 0);
if (params.isOptimized) {
indicesCalculationForOptimized();
return;
}

auto getSrcIdx = [this](const SizeVector& indexes){
size_t srcIdx = 0;
Expand All @@ -539,10 +556,10 @@ void MKLDNNStridedSliceNode::indicesCalculation() {
if (coords[k] < params.dstDims[k]) {
srcIdx += stride[k] * params.srcStrides[k] * params.dataSize;
break;
} else {
coords[k] = 0;
out = true;
}

coords[k] = 0;
out = true;
}

if (out)
Expand All @@ -551,6 +568,25 @@ void MKLDNNStridedSliceNode::indicesCalculation() {
});
}

void MKLDNNStridedSliceNode::indicesCalculationForOptimized() {
const size_t dstIdx0 = params.dstStrides[0] * params.dataSize;
const size_t dstIdx1 = params.dstStrides[1] * params.dataSize;
const size_t srcIdx0 = stride[0] * params.srcStrides[0] * params.dataSize;
const size_t srcIdx1 = stride[1] * params.srcStrides[1] * params.dataSize;

for (size_t i0 = 0; i0 < params.dstDims[0]; i0++) {
const size_t idx = i0 * params.dstDims[1];

params.dstIndices[idx] = i0 * dstIdx0;
params.srcIndices[idx] = i0 * srcIdx0;

for (size_t i1 = 1; i1 < params.dstDims[1]; i1++) {
params.dstIndices[idx + i1] = params.dstIndices[idx] + i1 * dstIdx1;
params.srcIndices[idx + i1] = params.srcIndices[idx] + i1 * srcIdx1;
}
}
}

void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
if (!params.parametersAreConstant) {
auto srcDims = getParentEdgeAt(DATA_ID)->getDims();
Expand Down Expand Up @@ -583,42 +619,15 @@ void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
SizeVector newSrcDims, newDstDims;
dimsNormalization(newSrcDims, newDstDims);
dimsGluing(dstDims.ndims(), newSrcDims, newDstDims);

if (params.dstDims.size() == 1 || params.nDimsForWork != 1)
indicesCalculation();
indicesCalculation();
}

if (params.dstDims.size() > 1 && params.nDimsForWork == 1)
stridedSliceV();
else
stridedSlice();
}

void MKLDNNStridedSliceNode::stridedSliceV() {
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(this->getParentEdgeAt(DATA_ID)->getMemoryPtr()->GetPtr()) +
(begin[0] * params.srcStrides[0] + begin[1] * params.srcStrides[1]) * params.dataSize;
uint8_t* dstData = reinterpret_cast<uint8_t*>(this->getChildEdgeAt(0)->getMemoryPtr()->GetPtr());

const size_t dstIdx = params.dstStrides[0] * params.dataSize;
const size_t srcIdx = stride[0] * params.srcStrides[0] * params.dataSize;
const size_t dstShift = params.dstStrides[1] * params.dataSize;
const size_t srcShift = stride[1] * params.srcStrides[1] * params.dataSize;

if (params.dstDims.size() > 2) {
parallel_for2d(params.dstDims[0], params.dstDims[1], [&](const size_t i, const size_t j) {
cpu_memcpy(&dstData[i * dstIdx + j * dstShift], &srcData[i * srcIdx + j * srcShift], params.lastDstDim);
});
} else {
parallel_for(params.dstDims[0], [&](const size_t i) {
cpu_memcpy(&dstData[i * dstIdx], &srcData[i * srcIdx], params.lastDstDim);
});
}
stridedSlice();
}

void MKLDNNStridedSliceNode::stridedSlice() {
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(this->getParentEdgeAt(DATA_ID)->getMemoryPtr()->GetPtr()) +
(stride.back() == 1 && stride.size() > 1 ? begin[params.nDimsForWork] * params.srcStrides[params.nDimsForWork] * params.dataSize : 0);
uint8_t* dstData = reinterpret_cast<uint8_t*>(this->getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
inline void MKLDNNStridedSliceNode::stridedSlice() {
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(params.srcMemPtr->GetPtr()) + params.srcShift;
uint8_t* dstData = reinterpret_cast<uint8_t*>(params.dstMemPtr->GetPtr());

parallel_nt(params.nThreads, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ class MKLDNNStridedSliceNode : public MKLDNNNode {
static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;

private:
void stridedSliceV();
void stridedSlice();
inline void stridedSlice();

void addHiddenDims(const size_t nSrcDims);
void orderParametersByLayouts();
void dimsNormalization(InferenceEngine::SizeVector& newSrcDims, InferenceEngine::SizeVector& newDstDims);
void dimsGluing(const size_t realNDims, const InferenceEngine::SizeVector& newSrcDims, const InferenceEngine::SizeVector& newDstDims);
void indicesCalculation();
void indicesCalculationForOptimized();

const size_t DATA_ID = 0;
const size_t BEGIN_ID = 1;
Expand All @@ -56,6 +56,8 @@ class MKLDNNStridedSliceNode : public MKLDNNNode {
InferenceEngine::SizeVector strideDims;

struct {
MKLDNNMemoryPtr srcMemPtr = nullptr;
MKLDNNMemoryPtr dstMemPtr = nullptr;
InferenceEngine::SizeVector srcDims;
InferenceEngine::SizeVector dstDims;
InferenceEngine::SizeVector srcStrides;
Expand All @@ -69,6 +71,8 @@ class MKLDNNStridedSliceNode : public MKLDNNNode {
size_t workAmount = 0;
size_t lastDstDim = 0;
size_t dataSize = 0;
size_t srcShift = 0;
bool isOptimized = false;
bool equalDims = false;
bool parametersAreConstant = true;
} params;
Expand Down

0 comments on commit 6359fd8

Please sign in to comment.