From 758891c3f0632548e69b01e877c235debdacced3 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 16 Sep 2024 12:21:57 -0700
Subject: [PATCH] Reapply D62466496: Build optimized kernels with bf16 support
 and gate usage at runtime (#5376)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5376

Now with fewer broken tests.

ghstack-source-id: 242772181

Reviewed By: kimishpatel

Differential Revision: D62680594

fbshipit-source-id: 517791f303165423977593631e93368b95864e95
(cherry picked from commit 2b3cc276a85147b5c5852d69b8d850dcdacd6977)
---
 kernels/optimized/blas/BlasKernel.cpp         | 67 +++++++++++--------
 kernels/optimized/lib_defs.bzl                | 16 ++++-
 kernels/test/op_linear_test.cpp               |  8 +--
 shim/xplat/executorch/build/env_interface.bzl |  3 +-
 4 files changed, 61 insertions(+), 33 deletions(-)

diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp
index cfa362420f..cfee709ae6 100644
--- a/kernels/optimized/blas/BlasKernel.cpp
+++ b/kernels/optimized/blas/BlasKernel.cpp
@@ -10,6 +10,7 @@
 
 #ifdef __aarch64__
 #include <arm_neon.h>
+#include <cpuinfo.h>
 #endif
 
 using torch::executor::BFloat16;
@@ -23,7 +24,7 @@ static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
   return vfmaq_f32(a, b, c);
 #else
   return vaddq_f32(a, vmulq_f32(b, c));
-#endif
+#endif // __ARM_FEATURE_FMA
 }
 
 // The below reduce overload and fp16_dot_with_fp32_arith are adapted
@@ -78,35 +79,39 @@ static ET_INLINE float32x4_t
 f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
   return vbfdotq_f32(a, b, c);
 }
-#endif
+#endif // __ARM_FEATURE_BF16
 
+template <bool useBfloat16Dot>
 static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
     const BFloat16* vec1,
     const BFloat16* vec2,
     float32x4_t sum[kF32RegistersPerIteration],
     int registerPairIndex) {
 #ifdef __ARM_FEATURE_BF16
-  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  sum[registerPairIndex] =
-      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
-#else
-  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-
-  sum[2 * registerPairIndex] = f32_fma_bf16(
-      sum[2 * registerPairIndex],
-      vget_low_u16(temp_vec1),
-      vget_low_u16(temp_vec2));
-  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
-      sum[2 * registerPairIndex + 1],
-      vget_high_u16(temp_vec1),
-      vget_high_u16(temp_vec2));
-#endif
+  if (useBfloat16Dot) {
+    const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    sum[registerPairIndex] =
+        f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+  } else
+#endif // __ARM_FEATURE_BF16
+  {
+    const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+    sum[2 * registerPairIndex] = f32_fma_bf16(
+        sum[2 * registerPairIndex],
+        vget_low_u16(temp_vec1),
+        vget_low_u16(temp_vec2));
+    sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+        sum[2 * registerPairIndex + 1],
+        vget_high_u16(temp_vec1),
+        vget_high_u16(temp_vec2));
+  }
 }
 
 static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
@@ -121,7 +126,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
   *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
 }
 
-template <typename T>
+template <typename T, bool useBfloat16Dot>
 float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
@@ -130,7 +135,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
     const auto* vec2_ = vec2 + j;
     utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
         [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
-          dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
+          dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
+              vec1_, vec2_, sum, k);
         });
   }
   auto reducedSum = reduce(sum);
@@ -157,9 +163,16 @@ float bf16_dot_with_fp32_arith(
     const BFloat16* vec1,
     const BFloat16* vec2,
     int64_t len) {
-  return dot_with_fp32_arith(vec1, vec2, len);
+#ifdef __ARM_FEATURE_BF16
+  if (cpuinfo_has_arm_bf16()) {
+    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
+  } else
+#endif // __ARM_FEATURE_BF16
+  {
+    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
+  }
 }
-#endif
+#endif // __aarch64__
 } // namespace internal
 } // namespace cpublas
 } // namespace executorch
diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl
index 04ee0cfde4..f4c103c0a0 100644
--- a/kernels/optimized/lib_defs.bzl
+++ b/kernels/optimized/lib_defs.bzl
@@ -1,5 +1,6 @@
 load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFORM_REGEX")
 load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
+load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 # Because vec exists as a collection of header files, compile and preprocessor
@@ -109,6 +110,8 @@ def define_libs():
         ],
     )
 
+    LIBBLAS_DEPS = [third_party_dep("cpuinfo")]
+
     for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp"), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl")]:
         runtime.cxx_library(
             name = libblas_name,
@@ -129,6 +132,14 @@ def define_libs():
                 ] if not runtime.is_oss else [],
                 "DEFAULT": [],
             }),
+            fbandroid_platform_compiler_flags = [
+                (
+                    "^android-arm64.*$",
+                    [
+                        "-march=armv8+bf16",
+                    ],
+                ),
+            ],
             fbandroid_platform_preprocessor_flags = [
                 (
                     "^android-arm64.*$",
@@ -145,6 +156,9 @@ def define_libs():
                     ],
                 ),
             ],
+            fbobjc_compiler_flags = [
+                "-march=armv8+bf16",
+            ],
             fbobjc_exported_preprocessor_flags = [
                 "-DET_BUILD_WITH_BLAS",
                 "-DET_BUILD_FOR_APPLE",
@@ -155,7 +169,7 @@ def define_libs():
             deps = select({
                 ":linux-x86_64": [mkl_dep] if not runtime.is_oss else [],
                 "DEFAULT": [],
-            }),
+            }) + LIBBLAS_DEPS,
             exported_deps = [
                 "//executorch/extension/parallel:thread_parallel",
                 "//executorch/kernels/optimized:libutils",
diff --git a/kernels/test/op_linear_test.cpp b/kernels/test/op_linear_test.cpp
index 96875cc6f7..47f8925af0 100644
--- a/kernels/test/op_linear_test.cpp
+++ b/kernels/test/op_linear_test.cpp
@@ -43,16 +43,16 @@ class OpLinearOutTest : public OperatorTest {
       }
     }
 
-    // matmul gives 4 * 2 * 3 = 24
-    Tensor x = tf.full({3, 4}, 2);
-    Tensor y = tf.full({5, 4}, 3);
+    // matmul gives 32 * 2 * 3 = 192
+    Tensor x = tf.full({3, 32}, 2);
+    Tensor y = tf.full({5, 32}, 3);
 
     // Output shape should be (3, 5)
     Tensor out = tf.zeros({3, 5});
 
     op_linear_out(x, y, out);
 
-    Tensor expected = tf.full({3, 5}, 24);
+    Tensor expected = tf.full({3, 5}, 192);
 
     EXPECT_TENSOR_EQ(out, expected);
   }
diff --git a/shim/xplat/executorch/build/env_interface.bzl b/shim/xplat/executorch/build/env_interface.bzl
index 5b0acd36da..b6e30cd9f6 100644
--- a/shim/xplat/executorch/build/env_interface.bzl
+++ b/shim/xplat/executorch/build/env_interface.bzl
@@ -118,7 +118,8 @@ def _remove_platform_specific_args(kwargs):
     """
     keys = []
     for key in kwargs:
-        if key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or key.startswith("fbobjc"):
+        if (key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or
+                key.startswith("fbobjc") or key.endswith("_platform_compiler_flags")):
             keys.append(key)
     for key in keys:
         kwargs.pop(key)
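
Note for readers skimming the diff: the fast path is doubly gated. Compiling with the new -march=armv8+bf16 flags defines __ARM_FEATURE_BF16, which makes the BFDOT kernel available at all, and cpuinfo_has_arm_bf16() then decides at runtime whether the CPU the binary actually runs on can execute it. The sketch below is illustrative only and not part of the patch; cpuinfo_initialize() and cpuinfo_has_arm_bf16() are real cpuinfo entry points, while the rest is a made-up stand-in for the dispatch done in bf16_dot_with_fp32_arith.

// Illustrative sketch: compile-time + runtime gating of a bf16 fast path,
// mirroring the shape of bf16_dot_with_fp32_arith in the diff above.
#include <cstdio>

#include <cpuinfo.h>

int main() {
  cpuinfo_initialize();  // cpuinfo must be initialized before feature queries
#ifdef __ARM_FEATURE_BF16
  // Built with a bf16-capable -march, so the fast kernel exists; still check
  // the CPU we are actually running on before using it.
  if (cpuinfo_has_arm_bf16()) {
    std::printf("dispatch: BFDOT kernel\n");
  } else {
    std::printf("dispatch: fallback kernel (CPU lacks bf16)\n");
  }
#else
  // Toolchain cannot emit bf16 instructions; only the fallback is compiled.
  std::printf("dispatch: fallback kernel (built without bf16 support)\n");
#endif
  return 0;
}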