Skip to content

Commit

Permalink
opt(armcommon): add int8 dot NCHW conv
Browse files (browse the repository at this point in the history)
GitOrigin-RevId: 65c3c219a3343438fbd1d60f0fc3f5f29b91cf11
  • Loading branch information
megvii-mge committed Jul 29, 2024
1 parent 76894c2 commit 959c301
Show file tree
Hide file tree
Showing 9 changed files with 2,053 additions and 50 deletions.
6 changes: 3 additions & 3 deletions benchmark/tools/cc_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ def main(passed_args=None):
kernel_rate.append(v[1] * 100)

topK = 10
kernel_name = kernel_name[0:topK]
kernel_rate = kernel_rate[0:topK]
kernel_name_topk = kernel_name[0:topK]
kernel_rate_topk = kernel_rate[0:topK]
br1 = np.arange(len(kernel_name))
for i, j in zip(kernel_name, kernel_rate):
for i, j in zip(kernel_name_topk, kernel_rate_topk):
print("{} {}%".format(i, j))
if not args.no_figure:
plt.figure(figsize=(25, 6))
Expand Down
46 changes: 29 additions & 17 deletions compiler/lib/KernelGen/Arm/Arm64/Activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,16 @@ std::string ActivationGenAsmBase::GenAsmQuantStore(
CC_ASSERT(args_reg.size() >= 1);
reg_0 = args_reg[0];
} else if (mode == "H_SWISH") {
CC_ASSERT(int_regs.size() == 2) << "you need to impl int_reg == 1";
CC_ASSERT(args_reg.size() >= 6);
CC_ASSERT(
(int_regs.size() == 2 && args_reg.size() >= 6) ||
(int_regs.size() == 1 && args_reg.size() >= 5));
reg_0 = args_reg[0];
reg_3 = args_reg[1];
reg_6 = args_reg[2];
reg_6_inv = args_reg[3];
reg_t0 = args_reg[4];
reg_t1 = args_reg[5];
if (int_regs.size() == 2)
reg_t1 = args_reg[5];
}
std::stringstream temp_ss;
temp_ss << R"(
Expand All @@ -117,20 +119,30 @@ std::string ActivationGenAsmBase::GenAsmQuantStore(
"fmax ${int_reg_2}.4s, ${int_reg_2}.4s, ${zero_reg}.4s\n" )";
}
} else if (mode == "H_SWISH") {
CC_ASSERT(int_regs.size() == 2);
//! PERF: reorder below to improve perf
temp_ss << R"(
"fadd ${reg_t0}.4s, ${int_reg}.4s, ${reg_3}.4s\n"
"fadd ${reg_t1}.4s, ${int_reg_2}.4s, ${reg_3}.4s\n"
"fmax ${reg_t0}.4s, ${reg_t0}.4s, ${zero_reg}.4s\n"
"fmax ${reg_t1}.4s, ${reg_t1}.4s, ${zero_reg}.4s\n"
"fmin ${reg_t0}.4s, ${reg_t0}.4s, ${reg_6}.4s\n"
"fmin ${reg_t1}.4s, ${reg_t1}.4s, ${reg_6}.4s\n"
"fmul ${int_reg}.4s, ${reg_t0}.4s, ${int_reg}.4s\n"
"fmul ${int_reg_2}.4s, ${reg_t1}.4s, ${int_reg_2}.4s\n"
"fmul ${int_reg}.4s, ${int_reg}.4s, ${reg_6_inv}.4s\n"
"fmul ${int_reg_2}.4s, ${int_reg_2}.4s, ${reg_6_inv}.4s\n"
)";
if (int_regs.size() == 1) {
temp_ss << R"(
"fadd ${reg_t0}.4s, ${int_reg}.4s, ${reg_3}.4s\n"
"fmax ${reg_t0}.4s, ${reg_t0}.4s, ${zero_reg}.4s\n"
"fmin ${reg_t0}.4s, ${reg_t0}.4s, ${reg_6}.4s\n"
"fmul ${int_reg}.4s, ${reg_t0}.4s, ${int_reg}.4s\n"
"fmul ${int_reg}.4s, ${int_reg}.4s, ${reg_6_inv}.4s\n"
)";
} else {
CC_ASSERT(int_regs.size() == 2);
//! PERF: reorder below to improve perf
temp_ss << R"(
"fadd ${reg_t0}.4s, ${int_reg}.4s, ${reg_3}.4s\n"
"fadd ${reg_t1}.4s, ${int_reg_2}.4s, ${reg_3}.4s\n"
"fmax ${reg_t0}.4s, ${reg_t0}.4s, ${zero_reg}.4s\n"
"fmax ${reg_t1}.4s, ${reg_t1}.4s, ${zero_reg}.4s\n"
"fmin ${reg_t0}.4s, ${reg_t0}.4s, ${reg_6}.4s\n"
"fmin ${reg_t1}.4s, ${reg_t1}.4s, ${reg_6}.4s\n"
"fmul ${int_reg}.4s, ${reg_t0}.4s, ${int_reg}.4s\n"
"fmul ${int_reg_2}.4s, ${reg_t1}.4s, ${int_reg_2}.4s\n"
"fmul ${int_reg}.4s, ${int_reg}.4s, ${reg_6_inv}.4s\n"
"fmul ${int_reg_2}.4s, ${int_reg_2}.4s, ${reg_6_inv}.4s\n"
)";
}

} else {
CC_ASSERT(mode == "IDENTITY");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ std::string gen_im2col(TContext* ctx, TContext* inner_ctx) {
auto sh = ctx->getAttrInt("stride_h");
auto sw = ctx->getAttrInt("stride_w");
if (sh == sw && sw == 1) {
ss << nchw_im2col_s1_kern;
ss << gen_nchw_im2col_s1_kern(inner_ctx);
} else {
ss << nchw_im2col_kern;
ss << gen_nchw_im2col_kern(inner_ctx);
}
ss << nchw_pad_src_kern;
ss << gen_nchw_pad_src_kern(inner_ctx);
}
return ss.str();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ std::string gen_im2col(TContext* ctx, TContext* inner_ctx) {
auto sh = ctx->getAttrInt("stride_h");
auto sw = ctx->getAttrInt("stride_w");
if (sh == sw && sw == 1) {
ss << nchw_im2col_s1_kern;
ss << gen_nchw_im2col_s1_kern(inner_ctx);
} else {
ss << nchw_im2col_kern;
ss << gen_nchw_im2col_kern(inner_ctx);
}
ss << nchw_pad_src_kern;
ss << gen_nchw_pad_src_kern(inner_ctx);
}
return ss.str();
}
Expand Down Expand Up @@ -89,8 +89,8 @@ bool ConvIm2colDot::IsAvailable(TContext* ctx) const {
std::string dst_oprands = std::string("operand:") + std::to_string(nr_operands - 1);
bool param_value_ok =
ctx->getAttrUInt("dilate_h") == 1 && ctx->getAttrUInt("dilate_w") == 1;
bool param_mode_ok =
(fmt == "NCHW44_DOT") && ctx->getAttrStr("mode") == "CROSS_CORRELATION";
bool param_mode_ok = (fmt == "NCHW44_DOT" || fmt == "NCHW") &&
ctx->getAttrStr("mode") == "CROSS_CORRELATION";
bool noline_ok = !ctx->haveAttr("nonlineMode") ||
ctx->getAttrStr("nonlineMode") == "IDENTITY" ||
ctx->getAttrStr("nonlineMode") == "RELU" ||
Expand Down Expand Up @@ -174,7 +174,7 @@ std::string ConvIm2colDot::GetInitBody(TContext* ctx) const {

MatmulInternal* ConvIm2colDot::GetInnerCtxMatmul(TContext* ctx) const {
static MatmulInt8DotM8N12MK4Kernel inner_mk4_gemm;
static MatmulM8N12Kernel inner_gemm;
static MatmulInt8M8N12K4Kernel inner_gemm;
auto fmt = ctx->getAttrStr("format");
if (fmt == "NCHW44_DOT") {
return &inner_mk4_gemm;
Expand Down Expand Up @@ -214,12 +214,10 @@ std::string ConvIm2colDot::GetWorkspaceBodyCondition(TContext* ctx, bool jit) co
const uint32_t iw = in_layout.dims[3];
const Layout weight_layout = inputs[1]->layout;
uint32_t group = 1;
uint32_t oc = weight_layout.dims[0] * 4;
uint32_t fh = weight_layout.dims[2];
uint32_t fw = weight_layout.dims[3];
if (weight_layout.nr_dim == ${group_weight_dim}) {
group = weight_layout.dims[0];
oc = weight_layout.dims[1] * 4;
fh = weight_layout.dims[3];
fw = weight_layout.dims[4];
}
Expand Down
Loading

0 comments on commit 959c301

Please sign in to comment.