From be606fa6f87db0a4e9166b79f5d9cb86804525ba Mon Sep 17 00:00:00 2001
From: daquexian
Date: Thu, 22 Aug 2019 17:13:26 +0800
Subject: [PATCH] use bgemm_naive instead of bconv_naive as fallback due to
 the 128-align weight

---
 dabnn/layers/BinConv.cpp | 57 +++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 27 deletions(-)

diff --git a/dabnn/layers/BinConv.cpp b/dabnn/layers/BinConv.cpp
index 27ad5ce..7481337 100644
--- a/dabnn/layers/BinConv.cpp
+++ b/dabnn/layers/BinConv.cpp
@@ -25,14 +25,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
       stride_h(stride_h),
       stride_w(stride_w) {
     auto &mat_map = net.lock()->mat_map_;
-    const auto binaized_name = "binaized_for_" + output + "_cal";
-    if (mat_map.find(binaized_name) == mat_map.end()) {
-        auto &input_mat = *mat_map[input];
-        mat_map[binaized_name] =
-            std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
-                                  DataType::Bit, binaized_name);
+    if (direct_conv_compatible()) {
+        const auto binaized_name = "binaized_for_" + output + "_cal";
+        if (mat_map.find(binaized_name) == mat_map.end()) {
+            auto &input_mat = *mat_map[input];
+            mat_map[binaized_name] =
+                std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
+                                      DataType::Bit, binaized_name);
+        }
+        binarized_mat = mat(binaized_name);
     }
-    binarized_mat = mat(binaized_name);
 
     const auto pad_name = "pad_for_" + output + "_cal";
     if (mat_map.find(pad_name) == mat_map.end()) {
@@ -43,18 +45,18 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
     }
     padded_mat = mat(pad_name);
 
-    const auto col_mat_name = "col_for_" + output + "_cal";
-    if (mat_map.find(col_mat_name) == mat_map.end()) {
-        const auto len =
-            output_mat->h * output_mat->w *
-            align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
-        mat_map[col_mat_name] =
-            std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
-    }
-    col_mat = mat(col_mat_name);
+    if (net.lock()->optimize && !direct_conv_compatible()) {
+
+        const auto col_mat_name = "col_for_" + output + "_cal";
+        if (mat_map.find(col_mat_name) == mat_map.end()) {
+            const auto len =
+                output_mat->h * output_mat->w *
+                align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
+            mat_map[col_mat_name] =
+                std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
+        }
+        col_mat = mat(col_mat_name);
 
-    if (net.lock()->optimize && !direct_conv_compatible() &&
-        gemm_compatible()) {
         const auto trans_weight_mat_name = "trans_" + weight;
         // transpose the weight for bgemm
         const int m = weight_mat->n;
@@ -126,7 +128,7 @@ void BinConv::forward_impl() const {
             pack_mat(*input_mat, *binarized_mat);
             pad(*binarized_mat, pad_h, pad_w, *padded_mat);
             bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
-        } else if (gemm_compatible()) {
+        } else {
             output_mat->fill(0.f);
 
             bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
@@ -136,14 +138,15 @@ void BinConv::forward_impl() const {
             const int m = weight_mat->n;
             const int n = output_mat->h * output_mat->w;
             const int k = weight_mat->total() / weight_mat->n;
-            bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
-                  m, static_cast<uint64_t *>(col_mat->data), k,
-                  static_cast<float *>(output_mat->data), m);
-        } else {
-            pack_mat(*input_mat, *binarized_mat);
-            baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
-                           weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1,
-                           1, output_mat->c, *output_mat);
+            if (gemm_compatible()) {
+                bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
+                      m, static_cast<uint64_t *>(col_mat->data), k,
+                      static_cast<float *>(output_mat->data), m);
+            } else {
+                bgemm_naive(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
+                            m, static_cast<uint64_t *>(col_mat->data), k,
+                            static_cast<float *>(output_mat->data), m);
+            }
         }
     } else {
         pack_mat(*input_mat, *binarized_mat);
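Note on the fallback: as the subject suggests, the packed weight is aligned to 128 bits, so the direct baseline_bconv fallback no longer matches the weight layout; after this patch every non-direct-conv case instead goes through the 128-bit-aligned im2col buffer (col_mat) and a bit-packed GEMM, with bgemm_naive covering the shapes that the optimized bgemm cannot. The following is only a minimal illustrative sketch of such a naive bit-packed GEMM; the name bgemm_naive_sketch, the column-major layout and the popcount-based accumulation are assumptions for illustration, not dabnn's actual bgemm_naive.

#include <cstdint>

// Illustrative sketch (not dabnn code): a holds m columns of k packed 64-bit
// words (one K-vector per output channel), b holds n such columns (one per
// output position), c is an m x n column-major float matrix accumulated into.
// lda/ldb are strides in 64-bit words, ldc is a stride in floats.
inline void bgemm_naive_sketch(int m, int n, int k, const uint64_t *a, int lda,
                               const uint64_t *b, int ldb, float *c, int ldc) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            uint32_t acc = 0;
            for (int p = 0; p < k; ++p) {
                // XOR + popcount counts the differing bits of the two packed
                // vectors (__builtin_popcountll is a GCC/Clang builtin). Since
                // both operands are padded to the same 128-bit-aligned K and
                // their padding words match, the padded tail contributes zero.
                acc += __builtin_popcountll(a[p + i * lda] ^ b[p + j * ldb]);
            }
            // Depending on the convention, this bit count may still have to be
            // mapped to a +1/-1 dot product (e.g. total_bits - 2 * acc) later.
            c[i + j * ldc] += static_cast<float>(acc);
        }
    }
}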