diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/aarch64/matmul/acl_matmul.cpp index 3d3e95a491d..64d594c6a9d 100644 --- a/src/cpu/aarch64/matmul/acl_matmul.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul.cpp @@ -76,12 +76,18 @@ status_t acl_matmul_t::pd_t::init(engine_t *engine) { = utils::everyone_is(data_type::bf16, src_md()->data_type, weights_md()->data_type, dst_md()->data_type) && platform::has_data_type_support(data_type::bf16); + const bool is_bf16f32_ok + = utils::everyone_is(data_type::bf16, src_md()->data_type, + weights_md()->data_type) + && utils::everyone_is(data_type::f32, dst_md()->data_type) + && platform::has_data_type_support(data_type::bf16); // we need to save this state as it can change inside set_default_formats() weights_format_kind_ = weights_md_.format_kind; VDISPATCH_MATMUL(is_dense_format_kind(), VERBOSE_UNSUPPORTED_SPARSE_CFG); - VDISPATCH_MATMUL(utils::one_of(true, is_fp32_ok, is_fp16_ok, is_bf16_ok), + VDISPATCH_MATMUL(utils::one_of(true, is_fp32_ok, is_fp16_ok, is_bf16_ok, + is_bf16f32_ok), VERBOSE_UNSUPPORTED_DT_CFG); VDISPATCH_MATMUL(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG);