Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTTI] Replace std::dynamic_(pointer)?_casts with ov::as_type_(ptr)? - Core, IE #28396

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ bool AssignAndReadValueTransformation::canBeTransformed(const TransformationCont
return false;
}

const auto readValue = std::dynamic_pointer_cast<op::util::ReadValueBase>(op->get_control_dependencies()[0]);
const auto readValue = ov::as_type_ptr<op::util::ReadValueBase>(op->get_control_dependencies()[0]);
if (!readValue) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void make_matcher_type_relaxed(ov::pass::GraphRewrite* transformation) {
auto p_node = std::make_shared<pass::pattern::op::Label>(element::f32, Shape{}, is_op_type);

ov::graph_rewrite_callback callback = [](ov::pass::pattern::Matcher& m) {
auto l_node = std::dynamic_pointer_cast<BaseOp>(m.get_match_root());
auto l_node = ov::as_type_ptr<BaseOp>(m.get_match_root());
if (!l_node) {
THROW_TRANSFORMATION_EXCEPTION << "unexpected operation type for type relaxed conversion";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,25 @@ bool ov::pass::low_precision::MarkupCanBeQuantized::run_on_model(const std::shar
continue;
}

if (const auto convolution = std::dynamic_pointer_cast<ov::opset1::Convolution>(node)) {
if (const auto convolution = ov::as_type_ptr<ov::opset1::Convolution>(node)) {
if (!ConvolutionTransformation::isQuantizedStatic(convolution, defaultPrecisions)) {
setEmptyPrecisions(convolution);
}
continue;
}
if (const auto convolutionBackpropData = std::dynamic_pointer_cast<ov::opset1::ConvolutionBackpropData>(node)) {
if (const auto convolutionBackpropData = ov::as_type_ptr<ov::opset1::ConvolutionBackpropData>(node)) {
if (!ConvolutionBackpropDataTransformation::isQuantizedStatic(convolutionBackpropData, defaultPrecisions)) {
setEmptyPrecisions(convolutionBackpropData);
}
continue;
}
if (const auto groupConvolution = std::dynamic_pointer_cast<ov::opset1::GroupConvolution>(node)) {
if (const auto groupConvolution = ov::as_type_ptr<ov::opset1::GroupConvolution>(node)) {
if (!GroupConvolutionTransformation::isQuantizedStatic(groupConvolution, defaultPrecisions)) {
setEmptyPrecisions(groupConvolution);
}
continue;
}
if (const auto concat = std::dynamic_pointer_cast<ov::opset1::Concat>(node)) {
if (const auto concat = ov::as_type_ptr<ov::opset1::Concat>(node)) {
if (!ConcatTransformation::isQuantizedStatic(concat)) {
setEmptyPrecisions(concat);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ov::pass::CompressWeightsWithFakeQuantize::CompressWeightsWithFakeQuantize() {
{weights_pattern, input_low_pattern, input_high_pattern, output_low_pattern, output_high_pattern});

ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto fq = std::dynamic_pointer_cast<op::v0::FakeQuantize>(m.get_match_root());
auto fq = ov::as_type_ptr<op::v0::FakeQuantize>(m.get_match_root());
if (!fq)
return false;
const auto& high_precision_type = fq->get_element_type();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ov::pass::InitConstMask::InitConstMask(const ov::AxisSet& dims,
pattern::type_matches_any({element::i8, element::u8, element::f16, element::f32, element::f64}));

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto const_node = std::dynamic_pointer_cast<opset6::Constant>(m.get_match_root());
auto const_node = ov::as_type_ptr<opset6::Constant>(m.get_match_root());
if (!const_node)
return false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ class ov::pass::init_masks::InitMatMulMask : public MatcherPass {

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto& matmul =
std::dynamic_pointer_cast<opset6::MatMul>(pattern_map.at(matmul_pattern).get_node_shared_ptr());
const auto& matmul = ov::as_type_ptr<opset6::MatMul>(pattern_map.at(matmul_pattern).get_node_shared_ptr());
if (!matmul)
return false;

Expand Down Expand Up @@ -117,7 +116,7 @@ class ov::pass::init_masks::InitMatMulMask : public MatcherPass {
return false;
}
// 2. Get constant rank to set mask on last dimension
const auto const_op = std::dynamic_pointer_cast<opset6::Constant>(cur_node);
const auto const_op = ov::as_type_ptr<opset6::Constant>(cur_node);
const auto shape_rank = const_op->get_shape().size();
const size_t shift = (matmul->get_transpose_b()) ? 2 : 1;
if (shape_rank < shift) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class ov::pass::mask_propagation::MatMul : public MatcherPass {
a_mask_row = a_mask.get();
auto b_mask_row = b_mask.get();

const auto matmul_op = std::dynamic_pointer_cast<opset10::MatMul>(m_matmul.get_node_shared_ptr());
const auto matmul_op = ov::as_type_ptr<opset10::MatMul>(m_matmul.get_node_shared_ptr());
const auto transpose_a = matmul_op->get_transpose_a();
const auto transpose_b = matmul_op->get_transpose_b();

Expand Down Expand Up @@ -717,13 +717,13 @@ class ov::pass::mask_propagation::FakeQuantize : public MatcherPass {
m_input_high.get_node_shared_ptr(),
m_output_low.get_node_shared_ptr(),
m_output_high.get_node_shared_ptr()};
auto fq_node = std::dynamic_pointer_cast<opset10::FakeQuantize>(m_output.get_node_shared_ptr());
auto fq_node = ov::as_type_ptr<opset10::FakeQuantize>(m_output.get_node_shared_ptr());
if (!fq_node)
return false;
size_t idx = 0;
if (fq_node->get_auto_broadcast() != ov::op::AutoBroadcastType::NONE) {
for (const auto& node : fq_params_nodes) {
auto const_node = std::dynamic_pointer_cast<op::v0::Constant>(node);
auto const_node = ov::as_type_ptr<op::v0::Constant>(node);
if (!const_node)
OPENVINO_THROW("Unexpected operation type.");
auto new_shape = broadcast_shape_to_rank(const_node->get_shape(),
Expand Down Expand Up @@ -771,7 +771,7 @@ class ov::pass::mask_propagation::Concat : public MatcherPass {
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto& m_output = pattern_map.at(concat);
auto concat_ptr = std::dynamic_pointer_cast<opset10::Concat>(m_output.get_node_shared_ptr());
auto concat_ptr = ov::as_type_ptr<opset10::Concat>(m_output.get_node_shared_ptr());
if (!concat_ptr) {
return false;
}
Expand Down Expand Up @@ -930,7 +930,7 @@ class ov::pass::mask_propagation::Reduce : public MatcherPass {
// Check reduce operation reduces only dimension without masks
if (auto input_mask = getMask(m_input)) {
auto output_mask = std::make_shared<ov::Mask>(m_output.get_partial_shape().rank().get_length());
const auto constant = std::dynamic_pointer_cast<opset10::Constant>(m_weights.get_node_shared_ptr());
const auto constant = ov::as_type_ptr<opset10::Constant>(m_weights.get_node_shared_ptr());
OPENVINO_ASSERT(!!constant, "Dynamic cast returned a nullptr");
const auto reduce_dims = constant->cast_vector<int64_t>();

Expand Down Expand Up @@ -1144,7 +1144,7 @@ class ov::pass::mask_propagation::Reshape : public MatcherPass {
if (is_type<opset10::GroupConvolution>(inp.get_node()))
return true;

auto constant = std::dynamic_pointer_cast<opset10::Constant>(m_weights.get_node_shared_ptr());
auto constant = ov::as_type_ptr<opset10::Constant>(m_weights.get_node_shared_ptr());
if (!constant) {
constant = ov::util::get_constant_from_source(m_weights.get_node_shared_ptr());
if (!constant) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static bool not_empty_mask(ov::Mask::Ptr mask) {
}

static bool is_static_reshape_op(std::shared_ptr<ov::Node> node) {
auto reshape_node = std::dynamic_pointer_cast<ov::opset6::Reshape>(node);
auto reshape_node = ov::as_type_ptr<ov::opset6::Reshape>(node);
if (!reshape_node)
return false;

Expand Down Expand Up @@ -224,7 +224,7 @@ bool ov::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ov::Model>& f)
continue;

// TODO: constant can be shared across functions so we need to avoid consumers from other function
auto const_node = std::dynamic_pointer_cast<opset6::Constant>(node);
auto const_node = ov::as_type_ptr<opset6::Constant>(node);
if (!const_node)
continue;

Expand Down
48 changes: 24 additions & 24 deletions src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,38 +69,38 @@ RegType Generator::get_op_out_reg_type(const ov::Output<Node>& out) const {
if (reg_type != RegType::undefined)
return reg_type;
const auto op = out.get_node_shared_ptr();
if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(op) ||
std::dynamic_pointer_cast<ov::op::v0::Result>(op) ||
std::dynamic_pointer_cast<op::LoopBegin>(op) ||
std::dynamic_pointer_cast<op::LoopEnd>(op) ||
std::dynamic_pointer_cast<op::Brgemm>(op) ||
std::dynamic_pointer_cast<op::Buffer>(op) ||
std::dynamic_pointer_cast<op::RankNormalization>(op) ||
std::dynamic_pointer_cast<op::Reshape>(op) ||
std::dynamic_pointer_cast<op::Reorder>(op) ||
std::dynamic_pointer_cast<snippets::op::Store>(op)
if (ov::as_type_ptr<ov::op::v0::Parameter>(op) ||
ov::as_type_ptr<ov::op::v0::Result>(op) ||
ov::as_type_ptr<op::LoopBegin>(op) ||
ov::as_type_ptr<op::LoopEnd>(op) ||
ov::as_type_ptr<op::Brgemm>(op) ||
ov::as_type_ptr<op::Buffer>(op) ||
ov::as_type_ptr<op::RankNormalization>(op) ||
ov::as_type_ptr<op::Reshape>(op) ||
ov::as_type_ptr<op::Reorder>(op) ||
ov::as_type_ptr<snippets::op::Store>(op)
#ifdef SNIPPETS_DEBUG_CAPS
|| std::dynamic_pointer_cast<op::PerfCountBeginBase>(op)
|| std::dynamic_pointer_cast<op::PerfCountEndBase>(op)
|| ov::as_type_ptr<op::PerfCountBeginBase>(op)
|| ov::as_type_ptr<op::PerfCountEndBase>(op)
#endif
)
return RegType::gpr;
else if (std::dynamic_pointer_cast<snippets::op::Load>(op) ||
std::dynamic_pointer_cast<snippets::op::BroadcastLoad>(op) ||
else if (ov::as_type_ptr<snippets::op::Load>(op) ||
ov::as_type_ptr<snippets::op::BroadcastLoad>(op) ||
ov::op::util::is_unary_elementwise_arithmetic(op) ||
ov::op::util::is_binary_elementwise_arithmetic(op) ||
ov::op::util::is_binary_elementwise_comparison(op) ||
ov::op::util::is_binary_elementwise_logical(op) ||
std::dynamic_pointer_cast<ov::op::v1::LogicalNot>(op) ||
std::dynamic_pointer_cast<ov::op::v0::PRelu>(op) ||
std::dynamic_pointer_cast<ov::op::v0::Convert>(op) ||
std::dynamic_pointer_cast<ov::op::v1::Select>(op) ||
std::dynamic_pointer_cast<op::VectorBuffer>(op) ||
std::dynamic_pointer_cast<op::BroadcastMove>(op) ||
std::dynamic_pointer_cast<op::Scalar>(op) ||
std::dynamic_pointer_cast<op::HorizonMax>(op) ||
std::dynamic_pointer_cast<op::HorizonSum>(op) ||
std::dynamic_pointer_cast<op::Fill>(op))
ov::as_type_ptr<ov::op::v1::LogicalNot>(op) ||
ov::as_type_ptr<ov::op::v0::PRelu>(op) ||
ov::as_type_ptr<ov::op::v0::Convert>(op) ||
ov::as_type_ptr<ov::op::v1::Select>(op) ||
ov::as_type_ptr<op::VectorBuffer>(op) ||
ov::as_type_ptr<op::BroadcastMove>(op) ||
ov::as_type_ptr<op::Scalar>(op) ||
ov::as_type_ptr<op::HorizonMax>(op) ||
ov::as_type_ptr<op::HorizonSum>(op) ||
ov::as_type_ptr<op::Fill>(op))
return RegType::vec;
else
OPENVINO_THROW("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!");
Expand Down
8 changes: 4 additions & 4 deletions src/common/snippets/src/lowered/pass/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void PassPipeline::run(const lowered::LinearIR& linear_ir) const {
if (m_pass_config->is_disabled(pass->get_type_info())) {
continue;
}
const auto const_pass = std::dynamic_pointer_cast<ConstPass>(pass);
const auto const_pass = ov::as_type_ptr<ConstPass>(pass);
OPENVINO_ASSERT(const_pass != nullptr,
"Unexpected pass (",
pass->get_type_info(),
Expand All @@ -56,11 +56,11 @@ void PassPipeline::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearI
if (m_pass_config->is_disabled(pass->get_type_info())) {
continue;
}
if (auto lir_pass = std::dynamic_pointer_cast<Pass>(pass)) {
if (auto lir_pass = ov::as_type_ptr<Pass>(pass)) {
lir_pass->run(linear_ir);
} else if (auto const_pass = std::dynamic_pointer_cast<ConstPass>(pass)) {
} else if (auto const_pass = ov::as_type_ptr<ConstPass>(pass)) {
const_pass->run(linear_ir);
} else if (auto ranged_pass = std::dynamic_pointer_cast<RangedPass>(pass)) {
} else if (auto ranged_pass = ov::as_type_ptr<RangedPass>(pass)) {
ranged_pass->run(linear_ir, begin, end);
} else {
OPENVINO_THROW("Unexpected pass (", pass->get_type_info(), ") is registered in PassPipeline");
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ auto Subgraph::constant_input_should_be_inside_body(const std::shared_ptr<ov::No
}

bool Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept {
const auto elementwise = std::dynamic_pointer_cast<const ov::op::util::BinaryElementwiseArithmetic>(node);
const auto elementwise = ov::as_type_ptr<const ov::op::util::BinaryElementwiseArithmetic>(node);
return
(elementwise == nullptr) ||
(elementwise->get_input_partial_shape(0).size() == elementwise->get_input_partial_shape(1).size()) ||
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/pass/fq_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ ov::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::FakeQuantizeDecomposition")
auto& pattern_to_output = m.get_pattern_value_map();
const auto fake_quantize_node = std::dynamic_pointer_cast<ov::op::v0::FakeQuantize>(
const auto fake_quantize_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(
pattern_to_output.at(fake_quantize).get_node_shared_ptr());

if (!fake_quantize_node || transformation_callback(fake_quantize_node)) {
Expand Down Expand Up @@ -358,7 +358,7 @@ bool ov::snippets::pass::CommonFakeQuantizeDecomposition::is_supported_fq(const
if (!greater_equal->constant_fold(result, greater_equal->input_values()))
return false;

const auto res_node = std::dynamic_pointer_cast<const ov::op::v0::Constant>(result[0].get_node_shared_ptr());
const auto res_node = ov::as_type_ptr<const ov::op::v0::Constant>(result[0].get_node_shared_ptr());
const auto comp_result = res_node->cast_vector<bool>();
return !std::any_of(comp_result.begin(), comp_result.end(), [](const bool value) {
return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ class FakeQuantizeDecompositionTest : public TransformationTestsF {
TransformationTestsF::TearDown();

auto subgraph = FunctionHelper::getSubgraph(model);
auto body = subgraph == nullptr ? nullptr : std::dynamic_pointer_cast<ov::snippets::op::Subgraph>(subgraph)->body_ptr();
auto body = subgraph == nullptr ? nullptr : ov::as_type_ptr<ov::snippets::op::Subgraph>(subgraph)->body_ptr();

auto subgraph_ref = FunctionHelper::getSubgraph(model_ref);
auto body_ref = subgraph_ref == nullptr ? nullptr : std::dynamic_pointer_cast<ov::snippets::op::Subgraph>(subgraph_ref)->body_ptr();
auto body_ref = subgraph_ref == nullptr ? nullptr : ov::as_type_ptr<ov::snippets::op::Subgraph>(subgraph_ref)->body_ptr();

auto res = comparator.compare(body, body_ref);
ASSERT_TRUE(res.valid) << res.message;
Expand All @@ -48,4 +48,4 @@ TEST_F(FakeQuantizeDecompositionTest, smoke_Snippets_PerTensorFakeQuantizeDecomp

} // namespace snippets
} // namespace test
} // namespace ov
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TRANSFORMATIONS_API DepthToSpaceFusion;
*
* // This callback enables DepthToSpaceFusion transformation
* auto callback = [](const std::shared_ptr<const ov::Node> & node) -> bool {
* return std::dynamic_pointer_cast<const ov::opset3::DepthToSpace>(node) != nullptr;
* return ov::as_type_ptr<const ov::opset3::DepthToSpace>(node) != nullptr;
* };
*
* auto p = ov::pass::DepthToSpaceFusion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ class ov::pass::ConvertReduceToPooling : public ov::pass::GraphRewrite {
template <class T>
ov::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
return [&](ov::pass::pattern::Matcher& m) {
auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
auto reduce = ov::as_type_ptr<T>(m.get_match_root());

if (!reduce || transformation_callback(reduce) || ov::shape_size(reduce->input_value(0).get_shape()) == 0) {
return false;
}

auto input = reduce->input_value(0);

auto axes_node = std::dynamic_pointer_cast<ov::op::v0::Constant>(reduce->input_value(1).get_node_shared_ptr());
auto axes_node = ov::as_type_ptr<ov::op::v0::Constant>(reduce->input_value(1).get_node_shared_ptr());
if (!axes_node) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class ov::pass::ConvertReduceToReshape : public ov::pass::GraphRewrite {
template <class T>
ov::matcher_pass_callback CvtReduceBase::convert_reduce_to_reshape() {
return [&](ov::pass::pattern::Matcher& m) {
auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
auto reduce = ov::as_type_ptr<T>(m.get_match_root());
if (!reduce)
return false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ class PatternValidator {
if (rt_info.count("symbolic_const_value")) {
// symbolic constant node, a symbol reference is observed
auto& symbols = rt_info["symbolic_const_value"].as<std::vector<Symbol>>();
auto constop = std::dynamic_pointer_cast<op::v0::Constant>(value_node);
auto constop = ov::as_type_ptr<op::v0::Constant>(value_node);
if (!constop) {
_VERBOSE_LOG("symbolic_const_value unexpected OP: ", value_node->get_friendly_name());
return false;
Expand Down Expand Up @@ -1292,9 +1292,9 @@ class PatternValidator {
}
continue;
}
if (auto pconst_node = std::dynamic_pointer_cast<ov::op::v0::Constant>(pnode)) {
if (auto pconst_node = ov::as_type_ptr<ov::op::v0::Constant>(pnode)) {
// const_node needs to match type/shape/value
auto vconst_node = std::dynamic_pointer_cast<ov::op::v0::Constant>(value_node);
auto vconst_node = ov::as_type_ptr<ov::op::v0::Constant>(value_node);
if (!vconst_node) {
_VERBOSE_LOG("expecting Constant op, but got ", value_node);
return false;
Expand Down
Loading
Loading