Skip to content

Commit

Permalink
Add qs8/qu8 vadd/vaddc RVV microkernel implementations and configs
Browse files Browse the repository at this point in the history
  • Loading branch information
KaustubhIMG committed Sep 20, 2024
1 parent d779b27 commit a08c147
Show file tree
Hide file tree
Showing 18 changed files with 654 additions and 0 deletions.
8 changes: 8 additions & 0 deletions cmake/gen/rvv_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ SET(PROD_RVV_MICROKERNEL_SRCS
src/f32-vrnd/gen/f32-vrndu-rvv-u4v.c
src/f32-vrnd/gen/f32-vrndz-rvv-u4v.c
src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c
src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u2v.c
src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c
src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u2v.c
src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c
src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u2v.c
src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c
src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u2v.c
src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c
src/x32-packw/gen/x32-packw-x4v-gemm-goi-rvv-u8.c
src/x32-transposec/gen/x32-transposec-4x4-rvv.c
Expand Down Expand Up @@ -180,9 +184,13 @@ SET(NON_PROD_RVV_MICROKERNEL_SRCS
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c
src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u1v.c
src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c
src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u1v.c
src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c
src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u1v.c
src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c
src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u1v.c
src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c
src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c
src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u4.c
Expand Down
8 changes: 8 additions & 0 deletions gen/rvv_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ PROD_RVV_MICROKERNEL_SRCS = [
"src/f32-vrnd/gen/f32-vrndu-rvv-u4v.c",
"src/f32-vrnd/gen/f32-vrndz-rvv-u4v.c",
"src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c",
"src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u2v.c",
"src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c",
"src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u2v.c",
"src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c",
"src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u2v.c",
"src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c",
"src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u2v.c",
"src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c",
"src/x32-packw/gen/x32-packw-x4v-gemm-goi-rvv-u8.c",
"src/x32-transposec/gen/x32-transposec-4x4-rvv.c",
Expand Down Expand Up @@ -177,9 +181,13 @@ NON_PROD_RVV_MICROKERNEL_SRCS = [
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c",
"src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u1v.c",
"src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c",
"src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u1v.c",
"src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c",
"src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u1v.c",
"src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c",
"src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u1v.c",
"src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c",
"src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c",
"src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u4.c",
Expand Down
13 changes: 13 additions & 0 deletions scripts/generate-qs8-vadd.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ tools/xngen src/qs8-vaddc/neon.c.in -D BATCH_TILE=32 -D LD128=0 -D DATATYPE=QU8

tools/xngen src/qs8-vaddc/neon.c.in -D BATCH_TILE=16 -D LD128=1 -D DATATYPE=QU8 -o src/qu8-vaddc/gen/qu8-vaddc-minmax-neon-ld128-u16.c &

################################ RISC-V Vector ################################
# Generates the QS8/QU8 vadd and vaddc RVV kernels from the shared
# src/qs8-vadd/rvv.c.in and src/qs8-vaddc/rvv.c.in templates, once per
# LMUL (1 -> u1v, 2 -> u2v) and DATATYPE.
# NOTE(review): these outputs are written under the *-vmul/*-vmulc gen
# directories, whereas the NEON and SSE entries of this script write to
# *-vadd/*-vaddc gen directories -- confirm this is intended. The CMake and
# Bazel source lists reference the same *-vmul/*-vmulc paths, so any move has
# to be done in all three places together.
tools/xngen src/qs8-vadd/rvv.c.in -D LMUL=1 -D DATATYPE=QS8 -o src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u1v.c &
tools/xngen src/qs8-vadd/rvv.c.in -D LMUL=2 -D DATATYPE=QS8 -o src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u2v.c &

tools/xngen src/qs8-vadd/rvv.c.in -D LMUL=1 -D DATATYPE=QU8 -o src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u1v.c &
tools/xngen src/qs8-vadd/rvv.c.in -D LMUL=2 -D DATATYPE=QU8 -o src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u2v.c &

tools/xngen src/qs8-vaddc/rvv.c.in -D LMUL=1 -D DATATYPE=QS8 -o src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u1v.c &
tools/xngen src/qs8-vaddc/rvv.c.in -D LMUL=2 -D DATATYPE=QS8 -o src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u2v.c &

tools/xngen src/qs8-vaddc/rvv.c.in -D LMUL=1 -D DATATYPE=QU8 -o src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u1v.c &
tools/xngen src/qs8-vaddc/rvv.c.in -D LMUL=2 -D DATATYPE=QU8 -o src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u2v.c &

################################### x86 SSE ###################################
tools/xngen src/qs8-vadd/sse-mul16-ld64.c.in -D BATCH_TILE=8 -D SSE=2 -D AVX=0 -D DATATYPE=QS8 -o src/qs8-vadd/gen/qs8-vadd-minmax-sse2-mul16-ld64-u8.c &
tools/xngen src/qs8-vadd/sse-mul16-ld64.c.in -D BATCH_TILE=16 -D SSE=2 -D AVX=0 -D DATATYPE=QS8 -o src/qs8-vadd/gen/qs8-vadd-minmax-sse2-mul16-ld64-u16.c &
Expand Down
14 changes: 14 additions & 0 deletions src/configs/binary-elementwise-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,13 @@ static void init_qs8_vadd_config(void) {
qs8_vadd_config.minmax.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vaddc_minmax_ukernel__wasmsimd_u32;
qs8_vadd_config.init.qs8_add = xnn_init_qs8_add_minmax_scalar_params;
qs8_vadd_config.minmax.element_tile = 32;
#elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
qs8_vadd_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vadd_minmax_ukernel__rvv_u2v;
qs8_vadd_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vaddc_minmax_ukernel__rvv_u2v;
qs8_vadd_config.minmax.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vaddc_minmax_ukernel__rvv_u2v;
qs8_vadd_config.init.qs8_add = xnn_init_qs8_add_minmax_scalar_params;
qs8_vadd_config.minmax.element_tile = hardware_config->vlenb;
#else
qs8_vadd_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vadd_minmax_ukernel__scalar_u4;
qs8_vadd_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vaddc_minmax_ukernel__scalar_u4;
Expand Down Expand Up @@ -1232,6 +1239,13 @@ static void init_qu8_vadd_config(void) {
qu8_vadd_config.minmax.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vaddc_minmax_ukernel__wasmsimd_u32;
qu8_vadd_config.init.qu8_add = xnn_init_qu8_add_minmax_scalar_params;
qu8_vadd_config.minmax.element_tile = 32;
#elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
qu8_vadd_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vadd_minmax_ukernel__rvv_u2v;
qu8_vadd_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vaddc_minmax_ukernel__rvv_u2v;
qu8_vadd_config.minmax.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vaddc_minmax_ukernel__rvv_u2v;
qu8_vadd_config.init.qu8_add = xnn_init_qu8_add_minmax_scalar_params;
qu8_vadd_config.minmax.element_tile = hardware_config->vlenb;
#else
qu8_vadd_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vadd_minmax_ukernel__scalar_u4;
qu8_vadd_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vaddc_minmax_ukernel__scalar_u4;
Expand Down
5 changes: 5 additions & 0 deletions src/qs8-vadd/qs8-vadd-minmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ XNN_UKERNEL_WITH_PARAMS(xnn_arch_hvx, xnn_qs8_vadd_minmax_ukernel__hvx_u96, 96,
XNN_UKERNEL_WITH_PARAMS(xnn_arch_hvx, xnn_qs8_vadd_minmax_ukernel__hvx_u128, 128, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
#endif // XNN_ENABLE_HVX && (XNN_ARCH_HEXAGON)

// RISC-V Vector (RVV) QS8 vadd kernels; listed only when building for RISC-V
// with the vector extension enabled.
// NOTE(review): the third macro argument matches the uNv suffix of the kernel
// name (1 for u1v, 2 for u2v) and the fourth is `true` here where the other
// entries in view pass `false` -- presumably this marks the tile as counted in
// vector registers rather than elements; confirm against the definition of
// XNN_UKERNEL_WITH_PARAMS.
#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vadd_minmax_ukernel__rvv_u1v, 1, true, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vadd_minmax_ukernel__rvv_u2v, 2, true, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
#endif  // XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR

XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vadd_minmax_ukernel__scalar_u1, 1, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vadd_minmax_ukernel__scalar_u2, 2, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vadd_minmax_ukernel__scalar_u4, 4, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
Expand Down
71 changes: 71 additions & 0 deletions src/qs8-vadd/rvv.c.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert LMUL in [1, 2]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vbinary.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]

// Quantized element-wise addition with min/max clamping for RISC-V Vector.
//
// For every element the kernel computes, in 32-bit integer arithmetic:
//   acc = bias + a * a_multiplier + b * b_multiplier
//   out = min(max((acc >> shift) + output_zero_point, output_min), output_max)
// where >> is an arithmetic right shift, and stores the low 8 bits of out.
//
// batch    Number of elements to process; must be non-zero.
// input_a  First operand, batch elements.
// input_b  Second operand, batch elements.
// output   Destination, batch elements.
// params   Quantization parameters; only the params->scalar fields are read.
void xnn_${DATATYPE.lower()}_vadd_minmax_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const ${XINT8_T}* input_a,
    const ${XINT8_T}* input_b,
    ${XINT8_T}* output,
    const struct xnn_${DATATYPE.lower()}_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  const int32_t bias = params->scalar.bias;
  const int32_t a_multiplier = params->scalar.a_multiplier;
  const int32_t b_multiplier = params->scalar.b_multiplier;
  const uint32_t shift = params->scalar.shift;
  const int32_t output_min = params->scalar.output_min;
  const int32_t output_max = params->scalar.output_max;
  const int32_t output_zero_point = params->scalar.output_zero_point;

  do {
    // vsetvl yields the element count for this iteration as size_t; keeping
    // that type avoids narrowing and signed/unsigned mixing below.
    const size_t n = __riscv_vsetvl_e8m${LMUL}(batch);
    batch -= n;

    $if DATATYPE == "QS8":
      // Load both operands and widen int8 -> int16 (signed).
      vint8m${LMUL}_t in_a_i8v = __riscv_vle8_v_i8m${LMUL}(input_a, n);
      input_a += n;
      vint8m${LMUL}_t in_b_i8v = __riscv_vle8_v_i8m${LMUL}(input_b, n);
      input_b += n;
      vint16m${LMUL*2}_t a_i16v = __riscv_vwcvt_x_x_v_i16m${LMUL*2}(in_a_i8v, n);
      vint16m${LMUL*2}_t b_i16v = __riscv_vwcvt_x_x_v_i16m${LMUL*2}(in_b_i8v, n);
    $else:
      // Load both operands, widen uint8 -> uint16 (unsigned), then view the
      // result as int16 so the remaining arithmetic is shared with QS8.
      vuint8m${LMUL}_t in_a_u8v = __riscv_vle8_v_u8m${LMUL}(input_a, n);
      input_a += n;
      vuint8m${LMUL}_t in_b_u8v = __riscv_vle8_v_u8m${LMUL}(input_b, n);
      input_b += n;
      vuint16m${LMUL*2}_t a_u16v = __riscv_vwcvtu_x_x_v_u16m${LMUL*2}(in_a_u8v, n);
      vuint16m${LMUL*2}_t b_u16v = __riscv_vwcvtu_x_x_v_u16m${LMUL*2}(in_b_u8v, n);
      vint16m${LMUL*2}_t a_i16v = __riscv_vreinterpret_v_u16m${LMUL*2}_i16m${LMUL*2}(a_u16v);
      vint16m${LMUL*2}_t b_i16v = __riscv_vreinterpret_v_u16m${LMUL*2}_i16m${LMUL*2}(b_u16v);
    // Widen to int32 and accumulate bias + a * a_multiplier + b * b_multiplier.
    vint32m${LMUL*4}_t a_i32v = __riscv_vwcvt_x_x_v_i32m${LMUL*4}(a_i16v, n);
    vint32m${LMUL*4}_t b_i32v = __riscv_vwcvt_x_x_v_i32m${LMUL*4}(b_i16v, n);
    a_i32v = __riscv_vmul_vx_i32m${LMUL*4}(a_i32v, a_multiplier, n);
    b_i32v = __riscv_vmul_vx_i32m${LMUL*4}(b_i32v, b_multiplier, n);
    vint32m${LMUL*4}_t acc_i32v = __riscv_vadd_vx_i32m${LMUL*4}(a_i32v, bias, n);
    acc_i32v = __riscv_vadd_vv_i32m${LMUL*4}(acc_i32v, b_i32v, n);

    // Requantize: arithmetic shift, add the output zero point, then clamp.
    vint32m${LMUL*4}_t out_i32v = __riscv_vsra_vx_i32m${LMUL*4}(acc_i32v, shift, n);
    out_i32v = __riscv_vadd_vx_i32m${LMUL*4}(out_i32v, output_zero_point, n);
    out_i32v = __riscv_vmax_vx_i32m${LMUL*4}(out_i32v, output_min, n);
    out_i32v = __riscv_vmin_vx_i32m${LMUL*4}(out_i32v, output_max, n);

    // Narrow 32 -> 16 -> 8 bits. The narrowing keeps the low bits only, so it
    // relies on the clamp above (presumably output_min/output_max lie within
    // the 8-bit output range -- confirm against the params initializer).
    vint16m${LMUL*2}_t out_i16v = __riscv_vncvt_x_x_w_i16m${LMUL*2}(out_i32v, n);
    $if DATATYPE == "QS8":
      vint8m${LMUL}_t out_i8v = __riscv_vncvt_x_x_w_i8m${LMUL}(out_i16v, n);
      __riscv_vse8_v_i8m${LMUL}(output, out_i8v, n);
      output += n;
    $else:
      vuint16m${LMUL*2}_t out_u16v = __riscv_vreinterpret_v_i16m${LMUL*2}_u16m${LMUL*2}(out_i16v);
      vuint8m${LMUL}_t out_u8v = __riscv_vncvt_x_x_w_u8m${LMUL}(out_u16v, n);
      __riscv_vse8_v_u8m${LMUL}(output, out_u8v, n);
      output += n;
  } while (batch != 0);
}
5 changes: 5 additions & 0 deletions src/qs8-vaddc/qs8-vaddc-minmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vaddc_minmax_ukernel__wasmsimd_u24, 24, false
XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vaddc_minmax_ukernel__wasmsimd_u32, 32, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD

// RISC-V Vector (RVV) QS8 vaddc kernels; listed only when building for RISC-V
// with the vector extension enabled.
// NOTE(review): the third macro argument matches the uNv suffix of the kernel
// name (1 for u1v, 2 for u2v) and the fourth is `true` here where the other
// entries in view pass `false` -- presumably this marks the tile as counted in
// vector registers rather than elements; confirm against the definition of
// XNN_UKERNEL_WITH_PARAMS.
#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vaddc_minmax_ukernel__rvv_u1v, 1, true, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vaddc_minmax_ukernel__rvv_u2v, 2, true, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
#endif  // XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR

XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vaddc_minmax_ukernel__scalar_u1, 1, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vaddc_minmax_ukernel__scalar_u2, 2, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vaddc_minmax_ukernel__scalar_u4, 4, false, int8_t, struct xnn_qs8_add_minmax_params, xnn_init_qs8_add_minmax_scalar_params)
Expand Down
62 changes: 62 additions & 0 deletions src/qs8-vaddc/rvv.c.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert LMUL in [1, 2]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vbinary.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]

// Quantized addition of a vector and a single constant, with min/max
// clamping, for RISC-V Vector.
//
// Only the first element of input_b is read. Its contribution
// (*input_b * b_multiplier) is folded into the bias once, before the loop.
// For every element of input_a the kernel then computes, in 32-bit integer
// arithmetic:
//   acc = folded_bias + a * a_multiplier
//   out = min(max((acc >> shift) + output_zero_point, output_min), output_max)
// where >> is an arithmetic right shift, and stores the low 8 bits of out.
//
// batch    Number of elements to process; must be non-zero.
// input_a  Vector operand, batch elements.
// input_b  Pointer to the constant operand (one element is read).
// output   Destination, batch elements.
// params   Quantization parameters; only the params->scalar fields are read.
void xnn_${DATATYPE.lower()}_vaddc_minmax_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const ${XINT8_T}* input_a,
    const ${XINT8_T}* input_b,
    ${XINT8_T}* output,
    const struct xnn_${DATATYPE.lower()}_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  // The constant operand is loop-invariant, so it is merged into the bias.
  const int32_t bias = params->scalar.bias + (int32_t) *input_b * params->scalar.b_multiplier;
  const int32_t a_multiplier = params->scalar.a_multiplier;
  const uint32_t shift = params->scalar.shift;
  const int32_t output_min = params->scalar.output_min;
  const int32_t output_max = params->scalar.output_max;
  const int32_t output_zero_point = params->scalar.output_zero_point;

  do {
    // vsetvl yields the element count for this iteration as size_t; keeping
    // that type avoids narrowing and signed/unsigned mixing below.
    const size_t n = __riscv_vsetvl_e8m${LMUL}(batch);
    batch -= n;

    $if DATATYPE == "QS8":
      // Load the vector operand and widen int8 -> int16 (signed).
      vint8m${LMUL}_t in_a_i8v = __riscv_vle8_v_i8m${LMUL}(input_a, n);
      input_a += n;
      vint16m${LMUL*2}_t a_i16v = __riscv_vwcvt_x_x_v_i16m${LMUL*2}(in_a_i8v, n);
    $else:
      // Load the vector operand, widen uint8 -> uint16 (unsigned), then view
      // the result as int16 so the remaining arithmetic is shared with QS8.
      vuint8m${LMUL}_t in_a_u8v = __riscv_vle8_v_u8m${LMUL}(input_a, n);
      input_a += n;
      vuint16m${LMUL*2}_t a_u16v = __riscv_vwcvtu_x_x_v_u16m${LMUL*2}(in_a_u8v, n);
      vint16m${LMUL*2}_t a_i16v = __riscv_vreinterpret_v_u16m${LMUL*2}_i16m${LMUL*2}(a_u16v);
    // Widen to int32 and accumulate folded_bias + a * a_multiplier.
    vint32m${LMUL*4}_t a_i32v = __riscv_vwcvt_x_x_v_i32m${LMUL*4}(a_i16v, n);
    a_i32v = __riscv_vmul_vx_i32m${LMUL*4}(a_i32v, a_multiplier, n);
    vint32m${LMUL*4}_t acc_i32v = __riscv_vadd_vx_i32m${LMUL*4}(a_i32v, bias, n);

    // Requantize: arithmetic shift, add the output zero point, then clamp.
    vint32m${LMUL*4}_t out_i32v = __riscv_vsra_vx_i32m${LMUL*4}(acc_i32v, shift, n);
    out_i32v = __riscv_vadd_vx_i32m${LMUL*4}(out_i32v, output_zero_point, n);
    out_i32v = __riscv_vmax_vx_i32m${LMUL*4}(out_i32v, output_min, n);
    out_i32v = __riscv_vmin_vx_i32m${LMUL*4}(out_i32v, output_max, n);

    // Narrow 32 -> 16 -> 8 bits. The narrowing keeps the low bits only, so it
    // relies on the clamp above (presumably output_min/output_max lie within
    // the 8-bit output range -- confirm against the params initializer).
    vint16m${LMUL*2}_t out_i16v = __riscv_vncvt_x_x_w_i16m${LMUL*2}(out_i32v, n);
    $if DATATYPE == "QS8":
      vint8m${LMUL}_t out_i8v = __riscv_vncvt_x_x_w_i8m${LMUL}(out_i16v, n);
      __riscv_vse8_v_i8m${LMUL}(output, out_i8v, n);
      output += n;
    $else:
      vuint16m${LMUL*2}_t out_u16v = __riscv_vreinterpret_v_i16m${LMUL*2}_u16m${LMUL*2}(out_i16v);
      vuint8m${LMUL}_t out_u8v = __riscv_vncvt_x_x_w_u8m${LMUL}(out_u16v, n);
      __riscv_vse8_v_u8m${LMUL}(output, out_u8v, n);
      output += n;
  } while (batch != 0);
}
59 changes: 59 additions & 0 deletions src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u1v.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Auto-generated file. Do not edit!
// Template: src/qs8-vadd/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vbinary.h"


// QS8 element-wise addition with min/max clamping, RISC-V Vector, LMUL=1
// for the 8-bit loads/stores (m2 for int16, m4 for int32 intermediates).
//
// For every element the kernel computes, in 32-bit integer arithmetic:
//   acc = bias + a * a_multiplier + b * b_multiplier
//   out = min(max((acc >> shift) + output_zero_point, output_min), output_max)
// where >> is an arithmetic right shift, and stores the low 8 bits of out.
//
// batch    Number of int8 elements to process; must be non-zero.
// input_a  First operand, batch elements.
// input_b  Second operand, batch elements.
// output   Destination, batch elements.
// params   Quantization parameters; only the params->scalar fields are read.
void xnn_qs8_vadd_minmax_ukernel__rvv_u1v(
    size_t batch,
    const int8_t* input_a,
    const int8_t* input_b,
    int8_t* output,
    const struct xnn_qs8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(int8_t) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  const int32_t bias = params->scalar.bias;
  const int32_t a_multiplier = params->scalar.a_multiplier;
  const int32_t b_multiplier = params->scalar.b_multiplier;
  const uint32_t shift = params->scalar.shift;
  const int32_t output_min = params->scalar.output_min;
  const int32_t output_max = params->scalar.output_max;
  const int32_t output_zero_point = params->scalar.output_zero_point;

  do {
    // vsetvl yields the element count for this iteration as size_t; keeping
    // that type avoids narrowing and signed/unsigned mixing below.
    const size_t n = __riscv_vsetvl_e8m1(batch);
    batch -= n;

    // Load both operands and widen int8 -> int16 -> int32 (signed).
    vint8m1_t in_a_i8v = __riscv_vle8_v_i8m1(input_a, n);
    input_a += n;
    vint8m1_t in_b_i8v = __riscv_vle8_v_i8m1(input_b, n);
    input_b += n;
    vint16m2_t a_i16v = __riscv_vwcvt_x_x_v_i16m2(in_a_i8v, n);
    vint16m2_t b_i16v = __riscv_vwcvt_x_x_v_i16m2(in_b_i8v, n);
    vint32m4_t a_i32v = __riscv_vwcvt_x_x_v_i32m4(a_i16v, n);
    vint32m4_t b_i32v = __riscv_vwcvt_x_x_v_i32m4(b_i16v, n);

    // Accumulate bias + a * a_multiplier + b * b_multiplier.
    a_i32v = __riscv_vmul_vx_i32m4(a_i32v, a_multiplier, n);
    b_i32v = __riscv_vmul_vx_i32m4(b_i32v, b_multiplier, n);
    vint32m4_t acc_i32v = __riscv_vadd_vx_i32m4(a_i32v, bias, n);
    acc_i32v = __riscv_vadd_vv_i32m4(acc_i32v, b_i32v, n);

    // Requantize: arithmetic shift, add the output zero point, then clamp.
    vint32m4_t out_i32v = __riscv_vsra_vx_i32m4(acc_i32v, shift, n);
    out_i32v = __riscv_vadd_vx_i32m4(out_i32v, output_zero_point, n);
    out_i32v = __riscv_vmax_vx_i32m4(out_i32v, output_min, n);
    out_i32v = __riscv_vmin_vx_i32m4(out_i32v, output_max, n);

    // Narrow 32 -> 16 -> 8 bits. The narrowing keeps the low bits only, so it
    // relies on the clamp above (presumably output_min/output_max lie within
    // the int8 range -- confirm against the params initializer).
    vint16m2_t out_i16v = __riscv_vncvt_x_x_w_i16m2(out_i32v, n);
    vint8m1_t out_i8v = __riscv_vncvt_x_x_w_i8m1(out_i16v, n);
    __riscv_vse8_v_i8m1(output, out_i8v, n);
    output += n;
  } while (batch != 0);
}
59 changes: 59 additions & 0 deletions src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u2v.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Auto-generated file. Do not edit!
// Template: src/qs8-vadd/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vbinary.h"


// QS8 element-wise addition with min/max clamping, RISC-V Vector, LMUL=2
// for the 8-bit loads/stores (m4 for int16, m8 for int32 intermediates).
//
// For every element the kernel computes, in 32-bit integer arithmetic:
//   acc = bias + a * a_multiplier + b * b_multiplier
//   out = min(max((acc >> shift) + output_zero_point, output_min), output_max)
// where >> is an arithmetic right shift, and stores the low 8 bits of out.
//
// batch    Number of int8 elements to process; must be non-zero.
// input_a  First operand, batch elements.
// input_b  Second operand, batch elements.
// output   Destination, batch elements.
// params   Quantization parameters; only the params->scalar fields are read.
void xnn_qs8_vadd_minmax_ukernel__rvv_u2v(
    size_t batch,
    const int8_t* input_a,
    const int8_t* input_b,
    int8_t* output,
    const struct xnn_qs8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(int8_t) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  const int32_t bias = params->scalar.bias;
  const int32_t a_multiplier = params->scalar.a_multiplier;
  const int32_t b_multiplier = params->scalar.b_multiplier;
  const uint32_t shift = params->scalar.shift;
  const int32_t output_min = params->scalar.output_min;
  const int32_t output_max = params->scalar.output_max;
  const int32_t output_zero_point = params->scalar.output_zero_point;

  do {
    // vsetvl yields the element count for this iteration as size_t; keeping
    // that type avoids narrowing and signed/unsigned mixing below.
    const size_t n = __riscv_vsetvl_e8m2(batch);
    batch -= n;

    // Load both operands and widen int8 -> int16 -> int32 (signed).
    vint8m2_t in_a_i8v = __riscv_vle8_v_i8m2(input_a, n);
    input_a += n;
    vint8m2_t in_b_i8v = __riscv_vle8_v_i8m2(input_b, n);
    input_b += n;
    vint16m4_t a_i16v = __riscv_vwcvt_x_x_v_i16m4(in_a_i8v, n);
    vint16m4_t b_i16v = __riscv_vwcvt_x_x_v_i16m4(in_b_i8v, n);
    vint32m8_t a_i32v = __riscv_vwcvt_x_x_v_i32m8(a_i16v, n);
    vint32m8_t b_i32v = __riscv_vwcvt_x_x_v_i32m8(b_i16v, n);

    // Accumulate bias + a * a_multiplier + b * b_multiplier.
    a_i32v = __riscv_vmul_vx_i32m8(a_i32v, a_multiplier, n);
    b_i32v = __riscv_vmul_vx_i32m8(b_i32v, b_multiplier, n);
    vint32m8_t acc_i32v = __riscv_vadd_vx_i32m8(a_i32v, bias, n);
    acc_i32v = __riscv_vadd_vv_i32m8(acc_i32v, b_i32v, n);

    // Requantize: arithmetic shift, add the output zero point, then clamp.
    vint32m8_t out_i32v = __riscv_vsra_vx_i32m8(acc_i32v, shift, n);
    out_i32v = __riscv_vadd_vx_i32m8(out_i32v, output_zero_point, n);
    out_i32v = __riscv_vmax_vx_i32m8(out_i32v, output_min, n);
    out_i32v = __riscv_vmin_vx_i32m8(out_i32v, output_max, n);

    // Narrow 32 -> 16 -> 8 bits. The narrowing keeps the low bits only, so it
    // relies on the clamp above (presumably output_min/output_max lie within
    // the int8 range -- confirm against the params initializer).
    vint16m4_t out_i16v = __riscv_vncvt_x_x_w_i16m4(out_i32v, n);
    vint8m2_t out_i8v = __riscv_vncvt_x_x_w_i8m2(out_i16v, n);
    __riscv_vse8_v_i8m2(output, out_i8v, n);
    output += n;
  } while (batch != 0);
}
Loading

0 comments on commit a08c147

Please sign in to comment.