cpu: matmul: align allocation/usage logic for accumulation buffer size
across all gemm-based implementations
akharito committed Dec 7, 2022
1 parent 2955c9d commit 989acd3
Showing 5 changed files with 93 additions and 109 deletions.
42 changes: 26 additions & 16 deletions src/cpu/matmul/gemm_based_common.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@ struct params_t {

// indicates if src batch dims can be fused into M, so that a single
// GeMM call can be made
bool can_fuse_src_batch_dims_ = false;
bool use_single_gemm_call_optimization_ = false;

// an attribute for post processing kernel
primitive_attr_t pp_attr_;
@@ -111,13 +111,14 @@ inline bool check_gemm_binary_per_oc_compatible_formats(const matmul_pd_t &pd) {
return ok && strides[0] == utils::array_product(dims + 1, ndims - 1);
}

inline size_t get_scratchpad_size(const dim_t batch, dim_t M, const dim_t N,
const bool can_fuse_src_batch_dims, const int nthr) {
inline size_t get_scratchpad_block_elements(const dim_t batch, dim_t M,
const dim_t N, const bool use_single_gemm_call_optimization,
const int nthr) {
assert(batch > 0);
assert(M > 0);
assert(N > 0);
size_t buffer_size;
if (can_fuse_src_batch_dims || batch == 1) {
if (use_single_gemm_call_optimization) {
buffer_size = (size_t)batch * M * N;
} else {
const size_t work_per_thr = utils::div_up((size_t)batch * M * N, nthr);
@@ -131,20 +132,29 @@ inline size_t get_scratchpad_size(const dim_t batch, dim_t M, const dim_t N,
return utils::rnd_up(buffer_size, 64);
}

inline size_t get_scratchpad_num_elements(const dim_t batch, dim_t M,
const dim_t N, const bool use_single_gemm_call_optimization,
const int nthr) {
const int num_scratchpad_blocks
= use_single_gemm_call_optimization ? 1 : nthr;
return get_scratchpad_block_elements(
batch, M, N, use_single_gemm_call_optimization, nthr)
* num_scratchpad_blocks;
}

inline void book_acc_scratchpad(matmul_pd_t &pd, const params_t &params,
size_t sizeof_acc_data, const int nthr) {

if (!params.dst_is_acc_
&& !memory_desc_wrapper(pd.dst_md()).has_runtime_dims()) {
const size_t buffer_size = get_scratchpad_size(pd.batch(), pd.M(),
pd.N(), params.can_fuse_src_batch_dims_, nthr);
const size_t sp_size = params.can_fuse_src_batch_dims_
? buffer_size
: buffer_size * nthr;
auto scratchpad = pd.scratchpad_registry().registrar();
scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt,
sp_size, sizeof_acc_data);
}
if (params.dst_is_acc_) return; // scratchpad buffer is not required

// scratchpad buffer must be allocated on execution stage
if (pd.has_runtime_dims_or_strides()) return;

const size_t buffer_size = get_scratchpad_num_elements(pd.batch(), pd.M(),
pd.N(), params.use_single_gemm_call_optimization_, nthr);
auto scratchpad = pd.scratchpad_registry().registrar();
scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt,
buffer_size, sizeof_acc_data);
}

} // namespace gemm_based
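The split above is the heart of the fix: `get_scratchpad_block_elements()` returns the per-block (per-thread) element count, and `get_scratchpad_num_elements()` multiplies it by the number of blocks, which is 1 on the single-GeMM path and `nthr` otherwise. Since `book_acc_scratchpad()` now books the latter while the execute paths stride by the former, allocation and usage are derived from one source and cannot drift apart. A minimal standalone sketch of that relationship follows; the names `block_elements`/`num_elements` and the per-thread block heuristic are simplified placeholders, not the exact oneDNN formulas.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Simplified stand-ins for the helpers above; the per-thread block size is a
// placeholder heuristic, not the exact oneDNN formula.
inline std::size_t block_elements(std::int64_t batch, std::int64_t M,
        std::int64_t N, bool single_gemm_call, int nthr) {
    const std::size_t total = std::size_t(batch) * M * N;
    std::size_t buf;
    if (single_gemm_call) {
        buf = total; // one block holds the whole batch x M x N accumulation
    } else {
        // each block holds roughly one thread's share of the work: at least
        // one output row of N elements, at most one full M x N matrix
        const std::size_t per_thr = (total + nthr - 1) / nthr;
        buf = std::max<std::size_t>(std::size_t(N),
                std::min<std::size_t>(per_thr, std::size_t(M) * N));
    }
    return (buf + 63) / 64 * 64; // round up to a multiple of 64 elements
}

inline std::size_t num_elements(std::int64_t batch, std::int64_t M,
        std::int64_t N, bool single_gemm_call, int nthr) {
    // total allocation = per-block size times the number of blocks
    const int blocks = single_gemm_call ? 1 : nthr;
    return block_elements(batch, M, N, single_gemm_call, nthr) * blocks;
}
```

The booking side uses the total, the execute side uses the block size as its per-thread stride into the same buffer; the identity total == blocks * block is what the old code re-derived by hand in every driver.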
36 changes: 10 additions & 26 deletions src/cpu/matmul/gemm_bf16_matmul.cpp
@@ -189,11 +189,9 @@ status_t gemm_bf16_matmul_t<dst_type>::execute_ref(
const int nthr = pd()->nthr_;

const gemm_based::params_t &params = pd()->params();
const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
? helper.can_fuse_src_batch_dims()
: params.can_fuse_src_batch_dims_;
const dim_t acc_stride = gemm_based::get_scratchpad_size(
batch, M, N, can_fuse_src_batch_dims, nthr);
const bool use_single_gemm_call = pd()->has_runtime_dims_or_strides()
? helper.use_single_gemm_call_optimization(po)
: params.use_single_gemm_call_optimization_;
bool dst_is_acc = params.dst_is_acc_;
acc_data_t *acc = dst_is_acc
? (acc_data_t *)dst
@@ -202,9 +200,10 @@ status_t gemm_bf16_matmul_t<dst_type>::execute_ref(
// case: dynamic sizes
bool need_free_acc = false;
if (acc == nullptr) {
acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
* ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
64);
const size_t buf_elements = gemm_based::get_scratchpad_num_elements(
batch, M, N, use_single_gemm_call, nthr);
acc = (acc_data_t *)malloc(sizeof(acc_data_t) * buf_elements, 64);

if (acc == nullptr) return status::out_of_memory;
need_free_acc = true;
}
@@ -216,24 +215,7 @@ status_t gemm_bf16_matmul_t<dst_type>::execute_ref(
= this->pd()->attr()->output_scales_.mask_ == (1 << (ndims - 1));

std::atomic<status_t> st(status::success);
// use parallel over batch when binary po with channel bcast
// (except batch == 1)
bool is_binary_po_per_oc;
bool is_binary_po_per_oc_sp;
bool is_binary_po_channel_bcast;
std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
is_binary_po_channel_bcast)
= bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial);
// if batched, parralel over batch for per_mb_sp and per_oc binary
// post-op broadcast
const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
&& IMPLICATION(
is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
if (!use_single_gemm_call) {
const int src_mask
= utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
const int wei_mask
@@ -243,6 +225,8 @@ status_t gemm_bf16_matmul_t<dst_type>::execute_ref(
: types::data_type_size(pd()->weights_md(1)->data_type);
const size_t work_amount = (size_t)batch * M * N;
const size_t work_per_batch = (size_t)M * N;
const dim_t acc_stride = gemm_based::get_scratchpad_block_elements(
batch, M, N, use_single_gemm_call, nthr);

// NOTE: inside lambda, type cast variables captured by reference using
// either c-like "(type)var" or functional "type(var)" notation in order
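The execute side of this hunk (and of the f32 and x8s8s32x drivers below, which receive the same change) now follows one pattern: take `dst` itself when it can serve as the accumulator, otherwise the booked scratchpad, otherwise, when shapes are only known at run time, allocate exactly `get_scratchpad_num_elements()` elements on the spot, then stride through the buffer in `get_scratchpad_block_elements()`-sized blocks per thread. A reduced sketch of that flow, with `acc_data_t`, the function name and the omitted GeMM/parallel machinery standing in for oneDNN internals:

```cpp
#include <cstddef>
#include <vector>

using acc_data_t = float; // assumed accumulator type for this sketch

// block_elems / total_elems are expected to come from the same helpers that
// booked the scratchpad (get_scratchpad_block_elements / _num_elements).
void run_accumulation(acc_data_t *dst, bool dst_is_acc,
        acc_data_t *booked_scratchpad, bool use_single_gemm_call, int nthr,
        std::size_t block_elems, std::size_t total_elems) {
    // 1. Pick the accumulation buffer.
    acc_data_t *acc = dst_is_acc ? dst : booked_scratchpad;
    std::vector<acc_data_t> owned; // real code: 64-byte-aligned malloc + free
    if (acc == nullptr) { // runtime shapes: nothing could be booked at pd time
        owned.resize(total_elems);
        acc = owned.data();
    }

    if (!use_single_gemm_call) {
        // 2. Parallel path: each thread owns one block; the stride is the same
        //    block size the allocation above was derived from.
        for (int ithr = 0; ithr < nthr; ++ithr) {
            acc_data_t *curr_acc = acc + std::size_t(ithr) * block_elems;
            (void)curr_acc; // per-thread GeMM calls accumulate into curr_acc
        }
    } else {
        // 3. Single-GeMM path: one call covers the whole problem, and
        //    total_elems == block_elems by construction.
        (void)acc;
    }
}
```

Before this commit the runtime-shape fallback sized its malloc as `acc_stride * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr)` in each driver, re-implementing the booking arithmetic by hand; routing both sides through `get_scratchpad_num_elements()` is what aligns allocation and usage.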
41 changes: 13 additions & 28 deletions src/cpu/matmul/gemm_f32_matmul.cpp
@@ -63,9 +63,9 @@ status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) {
if (!ok) return status::unimplemented;

if (!has_runtime_dims_or_strides())
params_.can_fuse_src_batch_dims_
params_.use_single_gemm_call_optimization_
= matmul_helper_t(src_md(), weights_md(), dst_md())
.can_fuse_src_batch_dims();
.use_single_gemm_call_optimization(attr()->post_ops_);

CHECK(check_and_configure_attributes());

@@ -193,11 +193,9 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
const gemm_based::params_t &params = pd()->params();
const float alpha = params.get_gemm_alpha(scales);
const float beta = params.gemm_beta_;
const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
? helper.can_fuse_src_batch_dims()
: params.can_fuse_src_batch_dims_;
const dim_t acc_stride = gemm_based::get_scratchpad_size(
batch, M, N, can_fuse_src_batch_dims, nthr);
const bool use_single_gemm_call = pd()->has_runtime_dims_or_strides()
? helper.use_single_gemm_call_optimization(po)
: params.use_single_gemm_call_optimization_;
bool dst_is_acc = params.dst_is_acc_;
acc_data_t *acc = dst_is_acc
? (acc_data_t *)dst
@@ -206,9 +204,10 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
// case: dynamic sizes
bool need_free_acc = false;
if (acc == nullptr) {
acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
* ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
64);
const size_t buf_elements = gemm_based::get_scratchpad_num_elements(
batch, M, N, use_single_gemm_call, nthr);
acc = (acc_data_t *)malloc(sizeof(acc_data_t) * buf_elements, 64);

if (acc == nullptr) return status::out_of_memory;
need_free_acc = true;
}
@@ -218,24 +217,7 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
= this->pd()->attr()->output_scales_.mask_ == (1 << (ndims - 1));

std::atomic<status_t> st(status::success);
// use parallel over batch when binary po with channel bcast
// (except batch == 1)
bool is_binary_po_per_oc = false;
bool is_binary_po_per_oc_sp = false;
bool is_binary_po_channel_bcast = false;
std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
is_binary_po_channel_bcast)
= bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial);
// if batched, parralel over batch for per_mb_sp and per_oc binary
// post-op broadcast
const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
&& IMPLICATION(
is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
if (!use_single_gemm_call) {
const int src_mask
= utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
const int wei_mask
@@ -245,6 +227,9 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
: types::data_type_size(pd()->weights_md(1)->data_type);
const size_t work_amount = (size_t)batch * M * N;
const size_t work_per_batch = (size_t)M * N;
const dim_t acc_stride = gemm_based::get_scratchpad_block_elements(
batch, M, N, use_single_gemm_call, nthr);

parallel(nthr, [&](int ithr, int nthr) {
size_t t_work_start {0}, t_work_end {0};
balance211(work_amount, nthr, ithr, t_work_start, t_work_end);
35 changes: 9 additions & 26 deletions src/cpu/matmul/gemm_x8s8s32x_matmul.cpp
@@ -220,11 +220,9 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
const int nthr = pd()->nthr_;

const gemm_based::params_t &params = pd()->params();
const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
? helper.can_fuse_src_batch_dims()
: params.can_fuse_src_batch_dims_;
const dim_t acc_stride = gemm_based::get_scratchpad_size(
batch, M, N, can_fuse_src_batch_dims, nthr);
const bool use_single_gemm_call = pd()->has_runtime_dims_or_strides()
? helper.use_single_gemm_call_optimization(po)
: params.use_single_gemm_call_optimization_;
bool dst_is_acc = params.dst_is_acc_;
int32_t *acc = dst_is_acc
? reinterpret_cast<int32_t *>(dst)
@@ -233,9 +231,9 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
// case: dynamic sizes
bool need_free_acc = false;
if (acc == nullptr) {
acc = (int32_t *)malloc(sizeof(int32_t) * acc_stride
* ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
64);
const size_t buf_elements = gemm_based::get_scratchpad_num_elements(
batch, M, N, use_single_gemm_call, nthr);
acc = (int32_t *)malloc(sizeof(int32_t) * buf_elements, 64);

if (acc == nullptr) return status::out_of_memory;
need_free_acc = true;
@@ -248,24 +246,7 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
= this->pd()->attr()->output_scales_.mask_ == (1 << (ndims - 1));

std::atomic<status_t> st(status::success);
// use parallel over batch when binary po with channel bcast
// (except batch == 1)
bool is_binary_po_per_oc = false;
bool is_binary_po_per_oc_sp = false;
bool is_binary_po_channel_bcast = false;
std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
is_binary_po_channel_bcast)
= bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial);
// if batched, parralel over batch for per_mb_sp and per_oc binary
// post-op broadcast
const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
&& IMPLICATION(
is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
if (!use_single_gemm_call) {
const int src_mask
= utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
const int wei_mask
@@ -276,6 +257,8 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
const size_t work_amount = (size_t)batch * M * N;
const size_t work_per_batch = (size_t)M * N;
const dim_t acc_stride = gemm_based::get_scratchpad_block_elements(
batch, M, N, use_single_gemm_call, nthr);

// NOTE: inside lambda, type cast variables captured by reference using
// either c-like "(type)var" or functional "type(var)" notation in order
48 changes: 35 additions & 13 deletions src/cpu/matmul/matmul_utils.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020 Intel Corporation
* Copyright 2020-2022 Intel Corporation
* Copyright 2022 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,8 @@
#include "common/memory_desc_wrapper.hpp"
#include "common/utils.hpp"

#include "cpu/binary_injector_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
@@ -73,10 +75,36 @@ struct matmul_helper_t {

dim_t ldc() const { return dst_md_.blocking_desc().strides[ndims() - 2]; }

bool use_single_gemm_call_optimization(const post_ops_t &post_ops) {
using namespace binary_injector_utils;
bool is_binary_po_per_oc;
bool is_binary_po_per_oc_sp;
bool is_binary_po_channel_bcast;
std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
is_binary_po_channel_bcast)
= bcast_strategies_present_tup(post_ops.entry_, dst_md_,
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial);

const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
&& IMPLICATION(is_binary_po_per_oc || is_binary_po_per_oc_sp,
ndims() == 2);

// single GeMM call can be made, avoid parallelization over GeMM calls
return can_use_po_with_fused_batch && can_fuse_src_batch_dims();
}

private:
mdw_t src_md_;
mdw_t weights_md_;
mdw_t dst_md_;

// TODO similar optimization is also possible for wei batch fusion.
bool can_fuse_src_batch_dims() const {
/* Note:
We can fuse src batch dims so that a single GeMM can be used iff
We can fuse src batch dims so that a single GeMM can be used if
0. always for batch = 1 case
1. src is not transposed
2. wei batch dims are all 1's
3. The strides in batch dims are trivial (allowing permutations).
@@ -91,19 +119,18 @@ struct matmul_helper_t {
A single GeMM call can be used instead with m = a*d*c*b*m
*/
// Note 0:
if (batch() == 1) return true;

// Note 1:
if (transA() == 'T') return false;

const int n_dims = ndims();
const int batch_ndims = n_dims - 2;
if (batch_ndims == 0) return true;

// Note 2:
if (utils::array_product(weights_md_.dims(), batch_ndims) != 1)
return false;
if (wei_batch() != 1) return false;

// determine batch dims layout
dims_t src_strides;
const int batch_ndims = ndims() - 2;
utils::array_copy(
src_strides, src_md_.blocking_desc().strides, batch_ndims);

@@ -137,11 +164,6 @@ struct matmul_helper_t {

return true;
}

private:
mdw_t src_md_;
mdw_t weights_md_;
mdw_t dst_md_;
};

} // namespace matmul
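The fusion condition that `can_fuse_src_batch_dims()` encodes can be read as a shape transformation: when the weights carry no real batch dims and the src batch strides form a dense (possibly permuted) packing over m*k blocks, the batch dims fold into M and one GeMM call with M' = batch * m replaces the per-batch loop. A toy illustration under those assumptions is below; the function name is invented for this sketch, the transpose and runtime-dimension checks are left out, and the post-op broadcast restrictions are handled separately in `use_single_gemm_call_optimization()` above.

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Toy check (assumes src_dims.size() >= 2): can a batched matmul src of shape
// [b..., m, k] be folded into a single GeMM with M' = prod(b...) * m?
bool batch_dims_fold_into_M(const std::vector<std::int64_t> &src_dims,
        const std::vector<std::int64_t> &src_batch_strides,
        const std::vector<std::int64_t> &wei_dims, std::int64_t &fused_M) {
    const std::size_t ndims = src_dims.size();
    const std::size_t batch_ndims = ndims - 2;
    const std::int64_t m = src_dims[ndims - 2];
    const std::int64_t k = src_dims[ndims - 1];

    // weights batch dims must all be 1 (weights shared across the batch)
    for (std::size_t i = 0; i < batch_ndims; ++i)
        if (wei_dims[i] != 1) return false;

    // collect (stride, dim) pairs for the src batch dims, sort by stride
    std::vector<std::pair<std::int64_t, std::int64_t>> sd;
    for (std::size_t i = 0; i < batch_ndims; ++i)
        sd.emplace_back(src_batch_strides[i], src_dims[i]);
    std::sort(sd.begin(), sd.end());

    // dense packing: the innermost batch stride is m*k, and each outer stride
    // is the previous stride times the previous dim (permutations allowed)
    std::int64_t expected = m * k;
    for (const auto &p : sd) {
        if (p.first != expected) return false;
        expected *= p.second;
    }

    std::int64_t batch = 1;
    for (std::size_t i = 0; i < batch_ndims; ++i) batch *= src_dims[i];
    fused_M = batch * m; // e.g. src {a, b, c, d, m, k} -> M' = a*b*c*d*m
    return true;
}
```

For example, a row-major src {a, b, m, k} with wei {1, 1, k, n} passes the check and fuses into a single (a*b*m) x k by k x n GeMM, which is the case the comment block in the hunk above spells out.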
