Skip to content

Commit

Permalink
Fix Incorrect output for A or B with dim=1 in GEMM
Browse files Browse the repository at this point in the history
  • Loading branch information
w8501 committed Jul 28, 2023
1 parent c45c01c commit 509f84a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 72 deletions.
52 changes: 26 additions & 26 deletions src/layer/arm/gemm_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3785,9 +3785,9 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c

static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -3840,8 +3840,8 @@ static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -3899,7 +3899,7 @@ static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int

static int gemm_AT_arm(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -3994,7 +3994,7 @@ static int gemm_AT_arm(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob,

static int gemm_BT_arm(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -4018,8 +4018,8 @@ static int gemm_BT_arm(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob,
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -4329,20 +4329,20 @@ int Gemm_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down Expand Up @@ -4502,9 +4502,9 @@ int Gemm_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
#if NCNN_BF16
static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -4557,8 +4557,8 @@ static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -4617,7 +4617,7 @@ static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo

static int gemm_AT_arm_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -4713,7 +4713,7 @@ static int gemm_AT_arm_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top

static int gemm_BT_arm_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -4737,8 +4737,8 @@ static int gemm_BT_arm_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -5001,20 +5001,20 @@ int Gemm_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down
26 changes: 13 additions & 13 deletions src/layer/arm/gemm_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2344,9 +2344,9 @@ static void get_optimal_tile_mnk_fp16sa(int M, int N, int K, int constant_TILE_M

static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -2399,8 +2399,8 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -2458,7 +2458,7 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl

static int gemm_AT_arm_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -2553,7 +2553,7 @@ static int gemm_AT_arm_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& to

static int gemm_BT_arm_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -2577,8 +2577,8 @@ static int gemm_BT_arm_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& to
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -2835,20 +2835,20 @@ int Gemm_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down
26 changes: 13 additions & 13 deletions src/layer/arm/gemm_arm_vfpv4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ extern void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max

static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -86,8 +86,8 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -146,7 +146,7 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo

static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand Down Expand Up @@ -242,7 +242,7 @@ static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top

static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);

// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);

Expand All @@ -266,8 +266,8 @@ static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1);

const int max_ii = std::min((M - i), TILE_M);

Expand Down Expand Up @@ -530,20 +530,20 @@ int Gemm_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
{
const Mat& B = bottom_blobs[0];
M = constantM;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = constantN;
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;
M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack);
N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1);
}

Mat C;
Expand Down
Loading

0 comments on commit 509f84a

Please sign in to comment.