Let mlas use the session threadpool for gemm functions (#1196)

This commit is contained in:
Changming Sun 2019-06-09 19:57:04 -07:00 committed by GitHub
parent be36385a8c
commit 280ab9a2d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 231 additions and 234 deletions

View file

@ -6,6 +6,9 @@
#include <gsl/span>
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}
namespace contrib {
template <typename T>

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "attention_wrapper.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/cpu/rnn/rnn_helpers.h"
#include <stdexcept>
@ -34,14 +35,14 @@ AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger&
// rnn_cell_output is of [batch_size, rnn_cell_hidden_size]
template <typename T>
void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_output) {
void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_output, concurrency::ThreadPool* tp) {
if (has_attn_layer_) {
// rnn_cell_output * cell_weights, (part of the attention layer above the attention mechanism).
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
rnn_cell_output.data(), inner_cell_hidden_size_,
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
math::GemmEx<T, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
rnn_cell_output.data(), inner_cell_hidden_size_,
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
attn_states_.data(), attn_layer_depth_, tp);
}
// Get the context which is calculated within attention mechanism.
@ -54,11 +55,11 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
//concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
// The first part is calulated above. Here just add the later.
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
attn_context_.data(), attn_context_depth_,
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
math::GemmEx<T, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
attn_context_.data(), attn_context_depth_,
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
attn_states_.data(), attn_layer_depth_, tp);
}
}

View file

@ -10,6 +10,9 @@
#include "core/framework/allocator.h"
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}
namespace contrib {
template <typename T>
@ -27,7 +30,7 @@ class AttentionWrapper {
virtual ~AttentionWrapper() = default;
// Calculation based on output of the inner wrapped rnn_cell.
void ProcessOutput(const gsl::span<const T>& rnn_cell_state);
void ProcessOutput(const gsl::span<const T>& rnn_cell_state, onnxruntime::concurrency::ThreadPool* tp);
gsl::span<const T> GetAttnStates() const;

View file

@ -15,8 +15,8 @@ namespace contrib {
template <typename T>
BahdanauAttention<T>::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger,
int batch_size, int max_memory_step, int memory_depth,
int query_depth, int attn_depth, bool normalize)
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize) {
int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* tp)
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), tp_(tp) {
values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true);
keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true);
processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true);
@ -72,11 +72,11 @@ void BahdanauAttention<T>::PrepareMemory(
"Real memory steps ", mem_steps, " is not in (0, ", max_memory_steps_, "]");
}
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
memory.data(), memory_depth_,
memory_layer_weights_.data(), attn_depth_, T{0.0},
keys_.data(), attn_depth_, &CPUMathUtil::Instance());
math::GemmEx<T, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
memory.data(), memory_depth_,
memory_layer_weights_.data(), attn_depth_, T{0.0},
keys_.data(), attn_depth_, tp_);
}
template <typename T>
@ -115,11 +115,11 @@ void BahdanauAttention<T>::Compute(
const gsl::span<T>& output,
const gsl::span<T>& aligns) const {
//process query in dense query layer without bias
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_depth_, query_depth_, T{1.0},
queries.data(), query_depth_,
query_layer_weights_.data(), attn_depth_, T{0.0},
processed_query_.data(), attn_depth_, &CPUMathUtil::Instance());
math::GemmEx<T, onnxruntime::concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_depth_, query_depth_, T{1.0},
queries.data(), query_depth_,
query_layer_weights_.data(), attn_depth_, T{0.0},
processed_query_.data(), attn_depth_, tp_);
std::fill(aligns.begin(), aligns.end(), T{});
@ -146,11 +146,11 @@ void BahdanauAttention<T>::Compute(
// Calculate the context
auto outspan = output.subspan(b * memory_depth_);
auto values = values_.subspan(b * max_memory_steps_ * memory_depth_);
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
1, memory_depth_, max_memory_steps_, T{1.0},
alignments, max_memory_steps_,
values.data(), memory_depth_, T{0.0},
outspan.data(), memory_depth_, &CPUMathUtil::Instance());
math::GemmEx<T, onnxruntime::concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
1, memory_depth_, max_memory_steps_, T{1.0},
alignments, max_memory_steps_,
values.data(), memory_depth_, T{0.0},
outspan.data(), memory_depth_, tp_);
}
}

View file

@ -23,7 +23,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
int memory_depth,
int query_depth,
int attn_depth,
bool normalize);
bool normalize, concurrency::ThreadPool* tp);
void SetWeights(
const gsl::span<const T>& attn_weights,
@ -53,7 +53,6 @@ class BahdanauAttention : public IAttentionMechanism<T> {
private:
AllocatorPtr allocator_;
const logging::Logger& logger_;
int batch_size_;
int max_memory_steps_;
int memory_depth_;
@ -77,6 +76,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
gsl::span<int> mem_seq_lengths_;
bool normalize_;
concurrency::ThreadPool* const tp_;
};
} // namespace contrib

