Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gemm int8 quantization #5706

Merged
merged 56 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
daecbe6
staging
nihui Sep 19, 2024
ab9f555
apply code-format changes
nihui Sep 19, 2024
dd6bf5c
clean
nihui Sep 20, 2024
86b03df
build for armv7
nihui Sep 20, 2024
7f2d1da
quantize gemm
nihui Sep 20, 2024
37208ab
write gemm quantize scales
nihui Sep 20, 2024
47cd674
update doc
nihui Sep 20, 2024
ecf2e3b
wip
nihui Sep 20, 2024
f70e5ef
fp32 alpha beta
nihui Sep 23, 2024
4b1b2b3
stash
nihui Sep 25, 2024
d59410f
stash
nihui Sep 25, 2024
07f6755
fix int8 bf16s
nihui Sep 26, 2024
6302017
build
nihui Sep 26, 2024
dcd0636
opt++
nihui Sep 26, 2024
09af847
fix
nihui Sep 26, 2024
8068e6c
test++
nihui Sep 27, 2024
a881bd9
less openmp args
nihui Sep 27, 2024
110f316
apply code-format changes
nihui Sep 27, 2024
4798d10
revert cpu runtime off build behavior
nihui Sep 27, 2024
392e38b
revert 2nd
nihui Sep 27, 2024
207e166
vfpv4 fp16
nihui Sep 29, 2024
e5e5cd3
apply code-format changes
nihui Sep 29, 2024
93fad8d
apply code-format changes
nihui Sep 29, 2024
eb0c833
stash
nihui Sep 29, 2024
55e0e57
stash
nihui Sep 30, 2024
87d5cd4
Merge branch 'gemm-quantize-r' of github.com:nihui/ncnn into gemm-qua…
nihui Sep 30, 2024
1e50e88
build
nihui Sep 30, 2024
d9e9b38
opt++
nihui Oct 8, 2024
5566752
opt++
nihui Oct 8, 2024
03168d7
apply code-format changes
nihui Oct 8, 2024
a280bfb
fast path
nihui Oct 8, 2024
3748dcc
cc
nihui Oct 8, 2024
9fa2532
fabsf
nihui Oct 8, 2024
6a1c346
fix build
nihui Oct 8, 2024
878f723
fix build
nihui Oct 8, 2024
8ab3c4d
fix tests
nihui Oct 9, 2024
24d2c15
x86 riscv fallback
nihui Oct 9, 2024
20a221f
skip gemm vulkan int8
nihui Oct 9, 2024
5b099e4
fix
nihui Oct 9, 2024
55e2de7
fix
nihui Oct 10, 2024
061205a
fix noint8 test, fix arm bf16 test
nihui Oct 10, 2024
73e7364
enable vfpv4 on neon build only
nihui Oct 10, 2024
9c3057b
fix test, test++
nihui Oct 10, 2024
0b7755d
fix gemm vulkan without C
nihui Oct 10, 2024
f5a828b
test++
nihui Oct 10, 2024
5e4fcfa
test++
nihui Oct 10, 2024
545f075
fp16 pack8 output, cc
nihui Oct 11, 2024
16447e8
fix
nihui Oct 11, 2024
7bcaeeb
cc
nihui Oct 11, 2024
a1f9b28
fix
nihui Oct 11, 2024
ec29d98
test fp16s
nihui Oct 11, 2024
e078b6e
enable elempack=8 only for asimdhp+
nihui Oct 12, 2024
854be47
cc
nihui Oct 12, 2024
1e21bb9
cc
nihui Oct 12, 2024
b8421f7
tiled gemm int8 test
nihui Oct 12, 2024
4f24312
opt arm64 tiles, fix asimdhp dispatch
nihui Oct 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,25 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
endif()

if(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s, _a, _b; _s = vmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM_NEON)

unset(CMAKE_REQUIRED_FLAGS)
else()
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
if(NCNN_COMPILER_SUPPORT_ARM_NEON)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4)
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
endif()
unset(CMAKE_REQUIRED_FLAGS)
else()
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

unset(CMAKE_REQUIRED_FLAGS)
if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4)
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
endif()

unset(CMAKE_REQUIRED_FLAGS)
endif()
endif()

