From 280ab9a2d09e626f4bf44531758b2b290cab0cd2 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sun, 9 Jun 2019 19:57:04 -0700 Subject: [PATCH] Let mlas use the session threadpool for gemm functions (#1196) --- .../cpu/attnlstm/attention_mechanism.h | 3 + .../cpu/attnlstm/attention_wrapper.cc | 23 +-- .../cpu/attnlstm/attention_wrapper.h | 5 +- .../cpu/attnlstm/bahdanau_attention.cc | 34 ++-- .../cpu/attnlstm/bahdanau_attention.h | 4 +- .../cpu/attnlstm/deep_cpu_attn_lstm.cc | 15 +- .../cpu/attnlstm/deep_cpu_attn_lstm.h | 6 - .../cpu/attnlstm/uni_dir_attn_lstm.cc | 9 +- .../contrib_ops/cpu/word_conv_embedding.cc | 12 +- .../contrib_ops/cpu/word_conv_embedding.h | 5 +- .../framework/op_kernel_context_internal.h | 1 + onnxruntime/core/providers/cpu/math/gemm.h | 9 +- .../core/providers/cpu/math/logsoftmax.cc | 7 +- onnxruntime/core/providers/cpu/math/matmul.cc | 11 +- .../core/providers/cpu/math/softmax.cc | 6 +- .../core/providers/cpu/math/softmax_shared.cc | 5 +- .../core/providers/cpu/math/softmax_shared.h | 5 +- onnxruntime/core/providers/cpu/nn/conv.cc | 24 +-- .../core/providers/cpu/nn/conv_transpose.cc | 9 +- .../core/providers/cpu/rnn/deep_cpu_gru.cc | 23 ++- .../core/providers/cpu/rnn/deep_cpu_lstm.cc | 6 +- onnxruntime/core/providers/cpu/rnn/rnn.cc | 12 +- .../core/providers/cpu/rnn/rnn_helpers.h | 6 +- onnxruntime/core/util/math.h | 14 +- onnxruntime/core/util/math_cpu.cc | 161 ++++++------------ onnxruntime/test/framework/math_test.cc | 42 ++--- .../test/providers/cpu/math/softmax_test.cc | 8 +- 27 files changed, 231 insertions(+), 234 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h b/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h index c648536f99..8aa36723bd 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h @@ -6,6 +6,9 @@ #include namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} namespace contrib { template diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc index 4555713a59..934855883c 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "attention_wrapper.h" +#include "core/framework/op_kernel_context_internal.h" #include "core/providers/cpu/rnn/rnn_helpers.h" #include @@ -34,14 +35,14 @@ AttentionWrapper::AttentionWrapper(AllocatorPtr alloc, const logging::Logger& // rnn_cell_output is of [batch_size, rnn_cell_hidden_size] template -void AttentionWrapper::ProcessOutput(const gsl::span& rnn_cell_output) { +void AttentionWrapper::ProcessOutput(const gsl::span& rnn_cell_output, concurrency::ThreadPool* tp) { if (has_attn_layer_) { // rnn_cell_output * cell_weights, (part of the attention layer above the attention mechanism). - math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0}, - rnn_cell_output.data(), inner_cell_hidden_size_, - attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0}, - attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0}, + rnn_cell_output.data(), inner_cell_hidden_size_, + attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0}, + attn_states_.data(), attn_layer_depth_, tp); } // Get the context which is calculated within attention mechanism. @@ -54,11 +55,11 @@ void AttentionWrapper::ProcessOutput(const gsl::span& rnn_cell_outpu //concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) = // p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_ // The first part is calulated above. Here just add the later. - math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0}, - attn_context_.data(), attn_context_depth_, - attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0}, - attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0}, + attn_context_.data(), attn_context_depth_, + attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0}, + attn_states_.data(), attn_layer_depth_, tp); } } diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h index 2469a7b99a..061f1b6728 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h @@ -10,6 +10,9 @@ #include "core/framework/allocator.h" namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} namespace contrib { template @@ -27,7 +30,7 @@ class AttentionWrapper { virtual ~AttentionWrapper() = default; // Calculation based on output of the inner wrapped rnn_cell. - void ProcessOutput(const gsl::span& rnn_cell_state); + void ProcessOutput(const gsl::span& rnn_cell_state, onnxruntime::concurrency::ThreadPool* tp); gsl::span GetAttnStates() const; diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc index 932ac263f8..ca66b24ffc 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc @@ -15,8 +15,8 @@ namespace contrib { template BahdanauAttention::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger, int batch_size, int max_memory_step, int memory_depth, - int query_depth, int attn_depth, bool normalize) - : allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize) { + int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* tp) + : allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), tp_(tp) { values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true); keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true); processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true); @@ -72,11 +72,11 @@ void BahdanauAttention::PrepareMemory( "Real memory steps ", mem_steps, " is not in (0, ", max_memory_steps_, "]"); } - math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0}, - memory.data(), memory_depth_, - memory_layer_weights_.data(), attn_depth_, T{0.0}, - keys_.data(), attn_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0}, + memory.data(), memory_depth_, + memory_layer_weights_.data(), attn_depth_, T{0.0}, + keys_.data(), attn_depth_, tp_); } template @@ -115,11 +115,11 @@ void BahdanauAttention::Compute( const gsl::span& output, const gsl::span& aligns) const { //process query in dense query layer without bias - math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_, attn_depth_, query_depth_, T{1.0}, - queries.data(), query_depth_, - query_layer_weights_.data(), attn_depth_, T{0.0}, - processed_query_.data(), attn_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_, attn_depth_, query_depth_, T{1.0}, + queries.data(), query_depth_, + query_layer_weights_.data(), attn_depth_, T{0.0}, + processed_query_.data(), attn_depth_, tp_); std::fill(aligns.begin(), aligns.end(), T{}); @@ -146,11 +146,11 @@ void BahdanauAttention::Compute( // Calculate the context auto outspan = output.subspan(b * memory_depth_); auto values = values_.subspan(b * max_memory_steps_ * memory_depth_); - math::GemmEx(CblasNoTrans, CblasNoTrans, - 1, memory_depth_, max_memory_steps_, T{1.0}, - alignments, max_memory_steps_, - values.data(), memory_depth_, T{0.0}, - outspan.data(), memory_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + 1, memory_depth_, max_memory_steps_, T{1.0}, + alignments, max_memory_steps_, + values.data(), memory_depth_, T{0.0}, + outspan.data(), memory_depth_, tp_); } } diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h index 755af6ba6d..0d052fd8bd 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h @@ -23,7 +23,7 @@ class BahdanauAttention : public IAttentionMechanism { int memory_depth, int query_depth, int attn_depth, - bool normalize); + bool normalize, concurrency::ThreadPool* tp); void SetWeights( const gsl::span& attn_weights, @@ -53,7 +53,6 @@ class BahdanauAttention : public IAttentionMechanism { private: AllocatorPtr allocator_; const logging::Logger& logger_; - int batch_size_; int max_memory_steps_; int memory_depth_; @@ -77,6 +76,7 @@ class BahdanauAttention : public IAttentionMechanism { gsl::span mem_seq_lengths_; bool normalize_; + concurrency::ThreadPool* const tp_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc index 23eb0cc8e1..3dc4be32f5 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc @@ -9,6 +9,7 @@ #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/framework/allocator.h" +#include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { namespace contrib { @@ -70,6 +71,8 @@ static gsl::span SecondHalfSpan(const gsl::span& dspan) { template Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { + auto ctx_internal = static_cast(&context); + auto tp = ctx_internal->GetOperatorThreadPool(); auto& logger = context.Logger(); // original lstm processing @@ -229,7 +232,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { last_cell_size_per_direction); auto fam = std::make_unique>( - alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false); + alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp); fam->SetWeights( FirstHalfSpan(am_v_weights.DataAsSpan()), FirstHalfSpan(am_query_layer_weights.DataAsSpan()), @@ -248,10 +251,10 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, *tp); auto bam = std::make_unique>( - alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false); + alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp); bam->SetWeights( SecondHalfSpan(am_v_weights.DataAsSpan()), SecondHalfSpan(am_query_layer_weights.DataAsSpan()), @@ -270,14 +273,14 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[3], activation_funcs_.Entries()[4], activation_funcs_.Entries()[5], - clip_, ttp_); + clip_, *tp); fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2); } else { auto fam = std::make_unique>( - alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false); + alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp); fam->SetWeights( am_v_weights.DataAsSpan(), am_query_layer_weights.DataAsSpan(), @@ -296,7 +299,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, *tp); fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); } diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h index a3accc70b8..c881a1efde 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h @@ -91,12 +91,6 @@ class DeepCpuAttnLstmOp final : public OpKernel { bool input_forget_ = false; ActivationFuncs activation_funcs_; - -// Threadpool for operator. If concurrent Compute calls are possible, it will be shared -// across them. mutable due to this. -// The alternative would be to create a threadpool in each call to Compute but that would incur thread creation -// cost on every call. - mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_ATTN_LSTM", (int)std::thread::hardware_concurrency()}; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc index caa05f9d5c..3ddf522ca3 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc @@ -200,6 +200,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, gsl::span& outputs, gsl::span& final_hidden_state, gsl::span& final_cell_state) { + onnxruntime::concurrency::ThreadPool* tp = &ttp_; // copy spans (just T* and size, not data in span) as we may change them gsl::span inputs = inputs_arg; gsl::span sequence_lengths = sequence_lengths_arg; @@ -254,7 +255,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, input_weights.cbegin(), input_weights.cend(), // W[iofc]^T input_size_ + attention_size_, T{0.0}, output_iofc_.begin(), output_iofc_.end(), - hidden_size_x4); + hidden_size_x4, tp); DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4); @@ -296,7 +297,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, input_weights.cbegin() + input_size_, input_weights.cend(), // WA[iofc] input_size_ + attention_size_, T{1.0}, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, tp); // calculate Xt*(W[iofc]^T) + Ht-1*R[iofc] ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, T{1.0}, @@ -305,7 +306,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] hidden_size_, T{1.0}, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, tp); span_T_iter batched_output, batched_output_end; if (output_sequence) { @@ -345,7 +346,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, previous_state = batched_output; previous_state_end = batched_output_end; - attention_wrapper_.ProcessOutput(outputs.subspan(step * output_step_length, batch_size_ * hidden_size_)); + attention_wrapper_.ProcessOutput(outputs.subspan(step * output_step_length, batch_size_ * hidden_size_), tp); } } diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc index 7d7f577d5e..7dd395ce76 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc @@ -6,6 +6,7 @@ #include "core/util/math.h" #include "core/util/math_cpuonly.h" #include "core/mlas/inc/mlas.h" +#include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { namespace contrib { @@ -45,7 +46,7 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation( int64_t char_embedding_size, int64_t filter_width, int64_t num_filters, - float* output) const { + float* output, concurrency::ThreadPool* tp) const { int64_t input_word_size = word_len * char_embedding_size; int64_t unfolded_width = word_len - filter_width + 1; int64_t unfolded_kernal_size = filter_width * char_embedding_size; @@ -83,12 +84,12 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation( tmp_word_inx++; } - math::GemmEx( + math::GemmEx( CblasNoTrans, CblasTrans, static_cast(words_unfolded_width), static_cast(num_filters), static_cast(unfolded_kernal_size), 1.0f, unfolded_buffer_p.get(), static_cast(unfolded_kernal_size), weights, static_cast(unfolded_kernal_size), 0.0f, - conv_buf_p, static_cast(num_filters), &CPUMathUtil::Instance()); + conv_buf_p, static_cast(num_filters), tp); for (int64_t unfolded_inx = 0; unfolded_inx < words_unfolded_width; unfolded_inx++) for (int64_t filter_inx = 0; filter_inx < num_filters; filter_inx++) { @@ -160,6 +161,9 @@ Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, co } Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + auto tp = ctx_internal->GetOperatorThreadPool(); + // original lstm processing const Tensor& sequence = *(ctx->Input(0)); // sequence: [sequence_length, word_length] const Tensor& w_conv = *(ctx->Input(1)); // conv weight: [M, C/group, kH, kW] @@ -216,7 +220,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { char_embedding_size, filter_width, filter_size, - Y->MutableData()); + Y->MutableData(), tp); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h index e74afab169..5ee4127e3b 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h @@ -8,6 +8,9 @@ #include "core/framework/tensor.h" namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} namespace contrib { class WordConvEmbedding final : public OpKernel { @@ -38,7 +41,7 @@ class WordConvEmbedding final : public OpKernel { int64_t char_embedding_size, int64_t filter_width, int64_t num_filters, - float* output) const; + float* output, onnxruntime::concurrency::ThreadPool* tp) const; void CalculateLengthOfEachWordInSequence( const int* seq_ptr, int* words_len_ptr, diff --git a/onnxruntime/core/framework/op_kernel_context_internal.h b/onnxruntime/core/framework/op_kernel_context_internal.h index 02515ba39a..edde9a6480 100644 --- a/onnxruntime/core/framework/op_kernel_context_internal.h +++ b/onnxruntime/core/framework/op_kernel_context_internal.h @@ -58,6 +58,7 @@ class OpKernelContextInternal : public OpKernelContext { const bool& GetTerminateFlag() const noexcept { return terminate_flag_; } const onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return session_state_.GetThreadPool(); } + onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() { return session_state_.GetThreadPool(); } private: const SessionState& session_state_; diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index c72c5bc1e0..e676759a66 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -4,9 +4,11 @@ #pragma once #include "core/common/common.h" +#include "core/platform/threadpool.h" #include "core/framework/op_kernel.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" +#include "core/framework/op_kernel_context_internal.h" #include "gemm_helper.h" namespace onnxruntime { @@ -30,6 +32,9 @@ class Gemm : public OpKernel { } Status Compute(OpKernelContext* context) const override { + auto ctx_internal = static_cast(context); + auto thread_pool = ctx_internal->GetOperatorThreadPool(); + const auto X = context->Input(0); const auto W = context->Input(1); const auto B = context->Input(2); @@ -95,7 +100,7 @@ class Gemm : public OpKernel { } // W * x - math::Gemm( + math::Gemm( trans_A_, trans_B_, M, @@ -106,7 +111,7 @@ class Gemm : public OpKernel { W->template Data(), beta_, y_data, - &CPUMathUtil::Instance()); + thread_pool); FuseActivation(activation_, y_data, M * N, leaky_relu_alpha_); diff --git a/onnxruntime/core/providers/cpu/math/logsoftmax.cc b/onnxruntime/core/providers/cpu/math/logsoftmax.cc index 281031e715..e4981bf073 100644 --- a/onnxruntime/core/providers/cpu/math/logsoftmax.cc +++ b/onnxruntime/core/providers/cpu/math/logsoftmax.cc @@ -4,6 +4,8 @@ #include "core/providers/cpu/math/logsoftmax.h" #include "core/framework/op_kernel.h" +#include "core/framework/op_kernel_context_internal.h" + #include "core/providers/common.h" #include "core/providers/cpu/math/softmax_shared.h" #include "core/util/math.h" @@ -12,6 +14,9 @@ namespace onnxruntime { template <> Status LogSoftmax::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + auto tp = ctx_internal->GetOperatorThreadPool(); + const auto* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& X = *tensor_pointer; @@ -32,7 +37,7 @@ Status LogSoftmax::Compute(OpKernelContext* ctx) const { const bool logarithmic = true; auto status = SoftmaxCPU(N, D, X.template Data(), Ydata, - scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data()); + scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp); return status; } diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 19a0f28f7e..7e3b3d3fd9 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. #include "core/providers/cpu/math/matmul.h" - +#include "core/platform/threadpool.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" +#include "core/framework/op_kernel_context_internal.h" + #include "matmul_helper.h" namespace onnxruntime { @@ -53,6 +55,9 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( template Status MatMul::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + auto thread_pool = ctx_internal->GetOperatorThreadPool(); + const auto* left_X = ctx->Input(0); const auto* right_X = ctx->Input(1); @@ -64,7 +69,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { // TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well size_t max_len = helper.OutputOffsets().size(); for (size_t i = 0; i < max_len; i++) { - math::Gemm( + math::Gemm( CblasNoTrans, CblasNoTrans, static_cast(helper.M()), @@ -75,7 +80,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { right_X->template Data() + helper.RightOffsets()[i], /* beta */ 0.0f, Y->template MutableData() + helper.OutputOffsets()[i], - &CPUMathUtil::Instance()); + thread_pool); } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/math/softmax.cc b/onnxruntime/core/providers/cpu/math/softmax.cc index 9242967901..dce810a423 100644 --- a/onnxruntime/core/providers/cpu/math/softmax.cc +++ b/onnxruntime/core/providers/cpu/math/softmax.cc @@ -4,6 +4,7 @@ #include "core/providers/cpu/math/softmax.h" #include "core/framework/op_kernel.h" +#include "core/framework/op_kernel_context_internal.h" #include "core/providers/common.h" #include "core/providers/cpu/math/softmax_shared.h" #include "core/util/math.h" @@ -12,6 +13,9 @@ namespace onnxruntime { template <> Status Softmax::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + auto tp = ctx_internal->GetOperatorThreadPool(); + const auto* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& X = *tensor_pointer; @@ -34,7 +38,7 @@ Status Softmax::Compute(OpKernelContext* ctx) const { const bool logarithmic = false; auto status = SoftmaxCPU(N, D, X.template Data(), Ydata, - scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data()); + scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp); return status; } diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 7dd3a10cfc..2bb3c337bf 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -31,6 +31,7 @@ #endif #include "core/providers/cpu/math/softmax_shared.h" + #include "core/util/math.h" #include "core/util/math_cpuonly.h" @@ -46,7 +47,7 @@ common::Status SoftmaxCPU(const int64_t N, float* scale, const float* sum_multiplier, bool logarithmic, - float* rowmax) { + float* rowmax, onnxruntime::concurrency::ThreadPool* tp) { // the Math functions SoftmaxCPU uses only support int32_t as input, so enforce that if (N * D > INT32_MAX || N > INT32_MAX || D > INT32_MAX) { std::ostringstream ss; @@ -65,7 +66,7 @@ common::Status SoftmaxCPU(const int64_t N, // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd)); - math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr); + math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, tp); // Exponentiation math::Exp(nd, Ydata, Ydata, nullptr); diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.h b/onnxruntime/core/providers/cpu/math/softmax_shared.h index 3439b9717f..26ffeb193f 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.h +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.h @@ -6,6 +6,9 @@ #include "core/common/status.h" namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} /** Calculate Softmax using CPU memory. @param N Number of rows @@ -18,5 +21,5 @@ Calculate Softmax using CPU memory. @param rowmax Storage for calculation of maximum in each row. Size must be >= N. */ common::Status SoftmaxCPU(int64_t N, int64_t D, const float* Xdata, float* Ydata, float* scale, - const float* sum_multiplier, bool logarithmic, float* rowmax); + const float* sum_multiplier, bool logarithmic, float* rowmax, concurrency::ThreadPool* tp); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index 7505aca264..43e56edd70 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -49,6 +49,11 @@ Status Conv::Compute(OpKernelContext* context) const { const size_t kernel_rank = kernel_shape.size(); + // Get access to the internal threadpool + // Temporarily derive concurrency parameters without access to session state + auto ctx_internal = static_cast(context); + auto thread_pool = ctx_internal->GetOperatorThreadPool(); + if (kernel_rank == 2 || kernel_rank == 3) { MLAS_ACTIVATION Activation; if (activation_.empty()) { @@ -66,11 +71,6 @@ Status Conv::Compute(OpKernelContext* context) const { ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_); } - // Get access to the internal threadpool - // Temporarily derive concurrency parameters without access to session state - auto ctx_internal = static_cast(context); - auto thread_pool = ctx_internal->GetOperatorThreadPool(); - MLAS_CONV_PARAMETERS Parameters; size_t WorkingBufferSize; MlasConvPrepare(&Parameters, @@ -87,9 +87,9 @@ Status Conv::Compute(OpKernelContext* context) const { static_cast(M / group_), &Activation, &WorkingBufferSize, - const_cast(thread_pool)); + thread_pool); - auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr; + auto working_data = WorkingBufferSize > 0 ? alloc->AllocArray(sizeof(float), WorkingBufferSize) : nullptr; BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc)); MlasConv(&Parameters, @@ -98,7 +98,7 @@ Status Conv::Compute(OpKernelContext* context) const { B != nullptr ? B->template Data() : nullptr, static_cast(working_buffer.get()), Ydata, - const_cast(thread_pool)); + thread_pool); } else { const int64_t input_image_size = input_shape.Size(); const int64_t output_image_size = output_shape.Size(); @@ -120,7 +120,7 @@ Status Conv::Compute(OpKernelContext* context) const { for (int image_id = 0; image_id < N; ++image_id) { for (int group_id = 0; group_id < group_; ++group_id) { - math::Im2colNd()( + math::Im2colNd()( Xdata + group_id * X_offset, image_shape.GetDims().data(), col_buffer_shape.data(), @@ -132,8 +132,8 @@ Status Conv::Compute(OpKernelContext* context) const { pads.data(), static_cast(kernel_shape.size()), col_buffer_data, - &CPUMathUtil::Instance()); - math::Gemm( + thread_pool); + math::Gemm( CblasNoTrans, CblasNoTrans, M / group_, @@ -144,7 +144,7 @@ Status Conv::Compute(OpKernelContext* context) const { col_buffer_data, 0, Ydata + group_id * Y_offset, - &CPUMathUtil::Instance()); + thread_pool); } if (B != nullptr) { diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index 14f13ccd20..e9d820f518 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -16,6 +16,8 @@ /* Modifications Copyright (c) Microsoft. */ #include "core/providers/cpu/nn/conv_transpose.h" +#include "core/framework/op_kernel_context_internal.h" + #include "core/util/math.h" #include "core/util/math_cpuonly.h" @@ -228,6 +230,9 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { template Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + auto ctx_internal = static_cast(context); + auto tp = ctx_internal->GetOperatorThreadPool(); + size_t num_inputs = OpKernel::Node().InputDefs().size(); Prepare p; bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; @@ -254,7 +259,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ for (auto image_id = 0; image_id < p.N; ++image_id) { for (int group_id = 0; group_id < group_; ++group_id) { // Weight term - math::Gemm( + math::Gemm( CblasTrans, CblasNoTrans, kernel_dim, @@ -265,7 +270,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ Xdata + group_id * X_offset, 0, col_buffer_data, - &CPUMathUtil::Instance()); + tp); // Col2im math::Col2im( diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index b6520bd60f..b953785478 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -18,6 +18,7 @@ #include "core/common/logging/logging.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" +#include "core/framework/op_kernel_context_internal.h" #include "core/platform/ort_mutex.h" @@ -171,7 +172,7 @@ class UniDirectionalGru { void Compute(const gsl::span& inputs, const gsl::span& sequence_lengths, int num_directions, const gsl::span& input_weights, const gsl::span& recurrent_weights, - gsl::span& outputs, gsl::span& final_hidden_state); + gsl::span& outputs, gsl::span& final_hidden_state, onnxruntime::concurrency::ThreadPool* tp); ~UniDirectionalGru() = default; @@ -263,6 +264,9 @@ Status DeepCpuGruOp::Compute(OpKernelContext* context) const { template Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { + auto ctx_internal = static_cast(&context); + auto tp = ctx_internal->GetOperatorThreadPool(); + const Tensor& X = *context.Input(0); // inputs. [seq_length, batch_size, input_size] const Tensor& W = *context.Input(1); // weights. [num_directions, 3*hidden_size, input_size] const Tensor& R = *context.Input(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size] @@ -375,7 +379,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], clip_); - fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1); + fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, tp); std::unique_ptr> bw = std::make_unique>( alloc, @@ -389,7 +393,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[2], activation_funcs_.Entries()[3], clip_); - bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2); + bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2, tp); } else { std::unique_ptr> gru_p = std::make_unique>( alloc, @@ -404,7 +408,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[1], clip_); - gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1); + gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, tp); } if (!output.empty()) @@ -505,7 +509,8 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, const gsl::span& input_weights, const gsl::span& recurrent_weights, gsl::span& outputs, - gsl::span& final_hidden_state) { + gsl::span& final_hidden_state, + onnxruntime::concurrency::ThreadPool* tp) { using span_T_const_iter = typename gsl::span::const_iterator; using span_T_iter = typename gsl::span::iterator; @@ -559,7 +564,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, input_weights.cbegin(), input_weights.cend(), input_size_, beta, outputZRH_.begin(), outputZRH_.end(), - hidden_size_x3); + hidden_size_x3, tp); DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_); @@ -624,7 +629,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(), hidden_size_, beta, outputZRH_.begin() + out_added_offset, outputZRH_.end(), - hidden_size_x3); + hidden_size_x3, tp); DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str, outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3); @@ -640,7 +645,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T hidden_size_, beta, linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output - hidden_size_); + hidden_size_, tp); DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_); } @@ -707,7 +712,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T hidden_size_, beta, out_H, outputZRH_.end(), - hidden_size_x3); + hidden_size_x3, tp); } DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset, diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index fbff94cfd6..c5d05f8e4f 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -783,7 +783,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, input_weights.cbegin(), input_weights.cend(), // W[iofc] input_size_, beta, output_iofc_.begin(), output_iofc_.end(), - hidden_size_x4); + hidden_size_x4, &ttp_); DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4); @@ -832,7 +832,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, &ttp_); DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4); @@ -910,7 +910,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, &ttp_); span_T_iter batched_output; span_T_iter batched_output_end; diff --git a/onnxruntime/core/providers/cpu/rnn/rnn.cc b/onnxruntime/core/providers/cpu/rnn/rnn.cc index 4030d65a94..e35387794c 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn.cc +++ b/onnxruntime/core/providers/cpu/rnn/rnn.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/op_kernel_context_internal.h" #include "core/providers/cpu/rnn/rnn.h" #include "core/providers/cpu/rnn/rnn_activation_functors.h" #include "core/providers/cpu/rnn/rnn_helpers.h" @@ -98,6 +99,9 @@ using EigenMatrixMapRowMajor = Eigen::Map< template <> Status RNN::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + auto tp = ctx_internal->GetOperatorThreadPool(); + using namespace rnn::detail; // inputs @@ -160,7 +164,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { } // X * W[direction]^t + B - math::Gemm( + math::Gemm( CblasNoTrans, CblasTrans, static_cast(seq_length * batch_size), @@ -171,7 +175,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { W.template Data() + direction * hidden_size_ * input_size, 1, x_matmul_w_buffer_data, - &CPUMathUtil::Instance()); + tp); for (int64_t t = 0; t < seq_length; t++) { int64_t time_step = isReverse ? (seq_length - t - 1) : t; @@ -192,7 +196,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { if (h_prev != nullptr) { // H_t_1 * R[direction]^t - math::Gemm( + math::Gemm( CblasNoTrans, CblasTrans, static_cast(batch_size), @@ -203,7 +207,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { R.template Data() + direction * hidden_size_ * hidden_size_, 0, Y_buffer_data_current_frame, - &CPUMathUtil::Instance()); + tp); } else { math::Set(batch_size * hidden_size_, 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance()); } diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index 2e3e5f88d7..bb1b1de10f 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -159,7 +159,7 @@ void ComputeGemm(const int M, const float beta, TSpanCIter C, TSpanCIter C_end, - const int ldc) { + const int ldc, onnxruntime::concurrency::ThreadPool* tp) { // validate all the inputs // need to use the lda/ldb/ldc strides which should be >= the columns for the span ORT_ENFORCE(lda >= K && ldb >= K && ldc >= N); @@ -167,12 +167,12 @@ void ComputeGemm(const int M, ORT_ENFORCE(B + (N * ldb - (ldb - K)) <= B_end); ORT_ENFORCE(C + (M * ldc - (ldc - N)) <= C_end); - ::onnxruntime::math::GemmEx( + ::onnxruntime::math::GemmEx( CblasNoTrans, CblasTrans, M, N, K, alpha, &*A, lda, &*B, ldb, beta, - &*C, ldc, &CPUMathUtil::Instance()); + &*C, ldc, tp); } // helper to convert a span to a raw pointer diff --git a/onnxruntime/core/util/math.h b/onnxruntime/core/util/math.h index 70d9cb3630..73901d17cc 100644 --- a/onnxruntime/core/util/math.h +++ b/onnxruntime/core/util/math.h @@ -40,6 +40,9 @@ extern "C" { #include "core/framework/tensor.h" namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} enum StorageOrder { UNKNOWN = 0, @@ -187,10 +190,7 @@ void Gemm( const T* B, float beta, T* C, - Provider* provider, - //Caffe2 use this type to control on GPU, what presicion do we want to do the calculation - //But not sure is this a good design for us. Keep it here for now. - MLDataType math_type = FLOAT_TYPE); + Provider*); // We also provide a gemm that has explicit lda, ldb and ldc specified. // In most cases you probably want to use the function above, though. @@ -209,7 +209,7 @@ void GemmEx( T beta, T* C, int ldc, - Provider* provider); + Provider*); // GemmBatched provides a simple abstraction into library routines template @@ -228,9 +228,7 @@ void GemmBatched( const T* B, float beta, T* C, - Provider* provider, - Tensor* scratch = nullptr, - MLDataType math_type = DataTypeImpl::FLOAT_TYPE); + Provider* tp); // Gemv always takes in a M*N matrix A, and depending on whether we set TransA // to Trans, the output is: diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 3359a2f334..98f83b577f 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -40,9 +40,7 @@ #include "core/util/math_cpuonly.h" #include "Eigen/src/Core/arch/GPU/Half.h" -#if defined(USE_MLAS) #include "core/mlas/inc/mlas.h" -#endif namespace onnxruntime { namespace math { @@ -107,7 +105,7 @@ void GemmEigen( // CBLAS call or the Eigen implementation. //////////////////////////////////////////////////////////////////////////////// // when USE_MKLML is defined, use cblas APIs for MKLML -#if defined(USE_EIGEN_FOR_BLAS) && !defined(USE_MKLML_FOR_BLAS) +#if !defined(USE_MKLML_FOR_BLAS) // Caffe2 gemm provides a simpler interface to the gemm functions, with the // limitation that the data has to be contiguous in memory. @@ -125,112 +123,60 @@ void GemmEigen( // (transpose) if the argument TransA or TransB is set to CblasNoTrans or // CblasTrans, respectively, for each of A and B. template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta, - float* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { -#if defined(USE_MLAS) +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta, + float* C, onnxruntime::concurrency::ThreadPool* threadpool) { int lda = static_cast((TransA == CblasNoTrans) ? K : M); int ldb = static_cast((TransB == CblasNoTrans) ? N : K); // TODO: Make this use the operator threadpool - MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, nullptr); -#else - GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); -#endif + MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, threadpool); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const double* A, const double* B, - float beta, double* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const double* A, const double* B, + float beta, double* C, onnxruntime::concurrency::ThreadPool*) { // No double precision Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B, - float beta, int32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B, + float beta, int32_t* C, onnxruntime::concurrency::ThreadPool*) { // No int32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B, - float beta, uint32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B, + float beta, uint32_t* C, onnxruntime::concurrency::ThreadPool*) { // No uint32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B, - float beta, int64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B, + float beta, int64_t* C, onnxruntime::concurrency::ThreadPool*) { // No int64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B, - float beta, uint64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B, + float beta, uint64_t* C, onnxruntime::concurrency::ThreadPool*) { // No uint64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void GemmEx(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K, - float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C, - int ldc, CPUMathUtil*) { -#if defined(USE_MLAS) - MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, nullptr); -#else - using OuterStride = Eigen::OuterStride; - using StridedMap = Eigen::Map; - using ConstStridedMap = Eigen::Map; - auto C_mat = StridedMap(C, N, M, OuterStride(ldc)); - if (beta == 0) { - C_mat.setZero(); - } else { - C_mat *= beta; - } - switch (TransA) { - case CblasNoTrans: { - switch (TransB) { - case CblasNoTrans: - C_mat.noalias() += - alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) * - ConstStridedMap(A, K, M, OuterStride(lda))); - return; - case CblasTrans: - C_mat.noalias() += - alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() * - ConstStridedMap(A, K, M, OuterStride(lda))); - return; - default: - ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); - } - } - case CblasTrans: { - switch (TransB) { - case CblasNoTrans: - C_mat.noalias() += - alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) * - ConstStridedMap(A, M, K, OuterStride(lda)).transpose()); - return; - case CblasTrans: - C_mat.noalias() += - alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() * - ConstStridedMap(A, M, K, OuterStride(lda)).transpose()); - return; - default: - ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); - } - } - default: - ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA); - } -#endif +void GemmEx(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K, + float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C, + int ldc, onnxruntime::concurrency::ThreadPool* tp) { + MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, tp); } template <> @@ -301,12 +247,12 @@ SPECIALIZED_AXPY(float) SPECIALIZED_AXPBY(float) #undef SPECIALIZED_AXPBY -#else // USE_EIGEN_FOR_BLAS +#else template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta, - float* C, CPUMathUtil* /*context*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta, + float* C, concurrency::ThreadPool*) { int lda = gsl::narrow_cast((TransA == CblasNoTrans) ? K : M); int ldb = gsl::narrow_cast((TransB == CblasNoTrans) ? N : K); cblas_sgemm(CblasRowMajor, TransA, TransB, @@ -318,9 +264,9 @@ void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOS } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const double* A, const double* B, - float beta, double* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const double* A, const double* B, + float beta, double* C, concurrency::ThreadPool*) { int lda = gsl::narrow_cast((TransA == CblasNoTrans) ? K : M); int ldb = gsl::narrow_cast((TransB == CblasNoTrans) ? N : K); cblas_dgemm(CblasRowMajor, TransA, TransB, gsl::narrow_cast(M), gsl::narrow_cast(N), @@ -329,41 +275,41 @@ void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPO } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B, - float beta, int32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B, + float beta, int32_t* C, concurrency::ThreadPool*) { // No int32_t Gemm offering from MKLML. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B, - float beta, uint32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B, + float beta, uint32_t* C, concurrency::ThreadPool*) { // No uint32_t Gemm offering from MKLML. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B, - float beta, int64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B, + float beta, int64_t* C, concurrency::ThreadPool*) { // No int64_t Gemm offering from MKLML. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, - const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B, - float beta, uint64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) { +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M, + const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B, + float beta, uint64_t* C, concurrency::ThreadPool*) { // No uint64_t Gemm offering from MKLML. Directly fallback to Eigen. GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); } template <> -void GemmEx(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K, - float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C, - int ldc, CPUMathUtil* /*context*/) { +void GemmEx(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K, + float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C, + int ldc, concurrency::ThreadPool*) { cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } @@ -417,20 +363,19 @@ CAFFE2_SPECIALIZED_AXPY(float, s) CAFFE2_SPECIALIZED_AXPBY(float, s) #undef CAFFE2_SPECIALIZED_AXPBY -#endif // USE_EIGEN_FOR_BLAS +#endif template <> -void GemmBatched(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int A_size, - int A_batches, int B_size, int B_batches, int M, int N, int K, float /*alpha*/, - const float* A, const float* B, float /*beta*/, float* C, CPUMathUtil* provider, - Tensor*, /* scratch */ - MLDataType /* math_type */) { +void GemmBatched(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int A_size, + int A_batches, int B_size, int B_batches, int M, int N, int K, float /*alpha*/, + const float* A, const float* B, float /*beta*/, float* C, + onnxruntime::concurrency::ThreadPool* tp) { auto a_offset = A_size / A_batches; auto b_offset = B_size / B_batches; auto y_offset = M * N; // loop over matrices in the batch for (int i = 0; i < A_batches; ++i) { - math::Gemm( + math::Gemm( TransA, TransB, M, @@ -441,7 +386,7 @@ void GemmBatched(const CBLAS_TRANSPOSE TransA, const CBLAS_T B + b_offset * i, 0, C + y_offset * i, - provider); + tp); } } diff --git a/onnxruntime/test/framework/math_test.cc b/onnxruntime/test/framework/math_test.cc index 9710464733..3eed336dcd 100644 --- a/onnxruntime/test/framework/math_test.cc +++ b/onnxruntime/test/framework/math_test.cc @@ -17,12 +17,14 @@ #include "core/util/math.h" #include +#include "core/platform/threadpool.h" #include "core/util/math_cpuonly.h" namespace onnxruntime { #define VECTOR_HEAD(x) x.size() > 0 ? &x[0] : NULL TEST(MathTest, GemmNoTransNoTrans) { + concurrency::ThreadPool tp("", 1); auto& provider = CPUMathUtil::Instance(); std::vector X(50); // 5 * 10 std::vector W(60); // 10 * 6 @@ -40,26 +42,26 @@ TEST(MathTest, GemmNoTransNoTrans) { const float kOne = 1.0; const float kPointFive = 0.5; const float kZero = 0.0; - math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne, - VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y), - &provider); + math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne, + VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y), + &tp); EXPECT_EQ(Y.size(), 30); for (size_t i = 0; i < Y.size(); ++i) { EXPECT_EQ(Y[i], 10) << i; } // Test Accumulate - math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne, - VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive, - VECTOR_HEAD(Y), &provider); + math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne, + VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive, + VECTOR_HEAD(Y), &tp); EXPECT_EQ(Y.size(), 30); for (size_t i = 0; i < Y.size(); ++i) { EXPECT_EQ(Y[i], 15) << i; } // Test Accumulate - math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, - kPointFive, - VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y), - &provider); + math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, + kPointFive, + VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y), + &tp); EXPECT_EQ(Y.size(), 30); for (size_t i = 0; i < Y.size(); ++i) { EXPECT_EQ(Y[i], 20) << i; @@ -68,6 +70,8 @@ TEST(MathTest, GemmNoTransNoTrans) { TEST(MathTest, GemmNoTransTrans) { auto& provider = CPUMathUtil::Instance(); + concurrency::ThreadPool tp("", 1); + std::vector X(50); // 5 * 10 std::vector W(60); // 10 * 6 std::vector Y(30); // 5 * 6 @@ -84,24 +88,24 @@ TEST(MathTest, GemmNoTransTrans) { const float kOne = 1.0; const float kPointFive = 0.5; const float kZero = 0.0; - math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, kOne, - VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y), - &provider); + math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, kOne, + VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y), + &tp); EXPECT_EQ(Y.size(), 30); for (size_t i = 0; i < Y.size(); ++i) { EXPECT_EQ(Y[i], 10) << i; } // Test Accumulate - math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, kOne, - VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive, - VECTOR_HEAD(Y), &provider); + math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, kOne, + VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive, + VECTOR_HEAD(Y), &tp); EXPECT_EQ(Y.size(), 30); for (size_t i = 0; i < Y.size(); ++i) { EXPECT_EQ(Y[i], 15) << i; } - math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, kPointFive, - VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y), - &provider); + math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, kPointFive, + VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y), + &tp); EXPECT_EQ(Y.size(), 30); for (size_t i = 0; i < Y.size(); ++i) { EXPECT_EQ(Y[i], 20) << i; diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index aad97cfaf8..0273378194 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -194,23 +194,23 @@ TEST(SoftmaxOperator, InvalidAxis) { TEST(SoftmaxOperator, TestInputTooLarge) { float* ignored = nullptr; - + concurrency::ThreadPool tp("", 1); // N > INT32_MAX int64_t N = int64_t(INT32_MAX) + 1; int64_t D = 1; - auto status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored); + auto status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp); EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT); // D > INT32_MAX N = 1; D = int64_t(INT32_MAX) + 1; - status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored); + status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp); EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT); // N * D > INT32_MAX N = int64_t(INT32_MAX) / 2; D = 3; - status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored); + status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp); EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT); /*