mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Decrease lock contention in qlstm caused by memory allocation. (#7815)
* Decrease lock contention in qlstm caused by memory allocation.
This commit is contained in:
parent
ff655175ff
commit
f49a4b6329
4 changed files with 48 additions and 15 deletions
|
|
@ -214,7 +214,8 @@ void ComputeGemm(const int M,
|
|||
float* C,
|
||||
float* C_end,
|
||||
const int ldc,
|
||||
AllocatorPtr /*allocator*/,
|
||||
uint8_t* /* quantized_A_buffer */,
|
||||
int32_t* /* quantize_agg_C_buffer */,
|
||||
concurrency::ThreadPool* thread_pool) {
|
||||
// validate all the inputs
|
||||
// need to use the lda/ldb/ldc strides which should be >= the columns for the span
|
||||
|
|
@ -249,7 +250,8 @@ void ComputeGemm(const int M,
|
|||
float* C,
|
||||
float* C_end,
|
||||
const int ldc,
|
||||
AllocatorPtr allocator,
|
||||
uint8_t* quantized_A_buffer,
|
||||
int32_t* quantize_agg_C_buffer,
|
||||
concurrency::ThreadPool* thread_pool) {
|
||||
// validate all the inputs
|
||||
// need to use the lda/ldb/ldc strides which should be >= the columns for the span
|
||||
|
|
@ -262,10 +264,8 @@ void ComputeGemm(const int M,
|
|||
uint8_t a_zero_point;
|
||||
GetQuantizationParameter(A, M * K, a_scale, a_zero_point, thread_pool);
|
||||
|
||||
uint8_t* a_data_quant = static_cast<uint8_t*>(allocator->Alloc(SafeInt<size_t>(M * K) * sizeof(uint8_t)));
|
||||
BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(allocator));
|
||||
// quantize the data
|
||||
ParQuantizeLinear(A, a_data_quant, M * K, a_scale, a_zero_point, thread_pool);
|
||||
ParQuantizeLinear(A, quantized_A_buffer, M * K, a_scale, a_zero_point, thread_pool);
|
||||
|
||||
bool b_is_signed = weights.quant_para_->is_signed;
|
||||
uint8_t b_zero_point = weights.quant_para_->zero_point ? *static_cast<const uint8_t*>(weights.quant_para_->zero_point) : 0;
|
||||
|
|
@ -277,11 +277,9 @@ void ComputeGemm(const int M,
|
|||
|
||||
size_t ld_C_buffer = ldc;
|
||||
int32_t* C_buffer = reinterpret_cast<int32_t*>(C);
|
||||
BufferUniquePtr tmp_res_buffer_holder;
|
||||
if (beta == 1.0f) {
|
||||
C_buffer = static_cast<int32_t*>(allocator->Alloc(SafeInt<size_t>(M * N) * sizeof(int32_t)));
|
||||
C_buffer = quantize_agg_C_buffer;
|
||||
ld_C_buffer = static_cast<size_t>(N);
|
||||
tmp_res_buffer_holder = BufferUniquePtr(C_buffer, BufferDeleter(allocator));
|
||||
}
|
||||
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR output_processor(
|
||||
|
|
@ -296,7 +294,7 @@ void ComputeGemm(const int M,
|
|||
gemm_shape.BIsSigned = b_is_signed;
|
||||
|
||||
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
|
||||
gemm_params.A = a_data_quant;
|
||||
gemm_params.A = quantized_A_buffer;
|
||||
gemm_params.lda = static_cast<size_t>(K);
|
||||
gemm_params.ZeroPointA = a_zero_point;
|
||||
gemm_params.B = weights.buffer_;
|
||||
|
|
|
|||
|
|
@ -227,7 +227,8 @@ void ComputeGemm(const int M,
|
|||
float* C,
|
||||
float* C_end,
|
||||
const int ldc,
|
||||
AllocatorPtr /*allocator*/,
|
||||
uint8_t* /* quantized_A_buffer */,
|
||||
int32_t* /* quantize_agg_C_buffer */,
|
||||
concurrency::ThreadPool* thread_pool);
|
||||
|
||||
void ComputeGemm(const int M,
|
||||
|
|
@ -241,7 +242,8 @@ void ComputeGemm(const int M,
|
|||
float* C,
|
||||
float* C_end,
|
||||
const int ldc,
|
||||
AllocatorPtr allocator,
|
||||
uint8_t* quantized_A_buffer,
|
||||
int32_t* quantize_agg_C_buffer,
|
||||
concurrency::ThreadPool* thread_pool);
|
||||
|
||||
// helper to convert a span to a raw pointer
|
||||
|
|
|
|||
|
|
@ -198,6 +198,20 @@ void UniDirectionalLstm<T>::LoadBias(const gsl::span<const T>& WbRb_values) {
|
|||
*/
|
||||
}
|
||||
|
||||
// Allocates the scratch buffers that are handed to ComputeGemm for quantized weights:
// quantized_input_or_a_ and quantized_C_buffer_, both obtained through Allocate() from allocator_.
// Does nothing unless sizeof(WeightT) == 1.
//
// max_sequence_length: multiplied by batch_size_ to get the row count used when sizing
//                      quantized_input_or_a_.
template <typename T>
|
||||
template <typename WeightT>
|
||||
void UniDirectionalLstm<T>::AllocateQuantizeBuffers(int max_sequence_length) {
|
||||
// Cannot specialize on WeightT without specifying T explicitly, so branch on sizeof instead.
// NOTE(review): a one-byte WeightT presumably means 8-bit quantized weights - confirm against callers.
|
||||
if (sizeof(WeightT) == 1) {
|
||||
const int hidden_size_x4 = 4 * hidden_size_;
|
||||
const int total_rows = max_sequence_length * batch_size_;
|
||||
|
||||
// One buffer serves two uses, so size it for the larger of
// total_rows * input_size_ and batch_size_ * hidden_size_ elements.
int input_or_a_size = std::max(total_rows * input_size_, batch_size_ * hidden_size_);
|
||||
// NOTE(review): the trailing `false` is presumably Allocate's "fill" flag - confirm against Allocate.
quantized_input_or_a_ = Allocate(allocator_, input_or_a_size, quantized_input_or_a_ptr_, false);
|
||||
// batch_size_ rows of 4 * hidden_size_ int32 values.
quantized_C_buffer_ = Allocate(allocator_, batch_size_ * hidden_size_x4, quantized_C_buffer_ptr_, false);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename WeightT>
|
||||
void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
||||
|
|
@ -247,8 +261,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
|
||||
// Calculate the max and min length
|
||||
const auto min_max_pair = std::minmax_element(sequence_lengths.cbegin(), sequence_lengths.cend());
|
||||
int32_t max_sequence_length = *min_max_pair.second;
|
||||
int32_t min_sequence_length = std::min(seq_length_, *min_max_pair.first);
|
||||
int max_sequence_length = *min_max_pair.second;
|
||||
int min_sequence_length = std::min(seq_length_, *min_max_pair.first);
|
||||
|
||||
///**************************LSTM Calculations****************************/
|
||||
float alpha = 1.0f;
|
||||
|
|
@ -257,10 +271,15 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
const int hidden_size_x4 = 4 * hidden_size_;
|
||||
const int total_rows = max_sequence_length * batch_size_;
|
||||
|
||||
AllocateQuantizeBuffers<WeightT>(max_sequence_length);
|
||||
|
||||
// apply the weights to all the inputs and save to output_IOFC
|
||||
ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(),
|
||||
input_weights,
|
||||
beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, allocator_, thread_pool_);
|
||||
beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4,
|
||||
quantized_input_or_a_.begin(),
|
||||
nullptr,
|
||||
thread_pool_);
|
||||
|
||||
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
|
||||
|
||||
|
|
@ -311,7 +330,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
previous_state, previous_state_end, // Ht-1
|
||||
recurrent_weights, // R[iofc]
|
||||
beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
|
||||
hidden_size_x4, allocator_, ttp);
|
||||
hidden_size_x4,
|
||||
quantized_input_or_a_.begin() + (seq_start * hidden_size_),
|
||||
quantized_C_buffer_.begin() + (seq_start * hidden_size_x4),
|
||||
ttp);
|
||||
|
||||
DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, num_seq_to_compute_adjusted, hidden_size_x4);
|
||||
|
||||
|
|
|
|||
|
|
@ -109,6 +109,17 @@ class UniDirectionalLstm {
|
|||
ActivationInfo<deepcpu::LstmMergeGatesFuncPtr> activation_h_;
|
||||
|
||||
concurrency::ThreadPool* thread_pool_;
|
||||
|
||||
// Quantized operation related allocation members
|
||||
template <typename WeightT>
|
||||
void AllocateQuantizeBuffers(int max_sequence_length);
|
||||
|
||||
// Buffer shared by the quantized whole input and the quantized A operand of each sequence step
|
||||
IAllocatorUniquePtr<uint8_t> quantized_input_or_a_ptr_;
|
||||
gsl::span<uint8_t> quantized_input_or_a_;
|
||||
|
||||
IAllocatorUniquePtr<int32_t> quantized_C_buffer_ptr_;
|
||||
gsl::span<int32_t> quantized_C_buffer_;
|
||||
};
|
||||
|
||||
} // namespace lstm
|
||||
|
|
|
|||
Loading…
Reference in a new issue