Optimize softmax cpu by parallel using openmp.

This commit is contained in:
Lei Zhang 2018-11-27 11:04:42 -08:00
parent 84fa1018a3
commit e7bdfa00db

View file

@ -38,6 +38,29 @@
namespace onnxruntime {
common::Status SoftmaxCore(const int n,
const int d,
const float* Xdata,
float* Ydata,
float* scale,
const float* sum_multiplier,
bool logarithmic,
float* rowmax) {
const int nd = n * d;
math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
// Exponentiation
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
return Status::OK();
}
common::Status SoftmaxCPU(const int64_t N,
const int64_t D,
const float* Xdata,
@ -59,19 +82,23 @@ common::Status SoftmaxCPU(const int64_t N,
const int d = gsl::narrow_cast<int>(D);
const int nd = gsl::narrow_cast<int>(N * D);
math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
static const int kGROUP = 8;
int g = (n + (kGROUP-1)) / kGROUP;
// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
#pragma omp parallel for
for (int i = 0; i < kGROUP; ++i) {
int s = g * i;
if (s < n) {
int c = (n - s >= g)?g : (n-s);
SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), scale + s, sum_multiplier, logarithmic, rowmax+s);
}
}
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
// Exponentiation
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
math::Gemv<float, CPUMathUtil>(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr);
// Do division
if (!logarithmic) {
#pragma omp parallel for
for (int i = 0; i < N; ++i) {
for (int j = 0; j < D; ++j) {
Ydata[i * D + j] /= scale[i];