View file

@ -9,6 +9,7 @@
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
#include "core/framework/op_kernel_context_internal.h"
namespace onnxruntime {
namespace contrib {
@ -70,6 +71,8 @@ static gsl::span<const T> SecondHalfSpan(const gsl::span<const T>& dspan) {
template <typename T>
Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(&context);
auto tp = ctx_internal->GetOperatorThreadPool();
auto& logger = context.Logger();
// original lstm processing
@ -229,7 +232,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
last_cell_size_per_direction);
auto fam = std::make_unique<BahdanauAttention<T>>(
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp);
fam->SetWeights(
FirstHalfSpan(am_v_weights.DataAsSpan<T>()),
FirstHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
@ -248,10 +251,10 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, ttp_);
clip_, *tp);
auto bam = std::make_unique<BahdanauAttention<T>>(
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp);
bam->SetWeights(
SecondHalfSpan(am_v_weights.DataAsSpan<T>()),
SecondHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
@ -270,14 +273,14 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[3],
activation_funcs_.Entries()[4],
activation_funcs_.Entries()[5],
clip_, ttp_);
clip_, *tp);
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
} else {
auto fam = std::make_unique<BahdanauAttention<T>>(
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp);
fam->SetWeights(
am_v_weights.DataAsSpan<T>(),
am_query_layer_weights.DataAsSpan<T>(),
@ -296,7 +299,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, ttp_);
clip_, *tp);
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
}

View file

@ -91,12 +91,6 @@ class DeepCpuAttnLstmOp final : public OpKernel {
bool input_forget_ = false;
ActivationFuncs activation_funcs_;
// Threadpool for operator. If concurrent Compute calls are possible, it will be shared
// across them. mutable due to this.
// The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
// cost on every call.
mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_ATTN_LSTM", (int)std::thread::hardware_concurrency()};
};
} // namespace contrib

View file

@ -200,6 +200,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
gsl::span<T>& outputs,
gsl::span<T>& final_hidden_state,
gsl::span<T>& final_cell_state) {
onnxruntime::concurrency::ThreadPool* tp = &ttp_;
// copy spans (just T* and size, not data in span) as we may change them
gsl::span<const T> inputs = inputs_arg;
gsl::span<const int> sequence_lengths = sequence_lengths_arg;
@ -254,7 +255,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.cbegin(), input_weights.cend(), // W[iofc]^T
input_size_ + attention_size_, T{0.0},
output_iofc_.begin(), output_iofc_.end(),
hidden_size_x4);
hidden_size_x4, tp);
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
@ -296,7 +297,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.cbegin() + input_size_, input_weights.cend(), // WA[iofc]
input_size_ + attention_size_, T{1.0},
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4);
hidden_size_x4, tp);
// calculate Xt*(W[iofc]^T) + Ht-1*R[iofc]
ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, T{1.0},
@ -305,7 +306,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
hidden_size_, T{1.0},
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4);
hidden_size_x4, tp);
span_T_iter batched_output, batched_output_end;
if (output_sequence) {
@ -345,7 +346,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
previous_state = batched_output;
previous_state_end = batched_output_end;
attention_wrapper_.ProcessOutput(outputs.subspan(step * output_step_length, batch_size_ * hidden_size_));
attention_wrapper_.ProcessOutput(outputs.subspan(step * output_step_length, batch_size_ * hidden_size_), tp);
}
}

View file

