Update ExecuTorch for XNNPACK 87ee0b4
Differential Revision: D61822607
GregoryComer committed Aug 26, 2024
1 parent 1f0487d commit cea3bbb
Showing 10 changed files with 100 additions and 70 deletions.
4 changes: 4 additions & 0 deletions backends/xnnpack/cmake/Dependencies.cmake
@@ -36,6 +36,10 @@ set(XNNPACK_ENABLE_AVXVNNI
OFF
CACHE BOOL ""
)
set(XNNPACK_ENABLE_KLEIDIAI
OFF
CACHE BOOL ""
)
add_subdirectory("${XNNPACK_SOURCE_DIR}")
include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR})
list(APPEND xnnpack_third_party XNNPACK)
73 changes: 66 additions & 7 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -21,6 +21,25 @@ namespace executor {
namespace xnnpack {
namespace delegate {

/*
* Provide compile-time allocation.
*/
class CompileAllocator {
public:
/*
* Allocate memory which will be automatically freed at the end
* of the compilation process.
*/
void* allocateTemporary(size_t size) {
auto mem = new uint8_t[size];
temporaries_.emplace_back(mem);
return mem;
}

private:
std::vector<std::unique_ptr<uint8_t[]>> temporaries_;
};

// Flatbuffer types
using ValuePtr = const fb_xnnpack::XValue*;
using NodePtr = const fb_xnnpack::XNode*;
@@ -35,6 +54,23 @@ using DefineNodeFunc = Error (*)(
const std::unordered_map<uint32_t, uint32_t>&,
NodePtr) noexcept;

/*
Convert a tensor from fp32 to bf16.
*/
void convertF32TensorToBF16(
const float* f32_data,
uint16_t* bf16_data_out,
size_t numel) {
for (auto i = 0u; i < numel; i++) {
// Adjust the f32 value such that it rounds properly after truncation.
// Constant factor scales 1+2^-8 to 1+2e-7.
float f32_adjusted = f32_data[i] * 1.00389105f;
uint32_t f32_bits;
memcpy(&f32_bits, &f32_adjusted, sizeof(float));
bf16_data_out[i] = static_cast<uint16_t>(f32_bits >> 16);
}
}

/*
Gets the output min and output max for a given node operator
*/
@@ -152,7 +188,8 @@ Error defineTensor(
GraphPtr flatbuffer_graph,
const uint8_t* constant_data_ptr,
std::vector<uint32_t>& input_ids,
std::vector<uint32_t>& output_ids) {
std::vector<uint32_t>& output_ids,
CompileAllocator& allocator) {
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;

@@ -356,12 +393,31 @@ Error defineTensor(
size_t group_size = qparams->group_size();
size_t output_channels = tensor_value->dims()->Get(0);
size_t input_channels = tensor_value->dims()->Get(1);

const uint16_t* scale_data = nullptr;
uint32_t scale_numel = 0;

// Block scales are preferably serialized as bf16 but can also be
// serialized as fp32 for backwards compatibility.
if (qparams->scale_bf16() != nullptr) {
scale_data =
static_cast<const uint16_t*>(qparams->scale_bf16()->data());
scale_numel = qparams->scale_bf16()->size();
} else {
// Read fp32 scales, convert to bf16.
auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
qparams->scale()->size() * sizeof(uint16_t)));
scale_numel = qparams->scale()->size();
convertF32TensorToBF16(
qparams->scale()->data(), conv_buffer, scale_numel);
scale_data = conv_buffer;
}

ET_CHECK_OR_RETURN_ERROR(
qparams->scale()->size() ==
output_channels * input_channels / group_size,
scale_numel == output_channels * input_channels / group_size,
Internal,
"scale size %zu != output channels %zu * group size %zu",
(size_t)qparams->scale()->size(),
static_cast<size_t>(scale_numel),
output_channels,
group_size);
int32_t zero_point =
@@ -370,18 +426,19 @@
Debug,
"define quant tensor (per channel group): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, grpup_size: %zu, output_channels: %zu, dtype: %u, zero_point: %d, datatype: %d\n",
buffer_ptr,
qparams->scale()->size(),
scale_numel,
qparams->channel_dim(),
group_size,
output_channels,
datatype,
zero_point,
datatype);

status = xnn_define_blockwise_quantized_tensor_value(
/*subgraph=*/subgraph_ptr,
/*datatype=*/datatype,
/*zero_point=*/zero_point,
/*scale=*/qparams->scale()->data(),
/*scale=*/scale_data,
/*num_dims=*/tensor_value->num_dims(),
/*channel_dim=*/qparams->channel_dim(),
/*block_size=*/qparams->group_size(),
@@ -1617,6 +1674,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
const uint8_t* flatbuffer_data = nullptr;
const uint8_t* constant_data = nullptr;
CompileAllocator compile_allocator;

// Header status can only either be Error::Ok or Error::NotFound
if (header.ok()) {
@@ -1688,7 +1746,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
flatbuffer_graph,
constant_data,
input_ids,
output_ids);
output_ids,
compile_allocator);

if (err != Error::Ok) {
return err;
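Note: the hunks above add a truncation-based fp32-to-bf16 conversion for block scales and route its scratch buffer through the new CompileAllocator so the converted scales remain valid for the rest of compileModel. Below is a minimal standalone sketch of that conversion; the round-trip helper bf16BitsToF32 and the demo main are illustrative additions, not part of this commit, and a plain std::vector stands in for CompileAllocator.

// Sketch: mirrors convertF32TensorToBF16 from XNNCompiler.cpp above.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Scale each value slightly so that dropping the low 16 bits of the fp32
// bit pattern rounds rather than purely truncates (per the commit's comment).
void convertF32TensorToBF16(
    const float* f32_data,
    uint16_t* bf16_data_out,
    std::size_t numel) {
  for (std::size_t i = 0; i < numel; i++) {
    float f32_adjusted = f32_data[i] * 1.00389105f;
    uint32_t f32_bits;
    std::memcpy(&f32_bits, &f32_adjusted, sizeof(float));
    bf16_data_out[i] = static_cast<uint16_t>(f32_bits >> 16);
  }
}

// Illustrative helper (not in the commit): widen a bf16 bit pattern back to
// fp32 by zero-filling the low 16 mantissa bits.
float bf16BitsToF32(uint16_t bf16_bits) {
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float f32;
  std::memcpy(&f32, &f32_bits, sizeof(float));
  return f32;
}

int main() {
  // Example block scales; in the delegate these come from qparams->scale().
  std::vector<float> scales = {0.0123f, 0.5f, 1.0f, 3.14159f};
  // In XNNCompiler.cpp the destination buffer is obtained from
  // CompileAllocator::allocateTemporary; a vector stands in for it here.
  std::vector<uint16_t> bf16(scales.size());
  convertF32TensorToBF16(scales.data(), bf16.data(), scales.size());
  for (std::size_t i = 0; i < scales.size(); i++) {
    std::printf(
        "%.6f -> 0x%04x -> %.6f\n",
        scales[i],
        static_cast<unsigned>(bf16[i]),
        bf16BitsToF32(bf16[i]));
  }
  return 0;
}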
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
@@ -63,6 +63,7 @@ table PerChannelGroupQuant {
scale:[float];
channel_dim:int;
group_size:int;
scale_bf16:[ushort];
}

table XNNTensorValue {
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/schema.fbs
@@ -48,6 +48,7 @@ table PerChannelGroupQuant {
scale:[float];
channel_dim:int;
group_size:int;
scale_bf16:[ushort];
}

table PerChannelQuant {
12 changes: 6 additions & 6 deletions backends/xnnpack/test/ops/linear.py
@@ -407,8 +407,8 @@ def test_qd8_per_channel_linear_parallel_and_sequential(self):
)
def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
M_sizes = [1, 2, 17, 31]
K_sizes = [8, 32, 64, 128]
bl_sizes = [8, 16, 16, 32]
K_sizes = [32, 32, 64, 128]
bl_sizes = [32, 32, 32, 64]
N_sizes = [2, 17, 92, 128]

for use_bias in [True, False]:
@@ -430,8 +430,8 @@ def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
)
def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
M_sizes = [1, 2, 17, 31]
K_sizes = [8, 32, 64, 128]
bl_sizes = [8, 16, 16, 32]
K_sizes = [32, 32, 64, 128]
bl_sizes = [32, 32, 32, 64]
N_sizes = [2, 17, 92, 128]

for use_bias in [True, False]:
@@ -602,8 +602,8 @@ def _test_groupwise_dq_linear(
use_bias: bool = False,
group_size: int = 8,
num_linears: int = 1,
atol: float = 1e-3,
rtol: float = 1e-3,
atol: float = 5e-3,
rtol: float = 5e-3,
):
quantize_(mod, int8_dynamic_activation_int4_weight(group_size=group_size))
unwrap_tensor_subclass(mod)
2 changes: 1 addition & 1 deletion backends/xnnpack/third-party/XNNPACK
Submodule XNNPACK updated 13544 files
27 changes: 21 additions & 6 deletions backends/xnnpack/third-party/generate-xnnpack-wrappers.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

from __future__ import print_function
from pathlib import Path
import collections
import os
import sys
@@ -36,8 +37,8 @@
"PROD_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512SKX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512VBMI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512VNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_RVV_MICROKERNEL_SRCS": "defined(__riscv) || defined(__riscv__)",
"PROD_AVXVNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"AARCH32_ASM_MICROKERNEL_SRCS": "defined(__arm__)",
@@ -46,7 +47,7 @@
# add non-prod microkernel sources here:
}

SRC_NAMES = set([
SRC_NAMES = {
"OPERATOR_SRCS",
"SUBGRAPH_SRCS",
"LOGGING_SRCS",
@@ -81,30 +82,42 @@
"PROD_AVX512F_MICROKERNEL_SRCS",
"PROD_AVX512SKX_MICROKERNEL_SRCS",
"PROD_AVX512VBMI_MICROKERNEL_SRCS",
"PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS",
"PROD_AVX512VNNI_MICROKERNEL_SRCS",
"PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS",
"PROD_RVV_MICROKERNEL_SRCS",
"PROD_AVXVNNI_MICROKERNEL_SRCS",
"AARCH32_ASM_MICROKERNEL_SRCS",
"AARCH64_ASM_MICROKERNEL_SRCS",

# add non-prod microkernel sources here:
])
}

def handle_singleline_parse(line):
start_index = line.find("(")
end_index = line.find(")")
line = line[start_index+1:end_index]
key_val = line.split(" ")
return key_val[0], list(map(lambda x: x[4:], key_val[1:]))
return key_val[0], [x[4:] for x in key_val[1:]]

def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"):
print(f"Updating sources from {cmakefile}")
sources = collections.defaultdict(list)
with open(os.path.join(xnnpack_path, cmakefile)) as cmake:
lines = cmake.readlines()
i = 0
while i < len(lines):
line = lines[i]

if lines[i].startswith("INCLUDE"):
file, _ = handle_singleline_parse(line)
if file.startswith("cmake/gen/"):
path = Path(xnnpack_path) / "XNNPACK" / file
local_sources = update_sources(xnnpack_path, path.absolute().as_posix())
for k,v in local_sources.items():
if k in sources:
sources[k] = sources[k] + local_sources[k]
else:
sources[k] = local_sources[k]

if lines[i].startswith("SET") and "src/" in lines[i]:
name, val = handle_singleline_parse(line)
@@ -132,7 +145,7 @@ def gen_wrappers(xnnpack_path):
xnnpack_sources = collections.defaultdict(list)
sources = update_sources(xnnpack_path)

microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake")
microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/gen/microkernels.cmake")
for key in microkernels_sources:
sources[key] = microkernels_sources[key]

@@ -186,6 +199,8 @@ def gen_wrappers(xnnpack_path):


def main(argv):
print("Generating wrappers...")

if argv is None or len(argv) == 0:
gen_wrappers(".")
else:
26 changes: 0 additions & 26 deletions backends/xnnpack/third-party/xnnpack.buck.bzl
@@ -1,7 +1,6 @@
load("//third-party:glob_defs.bzl", "subdir_glob")
load(
":xnnpack_src_defs.bzl",
"JIT_SRCS",
"LOGGING_SRCS",
"OPERATOR_SRCS",
"SUBGRAPH_SRCS",
@@ -69,27 +68,6 @@ def define_xnnpack():
],
)

# @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode.
native.cxx_library(
name = "jit_memory",
srcs = JIT_SRCS,
headers = subdir_glob([
("XNNPACK/src", "**/*.h"),
]),
header_namespace = "",
compiler_flags = [
"-std=c++17",
],
preferred_linkage = "static",
preprocessor_flags = [
"-DXNN_LOG_LEVEL=0",
],
exported_deps = [
":clog",
":interface",
],
)

# @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode.
native.cxx_library(
name = "operators",
@@ -139,7 +117,6 @@ def define_xnnpack():
preferred_linkage = "static",
preprocessor_flags = [
"-DXNN_LOG_LEVEL=0",
"-DXNN_ENABLE_JIT=0",
"-DXNN_ENABLE_SPARSE=0",
"-DXNN_ENABLE_GEMM_M_SPECIALIZATION=0",
"-DXNN_ENABLE_MEMOPT",
@@ -1223,7 +1200,6 @@ def define_xnnpack():
]

ARM_XNNPACK_DEPS = [
":jit_memory",
":ukernels_armsimd32",
":ukernels_fp16arith",
":ukernels_asm",
@@ -1246,7 +1222,6 @@ def define_xnnpack():
"XNNPACK/src/configs/hardware-config.c",
"XNNPACK/src/microparams-init.c",
"XNNPACK/src/operator-run.c",
"XNNPACK/src/operators/post-operation.c",
"XNNPACK/src/microkernel-utils.c",
],
headers = subdir_glob([
@@ -1271,7 +1246,6 @@ def define_xnnpack():
"-DXNN_NO_X8_OPERATORS",
"-DXNN_ENABLE_MEMOPT",
"-DXNN_ENABLE_SPARSE=0",
"-DXNN_ENABLE_JIT=0",
"-DXNN_ENABLE_ASSEMBLY",
"-DXNN_ENABLE_GEMM_M_SPECIALIZATION",
"-DXNN_ENABLE_ARM_DOTPROD",
12 changes: 0 additions & 12 deletions backends/xnnpack/third-party/xnnpack_src_defs.bzl
@@ -493,30 +493,18 @@ AARCH64_ASM_MICROKERNEL_SRCS = [
"XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S",
"XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S",
"XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S",
"XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S",
"XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S",
]

XNNPACK_SRCS = [