Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize softmax cpu by parallel using openmp. #36

Merged
merged 3 commits into from
Nov 29, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions onnxruntime/core/providers/cpu/math/softmax_shared.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,47 @@
#include "gsl/gsl_algorithm"
#include "gsl/gsl_util"

#if defined(_OPENMP)
#include <omp.h>
#endif

namespace onnxruntime {

// Computes the numerically-stable numerator of a softmax over an n x d
// row-major matrix: Y = exp(X - rowwise_max(X)). The caller performs the
// row-sum and final normalization.
//   n              : number of rows.
//   d              : elements per row.
//   Xdata          : input matrix (n * d floats).
//   Ydata          : output buffer (n * d floats); may not alias Xdata.
//   sum_multiplier : length-d vector of ones used to broadcast the per-row
//                    max across each row via a rank-1 Gemm.
//   rowmax         : scratch buffer of n floats receiving the row maxima.
// Always returns Status::OK().
common::Status SoftmaxCore(const int n,
                           const int d,
                           const float* Xdata,
                           float* Ydata,
                           const float* sum_multiplier,
                           float* rowmax) {
  const int total_elements = n * d;

  // Per-row maximum; subtracting it before exponentiation avoids overflow.
  math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);

  // Stage X into Y, then subtract each row's max from every entry of that
  // row: Y += (-1) * rowmax * sum_multiplier (an n x 1 by 1 x d outer product).
  gsl::copy(gsl::make_span(Xdata, total_elements), gsl::make_span(Ydata, total_elements));
  math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);

  // Elementwise exponentiation, in place.
  math::Exp<float, CPUMathUtil>(total_elements, Ydata, Ydata, nullptr);

  return Status::OK();
}

// Decides how many groups the n softmax rows should be split into for
// parallel processing. Returns 1 (serial) when OpenMP is unavailable, when
// there is at most one thread or row, or when the total workload is too
// small for threading overhead to pay off.
static int GetParallelGroupCount(int n, int d) {
#if defined(_OPENMP)
  // BUG FIX: the original called omp_get_num_threads(), which reports the
  // size of the *current* thread team and therefore returns 1 outside a
  // parallel region — making this function always return 1 and disabling
  // the parallelization entirely. omp_get_max_threads() reports the thread
  // budget available to an upcoming parallel region, which is what we want.
  int omp_num_threads = omp_get_max_threads();
  int group_count = std::min(omp_num_threads, n);
  if (group_count <= 1) return 1;

  // Keep at least 2048 floats (2048 * sizeof(float) == two cache pages) of
  // work per group so small inputs are not over-partitioned.
  static const int min_elements_per_group = 2048;
  // Widen to int64_t before multiplying so n * d cannot overflow int.
  int max_groups = gsl::narrow_cast<int>(
      (int64_t{n} * d + min_elements_per_group - 1) / min_elements_per_group);

  return std::min(group_count, max_groups);
#else
  (void)n;
  (void)d;
  return 1;
#endif
}

common::Status SoftmaxCPU(const int64_t N,
const int64_t D,
const float* Xdata,
Expand All @@ -57,21 +96,24 @@ common::Status SoftmaxCPU(const int64_t N,

const int n = gsl::narrow_cast<int>(N);
const int d = gsl::narrow_cast<int>(D);
const int nd = gsl::narrow_cast<int>(N * D);

math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);

// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
int parallel_group_count = GetParallelGroupCount(n, d);
int n_per_group = (n + (parallel_group_count-1)) / parallel_group_count;

math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
#pragma omp parallel for
for (int i = 0; i < parallel_group_count; ++i) {
int s = n_per_group * i;
if (s < n) {
int c = (n - s >= n_per_group) ? n_per_group : (n-s);
SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), sum_multiplier, rowmax+s);
}
}

// Exponentiation
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
math::Gemv<float, CPUMathUtil>(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr);

// Do division
if (!logarithmic) {
#pragma omp parallel for
for (int i = 0; i < N; ++i) {
for (int j = 0; j < D; ++j) {
Ydata[i * D + j] /= scale[i];
Expand Down