Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ExecuTorch] Reapply D62466496: Build optimized kernels with bf16 support and gate usage at runtime #5420

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions kernels/optimized/blas/BlasKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#ifdef __aarch64__
#include <arm_neon.h>
#include <cpuinfo.h>
#endif

using torch::executor::BFloat16;
Expand All @@ -23,7 +24,7 @@ static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
return vfmaq_f32(a, b, c);
#else
return vaddq_f32(a, vmulq_f32(b, c));
#endif
#endif // __ARM_FEATURE_FMA
}

// The below reduce overload and fp16_dot_with_fp32_arith are adapted
Expand Down Expand Up @@ -78,35 +79,39 @@ static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
return vbfdotq_f32(a, b, c);
}
#endif
#endif // __ARM_FEATURE_BF16

template <bool useBfloat16Dot>
static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
const BFloat16* vec1,
const BFloat16* vec2,
float32x4_t sum[kF32RegistersPerIteration],
int registerPairIndex) {
#ifdef __ARM_FEATURE_BF16
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
sum[registerPairIndex] =
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
#else
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

sum[2 * registerPairIndex] = f32_fma_bf16(
sum[2 * registerPairIndex],
vget_low_u16(temp_vec1),
vget_low_u16(temp_vec2));
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
sum[2 * registerPairIndex + 1],
vget_high_u16(temp_vec1),
vget_high_u16(temp_vec2));
#endif
if (useBfloat16Dot) {
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
sum[registerPairIndex] =
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
} else
#endif // __ARM_FEATURE_BF16
{
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

sum[2 * registerPairIndex] = f32_fma_bf16(
sum[2 * registerPairIndex],
vget_low_u16(temp_vec1),
vget_low_u16(temp_vec2));
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
sum[2 * registerPairIndex + 1],
vget_high_u16(temp_vec1),
vget_high_u16(temp_vec2));
}
}

static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
Expand All @@ -121,7 +126,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
*tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

template <typename T>
template <typename T, bool useBfloat16Dot>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
Expand All @@ -130,7 +135,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
const auto* vec2_ = vec2 + j;
utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
[vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
vec1_, vec2_, sum, k);
});
}
auto reducedSum = reduce(sum);
Expand All @@ -157,9 +163,16 @@ float bf16_dot_with_fp32_arith(
const BFloat16* vec1,
const BFloat16* vec2,
int64_t len) {
return dot_with_fp32_arith(vec1, vec2, len);
#ifdef __ARM_FEATURE_BF16
if (cpuinfo_has_arm_bf16()) {
return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
} else
#endif // __ARM_FEATURE_BF16
{
return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
}
}
#endif
#endif // __aarch64__
} // namespace internal
} // namespace cpublas
} // namespace executorch
16 changes: 15 additions & 1 deletion kernels/optimized/lib_defs.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFORM_REGEX")
load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

# Because vec exists as a collection of header files, compile and preprocessor
Expand Down Expand Up @@ -109,6 +110,8 @@ def define_libs():
],
)

LIBBLAS_DEPS = [third_party_dep("cpuinfo")]

for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp"), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl")]:
runtime.cxx_library(
name = libblas_name,
Expand All @@ -129,6 +132,14 @@ def define_libs():
] if not runtime.is_oss else [],
"DEFAULT": [],
}),
fbandroid_platform_compiler_flags = [
(
"^android-arm64.*$",
[
"-march=armv8+bf16",
],
),
],
fbandroid_platform_preprocessor_flags = [
(
"^android-arm64.*$",
Expand All @@ -145,6 +156,9 @@ def define_libs():
],
),
],
fbobjc_compiler_flags = [
"-march=armv8+bf16",
],
fbobjc_exported_preprocessor_flags = [
"-DET_BUILD_WITH_BLAS",
"-DET_BUILD_FOR_APPLE",
Expand All @@ -155,7 +169,7 @@ def define_libs():
deps = select({
":linux-x86_64": [mkl_dep] if not runtime.is_oss else [],
"DEFAULT": [],
}),
}) + LIBBLAS_DEPS,
exported_deps = [
"//executorch/extension/parallel:thread_parallel",
"//executorch/kernels/optimized:libutils",
Expand Down
8 changes: 4 additions & 4 deletions kernels/test/op_linear_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ class OpLinearOutTest : public OperatorTest {
}
}

// matmul gives 4 * 2 * 3 = 24
Tensor x = tf.full({3, 4}, 2);
Tensor y = tf.full({5, 4}, 3);
// matmul gives 32 * 2 * 3 = 192
Tensor x = tf.full({3, 32}, 2);
Tensor y = tf.full({5, 32}, 3);

// Output shape should be (3, 5)
Tensor out = tf.zeros({3, 5});

op_linear_out(x, y, out);

Tensor expected = tf.full({3, 5}, 24);
Tensor expected = tf.full({3, 5}, 192);

EXPECT_TENSOR_EQ(out, expected);
}
Expand Down
3 changes: 2 additions & 1 deletion shim/xplat/executorch/build/env_interface.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def _remove_platform_specific_args(kwargs):
"""
keys = []
for key in kwargs:
if key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or key.startswith("fbobjc"):
if (key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or
key.startswith("fbobjc") or key.endswith("_platform_compiler_flags")):
keys.append(key)
for key in keys:
kwargs.pop(key)
Expand Down
Loading