diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 32df249f36..762ee5bdfc 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -38,6 +38,29 @@ namespace onnxruntime { +common::Status SoftmaxCore(const int n, + const int d, + const float* Xdata, + float* Ydata, + float* scale, + const float* sum_multiplier, + bool logarithmic, + float* rowmax) { + + const int nd = n * d; + + math::RowwiseMax(n, d, Xdata, rowmax, nullptr); + + // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry + gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd)); + + math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr); + // Exponentiation + math::Exp(nd, Ydata, Ydata, nullptr); + + return Status::OK(); +} + common::Status SoftmaxCPU(const int64_t N, const int64_t D, const float* Xdata, @@ -59,19 +82,23 @@ common::Status SoftmaxCPU(const int64_t N, const int d = gsl::narrow_cast(D); const int nd = gsl::narrow_cast(N * D); - math::RowwiseMax(n, d, Xdata, rowmax, nullptr); + static const int kGROUP = 8; + int g = (n + (kGROUP-1)) / kGROUP; - // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry - gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd)); + #pragma omp parallel for + for (int i = 0; i < kGROUP; ++i) { + int s = g * i; + if (s < n) { + int c = (n - s >= g)?g : (n-s); + SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), scale + s, sum_multiplier, logarithmic, rowmax+s); + } + } - math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr); - - // Exponentiation - math::Exp(nd, Ydata, Ydata, nullptr); math::Gemv(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr); // Do division if (!logarithmic) { + #pragma omp parallel for for (int i = 0; i < N; ++i) { for (int j = 0; j < D; ++j) { Ydata[i * D + j] /= scale[i];