mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Optimize softmax cpu by parallel using openmp.
This commit is contained in:
parent
84fa1018a3
commit
e7bdfa00db
1 changed files with 34 additions and 7 deletions
|
|
@ -38,6 +38,29 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
common::Status SoftmaxCore(const int n,
|
||||
const int d,
|
||||
const float* Xdata,
|
||||
float* Ydata,
|
||||
float* scale,
|
||||
const float* sum_multiplier,
|
||||
bool logarithmic,
|
||||
float* rowmax) {
|
||||
|
||||
const int nd = n * d;
|
||||
|
||||
math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
|
||||
|
||||
// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
|
||||
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
|
||||
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
|
||||
// Exponentiation
|
||||
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status SoftmaxCPU(const int64_t N,
|
||||
const int64_t D,
|
||||
const float* Xdata,
|
||||
|
|
@ -59,19 +82,23 @@ common::Status SoftmaxCPU(const int64_t N,
|
|||
const int d = gsl::narrow_cast<int>(D);
|
||||
const int nd = gsl::narrow_cast<int>(N * D);
|
||||
|
||||
math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
|
||||
static const int kGROUP = 8;
|
||||
int g = (n + (kGROUP-1)) / kGROUP;
|
||||
|
||||
// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
|
||||
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < kGROUP; ++i) {
|
||||
int s = g * i;
|
||||
if (s < n) {
|
||||
int c = (n - s >= g)?g : (n-s);
|
||||
SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), scale + s, sum_multiplier, logarithmic, rowmax+s);
|
||||
}
|
||||
}
|
||||
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
|
||||
|
||||
// Exponentiation
|
||||
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
|
||||
math::Gemv<float, CPUMathUtil>(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr);
|
||||
|
||||
// Do division
|
||||
if (!logarithmic) {
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int j = 0; j < D; ++j) {
|
||||
Ydata[i * D + j] /= scale[i];
|
||||
|
|
|
|||
Loading…
Reference in a new issue