@ -6,6 +6,7 @@
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/mlas/inc/mlas.h"
#include "core/framework/op_kernel_context_internal.h"
namespace onnxruntime {
namespace contrib {
@ -45,7 +46,7 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation(
int64_t char_embedding_size,
int64_t filter_width,
int64_t num_filters,
float* output) const {
float* output, concurrency::ThreadPool* tp) const {
int64_t input_word_size = word_len * char_embedding_size;
int64_t unfolded_width = word_len - filter_width + 1;
int64_t unfolded_kernal_size = filter_width * char_embedding_size;
@ -83,12 +84,12 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation(
tmp_word_inx++;
}
math::GemmEx<float, CPUMathUtil>(
math::GemmEx<float, concurrency::ThreadPool>(
CblasNoTrans, CblasTrans,
static_cast<int>(words_unfolded_width), static_cast<int>(num_filters), static_cast<int>(unfolded_kernal_size), 1.0f,
unfolded_buffer_p.get(), static_cast<int>(unfolded_kernal_size),
weights, static_cast<int>(unfolded_kernal_size), 0.0f,
conv_buf_p, static_cast<int>(num_filters), &CPUMathUtil::Instance());
conv_buf_p, static_cast<int>(num_filters), tp);
for (int64_t unfolded_inx = 0; unfolded_inx < words_unfolded_width; unfolded_inx++)
for (int64_t filter_inx = 0; filter_inx < num_filters; filter_inx++) {
@ -160,6 +161,9 @@ Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, co
}
Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto tp = ctx_internal->GetOperatorThreadPool();
// original lstm processing
const Tensor& sequence = *(ctx->Input<Tensor>(0)); // sequence: [sequence_length, word_length]
const Tensor& w_conv = *(ctx->Input<Tensor>(1)); // conv weight: [M, C/group, kH, kW]
@ -216,7 +220,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
char_embedding_size,
filter_width,
filter_size,
Y->MutableData<float>());
Y->MutableData<float>(), tp);
return Status::OK();
}

View file

