mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Merge pull request #36 from Microsoft/zhalei/softmax_optimize
Optimize the CPU softmax by parallelizing it with OpenMP.
This commit is contained in:
commit
cd1042c94f
1 changed file with 50 additions and 8 deletions
|
|
@ -36,8 +36,47 @@
|
|||
#include "gsl/gsl_algorithm"
|
||||
#include "gsl/gsl_util"
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Computes exp(X - rowmax(X)) for an n x d block of floats and stores it in Ydata.
//
// n, d           - row and column counts of the block; n * d elements are processed.
// Xdata          - input block (read only).
// Ydata          - output block; receives the exponentiated, max-shifted values.
// sum_multiplier - right-hand operand of the Gemm that subtracts the row maxima
//                  (presumably a vector of d ones - confirm against the caller).
// rowmax         - scratch/output buffer filled by RowwiseMax (one entry per row).
//
// Always returns Status::OK().
common::Status SoftmaxCore(const int n,
                           const int d,
                           const float* Xdata,
                           float* Ydata,
                           const float* sum_multiplier,
                           float* rowmax) {
  const int element_count = n * d;

  // Step 1: per-row maximum of X goes into rowmax.
  math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);

  // Step 2: build X - max(X) in Y. Y starts as a copy of X ...
  const auto source = gsl::make_span(Xdata, element_count);
  const auto destination = gsl::make_span(Ydata, element_count);
  gsl::copy(source, destination);

  // ... and then the row maximum is subtracted from every entry via Gemm.
  math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
                                 n, d, 1,
                                 -1, rowmax, sum_multiplier,
                                 1, Ydata, nullptr);

  // Step 3: exponentiate Y in place.
  math::Exp<float, CPUMathUtil>(element_count, Ydata, Ydata, nullptr);

  return Status::OK();
}
|
||||
|
||||
// Picks how many row groups an n x d softmax input should be split into for
// the OpenMP loop in SoftmaxCPU.
//
// n - number of rows (the split happens along rows, so this caps the result).
// d - number of elements per row.
//
// Returns a value that is always >= 1. The caller divides by it, so zero must
// never be returned. When the code is built without OpenMP the result is 1.
static int GetParallelGroupCount(int n, int d) {
#if defined(_OPENMP)
  // This function is called before the parallel region is entered. In serial
  // code omp_get_num_threads() always reports 1, which would disable the split
  // entirely; omp_get_max_threads() reports the team size a following
  // parallel region may use.
  const int max_threads = omp_get_max_threads();
  const int group_count = std::min(max_threads, n);
  if (group_count <= 1) return 1;

  // Do not create groups smaller than 2048 floats (2048 * sizeof(float) = 8 KiB,
  // i.e. two 4 KiB pages) so that each thread gets a meaningful amount of work.
  constexpr int min_elements_per_group = 2048;
  const int64_t total_elements = int64_t{n} * d;
  const int max_groups =
      gsl::narrow_cast<int>((total_elements + min_elements_per_group - 1) / min_elements_per_group);

  // d == 0 makes max_groups 0; clamp so the caller never divides by zero.
  return std::max(1, std::min(group_count, max_groups));
#else
  (void)n;
  (void)d;
  return 1;
#endif
}
|
||||
|
||||
common::Status SoftmaxCPU(const int64_t N,
|
||||
const int64_t D,
|
||||
const float* Xdata,
|
||||
|
|
@ -57,21 +96,24 @@ common::Status SoftmaxCPU(const int64_t N,
|
|||
|
||||
const int n = gsl::narrow_cast<int>(N);
|
||||
const int d = gsl::narrow_cast<int>(D);
|
||||
const int nd = gsl::narrow_cast<int>(N * D);
|
||||
|
||||
math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
|
||||
int parallel_group_count = GetParallelGroupCount(n, d);
|
||||
int n_per_group = (n + (parallel_group_count-1)) / parallel_group_count;
|
||||
|
||||
// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
|
||||
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < parallel_group_count; ++i) {
|
||||
int s = n_per_group * i;
|
||||
if (s < n) {
|
||||
int c = (n - s >= n_per_group) ? n_per_group : (n-s);
|
||||
SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), sum_multiplier, rowmax+s);
|
||||
}
|
||||
}
|
||||
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
|
||||
|
||||
// Exponentiation
|
||||
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
|
||||
math::Gemv<float, CPUMathUtil>(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr);
|
||||
|
||||
// Do division
|
||||
if (!logarithmic) {
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int j = 0; j < D; ++j) {
|
||||
Ydata[i * D + j] /= scale[i];
|
||||
|
|
|
|||
Loading…
Reference in a new issue