diff --git a/src/layer/vulkan/multiheadattention_vulkan.cpp b/src/layer/vulkan/multiheadattention_vulkan.cpp
index 1abc09c30e6..b8cfd399cf2 100644
--- a/src/layer/vulkan/multiheadattention_vulkan.cpp
+++ b/src/layer/vulkan/multiheadattention_vulkan.cpp
@@ -43,6 +43,19 @@ MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
     pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
 }
 
+int MultiHeadAttention_vulkan::load_param(const ParamDict& pd)
+{
+    int ret = MultiHeadAttention::load_param(pd);
+
+    if (int8_scale_term)
+    {
+        support_vulkan = false;
+        support_image_storage = false;
+    }
+
+    return ret;
+}
+
 int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
 {
     const int embed_dim_per_head = embed_dim / num_heads;
diff --git a/src/layer/vulkan/multiheadattention_vulkan.h b/src/layer/vulkan/multiheadattention_vulkan.h
index 3b77d96db48..58e06bfc191 100644
--- a/src/layer/vulkan/multiheadattention_vulkan.h
+++ b/src/layer/vulkan/multiheadattention_vulkan.h
@@ -24,6 +24,8 @@ class MultiHeadAttention_vulkan : public MultiHeadAttention
 public:
     MultiHeadAttention_vulkan();
 
+    virtual int load_param(const ParamDict& pd);
+
     virtual int create_pipeline(const Option& opt);
     virtual int destroy_pipeline(const Option& opt);
 
diff --git a/src/layer/x86/multiheadattention_x86.cpp b/src/layer/x86/multiheadattention_x86.cpp
index 9bddb3a78ef..da5ac4022c6 100644
--- a/src/layer/x86/multiheadattention_x86.cpp
+++ b/src/layer/x86/multiheadattention_x86.cpp
@@ -36,8 +36,26 @@ MultiHeadAttention_x86::MultiHeadAttention_x86()
     o_gemm = 0;
 }
 
-int MultiHeadAttention_x86::create_pipeline(const Option& opt)
+int MultiHeadAttention_x86::create_pipeline(const Option& _opt)
 {
+    Option opt = _opt;
+    if (int8_scale_term)
+    {
+        support_packing = false;
+
+        opt.use_packing_layout = false;// TODO enable packing
+    }
+
+    {
+        qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax);
+        ncnn::ParamDict pd;
+        pd.set(0, -1);
+        pd.set(1, 1);
+        qk_softmax->load_param(pd);
+        qk_softmax->load_model(ModelBinFromMatArray(0));
+        qk_softmax->create_pipeline(opt);
+    }
+
     const int qdim = weight_data_size / embed_dim;
 
     {
@@ -57,10 +75,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         pd.set(11, 0); // output_N1M
         pd.set(12, 1); // output_elempack
         pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
         q_gemm->load_param(pd);
-        Mat weights[2];
+        Mat weights[3];
         weights[0] = q_weight_data;
         weights[1] = q_bias_data;
+#if NCNN_INT8
+        weights[2] = q_weight_data_int8_scales;
+#endif
         q_gemm->load_model(ModelBinFromMatArray(weights));
         q_gemm->create_pipeline(opt);
 
@@ -86,10 +110,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         pd.set(11, 0); // output_N1M
         pd.set(12, 1); // output_elempack
         pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
         k_gemm->load_param(pd);
-        Mat weights[2];
+        Mat weights[3];
         weights[0] = k_weight_data;
         weights[1] = k_bias_data;
+#if NCNN_INT8
+        weights[2] = k_weight_data_int8_scales;
+#endif
         k_gemm->load_model(ModelBinFromMatArray(weights));
         k_gemm->create_pipeline(opt);
 
@@ -115,10 +145,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         pd.set(11, 0); // output_N1M
         pd.set(12, 1); // output_elempack
         pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
         v_gemm->load_param(pd);
-        Mat weights[2];
+        Mat weights[3];
         weights[0] = v_weight_data;
         weights[1] = v_bias_data;
+#if NCNN_INT8
+        weights[2] = v_weight_data_int8_scales;
+#endif
         v_gemm->load_model(ModelBinFromMatArray(weights));
         v_gemm->create_pipeline(opt);
 
@@ -129,6 +165,41 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         }
     }
 
+    {
+        o_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(2, 1); // transA
+        pd.set(3, 1); // transB
+        pd.set(4, 0); // constantA
+        pd.set(5, 1); // constantB
+        pd.set(6, 1); // constantC
+        pd.set(7, 0); // M = outch
+        pd.set(8, qdim); // N = size
+        pd.set(9, embed_dim); // K = maxk*inch
+        pd.set(10, 4); // constant_broadcast_type_C
+        pd.set(11, 0); // output_N1M
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
+        o_gemm->load_param(pd);
+        Mat weights[3];
+        weights[0] = out_weight_data;
+        weights[1] = out_bias_data;
+#if NCNN_INT8
+        Mat out_weight_data_int8_scales(1);
+        out_weight_data_int8_scales[0] = out_weight_data_int8_scale;
+        weights[2] = out_weight_data_int8_scales;
+#endif
+        o_gemm->load_model(ModelBinFromMatArray(weights));
+        o_gemm->create_pipeline(opt);
+
+        if (opt.lightmode)
+        {
+            out_weight_data.release();
+            out_bias_data.release();
+        }
+    }
+
     {
         qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
         ncnn::ParamDict pd;
@@ -143,12 +214,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C
         pd.set(11, 0); // output_N1M
         pd.set(12, 1); // output_elempack
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
         qk_gemm->load_param(pd);
         qk_gemm->load_model(ModelBinFromMatArray(0));
         Option opt1 = opt;
         opt1.num_threads = 1;
         qk_gemm->create_pipeline(opt1);
     }
+
     {
         qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
         ncnn::ParamDict pd;
@@ -164,6 +239,9 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         pd.set(11, 0); // output_N1M
         pd.set(12, 1); // output_elempack
         pd.set(14, 1); // output_transpose
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
         qkv_gemm->load_param(pd);
         qkv_gemm->load_model(ModelBinFromMatArray(0));
         Option opt1 = opt;
@@ -171,48 +249,24 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         qkv_gemm->create_pipeline(opt1);
     }
 
+    return 0;
+}
+
+int MultiHeadAttention_x86::destroy_pipeline(const Option& _opt)
+{
+    Option opt = _opt;
+    if (int8_scale_term)
     {
-        qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax);
-        ncnn::ParamDict pd;
-        pd.set(0, -1);
-        pd.set(1, 1);
-        qk_softmax->load_param(pd);
-        qk_softmax->load_model(ModelBinFromMatArray(0));
-        qk_softmax->create_pipeline(opt);
+        opt.use_packing_layout = false;// TODO enable packing
     }
 
+    if (qk_softmax)
     {
-        o_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
-        ncnn::ParamDict pd;
-        pd.set(2, 1); // transA
-        pd.set(3, 1); // transB
-        pd.set(4, 0); // constantA
-        pd.set(5, 1); // constantB
-        pd.set(6, 1); // constantC
-        pd.set(7, 0); // M = outch
-        pd.set(8, qdim); // N = size
-        pd.set(9, embed_dim); // K = maxk*inch
-        pd.set(10, 4); // constant_broadcast_type_C
-        pd.set(11, 0); // output_N1M
-        o_gemm->load_param(pd);
-        Mat weights[2];
-        weights[0] = out_weight_data;
-        weights[1] = out_bias_data;
-        o_gemm->load_model(ModelBinFromMatArray(weights));
-        o_gemm->create_pipeline(opt);
-
-        if (opt.lightmode)
-        {
-            out_weight_data.release();
-            out_bias_data.release();
-        }
+        qk_softmax->destroy_pipeline(opt);
+        delete qk_softmax;
+        qk_softmax = 0;
     }
 
-    return 0;
-}
-
-int MultiHeadAttention_x86::destroy_pipeline(const Option& opt)
-{
     if (q_gemm)
     {
         q_gemm->destroy_pipeline(opt);
@@ -234,6 +288,13 @@ int MultiHeadAttention_x86::destroy_pipeline(const Option& opt)
         v_gemm = 0;
     }
 
+    if (o_gemm)
+    {
+        o_gemm->destroy_pipeline(opt);
+        delete o_gemm;
+        o_gemm = 0;
+    }
+
     if (qk_gemm)
     {
         qk_gemm->destroy_pipeline(opt);
@@ -247,30 +308,22 @@ int MultiHeadAttention_x86::destroy_pipeline(const Option& opt)
         qkv_gemm = 0;
     }
 
-    if (qk_softmax)
-    {
-        qk_softmax->destroy_pipeline(opt);
-        delete qk_softmax;
-        qk_softmax = 0;
-    }
-
-    if (o_gemm)
-    {
-        o_gemm->destroy_pipeline(opt);
-        delete o_gemm;
-        o_gemm = 0;
-    }
-
     return 0;
 }
 
-int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
+int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& _opt) const
 {
     const Mat& q_blob = bottom_blobs[0];
     const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
     const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
     const Mat& attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : Mat();
 
+    Option opt = _opt;
+    if (int8_scale_term)
+    {
+        opt.use_packing_layout = false;// TODO enable packing
+    }
+
     Mat attn_mask_blob_unpacked;
     if (attn_mask && attn_mask_blob.elempack != 1)
     {
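
The recurring pattern in this change: each constant-weight projection Gemm (q/k/v and the output projection) gets Gemm param 18 set to int8_scale_term and a third weight blob holding its per-output-channel int8 weight scales, while packing is forced off on the int8 path (hence the TODO markers). A minimal sketch of that wiring as a free function, mirroring the o_gemm block above; the helper name create_int8_projection_gemm is hypothetical, while the param indices and blob order are taken from the patch itself:

#include "layer.h"
#include "layer_type.h"

// Hypothetical helper: builds one constant-B projection Gemm as the patch does.
static ncnn::Layer* create_int8_projection_gemm(const ncnn::Mat& weight_data,
                                                const ncnn::Mat& bias_data,
                                                const ncnn::Mat& weight_int8_scales,
                                                int N, int K, int int8_scale_term,
                                                const ncnn::Option& opt)
{
    ncnn::Layer* gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);

    ncnn::ParamDict pd;
    pd.set(2, 1); // transA
    pd.set(3, 1); // transB
    pd.set(4, 0); // constantA: A is the runtime input blob
    pd.set(5, 1); // constantB: weights are baked into the layer
    pd.set(6, 1); // constantC: bias is baked into the layer
    pd.set(7, 0); // M is taken from the input blob at forward time
    pd.set(8, N); // N = output size
    pd.set(9, K); // K = reduction size
    pd.set(10, 4); // constant_broadcast_type_C
    pd.set(11, 0); // output_N1M
#if NCNN_INT8
    pd.set(18, int8_scale_term); // nonzero selects the int8 Gemm path
#endif
    gemm->load_param(pd);

    // int8 mode consumes a third weight blob: the per-channel weight scales
    ncnn::Mat weights[3];
    weights[0] = weight_data;
    weights[1] = bias_data;
#if NCNN_INT8
    weights[2] = weight_int8_scales;
#endif
    gemm->load_model(ncnn::ModelBinFromMatArray(weights));
    gemm->create_pipeline(opt);
    return gemm;
}

For the output projection the patch first wraps the scalar out_weight_data_int8_scale into a 1-element Mat before passing it as weights[2]; the q/k/v blocks follow the same scheme with the additional output_elempack and output_transpose params visible in their hunks.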