@ -8,6 +8,9 @@
#include "core/framework/tensor.h"
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}
namespace contrib {
class WordConvEmbedding final : public OpKernel {
@ -38,7 +41,7 @@ class WordConvEmbedding final : public OpKernel {
int64_t char_embedding_size,
int64_t filter_width,
int64_t num_filters,
float* output) const;
float* output, onnxruntime::concurrency::ThreadPool* tp) const;
void CalculateLengthOfEachWordInSequence(
const int* seq_ptr,
int* words_len_ptr,

View file

@ -58,6 +58,7 @@ class OpKernelContextInternal : public OpKernelContext {
const bool& GetTerminateFlag() const noexcept { return terminate_flag_; }
const onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return session_state_.GetThreadPool(); }
onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() { return session_state_.GetThreadPool(); }
private:
const SessionState& session_state_;

View file

@ -4,9 +4,11 @@
#pragma once
#include "core/common/common.h"
#include "core/platform/threadpool.h"
#include "core/framework/op_kernel.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/framework/op_kernel_context_internal.h"
#include "gemm_helper.h"
namespace onnxruntime {
@ -30,6 +32,9 @@ class Gemm : public OpKernel {
}
Status Compute(OpKernelContext* context) const override {
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
auto thread_pool = ctx_internal->GetOperatorThreadPool();
const auto X = context->Input<Tensor>(0);
const auto W = context->Input<Tensor>(1);
const auto B = context->Input<Tensor>(2);
@ -95,7 +100,7 @@ class Gemm : public OpKernel {
}
// W * x
math::Gemm<T_X, CPUMathUtil>(
math::Gemm<T_X, concurrency::ThreadPool>(
trans_A_,
trans_B_,
M,
@ -106,7 +111,7 @@ class Gemm : public OpKernel {
W->template Data<T_W>(),
beta_,
y_data,
&CPUMathUtil::Instance());
thread_pool);
FuseActivation<T_Y>(activation_, y_data, M * N, leaky_relu_alpha_);

View file

@ -4,6 +4,8 @@
#include "core/providers/cpu/math/logsoftmax.h"
#include "core/framework/op_kernel.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/common.h"
#include "core/providers/cpu/math/softmax_shared.h"
#include "core/util/math.h"
@ -12,6 +14,9 @@ namespace onnxruntime {
template <>
Status LogSoftmax<float>::Compute(OpKernelContext* ctx) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto tp = ctx_internal->GetOperatorThreadPool();
const auto* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
const Tensor& X = *tensor_pointer;
@ -32,7 +37,7 @@ Status LogSoftmax<float>::Compute(OpKernelContext* ctx) const {
const bool logarithmic = true;
auto status = SoftmaxCPU(N, D, X.template Data<float>(), Ydata,
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data());
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp);
return status;
}

View file

@ -2,9 +2,11 @@
// Licensed under the MIT License.
#include "core/providers/cpu/math/matmul.h"
#include "core/platform/threadpool.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/framework/op_kernel_context_internal.h"
#include "matmul_helper.h"
namespace onnxruntime {
@ -53,6 +55,9 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
template <typename T>
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto thread_pool = ctx_internal->GetOperatorThreadPool();
const auto* left_X = ctx->Input<Tensor>(0);
const auto* right_X = ctx->Input<Tensor>(1);
@ -64,7 +69,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
size_t max_len = helper.OutputOffsets().size();
for (size_t i = 0; i < max_len; i++) {
math::Gemm<T, CPUMathUtil>(
math::Gemm<T, concurrency::ThreadPool>(
CblasNoTrans,
CblasNoTrans,
static_cast<int>(helper.M()),
@ -75,7 +80,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
right_X->template Data<T>() + helper.RightOffsets()[i],
/* beta */ 0.0f,
Y->template MutableData<T>() + helper.OutputOffsets()[i],
&CPUMathUtil::Instance());
thread_pool);
}
return Status::OK();

View file

@ -4,6 +4,7 @@
#include "core/providers/cpu/math/softmax.h"
#include "core/framework/op_kernel.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/common.h"
#include "core/providers/cpu/math/softmax_shared.h"
#include "core/util/math.h"
@ -12,6 +13,9 @@ namespace onnxruntime {
template <>
Status Softmax<float>::Compute(OpKernelContext* ctx) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto tp = ctx_internal->GetOperatorThreadPool();
const auto* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
const Tensor& X = *tensor_pointer;
@ -34,7 +38,7 @@ Status Softmax<float>::Compute(OpKernelContext* ctx) const {
const bool logarithmic = false;
auto status = SoftmaxCPU(N, D, X.template Data<float>(), Ydata,
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data());
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp);
return status;
}

View file

@ -31,6 +31,7 @@
#endif
#include "core/providers/cpu/math/softmax_shared.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
@ -46,7 +47,7 @@ common::Status SoftmaxCPU(const int64_t N,
float* scale,
const float* sum_multiplier,
bool logarithmic,
float* rowmax) {
float* rowmax, onnxruntime::concurrency::ThreadPool* tp) {
// the Math functions SoftmaxCPU uses only support int32_t as input, so enforce that
if (N * D > INT32_MAX || N > INT32_MAX || D > INT32_MAX) {
std::ostringstream ss;
@ -65,7 +66,7 @@ common::Status SoftmaxCPU(const int64_t N,
// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, tp);
// Exponentiation
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);

View file

@ -6,6 +6,9 @@
#include "core/common/status.h"
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}
/**
Calculate Softmax using CPU memory.
@param N Number of rows
@ -18,5 +21,5 @@ Calculate Softmax using CPU memory.
@param rowmax Storage for calculation of maximum in each row. Size must be >= N.
*/
common::Status SoftmaxCPU(int64_t N, int64_t D, const float* Xdata, float* Ydata, float* scale,
const float* sum_multiplier, bool logarithmic, float* rowmax);
const float* sum_multiplier, bool logarithmic, float* rowmax, concurrency::ThreadPool* tp);
} // namespace onnxruntime

View file

@ -49,6 +49,11 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
const size_t kernel_rank = kernel_shape.size();
// Get access to the internal threadpool
// Temporarily derive concurrency parameters without access to session state
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
auto thread_pool = ctx_internal->GetOperatorThreadPool();
if (kernel_rank == 2 || kernel_rank == 3) {
MLAS_ACTIVATION Activation;
if (activation_.empty()) {
@ -66,11 +71,6 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_);
}
// Get access to the internal threadpool
// Temporarily derive concurrency parameters without access to session state
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
auto thread_pool = ctx_internal->GetOperatorThreadPool();
MLAS_CONV_PARAMETERS Parameters;
size_t WorkingBufferSize;
MlasConvPrepare(&Parameters,
@ -87,9 +87,9 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
static_cast<size_t>(M / group_),
&Activation,
&WorkingBufferSize,
const_cast<concurrency::ThreadPool*>(thread_pool));
thread_pool);
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
auto working_data = WorkingBufferSize > 0 ? alloc->AllocArray(sizeof(float), WorkingBufferSize) : nullptr;
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
MlasConv(&Parameters,
@ -98,7 +98,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
B != nullptr ? B->template Data<float>() : nullptr,
static_cast<float*>(working_buffer.get()),
Ydata,
const_cast<concurrency::ThreadPool*>(thread_pool));
thread_pool);
} else {
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
@ -120,7 +120,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>()(
math::Im2colNd<float, onnxruntime::concurrency::ThreadPool, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
image_shape.GetDims().data(),
col_buffer_shape.data(),
@ -132,8 +132,8 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
pads.data(),
static_cast<int>(kernel_shape.size()),
col_buffer_data,
&CPUMathUtil::Instance());
math::Gemm<float, CPUMathUtil>(
thread_pool);
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(
CblasNoTrans,
CblasNoTrans,
M / group_,
@ -144,7 +144,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
col_buffer_data,
0,
Ydata + group_id * Y_offset,
&CPUMathUtil::Instance());
thread_pool);
}
if (B != nullptr) {

View file

@ -16,6 +16,8 @@
/* Modifications Copyright (c) Microsoft. */
#include "core/providers/cpu/nn/conv_transpose.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
@ -228,6 +230,9 @@ Status ConvTranspose<T>::Compute(OpKernelContext* context) const {
template <typename T>
Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
auto tp = ctx_internal->GetOperatorThreadPool();
size_t num_inputs = OpKernel::Node().InputDefs().size();
Prepare p;
bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3;
@ -254,7 +259,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
for (auto image_id = 0; image_id < p.N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
// Weight term
math::Gemm<T, CPUMathUtil>(
math::Gemm<T, onnxruntime::concurrency::ThreadPool>(
CblasTrans,
CblasNoTrans,
kernel_dim,
@ -265,7 +270,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
Xdata + group_id * X_offset,
0,
col_buffer_data,
&CPUMathUtil::Instance());
tp);
// Col2im
math::Col2im<T, CPUMathUtil, StorageOrder::NCHW>(

View file

@ -18,6 +18,7 @@
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/platform/ort_mutex.h"
@ -171,7 +172,7 @@ class UniDirectionalGru {
void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
gsl::span<T>& outputs, gsl::span<T>& final_hidden_state);
gsl::span<T>& outputs, gsl::span<T>& final_hidden_state, onnxruntime::concurrency::ThreadPool* tp);
~UniDirectionalGru() = default;
@ -263,6 +264,9 @@ Status DeepCpuGruOp::Compute(OpKernelContext* context) const {
template <typename T>
Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(&context);
auto tp = ctx_internal->GetOperatorThreadPool();
const Tensor& X = *context.Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
const Tensor& W = *context.Input<Tensor>(1); // weights. [num_directions, 3*hidden_size, input_size]
const Tensor& R = *context.Input<Tensor>(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size]
@ -375,7 +379,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
clip_);
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, tp);
std::unique_ptr<detail::UniDirectionalGru<T>> bw = std::make_unique<detail::UniDirectionalGru<T>>(
alloc,
@ -389,7 +393,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[2],
activation_funcs_.Entries()[3],
clip_);
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2);
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2, tp);
} else {
std::unique_ptr<detail::UniDirectionalGru<T>> gru_p = std::make_unique<detail::UniDirectionalGru<T>>(
alloc,
@ -404,7 +408,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[1],
clip_);
gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, tp);
}
if (!output.empty())
@ -505,7 +509,8 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
const gsl::span<const T>& input_weights,
const gsl::span<const T>& recurrent_weights,
gsl::span<T>& outputs,
gsl::span<T>& final_hidden_state) {
gsl::span<T>& final_hidden_state,
onnxruntime::concurrency::ThreadPool* tp) {
using span_T_const_iter = typename gsl::span<T>::const_iterator;
using span_T_iter = typename gsl::span<T>::iterator;
@ -559,7 +564,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.cbegin(), input_weights.cend(),
input_size_, beta,
outputZRH_.begin(), outputZRH_.end(),
hidden_size_x3);
hidden_size_x3, tp);
DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_);
@ -624,7 +629,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
hidden_size_, beta,
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
hidden_size_x3);
hidden_size_x3, tp);
DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
@ -640,7 +645,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output
hidden_size_);
hidden_size_, tp);
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
}
@ -707,7 +712,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
out_H, outputZRH_.end(),
hidden_size_x3);
hidden_size_x3, tp);
}
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,

