Skip to content

Commit

Permalink
x64: brgemm bwd_w convolution: update scratchpad data preparing
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin committed Apr 17, 2023
1 parent caead72 commit 796a600
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions src/cpu/x64/jit_brgemm_conv_bwd_w.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1358,15 +1358,26 @@ void brgemm_convolution_bwd_weights_t::prepare_scratchpad_data(
const auto &jcp = pd()->jcp_;

auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
// Zero out guard elements that cross a buffer boundary to prevent a
// race condition due to buffer overflows from memory optimization where
// buffers sharing padding
// TODO: optimize it
for (size_t isb = 1; isb <= jcp.tr_src_buf_count; ++isb) {
src_data_t *ts
= &tr_src[isb * jcp.tr_src_buf_size * jcp.nb_ic_blocking];
for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i)
ts[i] = 0;
const auto tr_src_full_size = jcp.tr_src_buf_size * jcp.nb_ic_blocking;
if (jcp.oh_block < jcp.oh || jcp.id > 1) {
// if (oh_block < oh) or (id > 1) then we zero all buffer because last
// elements position may vary depending on position of od_s, oh_block,
// padding and kh
parallel_nd(jcp.tr_src_buf_count, [&](size_t isb) {
src_data_t *ts = &tr_src[isb * tr_src_full_size];
std::memset(ts, 0, jcp.src_dsz * tr_src_full_size);
});
// Zero out last guard elements
src_data_t *ts = &tr_src[jcp.tr_src_buf_count * tr_src_full_size];
std::memset(ts, 0, jcp.src_dsz * jcp.tr_src_num_guard_elems);
} else {
// Zero out guard elements that cross a buffer boundary to prevent a
// race condition due to buffer overflows from memory optimization where
// buffers sharing padding
parallel_nd(jcp.tr_src_buf_count, [&](size_t isb) {
src_data_t *ts = &tr_src[(isb + 1) * tr_src_full_size];
std::memset(ts, 0, jcp.src_dsz * jcp.tr_src_num_guard_elems);
});
}

if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
Expand Down

0 comments on commit 796a600

Please sign in to comment.