if(NCNN_COMPILER_SUPPORT_ARM_VFPV4 OR NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
Expand Down
96 changes: 48 additions & 48 deletions cmake/ncnn_add_layer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -144,25 +144,25 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__")
endif()
if(NCNN_AVX512VNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
endif()
if(NCNN_AVX512BF16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__")
endif()
if(NCNN_AVX512FP16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__")
endif()
if(NCNN_AVXVNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
if(NCNN_AVX2)
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
if(NCNN_XOP)
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__XOP__")
endif()
if(NCNN_F16C)
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__F16C__")
endif()
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
Expand All @@ -175,25 +175,25 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__")
endif()
if(NCNN_AVX512VNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
endif()
if(NCNN_AVX512BF16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__")
endif()
if(NCNN_AVX512FP16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__")
endif()
if(NCNN_AVXVNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
if(NCNN_AVX2)
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
if(NCNN_XOP)
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
ncnn_add_arch_opt_source(${class} xop "/arch:AVX -mxop /D__SSSE3__ /D__SSE4_1__ /D__XOP__")
endif()
if(NCNN_F16C)
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX -mf16c /D__SSSE3__ /D__SSE4_1__ /D__F16C__")
endif()
else()
Expand All @@ -206,25 +206,25 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
ncnn_add_arch_opt_layer(${class} avx "-mavx")
endif()
if(NCNN_AVX512VNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
ncnn_add_arch_opt_source(${class} avx512vnni "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni")
endif()
if(NCNN_AVX512BF16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
ncnn_add_arch_opt_source(${class} avx512bf16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16")
endif()
if(NCNN_AVX512FP16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
ncnn_add_arch_opt_source(${class} avx512fp16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16")
endif()
if(NCNN_AVXVNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni")
endif()
if(NCNN_AVX2)
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c")
endif()
if(NCNN_XOP)
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
ncnn_add_arch_opt_source(${class} xop "-mavx -mxop")
endif()
if(NCNN_F16C)
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
ncnn_add_arch_opt_source(${class} f16c "-mavx -mf16c")
endif()
endif()
Expand Down Expand Up @@ -254,28 +254,28 @@ macro(ncnn_add_layer class)
if(NCNN_ARM82)
ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM82DOT)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD")
endif()
if(NCNN_ARM82FP16FML)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML")
endif()
if(NCNN_ARM84BF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM84I8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8")
endif()
# TODO add support for sve family
if(NCNN_ARM86SVE)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
endif()
if(NCNN_ARM86SVE2)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
endif()
if(NCNN_ARM86SVEBF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
endif()
if(NCNN_ARM86SVEI8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
endif()
if(NCNN_ARM86SVEF32MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
endif()
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
if(NCNN_VFPV4)
Expand All @@ -284,28 +284,28 @@ macro(ncnn_add_layer class)
if(NCNN_ARM82)
ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 -march=armv8.2-a+fp16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM82DOT)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 -march=armv8.2-a+fp16+dotprod /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD")
endif()
if(NCNN_ARM82FP16FML)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 -march=armv8.2-a+fp16+fp16fml /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML")
endif()
if(NCNN_ARM84BF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+bf16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM84I8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+i8mm /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8")
endif()
# TODO add support for sve family
if(NCNN_ARM86SVE)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
endif()
if(NCNN_ARM86SVE2)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
endif()
if(NCNN_ARM86SVEBF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
endif()
if(NCNN_ARM86SVEI8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
endif()
if(NCNN_ARM86SVEF32MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
endif()
else()
if(NCNN_VFPV4)
Expand All @@ -314,31 +314,31 @@ macro(ncnn_add_layer class)
if(NCNN_ARM82)
ncnn_add_arch_opt_source(${class} asimdhp "-march=armv8.2-a+fp16")
endif()
if(NCNN_ARM82DOT)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
ncnn_add_arch_opt_source(${class} asimddp "-march=armv8.2-a+fp16+dotprod")
endif()
if(NCNN_ARM82FP16FML)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
ncnn_add_arch_opt_source(${class} asimdfhm "-march=armv8.2-a+fp16+fp16fml")
endif()
if(NCNN_ARM84BF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
ncnn_add_arch_opt_source(${class} bf16 "-march=armv8.4-a+fp16+dotprod+bf16")
endif()
if(NCNN_ARM84I8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
ncnn_add_arch_opt_source(${class} i8mm "-march=armv8.4-a+fp16+dotprod+i8mm")
endif()
if(NCNN_ARM86SVE)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
ncnn_add_arch_opt_source(${class} sve "-march=armv8.6-a+fp16+dotprod+sve")
endif()
if(NCNN_ARM86SVE2)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
ncnn_add_arch_opt_source(${class} sve2 "-march=armv8.6-a+fp16+dotprod+sve2")
endif()
if(NCNN_ARM86SVEBF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
ncnn_add_arch_opt_source(${class} svebf16 "-march=armv8.6-a+fp16+dotprod+sve+bf16")
endif()
if(NCNN_ARM86SVEI8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
ncnn_add_arch_opt_source(${class} svei8mm "-march=armv8.6-a+fp16+dotprod+sve+i8mm")
endif()
if(NCNN_ARM86SVEF32MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
ncnn_add_arch_opt_source(${class} svef32mm "-march=armv8.6-a+fp16+dotprod+sve+f32mm")
endif()
endif()
Expand Down
7 changes: 5 additions & 2 deletions docs/developer-guide/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -942,15 +942,18 @@ y = (gemm(a, b) + c * beta) * alpha
| 12 | output_elempack | int | 0 | |
| 13 | output_elemtype | int | 0 | |
| 14 | output_transpose | int | 0 | |
| 18 | int8_scale_term | int | 0 | |
| 20 | constant_TILE_M | int | 0 | |
| 21 | constant_TILE_N | int | 0 | |
| 22 | constant_TILE_K | int | 0 | |

| weight | type | shape |
| ------------- | ----- | --------------------- |
| A_data | float | [M, K] or [K, M] |
| B_data | float | [N, K] or [K, N] |
| A_data | float/fp16/int8 | [M, K] or [K, M] |
| B_data | float/fp16/int8 | [N, K] or [K, N] |
| C_data | float | [1], [M] or [N] or [1, M] or [N, 1] or [N, M] |
| A_data_int8_scales| float | [M] |
| B_data_int8_scales| float | [1] |

# GridSample
```
Expand Down
Loading
Loading