diff --git a/src/common/transformations/src/transformations/convert_precision.cpp b/src/common/transformations/src/transformations/convert_precision.cpp index f6cd1ab20012f7..6d9d11ff52bcba 100644 --- a/src/common/transformations/src/transformations/convert_precision.cpp +++ b/src/common/transformations/src/transformations/convert_precision.cpp @@ -1091,6 +1091,26 @@ std::shared_ptr change_constant_precision +std::shared_ptr change_constant_precision( + std::shared_ptr& constant) { + using src_type = typename element_type_traits::value_type; + using dst_type = typename element_type_traits::value_type; + + const auto* src_data = constant->get_data_ptr(); + const auto size = shape_size(constant->get_shape()); + + auto new_constant = std::make_shared(ov::element::Type_t::f16, constant->get_shape()); + new_constant->output(0).set_names(constant->output(0).get_names()); + auto* dst_data = const_cast(reinterpret_cast(new_constant->get_data_ptr())); + if (dst_data == nullptr) + OPENVINO_THROW("Can't get destination data pointer"); + + ov::reference::convert_from_bf16_to_f16_with_clamp(src_data, dst_data, size); + + return new_constant; +} + template <> std::shared_ptr change_constant_precision( std::shared_ptr& constant) { @@ -1326,6 +1346,8 @@ bool fuse_type_to_constant(const std::shared_ptr& node, new_const = change_constant_precision(constant); } else if (from == ov::element::bf16 && to == ov::element::f32) { new_const = change_constant_precision(constant); + } else if (from == ov::element::bf16 && to == ov::element::f16) { + new_const = change_constant_precision(constant); } else if (from == ov::element::f32 && to == ov::element::f16) { new_const = change_constant_precision(constant); } else if (from == ov::element::f16 && to == ov::element::f32) { diff --git a/src/common/transformations/tests/utils/convert_precision.cpp b/src/common/transformations/tests/utils/convert_precision.cpp index b685218b78dda9..f1782b6905c3c7 100644 --- a/src/common/transformations/tests/utils/convert_precision.cpp +++ b/src/common/transformations/tests/utils/convert_precision.cpp @@ -382,6 +382,40 @@ TEST(TransformationTests, ConvertPrecision_Convert_clamp_1) { ASSERT_TRUE(res.valid) << res.message; } +TEST(TransformationTests, ConvertPrecision_Convert_clamp_bf16_f16) { + // fp16 out of range should be clamped to [fp16_min, fp16_max] + std::shared_ptr model(nullptr), model_ref(nullptr); + { + auto input = std::make_shared(element::f16, Shape{1, 1000, 3}); + auto const_node = opset10::Constant::create(element::bf16, Shape{3}, {100000.0f, -100000.0f, 10.0f}); + auto convert = std::make_shared(const_node, element::f16); + auto add_1 = make_shared(input, convert); + model = std::make_shared(NodeVector{add_1}, ParameterVector{input}); + + pass::Manager manager; + static const precisions_map precisions = {{element::bf16, element::f16}}; + manager.register_pass(); + manager.register_pass(precisions); + manager.run_passes(model); + } + + { + auto max_fp16 = static_cast(std::numeric_limits::max()); + auto input = std::make_shared(element::f16, Shape{1, 1000, 3}); + auto const_node = opset10::Constant::create(element::f16, Shape{3}, {max_fp16, -max_fp16, 10.0f}); + auto add_1 = make_shared(input, const_node); + + model_ref = std::make_shared(NodeVector{add_1}, ParameterVector{input}); + } + ASSERT_NO_THROW(check_rt_info(model)); + const auto fc = FunctionsComparator::with_default() + .enable(FunctionsComparator::PRECISIONS) + .enable(FunctionsComparator::CONST_VALUES) + .enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); + const auto res = fc.compare(model, model_ref); + ASSERT_TRUE(res.valid) << res.message; +} + #if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) TEST(TransformationTests, ConvertPrecision_Convert_clamp_2) { #else diff --git a/src/core/reference/include/openvino/reference/convert.hpp b/src/core/reference/include/openvino/reference/convert.hpp index 980344fd7b083f..867679885337f8 100644 --- a/src/core/reference/include/openvino/reference/convert.hpp +++ b/src/core/reference/include/openvino/reference/convert.hpp @@ -82,5 +82,8 @@ size_t count_out_of_f16_range(const float* arg, size_t count); // Convert values from f32 to f16 with clamping to f16 min/max when value is out of normal finite numbers range void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t count); + +// Convert values from bf16 to f16 with clamping to f16 min/max when value is out of normal finite numbers range +void convert_from_bf16_to_f16_with_clamp(const bfloat16* arg, float16* out, size_t count); } // namespace reference } // namespace ov diff --git a/src/core/reference/src/op/convert.cpp b/src/core/reference/src/op/convert.cpp index 5f7f4baf9251db..856ec55a3428b2 100644 --- a/src/core/reference/src/op/convert.cpp +++ b/src/core/reference/src/op/convert.cpp @@ -66,6 +66,22 @@ void jit_convert_vec(jit::Generator& gen, const Xbyak::RegExp gen.vmovdqu(gen.xword[dst], f16vec); // move result to destination } +template <> +void jit_convert_vec(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) { + const auto f32vec = gen.ymm4; + const auto f16vec = gen.xmm3; + + auto upper_bound = gen.ymm5; + auto lower_bound = gen.ymm6; + + gen.vpmovzxwd(f32vec, gen.yword[src]); // load bf16 into tmp + gen.vpslld(f32vec, f32vec, 16); // convert bf16->f32 by bit shift + gen.vminps(f32vec, f32vec, upper_bound); // clamp f16 max + gen.vmaxps(f32vec, f32vec, lower_bound); // clamp f16 lowest + gen.vcvtps2ph(f16vec, f32vec, 0); // convert f32 -> f16 + gen.vmovdqu(gen.xword[dst], f16vec); // move result to destination +} + template <> void jit_convert_vec(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) { const auto f32vec = gen.ymm4; @@ -92,6 +108,11 @@ void jit_convert_vec_prepare(jit::Generator& gen) { gen.vmovdqu(lower_bound, gen.yword[addr]); } +template <> +void jit_convert_vec_prepare(jit::Generator& gen) { + jit_convert_vec_prepare(gen); +} + template <> void jit_convert_vec(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) { auto f16vec = gen.xmm3; @@ -552,6 +573,23 @@ void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t c #endif // defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) } +void convert_from_bf16_to_f16_with_clamp(const bfloat16* arg, float16* out, size_t count) { +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) + convert_impl(arg, out, count); +#else + // FIXME CVS-125496: duplicate and stub for ARM, provide optimized solution + for (size_t i = 0; i < count; ++i) { + if (arg[i] > std::numeric_limits::max()) { + out[i] = std::numeric_limits::max(); + } else if (arg[i] < std::numeric_limits::lowest()) { + out[i] = std::numeric_limits::lowest(); + } else { + out[i] = static_cast(arg[i]); + } + } +#endif // defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) +} + size_t count_out_of_f16_range(const float* arg, size_t count) { size_t num_out_of_range = 0;