Skip to content

Commit

Permalink
[ONNX] Disabled constant folding in DequantizeLinear-21
Browse files: browse the repository at this point in the history
  • Loading branch information
gkrivor authored Oct 30, 2024
1 parent f3e3f75 commit 14b7e18
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/frontends/onnx/frontend/src/op/dequantize_linear.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "transformations/rt_info/disable_constant_folding.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
using namespace ov::op;
Expand Down Expand Up @@ -227,6 +228,7 @@ ov::OutputVector dequantize_linear(const ov::frontend::onnx::Node& node) {
zp = inputs[2];
if (zp.get_element_type() != scale.get_element_type()) {
zp = std::make_shared<v1::ConvertLike>(zp, scale);
ov::pass::disable_constant_folding(zp.get_node_shared_ptr());
}
zp = std::make_shared<v0::Unsqueeze>(zp, unsqueezed_axes);
}
Expand All @@ -241,7 +243,11 @@ ov::OutputVector dequantize_linear(const ov::frontend::onnx::Node& node) {
src_x.get_shape()[0] % block_size == 0,
"DequantizeLinear doesn't support case when first dimension of X cannot be divided by block_size");

const auto& x = src_x.get_element_type() == scale_type ? src_x : std::make_shared<v1::ConvertLike>(src_x, scale);
ov::Output<ov::Node> x = src_x;
if (src_x.get_element_type() != scale_type) {
x = std::make_shared<v1::ConvertLike>(src_x, scale);
ov::pass::disable_constant_folding(x.get_node_shared_ptr());
}
// For further broadcasting scales and zp - reshape input to a shape [x.shape[0]/block_size, block_size, x.shape[1]]
ov::Output<ov::Node> broadcastable_x =
op::util::reshape(x, Shape{static_cast<size_t>(x.get_shape()[0]) / block_size, block_size, x.get_shape()[1]});
Expand Down

0 comments on commit 14b7e18

Please sign in to comment.