Skip to content

Commit

Permalink
fp32 alpha beta (pass alpha and beta through the int8 GEMM paths to the fp32 output unpack functions)
Browse files (browse the repository at this point in the history)
  • Loading branch information
nihui committed Sep 23, 2024
1 parent ecf2e3b commit f70e5ef
Show file tree
Hide file tree
Showing 4 changed files with 846 additions and 446 deletions.
57 changes: 17 additions & 40 deletions src/layer/arm/gemm_arm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -47,7 +47,7 @@ Gemm_arm::Gemm_arm()
#endif // __ARM_NEON

#if NCNN_BF16
support_bf16_storage = true;
// support_bf16_storage = true;
#endif

nT = 0;
Expand Down Expand Up @@ -5303,7 +5303,7 @@ int Gemm_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
#endif // NCNN_BF16

#if NCNN_INT8
static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
// NCNN_LOGE("gemm_arm_int8");

Expand Down Expand Up @@ -5532,22 +5532,22 @@ static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
}
}

return 0;
}

static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
// NCNN_LOGE("gemm_AT_arm_int8");

Expand Down Expand Up @@ -5730,22 +5730,22 @@ static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat&
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
}
}

return 0;
}

static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
// NCNN_LOGE("gemm_BT_arm_int8");

Expand Down Expand Up @@ -5851,22 +5851,22 @@ static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, con
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
}
}

return 0;
}

static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& BT, float B_int8_scale, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& BT, float B_int8_scale, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
// NCNN_LOGE("gemm_AT_BT_arm_int8");

Expand Down Expand Up @@ -5925,14 +5925,14 @@ static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Ma
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
}
}
Expand Down Expand Up @@ -6028,29 +6028,6 @@ int Gemm_arm::create_pipeline_int8(const Option& opt)
{
CT_data = C_data;

#if __ARM_NEON
if (constant_broadcast_type_C == 3 && opt.use_packing_layout)
{
int C_elempack = constantM % 4 == 0 ? 4 : 1;
convert_packing(C_data, CT_data, C_elempack, opt);
}
#endif // __ARM_NEON

// pre-multiply C with beta
if (beta != 1.f)
{
Mat C2;
C2.create_like(CT_data);

const int size = CT_data.total() * CT_data.elempack;
for (int i = 0; i < size; i++)
{
C2[i] = CT_data[i] * beta;
}

CT_data = C2;
}

#if NCNN_BF16
if (support_bf16_storage && opt.use_bf16_storage)
{
Expand Down Expand Up @@ -6209,23 +6186,23 @@ int Gemm_arm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat
int ret = 0;
if (constantA && constantB)
{
ret = gemm_AT_BT_arm_int8(AT_data, A_data_int8_scales, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
ret = gemm_AT_BT_arm_int8(AT_data, A_data_int8_scales, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
}
else if (constantA)
{
const Mat& B = bottom_blobs[0];
ret = gemm_AT_arm_int8(AT_data, A_data_int8_scales, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
ret = gemm_AT_arm_int8(AT_data, A_data_int8_scales, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
ret = gemm_BT_arm_int8(A, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
ret = gemm_BT_arm_int8(A, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
ret = gemm_arm_int8(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
ret = gemm_arm_int8(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
}

return ret;
Expand Down
8 changes: 4 additions & 4 deletions src/layer/arm/gemm_arm_asimddp.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -64,14 +64,14 @@ void transpose_pack_B_tile_fp32_to_int8_asimddp(const Mat& B, Mat& BT, int j, in
transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales)
void unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales);
unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}

void transpose_unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales)
void transpose_unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
transpose_unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales);
transpose_unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}

void gemm_transB_packed_tile_int8_asimddp(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk)
Expand Down
Loading

0 comments on commit f70e5ef

Please sign in to comment.