diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index 6a9cbae002..6ea3fce402 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -872,24 +872,22 @@ LogicalResult ONNXGroupNormalizationCommon( //"numgroups" and "C" should have the same dimension index llvm::SmallVector axesList, biasScaleVal; - if constexpr - scaleAndBiasWithNumGroupShape() { - // Opset18 Uses "numgroups" the number of groups of channels for the scale - // and bias - // Unsqueeze scale/bias from [NG] to [1 x NG x 1 x ... x 1] with numInNorm - // 1s. - biasScaleVal.emplace_back(numGroups); - for (int64_t i = 1; i <= numInNorm; ++i) { - biasScaleVal.emplace_back(1); - axesList.emplace_back(i); - } - - axes = create.onnx.constantInt64(axesList); - biasScaleType = RankedTensorType::get(biasScaleVal, elementType); - newScale = create.onnx.unsqueeze(biasScaleType, scale, axes); - newBias = create.onnx.unsqueeze(biasScaleType, bias, axes); + if constexpr (scaleAndBiasWithNumGroupShape) { + // Opset18 Uses "numgroups" the number of groups of channels for the scale + // and bias + // Unsqueeze scale/bias from [NG] to [1 x NG x 1 x ... x 1] with numInNorm + // 1s. + biasScaleVal.emplace_back(numGroups); + for (int64_t i = 1; i <= numInNorm; ++i) { + biasScaleVal.emplace_back(1); + axesList.emplace_back(i); } - else { + + axes = create.onnx.constantInt64(axesList); + biasScaleType = RankedTensorType::get(biasScaleVal, elementType); + newScale = create.onnx.unsqueeze(biasScaleType, scale, axes); + newBias = create.onnx.unsqueeze(biasScaleType, bias, axes); + } else { // Opset21 Uses "C" the number of channels for the scale and bias // The equivalent of "C" when split is "NG x C/NG" // Reshape scale/bias from [C] to [NG x C/NG x 1 x ... x 1] with numInNorm