Skip to content

Commit

Permalink
stash
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 25, 2024
1 parent f70e5ef commit 4b1b2b3
Show file tree
Hide file tree
Showing 4 changed files with 1,790 additions and 890 deletions.
16 changes: 8 additions & 8 deletions src/layer/arm/gemm_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5530,14 +5530,14 @@ static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob
if (output_transpose)
{
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
Expand Down Expand Up @@ -5728,14 +5728,14 @@ static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat&
if (output_transpose)
{
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
Expand Down Expand Up @@ -5849,14 +5849,14 @@ static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, con
if (output_transpose)
{
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
Expand Down Expand Up @@ -5923,14 +5923,14 @@ static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Ma
if (output_transpose)
{
if (top_blob.elembits() == 16)
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
else
{
if (top_blob.elembits() == 16)
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
else
unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
}
Expand Down
8 changes: 4 additions & 4 deletions src/layer/arm/gemm_arm_asimddp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, in
transpose_pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales)
void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales);
unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}

void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales)
void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales);
transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}
#endif // NCNN_BF16

Expand Down
Loading

0 comments on commit 4b1b2b3

Please sign in to comment.