[GNA] KSOFunction test fix #6678

Merged: 2 commits, Jul 19, 2021
5 changes: 3 additions & 2 deletions inference-engine/src/gna_plugin/gna_plugin.cpp
@@ -752,12 +752,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
        passes->registerPass<FuseFQIntoWeightsPass>();
        passes->registerPass<MoveFakeQuantizeLayerIntoQuantParamsPass>();

+        passes->registerPass<SubstituteScaleShiftBroadCastPass>();
+        passes->registerPass<BroadcastConstPass>();
+
        passes->registerPass<TransposeWeightsFromNCHWToNHWCPass>();

        passes->registerPass<SubstitutePReluPass>();
        passes->registerPass<SubstituteSoftSignPass>();

-        passes->registerPass<BroadcastConstPass>();
        passes->registerPass<ReorderMaxPoolPass>();
        passes->registerPass<EltwiseSplitOverChannelsPass>();
        passes->registerPass<InsertSplitAligningFilterPass>();
@@ -775,7 +777,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
#if GNA_LIB_VER == 2
        passes->registerPass<ForbidActivationFusingPass>();
#endif
-        passes->registerPass<SubstituteScaleShiftBroadCastPass>();
        passes->registerPass<FuseMultipleIdentitiesPass>();
        passIdx = passes->run(passIdx);
    };
48 changes: 22 additions & 26 deletions inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
@@ -1530,16 +1530,7 @@ void SubstituteScaleShiftBroadCastPass::run() {
            continue;
        }

-        // only 3d scaleshift supported where number of c is arbitrary
-        auto lastD = reshape_batch ? dataDims[1] : dataDims.back();
-        if (lastD != weightsElements) {
-            THROW_GNA_EXCEPTION << "Unsupported layer: " << l->name
-                << " should have last dim(" << lastD << ") equal to weights(" << weightsElements << ") length";
-        }
-        if (dataDims.size() == 2) {
-            THROW_GNA_EXCEPTION << "For layer: " << l->name
-                << " weights size(" << weightsElements<< ") invalid: should match input size of(" << lastD << ")";
-        }
+        // TODO: add broadcasting rules checks
[Review comment from a Contributor on the line above]
Why not add those checks right now?
        gnalog() << "Substitution ScaleShift broadcast for layer: " << l->name << "\n";
        if (nElements % scaleShift->_weights->size()) {
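
The modulo test just above is the only broadcast validation left after this change. As a rough illustration of the deferred "broadcasting rules checks" the TODO and the reviewer refer to, here is a minimal standalone sketch of the divisibility rule that test implies; the helper name, its signature, and the use of std::runtime_error are hypothetical, not plugin API:

#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <vector>

// Hypothetical check, not part of this PR: a ScaleShift weights vector of
// W elements can only be broadcast over a tensor whose total element count
// is a whole multiple of W.
void checkScaleShiftBroadcast(const std::vector<size_t>& dataDims, size_t weightsElements) {
    size_t nElements = 1;
    for (size_t d : dataDims) nElements *= d;  // total number of data elements
    if (weightsElements == 0 || nElements % weightsElements != 0) {
        std::ostringstream msg;
        msg << "weights length " << weightsElements
            << " cannot be broadcast over " << nElements << " data elements";
        throw std::runtime_error(msg.str());
    }
}
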
@@ -2220,6 +2211,17 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
        }
    };

+    auto transpInfoMatchWeightsSize = [](const std::vector<TranspositionInfo> &transpositionInfo, size_t weightsSize, const std::string &layerName) {
+        size_t totalElements = 0;
+        for (auto && transpositionInfoPart : transpositionInfo) {
+            totalElements += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
+        }
+        if (totalElements != weightsSize) {
+            THROW_GNA_EXCEPTION << layerName << " weights elements from transposition info (" << totalElements
+                << ") don't match input dimensions (" << weightsSize << ")";
+        }
+    };
+
    for (auto &&l : *pLayers) {
        if (LayerInfo(l).isScaleShift()) {
            std::vector<TranspositionInfo> transpositionInfo;
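
To make the new helper's contract concrete, a usage sketch with made-up values; the layer name and part sizes are illustrative, and it assumes TranspositionInfo is an aggregate of a transpose flag plus num_transpose_rows and num_transpose_columns, as the loop above uses:

// Two parts covering 8*64 + 8*64 = 1024 elements match a 1024-element
// weights buffer; any other size would hit THROW_GNA_EXCEPTION.
std::vector<TranspositionInfo> parts = {{true, 8, 64}, {false, 8, 64}};
transpInfoMatchWeightsSize(parts, 8 * 64 + 8 * 64, "fc1");   // passes
// transpInfoMatchWeightsSize(parts, 512, "fc1");            // would throw
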
@@ -2237,6 +2239,10 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
            }
            auto weightable = dynamic_cast<WeightableLayer*>(l.get());
            IE_ASSERT(weightable != nullptr);
+
+            size_t totalWeights = weightable->_weights->size();
+            transpInfoMatchWeightsSize(transpositionInfo, totalWeights, l->name);
+
            ConvertTensorFromNCHWToNHWC(weightable->precision.size(), 1, weightable->_weights->size(),
                                        weightable->_weights->cbuffer().as<uint8_t*>(), true, transpositionInfo);
            if (weightable->_biases) {
@@ -2270,14 +2276,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
                // If we found a split it's not possible to rotate data
                THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a split before it";
            }
-            size_t totalColumns = 0;
-            for (auto && transpositionInfoPart : transpositionInfo) {
-                totalColumns += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
-            }
-            if (weightsColumns != totalColumns) {
-                THROW_GNA_EXCEPTION << l->name << " weights columns from transposition info (" << totalColumns
-                    << ") don't match input dimensions (" << weightsColumns << ")";
-            }
+
+            transpInfoMatchWeightsSize(transpositionInfo, weightsColumns, l->name);
+
            ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
                                        true, transpositionInfo);
            gnalog() << l->name << " weights rows transposition info:\n";
@@ -2297,14 +2298,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
                // If we found a concat it's not possible to rotate data
                THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a concat after it";
            }
-            size_t totalRows = 0;
-            for (const auto& transpositionInfoPart : transpositionInfo) {
-                totalRows += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
-            }
-            if (weightsRows != totalRows) {
-                THROW_GNA_EXCEPTION << l->name << " weights rows from transposition info (" << totalRows
-                    << ") don't match output dimensions (" << weightsRows << ")";
-            }
+
+            transpInfoMatchWeightsSize(transpositionInfo, weightsRows, l->name);
+
            ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
                                        false, transpositionInfo);
            gnalog() << l->name << " weights columns transposition info:\n";
@@ -44,8 +44,6 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*ConstantResultSubgraphTest.*inPrc=(U8|I8|I32|U64|I64|BOOL).*)",
// TODO: Issue 51528
R"(.*CachingSupport.*_(u8|i16)_.*)",
// TODO: Issue 51525
R"(.*CachingSupport.*KSOFunction.*)",
// TODO: Issue 57363 (Param -> Result subgraphs)
R"(.*smoke_MemoryTest.*LOW_LATENCY.*iteration_count=1_.*)",
// TODO: Issue 57368 (accuracy)
@@ -13,7 +13,8 @@ const std::vector<std::vector<std::vector<size_t>>> shapes = {
    {{1, 64}, {64, 1}},
    {{8, 256}, {16, 128}},
    {{6, 384}, {18, 128}},
-    {{8, 2048}, {32, 512}}
+    {{8, 2048}, {32, 512}},
+    {{2, 4, 64, 64}, {1, 8, 64, 64}}
};

const std::vector<InferenceEngine::Precision> netPrecisions = {
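
Every pair in this shapes list keeps the total element count equal across its two shapes, and the newly added 4-D pair follows the same rule. A quick compile-time sanity check (illustrative only, not part of the test file):

// 2*4*64*64 and 1*8*64*64 both describe 32768 elements, matching the
// pattern of the existing 2-D pairs (e.g. 8*2048 == 32*512 == 16384).
static_assert(2 * 4 * 64 * 64 == 1 * 8 * 64 * 64, "element counts must match");
static_assert(8 * 2048 == 32 * 512, "element counts must match");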