Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Optimize MemoryInput/Output for empty shapes #27015

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/plugins/intel_cpu/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ std::vector<EdgePtr> Node::getChildEdgesAtPort(int inputNum) const {
if (!edge)
OPENVINO_THROW("Node ", getName(), " contains dead weak ptr");
if (edge->getInputNum() == inputNum)
res.push_back(edge);
res.emplace_back(std::move(edge));
}
return res;
}
Expand Down Expand Up @@ -793,11 +793,10 @@ void Node::redefineOutputMemory(const std::vector<VectorDims> &newOutputShapes)
void Node::redefineOutputMemory(const size_t port, const VectorDims& new_output_shape) {
const auto edges = getChildEdgesAtPort(port);

static const VectorDims single_element_shape = {1};

// avoid 0D shape incompatible
auto new_shape = new_output_shape;
if (new_shape.empty()) {
new_shape.push_back(1);
}
const auto& new_shape = new_output_shape.empty() ? single_element_shape : new_output_shape;

const auto& curr_desc = edges[0]->getMemory().getDesc();
if (curr_desc.getShape().isStatic() && curr_desc.getShape().getStaticDims() == new_shape) {
Expand Down
55 changes: 37 additions & 18 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,21 +300,27 @@ void MemoryOutput::runStatic(dnnl::stream strm) {
void MemoryOutput::runDynamic(dnnl::stream strm) {
//first we have to resize the output memory
auto inputMem = getSrcMemoryAtPort(0);
const auto& newDims = inputMem->getStaticDims();
OPENVINO_ASSERT(extMemDesc,
"MemoryOutput ",
getName(),
" uninitialized assigned memory");

auto newExternDesc = extMemDesc->cloneWithNewDims(newDims);

OPENVINO_ASSERT(assignedMem,
"MemoryOutput ",
getName(),
" uninitialized assigned memory");
assignedMem->redefineDesc(newExternDesc);

runStatic(strm);
const auto& newShape = inputMem->getShape();
const auto& stateShape = assignedMem->getShape();

if (stateShape.isDynamic() || stateShape.getStaticDims() != newShape.getStaticDims()) {
OPENVINO_ASSERT(extMemDesc,
"MemoryOutput ",
getName(),
" uninitialized assigned memory");
auto newExternDesc = extMemDesc->cloneWithNewDims(newShape.getStaticDims());
assignedMem->redefineDesc(newExternDesc);
}

if (!newShape.hasZeroDims()) { // no need to copy data for empty tensor
runStatic(strm);
}
}

bool MemoryOutputStub::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
Expand Down Expand Up @@ -593,31 +599,44 @@ void MemoryInput::runDynamic(dnnl::stream strm) {
getName(),
" assigned state has null memory ptr");

// check whether we can share memory block
const auto& stateDims = assignedMem->getStaticDims();
const bool hasZeroDims = std::count(std::begin(stateDims), std::end(stateDims), 0) > 0;
auto internDesc = getBaseMemDescAtOutputPort(0)->cloneWithNewDims(stateDims, hasZeroDims);

OPENVINO_ASSERT(memBlock,
"MemoryInput ",
getName(),
" has uninitialized memory block.");

// check whether we can share memory block
const auto& shape = assignedMem->getShape();
const bool hasZeroDims = shape.hasZeroDims();
const bool processInitGraph = needInitGraphProcessing();
const auto& stateDims = shape.getStaticDims();

if (hasZeroDims && !processInitGraph) {
// fast track as we don't really need to share memory and transfer any data for empty tensors
memBlock->reset();
redefineOutputMemory(0, stateDims);
return;
}

auto dst = getDstMemoryAtPort(0);
auto currentOutputDesc = dst->getDescPtr();

auto internDesc = currentOutputDesc->isDefined() && (currentOutputDesc->getShape().getStaticDims() == stateDims)
? currentOutputDesc
: getBaseMemDescAtOutputPort(0)->cloneWithNewDims(stateDims, hasZeroDims);

if (internDesc->isCompatible(assignedMem->getDesc())) {
memBlock->setMemBlock(assignedMem->getMemoryBlock());
} else {
memBlock->reset();
}

const bool processInitGraph = needInitGraphProcessing();
//reshape output
const auto& newDims = processInitGraph ? getSrcMemoryAtPort(0)->getStaticDims() : stateDims;

redefineOutputMemory({newDims});
redefineOutputMemory(0, newDims);

//copy data when necessary
auto src = processInitGraph ? getSrcMemoryAtPort(0) : assignedMem;
auto dst = getDstMemoryAtPort(0);
if (src->getData() != dst->getData()) {
dst->load(*src);
}
Expand Down Expand Up @@ -847,6 +866,6 @@ void MemoryInputSDPA::resolveInPlaceEdges(Edge::LOOK look) {
}
}

} // namespace node
} // namespace node
} // namespace intel_cpu
} // namespace ov
Loading