Skip to content

Commit

Permalink
x64: brgemm bwd_w convolution: update threading for small minibatch
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin committed Apr 17, 2023
1 parent 82cb7d3 commit 21bdc21
Showing 1 changed file with 155 additions and 0 deletions.
155 changes: 155 additions & 0 deletions src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2499,6 +2499,161 @@ void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) {
}
nthr_ic_b = jcp.nthr / (nthr_mb * nthr_oc_b);
nthr = nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b;
} else if (is_amx(jcp.isa) && jcp.mb <= jcp.nthr / 2 && jcp.oc >= 64
&& jcp.ic >= 64 && jcp.ngroups == 1) {
// This heuristic is intended for usual convolutions if the minibatch
// is much less than the number of threads: it tries to divide the
// total amount of work into more-less 4-dimensional (by mb, g, oc, ic)
// "cubic" pieces
enum bwd_w_dims { g, ic, oc, sp };
constexpr int nd = 4;
// Keep maximum values for each dimension as a map
std::map<bwd_w_dims, int> maxv;
maxv.emplace(bwd_w_dims::g, jcp.ngroups);
maxv.emplace(bwd_w_dims::ic, div_up(jcp.nb_ic, 2));
maxv.emplace(bwd_w_dims::oc, div_up(jcp.nb_oc, 2));
maxv.emplace(bwd_w_dims::sp, jcp.mb * jcp.od * jcp.oh);

// Keep dimension values as a vector
std::vector<std::pair<double, bwd_w_dims>> dv;
const auto ks = jcp.kd * jcp.kh * jcp.kw;
double v = (jcp.ngroups > 1) ? static_cast<double>(jcp.ic) * jcp.oc
* jcp.ngroups * jcp.ngroups * ks
: 1;
dv.emplace_back(v, bwd_w_dims::g);
v = 5 * div_up(jcp.ic, jcp.amx_h) * ks;
dv.emplace_back(v, bwd_w_dims::ic);
v = 3 * div_up(jcp.oc, jcp.amx_h) * ks;
dv.emplace_back(v, bwd_w_dims::oc);
v = div_up(jcp.mb * jcp.od * jcp.oh * jcp.ow, jcp.amx_w);
dv.emplace_back(v, bwd_w_dims::sp);
// Estimate the size of "cubic" piece
double xd = 1;
for (int j = 0; j < nd; j++)
xd *= dv[j].first;
xd = pow(xd / jcp.nthr, 1.f / nd);
// Adjust piece to fit into dimensions
std::sort(dv.begin(), dv.end());
double tot_v = 1;
for (int i = 0; i < nd; i++) {
auto &dvf = dv[i].first;
const auto &dvs = dv[i].second;
const auto maxvf = static_cast<double>(maxv[dvs]);
if (dvf < xd) {
v = 1;
xd = 1;
for (int j = i + 1; j < nd; j++)
xd *= dv[j].first;
xd = pow(xd / jcp.nthr, 1.f / (nd - i - 1));
} else {
v = nstl::min(dvf / xd, maxvf);
}
tot_v *= v;
dvf = v;
}
std::sort(dv.begin(), dv.end());

// Normalize dimension values so product should be ~= nthr
double knorm = pow(jcp.nthr / tot_v, 1.f / nd);
tot_v = 1;
for (int i = 0; i < nd; i++) {
auto &dvf = dv[i].first;
auto &dvs = dv[i].second;
const auto maxvf = static_cast<double>(maxv[dvs]);
const auto new_dvf = dvf * knorm;
dvf = utils::saturate(1., maxvf, new_dvf);
knorm *= pow(new_dvf / dvf, 1.f / (nd - i - 1));
tot_v *= dvf;
}
std::sort(dv.begin(), dv.end());
knorm = jcp.nthr / tot_v;
for (int i = 0; i < nd; i++) {
auto &dvf = dv[i].first;
auto &dvs = dv[i].second;
const auto maxvf = static_cast<double>(maxv[dvs]);
const auto new_dvf = dvf * knorm;
dvf = utils::saturate(1., maxvf, new_dvf);
knorm = new_dvf / dvf;
}
std::sort(dv.begin(), dv.end());

// Selecting the number of threads for every dimension closest to what
// we defined before
auto calc_diff =
[&](const std::vector<std::pair<int, bwd_w_dims>> &cv) {
auto tot_n = 1;
double res = 1;
for (int i = 0; i < nd; i++) {
const auto nvf = dv[i].first;
const auto n = cv[i].first;
const auto v = maxv[cv[i].second];
const auto disb
= nvf * static_cast<double>(rnd_up(v, n)) / v;
const auto nf = static_cast<double>(n);
const auto var = ((nf > nvf) ? (nf / nvf) : (nvf / nf));
tot_n *= n;
res *= disb * var;
}
const auto thr_disb = static_cast<double>(jcp.nthr) / tot_n;
return res * thr_disb;
};

// nv: vector to keep result of selection
std::vector<std::pair<int, bwd_w_dims>> nv;
// Initial vector and estimation
for (int i = 0; i < nd; i++) {
const auto dvf = dv[i].first;
const auto dvs = dv[i].second;
const auto maxvf = maxv[dvs];
nv.emplace_back(
utils::saturate(1, maxvf, static_cast<int>(dvf + 0.5f)),
dvs);
}
nv[nd - 1].first = jcp.nthr / (nv[0].first * nv[1].first * nv[2].first);
double best_diff = calc_diff(nv);

// Iterate through all combinations of numbers
std::vector<std::pair<int, bwd_w_dims>> cv = nv;
const auto n0_max = jcp.nthr;
for (int n0 = 1; n0 <= n0_max; n0++) {
if (n0 > maxv[dv[0].second]) continue;
cv[0].first = n0;
const auto n1_max = n0_max / n0;
for (int n1 = 1; n1 <= n1_max; n1++) {
if (n1 > maxv[dv[1].second]) continue;
cv[1].first = n1;
const auto n2_max = n1_max / n1;
for (int n2 = 1; n2 <= n2_max; n2++) {
if (n2 > maxv[dv[2].second]) continue;
cv[2].first = n2;
const auto n3_max = n2_max / n2;
for (int n3 = n3_max; n3 >= 1; n3--) {
if (n3 > maxv[dv[3].second]) continue;
cv[3].first = n3;
const auto tot_n = n0 * n1 * n2 * n3;
const auto cdiff = calc_diff(cv);
if (cdiff < best_diff && tot_n <= jcp.nthr) {
best_diff = cdiff;
nv = cv;
}
}
}
}
}

for (size_t i = 0; i < nd; i++) {
const auto &nvf = nv[i].first;
const auto &nvs = nv[i].second;
if (nvs == bwd_w_dims::g)
nthr_g = nvf;
else if (nvs == bwd_w_dims::ic)
nthr_ic_b = nvf;
else if (nvs == bwd_w_dims::oc)
nthr_oc_b = nvf;
else if (nvs == bwd_w_dims::sp)
nthr_mb = nvf;
}
nthr = nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b;
} else if (jcp.ngroups == 1 && (jcp.oc > 2048 || jcp.ic > 2048)) {
const bool more_oc = (jcp.ic < jcp.oc);
if (more_oc) {
Expand Down

0 comments on commit 21bdc21

Please sign in to comment.