View file

@ -783,7 +783,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.cbegin(), input_weights.cend(), // W[iofc]
input_size_, beta,
output_iofc_.begin(), output_iofc_.end(),
hidden_size_x4);
hidden_size_x4, &ttp_);
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
@ -832,7 +832,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
hidden_size_, beta,
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4);
hidden_size_x4, &ttp_);
DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str,
&*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4);
@ -910,7 +910,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
hidden_size_, beta,
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4);
hidden_size_x4, &ttp_);
span_T_iter batched_output;
span_T_iter batched_output_end;

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/cpu/rnn/rnn.h"
#include "core/providers/cpu/rnn/rnn_activation_functors.h"
#include "core/providers/cpu/rnn/rnn_helpers.h"
@ -98,6 +99,9 @@ using EigenMatrixMapRowMajor = Eigen::Map<
template <>
Status RNN<float>::Compute(OpKernelContext* ctx) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto tp = ctx_internal->GetOperatorThreadPool();
using namespace rnn::detail;
// inputs
@ -160,7 +164,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
}
// X * W[direction]^t + B
math::Gemm<float, CPUMathUtil>(
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(
CblasNoTrans,
CblasTrans,
static_cast<int>(seq_length * batch_size),
@ -171,7 +175,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
W.template Data<float>() + direction * hidden_size_ * input_size,
1,
x_matmul_w_buffer_data,
&CPUMathUtil::Instance());
tp);
for (int64_t t = 0; t < seq_length; t++) {
int64_t time_step = isReverse ? (seq_length - t - 1) : t;
@ -192,7 +196,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
if (h_prev != nullptr) {
// H_t_1 * R[direction]^t
math::Gemm<float, CPUMathUtil>(
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(
CblasNoTrans,
CblasTrans,
static_cast<int>(batch_size),
@ -203,7 +207,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
R.template Data<float>() + direction * hidden_size_ * hidden_size_,
0,
Y_buffer_data_current_frame,
&CPUMathUtil::Instance());
tp);
} else {
math::Set<float, CPUMathUtil>(batch_size * hidden_size_, 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance());
}

View file

@ -159,7 +159,7 @@ void ComputeGemm(const int M,
const float beta,
TSpanCIter C,
TSpanCIter C_end,
const int ldc) {
const int ldc, onnxruntime::concurrency::ThreadPool* tp) {
// validate all the inputs
// need to use the lda/ldb/ldc strides which should be >= the columns for the span
ORT_ENFORCE(lda >= K && ldb >= K && ldc >= N);
@ -167,12 +167,12 @@ void ComputeGemm(const int M,
ORT_ENFORCE(B + (N * ldb - (ldb - K)) <= B_end);
ORT_ENFORCE(C + (M * ldc - (ldc - N)) <= C_end);
::onnxruntime::math::GemmEx<float, CPUMathUtil>(
::onnxruntime::math::GemmEx<float, ::onnxruntime::concurrency::ThreadPool>(
CblasNoTrans, CblasTrans,
M, N, K, alpha,
&*A, lda,
&*B, ldb, beta,
&*C, ldc, &CPUMathUtil::Instance());
&*C, ldc, tp);
}
// helper to convert a span to a raw pointer

View file

@ -40,6 +40,9 @@ extern "C" {
#include "core/framework/tensor.h"
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}
enum StorageOrder {
UNKNOWN = 0,
@ -187,10 +190,7 @@ void Gemm(
const T* B,
float beta,
T* C,
Provider* provider,
//Caffe2 use this type to control on GPU, what presicion do we want to do the calculation
//But not sure is this a good design for us. Keep it here for now.
MLDataType math_type = FLOAT_TYPE);
Provider*);
// We also provide a gemm that has explicit lda, ldb and ldc specified.
// In most cases you probably want to use the function above, though.
@ -209,7 +209,7 @@ void GemmEx(
T beta,
T* C,
int ldc,
Provider* provider);
Provider*);
// GemmBatched provides a simple abstraction into library routines
template <typename T, class Provider>
@ -228,9 +228,7 @@ void GemmBatched(
const T* B,
float beta,
T* C,
Provider* provider,
Tensor* scratch = nullptr,
MLDataType math_type = DataTypeImpl::FLOAT_TYPE);
Provider* tp);
// Gemv always takes in a M*N matrix A, and depending on whether we set TransA
// to Trans, the output is:

