mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
Let mlas use the session threadpool for gemm functions (#1196)
This commit is contained in:
parent
be36385a8c
commit
280ab9a2d0
27 changed files with 231 additions and 234 deletions
|
|
@ -6,6 +6,9 @@
|
|||
#include <gsl/span>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "attention_wrapper.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/providers/cpu/rnn/rnn_helpers.h"
|
||||
|
||||
#include <stdexcept>
|
||||
|
|
@ -34,14 +35,14 @@ AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger&
|
|||
|
||||
// rnn_cell_output is of [batch_size, rnn_cell_hidden_size]
|
||||
template <typename T>
|
||||
void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_output) {
|
||||
void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_output, concurrency::ThreadPool* tp) {
|
||||
if (has_attn_layer_) {
|
||||
// rnn_cell_output * cell_weights, (part of the attention layer above the attention mechanism).
|
||||
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
|
||||
rnn_cell_output.data(), inner_cell_hidden_size_,
|
||||
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
|
||||
attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
|
||||
math::GemmEx<T, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
|
||||
rnn_cell_output.data(), inner_cell_hidden_size_,
|
||||
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
|
||||
attn_states_.data(), attn_layer_depth_, tp);
|
||||
}
|
||||
|
||||
// Get the context which is calculated within attention mechanism.
|
||||
|
|
@ -54,11 +55,11 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
|
|||
//concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
|
||||
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
|
||||
// The first part is calulated above. Here just add the later.
|
||||
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
|
||||
attn_context_.data(), attn_context_depth_,
|
||||
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
|
||||
attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
|
||||
math::GemmEx<T, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
|
||||
attn_context_.data(), attn_context_depth_,
|
||||
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
|
||||
attn_states_.data(), attn_layer_depth_, tp);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@
|
|||
#include "core/framework/allocator.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -27,7 +30,7 @@ class AttentionWrapper {
|
|||
virtual ~AttentionWrapper() = default;
|
||||
|
||||
// Calculation based on output of the inner wrapped rnn_cell.
|
||||
void ProcessOutput(const gsl::span<const T>& rnn_cell_state);
|
||||
void ProcessOutput(const gsl::span<const T>& rnn_cell_state, onnxruntime::concurrency::ThreadPool* tp);
|
||||
|
||||
gsl::span<const T> GetAttnStates() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ namespace contrib {
|
|||
template <typename T>
|
||||
BahdanauAttention<T>::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger,
|
||||
int batch_size, int max_memory_step, int memory_depth,
|
||||
int query_depth, int attn_depth, bool normalize)
|
||||
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize) {
|
||||
int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* tp)
|
||||
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), tp_(tp) {
|
||||
values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true);
|
||||
keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true);
|
||||
processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true);
|
||||
|
|
@ -72,11 +72,11 @@ void BahdanauAttention<T>::PrepareMemory(
|
|||
"Real memory steps ", mem_steps, " is not in (0, ", max_memory_steps_, "]");
|
||||
}
|
||||
|
||||
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
|
||||
memory.data(), memory_depth_,
|
||||
memory_layer_weights_.data(), attn_depth_, T{0.0},
|
||||
keys_.data(), attn_depth_, &CPUMathUtil::Instance());
|
||||
math::GemmEx<T, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
|
||||
memory.data(), memory_depth_,
|
||||
memory_layer_weights_.data(), attn_depth_, T{0.0},
|
||||
keys_.data(), attn_depth_, tp_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -115,11 +115,11 @@ void BahdanauAttention<T>::Compute(
|
|||
const gsl::span<T>& output,
|
||||
const gsl::span<T>& aligns) const {
|
||||
//process query in dense query layer without bias
|
||||
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_, attn_depth_, query_depth_, T{1.0},
|
||||
queries.data(), query_depth_,
|
||||
query_layer_weights_.data(), attn_depth_, T{0.0},
|
||||
processed_query_.data(), attn_depth_, &CPUMathUtil::Instance());
|
||||
math::GemmEx<T, onnxruntime::concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
|
||||
batch_size_, attn_depth_, query_depth_, T{1.0},
|
||||
queries.data(), query_depth_,
|
||||
query_layer_weights_.data(), attn_depth_, T{0.0},
|
||||
processed_query_.data(), attn_depth_, tp_);
|
||||
|
||||
std::fill(aligns.begin(), aligns.end(), T{});
|
||||
|
||||
|
|
@ -146,11 +146,11 @@ void BahdanauAttention<T>::Compute(
|
|||
// Calculate the context
|
||||
auto outspan = output.subspan(b * memory_depth_);
|
||||
auto values = values_.subspan(b * max_memory_steps_ * memory_depth_);
|
||||
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
|
||||
1, memory_depth_, max_memory_steps_, T{1.0},
|
||||
alignments, max_memory_steps_,
|
||||
values.data(), memory_depth_, T{0.0},
|
||||
outspan.data(), memory_depth_, &CPUMathUtil::Instance());
|
||||
math::GemmEx<T, onnxruntime::concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans,
|
||||
1, memory_depth_, max_memory_steps_, T{1.0},
|
||||
alignments, max_memory_steps_,
|
||||
values.data(), memory_depth_, T{0.0},
|
||||
outspan.data(), memory_depth_, tp_);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
|
|||
int memory_depth,
|
||||
int query_depth,
|
||||
int attn_depth,
|
||||
bool normalize);
|
||||
bool normalize, concurrency::ThreadPool* tp);
|
||||
|
||||
void SetWeights(
|
||||
const gsl::span<const T>& attn_weights,
|
||||
|
|
@ -53,7 +53,6 @@ class BahdanauAttention : public IAttentionMechanism<T> {
|
|||
private:
|
||||
AllocatorPtr allocator_;
|
||||
const logging::Logger& logger_;
|
||||
|
||||
int batch_size_;
|
||||
int max_memory_steps_;
|
||||
int memory_depth_;
|
||||
|
|
@ -77,6 +76,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
|
|||
gsl::span<int> mem_seq_lengths_;
|
||||
|
||||
bool normalize_;
|
||||
concurrency::ThreadPool* const tp_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include "core/common/common.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -70,6 +71,8 @@ static gsl::span<const T> SecondHalfSpan(const gsl::span<const T>& dspan) {
|
|||
|
||||
template <typename T>
|
||||
Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(&context);
|
||||
auto tp = ctx_internal->GetOperatorThreadPool();
|
||||
auto& logger = context.Logger();
|
||||
|
||||
// original lstm processing
|
||||
|
|
@ -229,7 +232,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
last_cell_size_per_direction);
|
||||
|
||||
auto fam = std::make_unique<BahdanauAttention<T>>(
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp);
|
||||
fam->SetWeights(
|
||||
FirstHalfSpan(am_v_weights.DataAsSpan<T>()),
|
||||
FirstHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
|
||||
|
|
@ -248,10 +251,10 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
activation_funcs_.Entries()[2],
|
||||
clip_, ttp_);
|
||||
clip_, *tp);
|
||||
|
||||
auto bam = std::make_unique<BahdanauAttention<T>>(
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp);
|
||||
bam->SetWeights(
|
||||
SecondHalfSpan(am_v_weights.DataAsSpan<T>()),
|
||||
SecondHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
|
||||
|
|
@ -270,14 +273,14 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[3],
|
||||
activation_funcs_.Entries()[4],
|
||||
activation_funcs_.Entries()[5],
|
||||
clip_, ttp_);
|
||||
clip_, *tp);
|
||||
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
|
||||
|
||||
} else {
|
||||
auto fam = std::make_unique<BahdanauAttention<T>>(
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
|
||||
alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false, tp);
|
||||
fam->SetWeights(
|
||||
am_v_weights.DataAsSpan<T>(),
|
||||
am_query_layer_weights.DataAsSpan<T>(),
|
||||
|
|
@ -296,7 +299,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
activation_funcs_.Entries()[2],
|
||||
clip_, ttp_);
|
||||
clip_, *tp);
|
||||
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,12 +91,6 @@ class DeepCpuAttnLstmOp final : public OpKernel {
|
|||
bool input_forget_ = false;
|
||||
|
||||
ActivationFuncs activation_funcs_;
|
||||
|
||||
// Threadpool for operator. If concurrent Compute calls are possible, it will be shared
|
||||
// across them. mutable due to this.
|
||||
// The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
|
||||
// cost on every call.
|
||||
mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_ATTN_LSTM", (int)std::thread::hardware_concurrency()};
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -200,6 +200,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
gsl::span<T>& outputs,
|
||||
gsl::span<T>& final_hidden_state,
|
||||
gsl::span<T>& final_cell_state) {
|
||||
onnxruntime::concurrency::ThreadPool* tp = &ttp_;
|
||||
// copy spans (just T* and size, not data in span) as we may change them
|
||||
gsl::span<const T> inputs = inputs_arg;
|
||||
gsl::span<const int> sequence_lengths = sequence_lengths_arg;
|
||||
|
|
@ -254,7 +255,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
input_weights.cbegin(), input_weights.cend(), // W[iofc]^T
|
||||
input_size_ + attention_size_, T{0.0},
|
||||
output_iofc_.begin(), output_iofc_.end(),
|
||||
hidden_size_x4);
|
||||
hidden_size_x4, tp);
|
||||
|
||||
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
|
||||
|
||||
|
|
@ -296,7 +297,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
input_weights.cbegin() + input_size_, input_weights.cend(), // WA[iofc]
|
||||
input_size_ + attention_size_, T{1.0},
|
||||
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
|
||||
hidden_size_x4);
|
||||
hidden_size_x4, tp);
|
||||
|
||||
// calculate Xt*(W[iofc]^T) + Ht-1*R[iofc]
|
||||
ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, T{1.0},
|
||||
|
|
@ -305,7 +306,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
|
||||
hidden_size_, T{1.0},
|
||||
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
|
||||
hidden_size_x4);
|
||||
hidden_size_x4, tp);
|
||||
|
||||
span_T_iter batched_output, batched_output_end;
|
||||
if (output_sequence) {
|
||||
|
|
@ -345,7 +346,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
previous_state = batched_output;
|
||||
previous_state_end = batched_output_end;
|
||||
|
||||
attention_wrapper_.ProcessOutput(outputs.subspan(step * output_step_length, batch_size_ * hidden_size_));
|
||||
attention_wrapper_.ProcessOutput(outputs.subspan(step * output_step_length, batch_size_ * hidden_size_), tp);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -45,7 +46,7 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation(
|
|||
int64_t char_embedding_size,
|
||||
int64_t filter_width,
|
||||
int64_t num_filters,
|
||||
float* output) const {
|
||||
float* output, concurrency::ThreadPool* tp) const {
|
||||
int64_t input_word_size = word_len * char_embedding_size;
|
||||
int64_t unfolded_width = word_len - filter_width + 1;
|
||||
int64_t unfolded_kernal_size = filter_width * char_embedding_size;
|
||||
|
|
@ -83,12 +84,12 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation(
|
|||
tmp_word_inx++;
|
||||
}
|
||||
|
||||
math::GemmEx<float, CPUMathUtil>(
|
||||
math::GemmEx<float, concurrency::ThreadPool>(
|
||||
CblasNoTrans, CblasTrans,
|
||||
static_cast<int>(words_unfolded_width), static_cast<int>(num_filters), static_cast<int>(unfolded_kernal_size), 1.0f,
|
||||
unfolded_buffer_p.get(), static_cast<int>(unfolded_kernal_size),
|
||||
weights, static_cast<int>(unfolded_kernal_size), 0.0f,
|
||||
conv_buf_p, static_cast<int>(num_filters), &CPUMathUtil::Instance());
|
||||
conv_buf_p, static_cast<int>(num_filters), tp);
|
||||
|
||||
for (int64_t unfolded_inx = 0; unfolded_inx < words_unfolded_width; unfolded_inx++)
|
||||
for (int64_t filter_inx = 0; filter_inx < num_filters; filter_inx++) {
|
||||
|
|
@ -160,6 +161,9 @@ Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, co
|
|||
}
|
||||
|
||||
Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto tp = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
// original lstm processing
|
||||
const Tensor& sequence = *(ctx->Input<Tensor>(0)); // sequence: [sequence_length, word_length]
|
||||
const Tensor& w_conv = *(ctx->Input<Tensor>(1)); // conv weight: [M, C/group, kH, kW]
|
||||
|
|
@ -216,7 +220,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
|
|||
char_embedding_size,
|
||||
filter_width,
|
||||
filter_size,
|
||||
Y->MutableData<float>());
|
||||
Y->MutableData<float>(), tp);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@
|
|||
#include "core/framework/tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
namespace contrib {
|
||||
|
||||
class WordConvEmbedding final : public OpKernel {
|
||||
|
|
@ -38,7 +41,7 @@ class WordConvEmbedding final : public OpKernel {
|
|||
int64_t char_embedding_size,
|
||||
int64_t filter_width,
|
||||
int64_t num_filters,
|
||||
float* output) const;
|
||||
float* output, onnxruntime::concurrency::ThreadPool* tp) const;
|
||||
void CalculateLengthOfEachWordInSequence(
|
||||
const int* seq_ptr,
|
||||
int* words_len_ptr,
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class OpKernelContextInternal : public OpKernelContext {
|
|||
const bool& GetTerminateFlag() const noexcept { return terminate_flag_; }
|
||||
|
||||
const onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return session_state_.GetThreadPool(); }
|
||||
onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() { return session_state_.GetThreadPool(); }
|
||||
|
||||
private:
|
||||
const SessionState& session_state_;
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "gemm_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -30,6 +32,9 @@ class Gemm : public OpKernel {
|
|||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
auto thread_pool = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
const auto X = context->Input<Tensor>(0);
|
||||
const auto W = context->Input<Tensor>(1);
|
||||
const auto B = context->Input<Tensor>(2);
|
||||
|
|
@ -95,7 +100,7 @@ class Gemm : public OpKernel {
|
|||
}
|
||||
|
||||
// W * x
|
||||
math::Gemm<T_X, CPUMathUtil>(
|
||||
math::Gemm<T_X, concurrency::ThreadPool>(
|
||||
trans_A_,
|
||||
trans_B_,
|
||||
M,
|
||||
|
|
@ -106,7 +111,7 @@ class Gemm : public OpKernel {
|
|||
W->template Data<T_W>(),
|
||||
beta_,
|
||||
y_data,
|
||||
&CPUMathUtil::Instance());
|
||||
thread_pool);
|
||||
|
||||
FuseActivation<T_Y>(activation_, y_data, M * N, leaky_relu_alpha_);
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
#include "core/providers/cpu/math/logsoftmax.h"
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cpu/math/softmax_shared.h"
|
||||
#include "core/util/math.h"
|
||||
|
|
@ -12,6 +14,9 @@ namespace onnxruntime {
|
|||
|
||||
template <>
|
||||
Status LogSoftmax<float>::Compute(OpKernelContext* ctx) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto tp = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
const auto* tensor_pointer = ctx->Input<Tensor>(0);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
const Tensor& X = *tensor_pointer;
|
||||
|
|
@ -32,7 +37,7 @@ Status LogSoftmax<float>::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
const bool logarithmic = true;
|
||||
auto status = SoftmaxCPU(N, D, X.template Data<float>(), Ydata,
|
||||
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data());
|
||||
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/math/matmul.h"
|
||||
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
#include "matmul_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -53,6 +55,9 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
|||
|
||||
template <typename T>
|
||||
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto thread_pool = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
const auto* left_X = ctx->Input<Tensor>(0);
|
||||
const auto* right_X = ctx->Input<Tensor>(1);
|
||||
|
||||
|
|
@ -64,7 +69,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
|||
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
|
||||
size_t max_len = helper.OutputOffsets().size();
|
||||
for (size_t i = 0; i < max_len; i++) {
|
||||
math::Gemm<T, CPUMathUtil>(
|
||||
math::Gemm<T, concurrency::ThreadPool>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
static_cast<int>(helper.M()),
|
||||
|
|
@ -75,7 +80,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
|||
right_X->template Data<T>() + helper.RightOffsets()[i],
|
||||
/* beta */ 0.0f,
|
||||
Y->template MutableData<T>() + helper.OutputOffsets()[i],
|
||||
&CPUMathUtil::Instance());
|
||||
thread_pool);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/providers/cpu/math/softmax.h"
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cpu/math/softmax_shared.h"
|
||||
#include "core/util/math.h"
|
||||
|
|
@ -12,6 +13,9 @@ namespace onnxruntime {
|
|||
|
||||
template <>
|
||||
Status Softmax<float>::Compute(OpKernelContext* ctx) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto tp = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
const auto* tensor_pointer = ctx->Input<Tensor>(0);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
const Tensor& X = *tensor_pointer;
|
||||
|
|
@ -34,7 +38,7 @@ Status Softmax<float>::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
const bool logarithmic = false;
|
||||
auto status = SoftmaxCPU(N, D, X.template Data<float>(), Ydata,
|
||||
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data());
|
||||
scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@
|
|||
#endif
|
||||
|
||||
#include "core/providers/cpu/math/softmax_shared.h"
|
||||
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
|
|
@ -46,7 +47,7 @@ common::Status SoftmaxCPU(const int64_t N,
|
|||
float* scale,
|
||||
const float* sum_multiplier,
|
||||
bool logarithmic,
|
||||
float* rowmax) {
|
||||
float* rowmax, onnxruntime::concurrency::ThreadPool* tp) {
|
||||
// the Math functions SoftmaxCPU uses only support int32_t as input, so enforce that
|
||||
if (N * D > INT32_MAX || N > INT32_MAX || D > INT32_MAX) {
|
||||
std::ostringstream ss;
|
||||
|
|
@ -65,7 +66,7 @@ common::Status SoftmaxCPU(const int64_t N,
|
|||
// Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
|
||||
gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
|
||||
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
|
||||
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, tp);
|
||||
|
||||
// Exponentiation
|
||||
math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@
|
|||
#include "core/common/status.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
/**
|
||||
Calculate Softmax using CPU memory.
|
||||
@param N Number of rows
|
||||
|
|
@ -18,5 +21,5 @@ Calculate Softmax using CPU memory.
|
|||
@param rowmax Storage for calculation of maximum in each row. Size must be >= N.
|
||||
*/
|
||||
common::Status SoftmaxCPU(int64_t N, int64_t D, const float* Xdata, float* Ydata, float* scale,
|
||||
const float* sum_multiplier, bool logarithmic, float* rowmax);
|
||||
const float* sum_multiplier, bool logarithmic, float* rowmax, concurrency::ThreadPool* tp);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -49,6 +49,11 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
// Get access to the internal threadpool
|
||||
// Temporarily derive concurrency parameters without access to session state
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
auto thread_pool = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
if (kernel_rank == 2 || kernel_rank == 3) {
|
||||
MLAS_ACTIVATION Activation;
|
||||
if (activation_.empty()) {
|
||||
|
|
@ -66,11 +71,6 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_);
|
||||
}
|
||||
|
||||
// Get access to the internal threadpool
|
||||
// Temporarily derive concurrency parameters without access to session state
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
auto thread_pool = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize;
|
||||
MlasConvPrepare(&Parameters,
|
||||
|
|
@ -87,9 +87,9 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
static_cast<size_t>(M / group_),
|
||||
&Activation,
|
||||
&WorkingBufferSize,
|
||||
const_cast<concurrency::ThreadPool*>(thread_pool));
|
||||
thread_pool);
|
||||
|
||||
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
|
||||
auto working_data = WorkingBufferSize > 0 ? alloc->AllocArray(sizeof(float), WorkingBufferSize) : nullptr;
|
||||
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
|
||||
|
||||
MlasConv(&Parameters,
|
||||
|
|
@ -98,7 +98,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
B != nullptr ? B->template Data<float>() : nullptr,
|
||||
static_cast<float*>(working_buffer.get()),
|
||||
Ydata,
|
||||
const_cast<concurrency::ThreadPool*>(thread_pool));
|
||||
thread_pool);
|
||||
} else {
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
|
|
@ -120,7 +120,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
|
||||
for (int image_id = 0; image_id < N; ++image_id) {
|
||||
for (int group_id = 0; group_id < group_; ++group_id) {
|
||||
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>()(
|
||||
math::Im2colNd<float, onnxruntime::concurrency::ThreadPool, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
image_shape.GetDims().data(),
|
||||
col_buffer_shape.data(),
|
||||
|
|
@ -132,8 +132,8 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
pads.data(),
|
||||
static_cast<int>(kernel_shape.size()),
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
math::Gemm<float, CPUMathUtil>(
|
||||
thread_pool);
|
||||
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
M / group_,
|
||||
|
|
@ -144,7 +144,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
col_buffer_data,
|
||||
0,
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
thread_pool);
|
||||
}
|
||||
|
||||
if (B != nullptr) {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@
|
|||
/* Modifications Copyright (c) Microsoft. */
|
||||
|
||||
#include "core/providers/cpu/nn/conv_transpose.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
|
|
@ -228,6 +230,9 @@ Status ConvTranspose<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
template <typename T>
|
||||
Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
auto tp = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
Prepare p;
|
||||
bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3;
|
||||
|
|
@ -254,7 +259,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
for (auto image_id = 0; image_id < p.N; ++image_id) {
|
||||
for (int group_id = 0; group_id < group_; ++group_id) {
|
||||
// Weight term
|
||||
math::Gemm<T, CPUMathUtil>(
|
||||
math::Gemm<T, onnxruntime::concurrency::ThreadPool>(
|
||||
CblasTrans,
|
||||
CblasNoTrans,
|
||||
kernel_dim,
|
||||
|
|
@ -265,7 +270,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
Xdata + group_id * X_offset,
|
||||
0,
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
tp);
|
||||
|
||||
// Col2im
|
||||
math::Col2im<T, CPUMathUtil, StorageOrder::NCHW>(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
#include "core/platform/ort_mutex.h"
|
||||
|
||||
|
|
@ -171,7 +172,7 @@ class UniDirectionalGru {
|
|||
|
||||
void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
|
||||
const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
|
||||
gsl::span<T>& outputs, gsl::span<T>& final_hidden_state);
|
||||
gsl::span<T>& outputs, gsl::span<T>& final_hidden_state, onnxruntime::concurrency::ThreadPool* tp);
|
||||
|
||||
~UniDirectionalGru() = default;
|
||||
|
||||
|
|
@ -263,6 +264,9 @@ Status DeepCpuGruOp::Compute(OpKernelContext* context) const {
|
|||
|
||||
template <typename T>
|
||||
Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(&context);
|
||||
auto tp = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
const Tensor& X = *context.Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
|
||||
const Tensor& W = *context.Input<Tensor>(1); // weights. [num_directions, 3*hidden_size, input_size]
|
||||
const Tensor& R = *context.Input<Tensor>(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size]
|
||||
|
|
@ -375,7 +379,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
clip_);
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, tp);
|
||||
|
||||
std::unique_ptr<detail::UniDirectionalGru<T>> bw = std::make_unique<detail::UniDirectionalGru<T>>(
|
||||
alloc,
|
||||
|
|
@ -389,7 +393,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[2],
|
||||
activation_funcs_.Entries()[3],
|
||||
clip_);
|
||||
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2);
|
||||
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2, tp);
|
||||
} else {
|
||||
std::unique_ptr<detail::UniDirectionalGru<T>> gru_p = std::make_unique<detail::UniDirectionalGru<T>>(
|
||||
alloc,
|
||||
|
|
@ -404,7 +408,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
|
|||
activation_funcs_.Entries()[1],
|
||||
clip_);
|
||||
|
||||
gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
|
||||
gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, tp);
|
||||
}
|
||||
|
||||
if (!output.empty())
|
||||
|
|
@ -505,7 +509,8 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
const gsl::span<const T>& input_weights,
|
||||
const gsl::span<const T>& recurrent_weights,
|
||||
gsl::span<T>& outputs,
|
||||
gsl::span<T>& final_hidden_state) {
|
||||
gsl::span<T>& final_hidden_state,
|
||||
onnxruntime::concurrency::ThreadPool* tp) {
|
||||
using span_T_const_iter = typename gsl::span<T>::const_iterator;
|
||||
using span_T_iter = typename gsl::span<T>::iterator;
|
||||
|
||||
|
|
@ -559,7 +564,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
input_weights.cbegin(), input_weights.cend(),
|
||||
input_size_, beta,
|
||||
outputZRH_.begin(), outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
hidden_size_x3, tp);
|
||||
|
||||
DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_);
|
||||
|
||||
|
|
@ -624,7 +629,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
|
||||
hidden_size_, beta,
|
||||
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
hidden_size_x3, tp);
|
||||
|
||||
DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
|
||||
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
|
||||
|
|
@ -640,7 +645,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
|
||||
hidden_size_, beta,
|
||||
linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output
|
||||
hidden_size_);
|
||||
hidden_size_, tp);
|
||||
|
||||
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
|
||||
}
|
||||
|
|
@ -707,7 +712,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
|
||||
hidden_size_, beta,
|
||||
out_H, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
hidden_size_x3, tp);
|
||||
}
|
||||
|
||||
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,
|
||||
|
|
|
|||
|
|
@ -783,7 +783,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
input_weights.cbegin(), input_weights.cend(), // W[iofc]
|
||||
input_size_, beta,
|
||||
output_iofc_.begin(), output_iofc_.end(),
|
||||
hidden_size_x4);
|
||||
hidden_size_x4, &ttp_);
|
||||
|
||||
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
|
||||
|
||||
|
|
@ -832,7 +832,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
|
||||
hidden_size_, beta,
|
||||
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
|
||||
hidden_size_x4);
|
||||
hidden_size_x4, &ttp_);
|
||||
|
||||
DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str,
|
||||
&*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4);
|
||||
|
|
@ -910,7 +910,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
|
||||
hidden_size_, beta,
|
||||
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
|
||||
hidden_size_x4);
|
||||
hidden_size_x4, &ttp_);
|
||||
|
||||
span_T_iter batched_output;
|
||||
span_T_iter batched_output_end;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/providers/cpu/rnn/rnn.h"
|
||||
#include "core/providers/cpu/rnn/rnn_activation_functors.h"
|
||||
#include "core/providers/cpu/rnn/rnn_helpers.h"
|
||||
|
|
@ -98,6 +99,9 @@ using EigenMatrixMapRowMajor = Eigen::Map<
|
|||
|
||||
template <>
|
||||
Status RNN<float>::Compute(OpKernelContext* ctx) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto tp = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
using namespace rnn::detail;
|
||||
|
||||
// inputs
|
||||
|
|
@ -160,7 +164,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
|
|||
}
|
||||
|
||||
// X * W[direction]^t + B
|
||||
math::Gemm<float, CPUMathUtil>(
|
||||
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(
|
||||
CblasNoTrans,
|
||||
CblasTrans,
|
||||
static_cast<int>(seq_length * batch_size),
|
||||
|
|
@ -171,7 +175,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
|
|||
W.template Data<float>() + direction * hidden_size_ * input_size,
|
||||
1,
|
||||
x_matmul_w_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
tp);
|
||||
|
||||
for (int64_t t = 0; t < seq_length; t++) {
|
||||
int64_t time_step = isReverse ? (seq_length - t - 1) : t;
|
||||
|
|
@ -192,7 +196,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
if (h_prev != nullptr) {
|
||||
// H_t_1 * R[direction]^t
|
||||
math::Gemm<float, CPUMathUtil>(
|
||||
math::Gemm<float, onnxruntime::concurrency::ThreadPool>(
|
||||
CblasNoTrans,
|
||||
CblasTrans,
|
||||
static_cast<int>(batch_size),
|
||||
|
|
@ -203,7 +207,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
|
|||
R.template Data<float>() + direction * hidden_size_ * hidden_size_,
|
||||
0,
|
||||
Y_buffer_data_current_frame,
|
||||
&CPUMathUtil::Instance());
|
||||
tp);
|
||||
} else {
|
||||
math::Set<float, CPUMathUtil>(batch_size * hidden_size_, 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ void ComputeGemm(const int M,
|
|||
const float beta,
|
||||
TSpanCIter C,
|
||||
TSpanCIter C_end,
|
||||
const int ldc) {
|
||||
const int ldc, onnxruntime::concurrency::ThreadPool* tp) {
|
||||
// validate all the inputs
|
||||
// need to use the lda/ldb/ldc strides which should be >= the columns for the span
|
||||
ORT_ENFORCE(lda >= K && ldb >= K && ldc >= N);
|
||||
|
|
@ -167,12 +167,12 @@ void ComputeGemm(const int M,
|
|||
ORT_ENFORCE(B + (N * ldb - (ldb - K)) <= B_end);
|
||||
ORT_ENFORCE(C + (M * ldc - (ldc - N)) <= C_end);
|
||||
|
||||
::onnxruntime::math::GemmEx<float, CPUMathUtil>(
|
||||
::onnxruntime::math::GemmEx<float, ::onnxruntime::concurrency::ThreadPool>(
|
||||
CblasNoTrans, CblasTrans,
|
||||
M, N, K, alpha,
|
||||
&*A, lda,
|
||||
&*B, ldb, beta,
|
||||
&*C, ldc, &CPUMathUtil::Instance());
|
||||
&*C, ldc, tp);
|
||||
}
|
||||
|
||||
// helper to convert a span to a raw pointer
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ extern "C" {
|
|||
#include "core/framework/tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
|
||||
enum StorageOrder {
|
||||
UNKNOWN = 0,
|
||||
|
|
@ -187,10 +190,7 @@ void Gemm(
|
|||
const T* B,
|
||||
float beta,
|
||||
T* C,
|
||||
Provider* provider,
|
||||
//Caffe2 use this type to control on GPU, what presicion do we want to do the calculation
|
||||
//But not sure is this a good design for us. Keep it here for now.
|
||||
MLDataType math_type = FLOAT_TYPE);
|
||||
Provider*);
|
||||
|
||||
// We also provide a gemm that has explicit lda, ldb and ldc specified.
|
||||
// In most cases you probably want to use the function above, though.
|
||||
|
|
@ -209,7 +209,7 @@ void GemmEx(
|
|||
T beta,
|
||||
T* C,
|
||||
int ldc,
|
||||
Provider* provider);
|
||||
Provider*);
|
||||
|
||||
// GemmBatched provides a simple abstraction into library routines
|
||||
template <typename T, class Provider>
|
||||
|
|
@ -228,9 +228,7 @@ void GemmBatched(
|
|||
const T* B,
|
||||
float beta,
|
||||
T* C,
|
||||
Provider* provider,
|
||||
Tensor* scratch = nullptr,
|
||||
MLDataType math_type = DataTypeImpl::FLOAT_TYPE);
|
||||
Provider* tp);
|
||||
|
||||
// Gemv always takes in a M*N matrix A, and depending on whether we set TransA
|
||||
// to Trans, the output is:
|
||||
|
|
|
|||
|
|
@ -40,9 +40,7 @@
|
|||
#include "core/util/math_cpuonly.h"
|
||||
#include "Eigen/src/Core/arch/GPU/Half.h"
|
||||
|
||||
#if defined(USE_MLAS)
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace math {
|
||||
|
|
@ -107,7 +105,7 @@ void GemmEigen(
|
|||
// CBLAS call or the Eigen implementation.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// when USE_MKLML is defined, use cblas APIs for MKLML
|
||||
#if defined(USE_EIGEN_FOR_BLAS) && !defined(USE_MKLML_FOR_BLAS)
|
||||
#if !defined(USE_MKLML_FOR_BLAS)
|
||||
|
||||
// Caffe2 gemm provides a simpler interface to the gemm functions, with the
|
||||
// limitation that the data has to be contiguous in memory.
|
||||
|
|
@ -125,112 +123,60 @@ void GemmEigen(
|
|||
// (transpose) if the argument TransA or TransB is set to CblasNoTrans or
|
||||
// CblasTrans, respectively, for each of A and B.
|
||||
template <>
|
||||
void Gemm<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
|
||||
float* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
#if defined(USE_MLAS)
|
||||
void Gemm<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
|
||||
float* C, onnxruntime::concurrency::ThreadPool* threadpool) {
|
||||
int lda = static_cast<int>((TransA == CblasNoTrans) ? K : M);
|
||||
int ldb = static_cast<int>((TransB == CblasNoTrans) ? N : K);
|
||||
// TODO: Make this use the operator threadpool
|
||||
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, nullptr);
|
||||
#else
|
||||
GemmEigen<float>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
#endif
|
||||
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, threadpool);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<double, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
|
||||
float beta, double* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<double, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
|
||||
float beta, double* C, onnxruntime::concurrency::ThreadPool*) {
|
||||
// No double precision Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<double>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
|
||||
float beta, int32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<int32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
|
||||
float beta, int32_t* C, onnxruntime::concurrency::ThreadPool*) {
|
||||
// No int32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
|
||||
float beta, uint32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<uint32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
|
||||
float beta, uint32_t* C, onnxruntime::concurrency::ThreadPool*) {
|
||||
// No uint32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
|
||||
float beta, int64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<int64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
|
||||
float beta, int64_t* C, onnxruntime::concurrency::ThreadPool*) {
|
||||
// No int64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
|
||||
float beta, uint64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<uint64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
|
||||
float beta, uint64_t* C, onnxruntime::concurrency::ThreadPool*) {
|
||||
// No uint64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmEx<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
|
||||
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
|
||||
int ldc, CPUMathUtil*) {
|
||||
#if defined(USE_MLAS)
|
||||
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, nullptr);
|
||||
#else
|
||||
using OuterStride = Eigen::OuterStride<Eigen::Dynamic>;
|
||||
using StridedMap = Eigen::Map<Eigen::MatrixXf, 0, OuterStride>;
|
||||
using ConstStridedMap = Eigen::Map<const Eigen::MatrixXf, 0, OuterStride>;
|
||||
auto C_mat = StridedMap(C, N, M, OuterStride(ldc));
|
||||
if (beta == 0) {
|
||||
C_mat.setZero();
|
||||
} else {
|
||||
C_mat *= beta;
|
||||
}
|
||||
switch (TransA) {
|
||||
case CblasNoTrans: {
|
||||
switch (TransB) {
|
||||
case CblasNoTrans:
|
||||
C_mat.noalias() +=
|
||||
alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) *
|
||||
ConstStridedMap(A, K, M, OuterStride(lda)));
|
||||
return;
|
||||
case CblasTrans:
|
||||
C_mat.noalias() +=
|
||||
alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() *
|
||||
ConstStridedMap(A, K, M, OuterStride(lda)));
|
||||
return;
|
||||
default:
|
||||
ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
|
||||
}
|
||||
}
|
||||
case CblasTrans: {
|
||||
switch (TransB) {
|
||||
case CblasNoTrans:
|
||||
C_mat.noalias() +=
|
||||
alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) *
|
||||
ConstStridedMap(A, M, K, OuterStride(lda)).transpose());
|
||||
return;
|
||||
case CblasTrans:
|
||||
C_mat.noalias() +=
|
||||
alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() *
|
||||
ConstStridedMap(A, M, K, OuterStride(lda)).transpose());
|
||||
return;
|
||||
default:
|
||||
ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
|
||||
}
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA);
|
||||
}
|
||||
#endif
|
||||
void GemmEx<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
|
||||
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
|
||||
int ldc, onnxruntime::concurrency::ThreadPool* tp) {
|
||||
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, tp);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
@ -301,12 +247,12 @@ SPECIALIZED_AXPY(float)
|
|||
SPECIALIZED_AXPBY(float)
|
||||
#undef SPECIALIZED_AXPBY
|
||||
|
||||
#else // USE_EIGEN_FOR_BLAS
|
||||
#else
|
||||
|
||||
template <>
|
||||
void Gemm<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
|
||||
float* C, CPUMathUtil* /*context*/, MLDataType /*math_type*/) {
|
||||
void Gemm<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const float* A, const float* B, float beta,
|
||||
float* C, concurrency::ThreadPool*) {
|
||||
int lda = gsl::narrow_cast<int>((TransA == CblasNoTrans) ? K : M);
|
||||
int ldb = gsl::narrow_cast<int>((TransB == CblasNoTrans) ? N : K);
|
||||
cblas_sgemm(CblasRowMajor, TransA, TransB,
|
||||
|
|
@ -318,9 +264,9 @@ void Gemm<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOS
|
|||
}
|
||||
|
||||
template <>
|
||||
void Gemm<double, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
|
||||
float beta, double* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<double, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const double* A, const double* B,
|
||||
float beta, double* C, concurrency::ThreadPool*) {
|
||||
int lda = gsl::narrow_cast<int>((TransA == CblasNoTrans) ? K : M);
|
||||
int ldb = gsl::narrow_cast<int>((TransB == CblasNoTrans) ? N : K);
|
||||
cblas_dgemm(CblasRowMajor, TransA, TransB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
|
||||
|
|
@ -329,41 +275,41 @@ void Gemm<double, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPO
|
|||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
|
||||
float beta, int32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<int32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int32_t* A, const int32_t* B,
|
||||
float beta, int32_t* C, concurrency::ThreadPool*) {
|
||||
// No int32_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint32_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
|
||||
float beta, uint32_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<uint32_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint32_t* A, const uint32_t* B,
|
||||
float beta, uint32_t* C, concurrency::ThreadPool*) {
|
||||
// No uint32_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
|
||||
float beta, int64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<int64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const int64_t* A, const int64_t* B,
|
||||
float beta, int64_t* C, concurrency::ThreadPool*) {
|
||||
// No int64_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint64_t, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
|
||||
float beta, uint64_t* C, CPUMathUtil* /*provider*/, MLDataType /*math_type*/) {
|
||||
void Gemm<uint64_t, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, float alpha, const uint64_t* A, const uint64_t* B,
|
||||
float beta, uint64_t* C, concurrency::ThreadPool*) {
|
||||
// No uint64_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmEx<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
|
||||
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
|
||||
int ldc, CPUMathUtil* /*context*/) {
|
||||
void GemmEx<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int M, int N, int K,
|
||||
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
|
||||
int ldc, concurrency::ThreadPool*) {
|
||||
cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
|
||||
beta, C, ldc);
|
||||
}
|
||||
|
|
@ -417,20 +363,19 @@ CAFFE2_SPECIALIZED_AXPY(float, s)
|
|||
CAFFE2_SPECIALIZED_AXPBY(float, s)
|
||||
#undef CAFFE2_SPECIALIZED_AXPBY
|
||||
|
||||
#endif // USE_EIGEN_FOR_BLAS
|
||||
#endif
|
||||
|
||||
template <>
|
||||
void GemmBatched<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int A_size,
|
||||
int A_batches, int B_size, int B_batches, int M, int N, int K, float /*alpha*/,
|
||||
const float* A, const float* B, float /*beta*/, float* C, CPUMathUtil* provider,
|
||||
Tensor*, /* scratch */
|
||||
MLDataType /* math_type */) {
|
||||
void GemmBatched<float, concurrency::ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, int A_size,
|
||||
int A_batches, int B_size, int B_batches, int M, int N, int K, float /*alpha*/,
|
||||
const float* A, const float* B, float /*beta*/, float* C,
|
||||
onnxruntime::concurrency::ThreadPool* tp) {
|
||||
auto a_offset = A_size / A_batches;
|
||||
auto b_offset = B_size / B_batches;
|
||||
auto y_offset = M * N;
|
||||
// loop over matrices in the batch
|
||||
for (int i = 0; i < A_batches; ++i) {
|
||||
math::Gemm<float, CPUMathUtil>(
|
||||
math::Gemm<float, concurrency::ThreadPool>(
|
||||
TransA,
|
||||
TransB,
|
||||
M,
|
||||
|
|
@ -441,7 +386,7 @@ void GemmBatched<float, CPUMathUtil>(const CBLAS_TRANSPOSE TransA, const CBLAS_T
|
|||
B + b_offset * i,
|
||||
0,
|
||||
C + y_offset * i,
|
||||
provider);
|
||||
tp);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@
|
|||
|
||||
#include "core/util/math.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
namespace onnxruntime {
|
||||
|
||||
#define VECTOR_HEAD(x) x.size() > 0 ? &x[0] : NULL
|
||||
|
||||
TEST(MathTest, GemmNoTransNoTrans) {
|
||||
concurrency::ThreadPool tp("", 1);
|
||||
auto& provider = CPUMathUtil::Instance();
|
||||
std::vector<float> X(50); // 5 * 10
|
||||
std::vector<float> W(60); // 10 * 6
|
||||
|
|
@ -40,26 +42,26 @@ TEST(MathTest, GemmNoTransNoTrans) {
|
|||
const float kOne = 1.0;
|
||||
const float kPointFive = 0.5;
|
||||
const float kZero = 0.0;
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
|
||||
&provider);
|
||||
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
|
||||
&tp);
|
||||
EXPECT_EQ(Y.size(), 30);
|
||||
for (size_t i = 0; i < Y.size(); ++i) {
|
||||
EXPECT_EQ(Y[i], 10) << i;
|
||||
}
|
||||
// Test Accumulate
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
|
||||
VECTOR_HEAD(Y), &provider);
|
||||
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
|
||||
VECTOR_HEAD(Y), &tp);
|
||||
EXPECT_EQ(Y.size(), 30);
|
||||
for (size_t i = 0; i < Y.size(); ++i) {
|
||||
EXPECT_EQ(Y[i], 15) << i;
|
||||
}
|
||||
// Test Accumulate
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, 5, 6, 10,
|
||||
kPointFive,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
|
||||
&provider);
|
||||
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasNoTrans, 5, 6, 10,
|
||||
kPointFive,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
|
||||
&tp);
|
||||
EXPECT_EQ(Y.size(), 30);
|
||||
for (size_t i = 0; i < Y.size(); ++i) {
|
||||
EXPECT_EQ(Y[i], 20) << i;
|
||||
|
|
@ -68,6 +70,8 @@ TEST(MathTest, GemmNoTransNoTrans) {
|
|||
|
||||
TEST(MathTest, GemmNoTransTrans) {
|
||||
auto& provider = CPUMathUtil::Instance();
|
||||
concurrency::ThreadPool tp("", 1);
|
||||
|
||||
std::vector<float> X(50); // 5 * 10
|
||||
std::vector<float> W(60); // 10 * 6
|
||||
std::vector<float> Y(30); // 5 * 6
|
||||
|
|
@ -84,24 +88,24 @@ TEST(MathTest, GemmNoTransTrans) {
|
|||
const float kOne = 1.0;
|
||||
const float kPointFive = 0.5;
|
||||
const float kZero = 0.0;
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
|
||||
&provider);
|
||||
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kZero, VECTOR_HEAD(Y),
|
||||
&tp);
|
||||
EXPECT_EQ(Y.size(), 30);
|
||||
for (size_t i = 0; i < Y.size(); ++i) {
|
||||
EXPECT_EQ(Y[i], 10) << i;
|
||||
}
|
||||
// Test Accumulate
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
|
||||
VECTOR_HEAD(Y), &provider);
|
||||
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasTrans, 5, 6, 10, kOne,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kPointFive,
|
||||
VECTOR_HEAD(Y), &tp);
|
||||
EXPECT_EQ(Y.size(), 30);
|
||||
for (size_t i = 0; i < Y.size(); ++i) {
|
||||
EXPECT_EQ(Y[i], 15) << i;
|
||||
}
|
||||
math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasTrans, 5, 6, 10, kPointFive,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
|
||||
&provider);
|
||||
math::Gemm<float, concurrency::ThreadPool>(CblasNoTrans, CblasTrans, 5, 6, 10, kPointFive,
|
||||
VECTOR_HEAD(X), VECTOR_HEAD(W), kOne, VECTOR_HEAD(Y),
|
||||
&tp);
|
||||
EXPECT_EQ(Y.size(), 30);
|
||||
for (size_t i = 0; i < Y.size(); ++i) {
|
||||
EXPECT_EQ(Y[i], 20) << i;
|
||||
|
|
|
|||
|
|
@ -194,23 +194,23 @@ TEST(SoftmaxOperator, InvalidAxis) {
|
|||
|
||||
TEST(SoftmaxOperator, TestInputTooLarge) {
|
||||
float* ignored = nullptr;
|
||||
|
||||
concurrency::ThreadPool tp("", 1);
|
||||
// N > INT32_MAX
|
||||
int64_t N = int64_t(INT32_MAX) + 1;
|
||||
int64_t D = 1;
|
||||
auto status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored);
|
||||
auto status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp);
|
||||
EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT);
|
||||
|
||||
// D > INT32_MAX
|
||||
N = 1;
|
||||
D = int64_t(INT32_MAX) + 1;
|
||||
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored);
|
||||
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp);
|
||||
EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT);
|
||||
|
||||
// N * D > INT32_MAX
|
||||
N = int64_t(INT32_MAX) / 2;
|
||||
D = 3;
|
||||
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored);
|
||||
status = SoftmaxCPU(N, D, ignored, ignored, ignored, ignored, true, ignored, &tp);
|
||||
EXPECT_EQ(status.Code(), common::INVALID_ARGUMENT);
|
||||
|
||||
/*
|
||||
|
|
|
|||
Loading…
Reference in a new issue