Skip to content

Commit

Permalink
Use vnclip+max instead of min+max+vncvt
Browse files Browse the repository at this point in the history
  • Loading branch information
KaustubhIMG committed Sep 25, 2024
1 parent bb3d9e9 commit 763c584
Show file tree
Hide file tree
Showing 11 changed files with 22 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/configs/binary-elementwise-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ static void init_qs8_vadd_config(void) {
qs8_vadd_config.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vadd_minmax_ukernel__rvv_u2v;
qs8_vadd_config.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vaddc_minmax_ukernel__rvv_u2v;
qs8_vadd_config.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vaddc_minmax_ukernel__rvv_u2v;
qs8_vadd_config.init = xnn_init_qs8_add_minmax_scalar_params;
qs8_vadd_config.init = (xnn_init_binary_params_fn) xnn_init_qs8_add_minmax_scalar_params;
qs8_vadd_config.element_tile = hardware_config->vlenb;
#else
qs8_vadd_config.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qs8_vadd_minmax_ukernel__scalar_u4;
Expand Down Expand Up @@ -1110,7 +1110,7 @@ static void init_qu8_vadd_config(void) {
qu8_vadd_config.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vadd_minmax_ukernel__rvv_u2v;
qu8_vadd_config.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vaddc_minmax_ukernel__rvv_u2v;
qu8_vadd_config.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vaddc_minmax_ukernel__rvv_u2v;
qu8_vadd_config.init = xnn_init_qu8_add_minmax_scalar_params;
qu8_vadd_config.init = (xnn_init_binary_params_fn) xnn_init_qu8_add_minmax_scalar_params;
qu8_vadd_config.element_tile = hardware_config->vlenb;
#else
qu8_vadd_config.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_qu8_vadd_minmax_ukernel__scalar_u4;
Expand Down
5 changes: 2 additions & 3 deletions src/qs8-vadd/rvv.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ void xnn_${DATATYPE.lower()}_vadd_minmax_ukernel__rvv_u${LMUL}v(
acc_i32v = __riscv_vadd_vv_i32m${LMUL*4}(acc_i32v, b_i32v, n);
vint32m${LMUL*4}_t out_i32v = __riscv_vsra_vx_i32m${LMUL*4}(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m${LMUL*4}(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m${LMUL*4}(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m${LMUL*4}(out_i32v, output_max, n);
vint16m${LMUL*2}_t out_i16v = __riscv_vncvt_x_x_w_i16m${LMUL*2}(out_i32v, n);
vint16m${LMUL*2}_t out_i16v = __riscv_vnclip_wx_i16m${LMUL*2}(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m${LMUL*2}(out_i16v, output_min, n);
$if DATATYPE == "QS8":
vint8m${LMUL}_t out_i8v = __riscv_vncvt_x_x_w_i8m${LMUL}(out_i16v, n);
__riscv_vse8_v_i8m${LMUL}(output, out_i8v, n); output += n;
Expand Down
5 changes: 2 additions & 3 deletions src/qs8-vaddc/rvv.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ void xnn_${DATATYPE.lower()}_vaddc_minmax_ukernel__rvv_u${LMUL}v(
vint32m${LMUL*4}_t acc_i32v = __riscv_vadd_vx_i32m${LMUL*4}(a_i32v, bias, n);
vint32m${LMUL*4}_t out_i32v = __riscv_vsra_vx_i32m${LMUL*4}(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m${LMUL*4}(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m${LMUL*4}(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m${LMUL*4}(out_i32v, output_max, n);
vint16m${LMUL*2}_t out_i16v = __riscv_vncvt_x_x_w_i16m${LMUL*2}(out_i32v, n);
vint16m${LMUL*2}_t out_i16v = __riscv_vnclip_wx_i16m${LMUL*2}(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m${LMUL*2}(out_i16v, output_min, n);
$if DATATYPE == "QS8":
vint8m${LMUL}_t out_i8v = __riscv_vncvt_x_x_w_i8m${LMUL}(out_i16v, n);
__riscv_vse8_v_i8m${LMUL}(output, out_i8v, n); output += n;
Expand Down
5 changes: 2 additions & 3 deletions src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ void xnn_qs8_vadd_minmax_ukernel__rvv_u1v(
acc_i32v = __riscv_vadd_vv_i32m4(acc_i32v, b_i32v, n);
vint32m4_t out_i32v = __riscv_vsra_vx_i32m4(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m4(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m4(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m4(out_i32v, output_max, n);
vint16m2_t out_i16v = __riscv_vncvt_x_x_w_i16m2(out_i32v, n);
vint16m2_t out_i16v = __riscv_vnclip_wx_i16m2(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m2(out_i16v, output_min, n);
vint8m1_t out_i8v = __riscv_vncvt_x_x_w_i8m1(out_i16v, n);
__riscv_vse8_v_i8m1(output, out_i8v, n); output += n;
} while (batch != 0);
Expand Down
5 changes: 2 additions & 3 deletions src/qs8-vmul/gen/qs8-vadd-minmax-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ void xnn_qs8_vadd_minmax_ukernel__rvv_u2v(
acc_i32v = __riscv_vadd_vv_i32m8(acc_i32v, b_i32v, n);
vint32m8_t out_i32v = __riscv_vsra_vx_i32m8(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m8(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m8(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m8(out_i32v, output_max, n);
vint16m4_t out_i16v = __riscv_vncvt_x_x_w_i16m4(out_i32v, n);
vint16m4_t out_i16v = __riscv_vnclip_wx_i16m4(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m4(out_i16v, output_min, n);
vint8m2_t out_i8v = __riscv_vncvt_x_x_w_i8m2(out_i16v, n);
__riscv_vse8_v_i8m2(output, out_i8v, n); output += n;
} while (batch != 0);
Expand Down
5 changes: 2 additions & 3 deletions src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ void xnn_qs8_vaddc_minmax_ukernel__rvv_u1v(
vint32m4_t acc_i32v = __riscv_vadd_vx_i32m4(a_i32v, bias, n);
vint32m4_t out_i32v = __riscv_vsra_vx_i32m4(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m4(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m4(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m4(out_i32v, output_max, n);
vint16m2_t out_i16v = __riscv_vncvt_x_x_w_i16m2(out_i32v, n);
vint16m2_t out_i16v = __riscv_vnclip_wx_i16m2(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m2(out_i16v, output_min, n);
vint8m1_t out_i8v = __riscv_vncvt_x_x_w_i8m1(out_i16v, n);
__riscv_vse8_v_i8m1(output, out_i8v, n); output += n;
} while (batch != 0);
Expand Down
5 changes: 2 additions & 3 deletions src/qs8-vmulc/gen/qs8-vaddc-minmax-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ void xnn_qs8_vaddc_minmax_ukernel__rvv_u2v(
vint32m8_t acc_i32v = __riscv_vadd_vx_i32m8(a_i32v, bias, n);
vint32m8_t out_i32v = __riscv_vsra_vx_i32m8(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m8(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m8(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m8(out_i32v, output_max, n);
vint16m4_t out_i16v = __riscv_vncvt_x_x_w_i16m4(out_i32v, n);
vint16m4_t out_i16v = __riscv_vnclip_wx_i16m4(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m4(out_i16v, output_min, n);
vint8m2_t out_i8v = __riscv_vncvt_x_x_w_i8m2(out_i16v, n);
__riscv_vse8_v_i8m2(output, out_i8v, n); output += n;
} while (batch != 0);
Expand Down
5 changes: 2 additions & 3 deletions src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ void xnn_qu8_vadd_minmax_ukernel__rvv_u1v(
acc_i32v = __riscv_vadd_vv_i32m4(acc_i32v, b_i32v, n);
vint32m4_t out_i32v = __riscv_vsra_vx_i32m4(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m4(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m4(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m4(out_i32v, output_max, n);
vint16m2_t out_i16v = __riscv_vncvt_x_x_w_i16m2(out_i32v, n);
vint16m2_t out_i16v = __riscv_vnclip_wx_i16m2(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m2(out_i16v, output_min, n);
a_u16v = __riscv_vreinterpret_v_i16m2_u16m2(out_i16v);
vuint8m1_t out_u8v = __riscv_vncvt_x_x_w_u8m1(a_u16v, n);
__riscv_vse8_v_u8m1(output, out_u8v, n); output += n;
Expand Down
5 changes: 2 additions & 3 deletions src/qu8-vmul/gen/qu8-vadd-minmax-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ void xnn_qu8_vadd_minmax_ukernel__rvv_u2v(
acc_i32v = __riscv_vadd_vv_i32m8(acc_i32v, b_i32v, n);
vint32m8_t out_i32v = __riscv_vsra_vx_i32m8(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m8(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m8(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m8(out_i32v, output_max, n);
vint16m4_t out_i16v = __riscv_vncvt_x_x_w_i16m4(out_i32v, n);
vint16m4_t out_i16v = __riscv_vnclip_wx_i16m4(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m4(out_i16v, output_min, n);
a_u16v = __riscv_vreinterpret_v_i16m4_u16m4(out_i16v);
vuint8m2_t out_u8v = __riscv_vncvt_x_x_w_u8m2(a_u16v, n);
__riscv_vse8_v_u8m2(output, out_u8v, n); output += n;
Expand Down
5 changes: 2 additions & 3 deletions src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ void xnn_qu8_vaddc_minmax_ukernel__rvv_u1v(
vint32m4_t acc_i32v = __riscv_vadd_vx_i32m4(a_i32v, bias, n);
vint32m4_t out_i32v = __riscv_vsra_vx_i32m4(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m4(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m4(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m4(out_i32v, output_max, n);
vint16m2_t out_i16v = __riscv_vncvt_x_x_w_i16m2(out_i32v, n);
vint16m2_t out_i16v = __riscv_vnclip_wx_i16m2(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m2(out_i16v, output_min, n);
a_u16v = __riscv_vreinterpret_v_i16m2_u16m2(out_i16v);
vuint8m1_t out_u8v = __riscv_vncvt_x_x_w_u8m1(a_u16v, n);
__riscv_vse8_v_u8m1(output, out_u8v, n); output += n;
Expand Down
5 changes: 2 additions & 3 deletions src/qu8-vmulc/gen/qu8-vaddc-minmax-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ void xnn_qu8_vaddc_minmax_ukernel__rvv_u2v(
vint32m8_t acc_i32v = __riscv_vadd_vx_i32m8(a_i32v, bias, n);
vint32m8_t out_i32v = __riscv_vsra_vx_i32m8(acc_i32v, shift, n);
out_i32v = __riscv_vadd_vx_i32m8(out_i32v, output_zero_point, n);
out_i32v = __riscv_vmax_vx_i32m8(out_i32v, output_min, n);
out_i32v = __riscv_vmin_vx_i32m8(out_i32v, output_max, n);
vint16m4_t out_i16v = __riscv_vncvt_x_x_w_i16m4(out_i32v, n);
vint16m4_t out_i16v = __riscv_vnclip_wx_i16m4(out_i32v, output_max, __RISCV_VXRM_RNE, n);
out_i16v = __riscv_vmax_vx_i16m4(out_i16v, output_min, n);
a_u16v = __riscv_vreinterpret_v_i16m4_u16m4(out_i16v);
vuint8m2_t out_u8v = __riscv_vncvt_x_x_w_u8m2(a_u16v, n);
__riscv_vse8_v_u8m2(output, out_u8v, n); output += n;
Expand Down

0 comments on commit 763c584

Please sign in to comment.