View file

@ -40,9 +40,7 @@
#include "core/util/math_cpuonly.h"
#include "Eigen/src/Core/arch/GPU/Half.h"
#if defined(USE_MLAS)
#include "core/mlas/inc/mlas.h"
#endif
namespace onnxruntime {
namespace math {
@ -107,7 +105,7 @@ void GemmEigen(
// CBLAS call or the Eigen implementation.
////////////////////////////////////////////////////////////////////////////////
// when USE_MKLML is defined, use cblas APIs for MKLML
#if defined(USE_EIGEN_FOR_BLAS) && !defined(USE_MKLML_FOR_BLAS)
#if !defined(USE_MKLML_FOR_BLAS)
// Caffe2 gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
@ -125,112 +123,60 @@ void GemmEigen(
// (transpose) if the argument TransA or TransB is set to CblasNoTrans or
// CblasTrans, respectively, for each of A and B.
template <>
void Gemm<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
float* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
#if defined(USE_MLAS)
void Gemm<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
float* C, onnxruntime::concurrency::ThreadPool* threadpool) {
int lda = static_cast<int>((TransA == CblasNoTrans) ? K : M);
int ldb = static_cast<int>((TransB == CblasNoTrans) ? N : K);
// TODO: Make this use the operator threadpool
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, nullptr);
#else
GemmEigen<float>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
#endif
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, threadpool);
}
template <>
void Gemm<double, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
float beta, double* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<double, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
float beta, double* C, onnxruntime::concurrency::ThreadPool*) {
// No double precision Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
GemmEigen<double>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void Gemm<int32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
float beta, int32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<int32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
float beta, int32_t* C, onnxruntime::concurrency::ThreadPool*) {
// No int32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void Gemm<uint32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
float beta, uint32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<uint32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
float beta, uint32_t* C, onnxruntime::concurrency::ThreadPool*) {
// No uint32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void Gemm<int64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
float beta, int64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<int64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
float beta, int64_t* C, onnxruntime::concurrency::ThreadPool*) {
// No int64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void Gemm<uint64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
float beta, uint64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<uint64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
float beta, uint64_t* C, onnxruntime::concurrency::ThreadPool*) {
// No uint64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void GemmEx<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
int ldc, CPUMathUtil*) {
#if defined(USE_MLAS)
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, nullptr);
#else
using OuterStride = Eigen::OuterStride<Eigen::Dynamic>;
using StridedMap = Eigen::Map<Eigen::MatrixXf, 0, OuterStride>;
using ConstStridedMap = Eigen::Map<const Eigen::MatrixXf, 0, OuterStride>;
auto C_mat = StridedMap(C, N, M, OuterStride(ldc));
if (beta == 0) {
C_mat.setZero();
} else {
C_mat *= beta;
}
switch (TransA) {
case CblasNoTrans: {
switch (TransB) {
case CblasNoTrans:
C_mat.noalias() +=
alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) *
ConstStridedMap(A, K, M, OuterStride(lda)));
return;
case CblasTrans:
C_mat.noalias() +=
alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() *
ConstStridedMap(A, K, M, OuterStride(lda)));
return;
default:
ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
}
}
case CblasTrans: {
switch (TransB) {
case CblasNoTrans:
C_mat.noalias() +=
alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) *
ConstStridedMap(A, M, K, OuterStride(lda)).transpose());
return;
case CblasTrans:
C_mat.noalias() +=
alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() *
ConstStridedMap(A, M, K, OuterStride(lda)).transpose());
return;
default:
ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
}
}
default:
ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA);
}
#endif
void GemmEx<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
int ldc, onnxruntime::concurrency::ThreadPool* tp) {
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, tp);
}
template <>
@ -301,12 +247,12 @@ SPECIALIZED_AXPY(float)
SPECIALIZED_AXPBY(float)
#undef SPECIALIZED_AXPBY
#else // USE_EIGEN_FOR_BLAS
#else
template <>
void Gemm<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
float* C, CPUMathUtil* /*context*/, MLDataType /*math_type*/) {
void Gemm<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
float* C, concurrency::ThreadPool*) {
int lda = gsl::narrow_cast<int>((TransA == CblasNoTrans) ? K : M);
int ldb = gsl::narrow_cast<int>((TransB == CblasNoTrans) ? N : K);
cblas_sgemm(CblasRowMajor, TransA, TransB,
@ -318,9 +264,9 @@ void Gemm<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOS
}
template <>
void Gemm<double, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
float beta, double* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<double, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
float beta, double* C, concurrency::ThreadPool*) {
int lda = gsl::narrow_cast<int>((TransA == CblasNoTrans) ? K : M);
int ldb = gsl::narrow_cast<int>((TransB == CblasNoTrans) ? N : K);
cblas_dgemm(CblasRowMajor, TransA, TransB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
@ -329,41 +275,41 @@ void Gemm<double, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPO
}
template <>
void Gemm<int32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
float beta, int32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<int32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
float beta, int32_t* C, concurrency::ThreadPool*) {
// No int32_t Gemm offering from MKLML. Directly fallback to Eigen.
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void Gemm<uint32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
float beta, uint32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<uint32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
float beta, uint32_t* C, concurrency::ThreadPool*) {
// No uint32_t Gemm offering from MKLML. Directly fallback to Eigen.
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void Gemm<int64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
float beta, int64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<int64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
float beta, int64_t* C, concurrency::ThreadPool*) {
// No int64_t Gemm offering from MKLML. Directly fallback to Eigen.
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void Gemm<uint64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
float beta, uint64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
void Gemm<uint64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
float beta, uint64_t* C, concurrency::ThreadPool*) {
// No uint64_t Gemm offering from MKLML. Directly fallback to Eigen.
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void GemmEx<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
int ldc, CPUMathUtil* /*context*/) {
void GemmEx<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
int ldc, concurrency::ThreadPool*) {
cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
@ -417,20 +363,19 @@ CAFFE2_SPECIALIZED_AXPY(float, s)
CAFFE2_SPECIALIZED_AXPBY(float, s)
#undef CAFFE2_SPECIALIZED_AXPBY
#endif // USE_EIGEN_FOR_BLAS
#endif
template <>
void GemmBatched<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int A_size,
int A_batches, int B_size, int B_batches, int M, int N, int K, float /*alpha*/,
const float* A, const float* B, float /*beta*/, float* C, CPUMathUtil* provider,
Tensor*, /* scratch */
MLDataType /* math_type */) {
void GemmBatched<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int A_size,
int A_batches, int B_size, int B_batches, int M, int N, int K, float /*alpha*/,
const float* A, const float* B, float /*beta*/, float* C,
onnxruntime::concurrency::ThreadPool* tp) {
auto a_offset = A_size / A_batches;
auto b_offset = B_size / B_batches;
auto y_offset = M * N;
// loop over matrices in the batch
for (int i = 0; i < A_batches; ++i) {
math::Gemm<float, CPUMathUtil>(
math::Gemm<float, concurrency::ThreadPool>(
TransA,
TransB,
M,
@ -441,7 +386,7 @@ void GemmBatched<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_T
B + b_offset * i,
0,
C + y_offset * i,
provider);
tp);
}
}

