Skip to content

Commit

Permalink
x64: brgemm bwd_w conv: use oh_block instead of oh for tr_diff_dst sc…
Browse files Browse the repository at this point in the history
…ratchpad
  • Loading branch information
ankalinin committed Apr 17, 2023
1 parent 017950a commit 4c34f89
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions src/cpu/x64/jit_brgemm_conv_bwd_w.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) {
brgattr.max_bottom_vpad = 0;

brgattr.LDA2 = jcp_.tr_iw * jcp_.ih_block * jcp_.id;
brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od;
brgattr.LDB2
= jcp_.tr_ow * jcp_.oc_block * jcp_.oh_block * jcp_.od;
brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw;
brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block
* jcp_.kd * jcp_.kh * jcp_.kw;
Expand Down Expand Up @@ -474,7 +475,7 @@ struct brgemm_convolution_bwd_weights_t::thread_info_t {

size_t tr_diff_dst_off(int g, int ocb, int od, int oh) const {
const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
const size_t tr_3d_size = tr_row_size * jcp.oh;
const size_t tr_3d_size = tr_row_size * jcp.oh_block;
int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
return tr_diff_dst_buf_number(g, ocb) * adj * jcp.tr_diff_dst_buf_size
+ od * tr_3d_size + oh * tr_row_size;
Expand Down Expand Up @@ -1026,7 +1027,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
+ (bs_id_s - id_s) * jcp.ih_block * jcp.tr_iw * jcp.ic_block;
const void *ptr_B = ((diff_dst_data_t *)p_dst)
+ (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block
+ (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block;
+ (bs_od_s - od_s) * jcp.oh_block * jcp.tr_ow * jcp.oc_block;
void *ptr_C = (jcp.transform_to_vnni)
? diff_wei + wei_offset_int(g, oc_b, ic_b, kd, kh, kw)
: diff_wei
Expand All @@ -1049,7 +1050,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
* jcp.ic_block * jcp.stride_d;
ti->brg_batch[odb * bs_h + ohb].ptr.B = (char *)ptr_B
+ ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block
+ odb * jcp.typesize_in * jcp.oh * jcp.tr_ow
+ odb * jcp.typesize_in * jcp.oh_block * jcp.tr_ow
* jcp.oc_block;
}
}
Expand Down Expand Up @@ -1127,8 +1128,8 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
&& (odb_s == od_s) && (iodb == odb_s)
&& (ohb_s == oh_s);
bp.dst = ((diff_dst_data_t *)p_dst)
+ (iodb - od_s) * jcp.oh * jcp.tr_ow
* jcp.oc_block
+ (iodb - od_s) * jcp.oh_block
* jcp.tr_ow * jcp.oc_block
+ (ohb_s - oh_s) * jcp.tr_ow
* jcp.oc_block;
(*diff_bias_kernel_)(&bp);
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2760,7 +2760,7 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
: jcp.nthr;
jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih_block * jcp.id;
jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh_block * jcp.od;

const int iframe_size = irow_size * jcp.id;
const int oframe_size = orow_size * jcp.od;
Expand Down

0 comments on commit 4c34f89

Please sign in to comment.