View file

@ -17,12 +17,14 @@
#include "core/util/math.h"
#include <gtest/gtest.h>
#include "core/platform/threadpool.h"
#include "core/util/math_cpuonly.h"
namespace onnxruntime {
#define VECTOR_HEAD(x) x.size() > 0 ? &x[0] : NULL
TEST(MathTest, GemmNoTransNoTrans) {
concurrency::ThreadPool tp("", 1);
auto& provider = CPUMathUtil::Instance();
std::vector<float> X(50); // 5 * 10
std::vector<float> W(60); // 10 * 6
@ -40,26 +42,26 @@ TEST(MathTest, GemmNoTransNoTrans) {
const float kOne = 1.0;
const float kPointFive = 0.5;
const float kZero = 0.0;
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
&provider);
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
&tp);
EXPECT_EQ(Y.size(), 30);
for (size_t i = 0; i < Y.size(); ++i) {
EXPECT_EQ(Y[i], 10) << i;
}
// Test Accumulate
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
VECTOR_HEAD(Y), &provider);
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
VECTOR_HEAD(Y), &tp);
EXPECT_EQ(Y.size(), 30);
for (size_t i = 0; i < Y.size(); ++i) {
EXPECT_EQ(Y[i], 15) << i;
}
// Test Accumulate
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, 5, 6, 10,
kPointFive,
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
&provider);
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, 5, 6, 10,
kPointFive,
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
&tp);
EXPECT_EQ(Y.size(), 30);
for (size_t i = 0; i < Y.size(); ++i) {
EXPECT_EQ(Y[i], 20) << i;
@ -68,6 +70,8 @@ TEST(MathTest, GemmNoTransNoTrans) {
TEST(MathTest, GemmNoTransTrans) {
auto& provider = CPUMathUtil::Instance();
concurrency::ThreadPool tp("", 1);
std::vector<float> X(50); // 5 * 10
std::vector<float> W(60); // 10 * 6
std::vector<float> Y(30); // 5 * 6
@ -84,24 +88,24 @@ TEST(MathTest, GemmNoTransTrans) {
const float kOne = 1.0;
const float kPointFive = 0.5;
const float kZero = 0.0;
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
&provider);
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
&tp);
EXPECT_EQ(Y.size(), 30);
for (size_t i = 0; i < Y.size(); ++i) {
EXPECT_EQ(Y[i], 10) << i;
}
// Test Accumulate
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
VECTOR_HEAD(Y), &provider);
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
VECTOR_HEAD(Y), &tp);
EXPECT_EQ(Y.size(), 30);
for (size_t i = 0; i < Y.size(); ++i) {
EXPECT_EQ(Y[i], 15) << i;
}
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasTrans, 5, 6, 10, kPointFive,
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
&provider);
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasTrans, 5, 6, 10, kPointFive,
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
&tp);
EXPECT_EQ(Y.size(), 30);
for (size_t i = 0; i < Y.size(); ++i) {
EXPECT_EQ(Y[i], 20) << i;

View file

@ -194,23 +194,23 @@ TEST(SoftmaxOperator, InvalidAxis) {
TEST(SoftmaxOperator, TestInputTooLarge) {
float* ignored = nullptr;
concurrency::ThreadPool tp("", 1);
// N > INT32_MAX
int64_t N = int64_t(INT32_MAX) + 1;
int64_t D = 1;
auto status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored);
auto status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp);
EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT);
// D > INT32_MAX
N = 1;
D = int64_t(INT32_MAX) + 1;
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored);
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp);
EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT);
// N * D > INT32_MAX
N = int64_t(INT32_MAX) / 2;
D = 3;
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored);
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp);
EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT);
/*