Decrease lock contention in qlstm caused by memory allocation. (#7815)

* Decrease lock contention in qlstm caused by memory allocation.
This commit is contained in:
Zhang Lei 2021-05-25 17:08:42 -07:00 committed by GitHub
parent ff655175ff
commit f49a4b6329
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 15 deletions

View file

@ -214,7 +214,8 @@ void ComputeGemm(const int M,
float* C,
float* C_end,
const int ldc,
AllocatorPtr /*allocator*/,
uint8_t* /* quantized_A_buffer */,
int32_t* /* quantize_agg_C_buffer */,
concurrency::ThreadPool* thread_pool) {
// validate all the inputs
// need to use the lda/ldb/ldc strides which should be >= the columns for the span
@ -249,7 +250,8 @@ void ComputeGemm(const int M,
float* C,
float* C_end,
const int ldc,
AllocatorPtr allocator,
uint8_t* quantized_A_buffer,
int32_t* quantize_agg_C_buffer,
concurrency::ThreadPool* thread_pool) {
// validate all the inputs
// need to use the lda/ldb/ldc strides which should be >= the columns for the span
@ -262,10 +264,8 @@ void ComputeGemm(const int M,
uint8_t a_zero_point;
GetQuantizationParameter(A, M * K, a_scale, a_zero_point, thread_pool);
uint8_t* a_data_quant = static_cast<uint8_t*>(allocator->Alloc(SafeInt<size_t>(M * K) * sizeof(uint8_t)));
BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(allocator));
// quantize the data
ParQuantizeLinear(A, a_data_quant, M * K, a_scale, a_zero_point, thread_pool);
ParQuantizeLinear(A, quantized_A_buffer, M * K, a_scale, a_zero_point, thread_pool);
bool b_is_signed = weights.quant_para_->is_signed;
uint8_t b_zero_point = weights.quant_para_->zero_point ? *static_cast<const uint8_t*>(weights.quant_para_->zero_point) : 0;
@ -277,11 +277,9 @@ void ComputeGemm(const int M,
size_t ld_C_buffer = ldc;
int32_t* C_buffer = reinterpret_cast<int32_t*>(C);
BufferUniquePtr tmp_res_buffer_holder;
if (beta == 1.0f) {
C_buffer = static_cast<int32_t*>(allocator->Alloc(SafeInt<size_t>(M * N) * sizeof(int32_t)));
C_buffer = quantize_agg_C_buffer;
ld_C_buffer = static_cast<size_t>(N);
tmp_res_buffer_holder = BufferUniquePtr(C_buffer, BufferDeleter(allocator));
}
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR output_processor(
@ -296,7 +294,7 @@ void ComputeGemm(const int M,
gemm_shape.BIsSigned = b_is_signed;
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
gemm_params.A = a_data_quant;
gemm_params.A = quantized_A_buffer;
gemm_params.lda = static_cast<size_t>(K);
gemm_params.ZeroPointA = a_zero_point;
gemm_params.B = weights.buffer_;

View file

@ -227,7 +227,8 @@ void ComputeGemm(const int M,
float* C,
float* C_end,
const int ldc,
AllocatorPtr /*allocator*/,
uint8_t* /* quantized_A_buffer */,
int32_t* /* quantize_agg_C_buffer */,
concurrency::ThreadPool* thread_pool);
void ComputeGemm(const int M,
@ -241,7 +242,8 @@ void ComputeGemm(const int M,
float* C,
float* C_end,
const int ldc,
AllocatorPtr allocator,
uint8_t* quantized_A_buffer,
int32_t* quantize_agg_C_buffer,
concurrency::ThreadPool* thread_pool);
// helper to convert a span to a raw pointer

View file

@ -198,6 +198,20 @@ void UniDirectionalLstm<T>::LoadBias(const gsl::span<const T>& WbRb_values) {
*/
}
// Allocates the scratch buffers handed to the quantized ComputeGemm calls.
// Does nothing unless WeightT is a single-byte type.
//
// max_sequence_length: longest sequence in the batch; together with
//                      batch_size_ and input_size_ it bounds the size of the
//                      whole input.
template <typename T>
template <typename WeightT>
void UniDirectionalLstm<T>::AllocateQuantizeBuffers(int max_sequence_length) {
  // WeightT cannot be specialized without also spelling out T explicitly,
  // so single-byte weights are detected through sizeof instead.
  if (sizeof(WeightT) != 1) {
    return;
  }

  // One buffer is reused for two purposes, so it is sized for the larger of:
  //   - the whole input: (max_sequence_length * batch_size_) rows of input_size_
  //   - one step's A matrix: batch_size_ rows of hidden_size_
  const int rows_in_whole_input = max_sequence_length * batch_size_;
  int shared_a_elems = std::max(rows_in_whole_input * input_size_, batch_size_ * hidden_size_);
  quantized_input_or_a_ = Allocate(allocator_, shared_a_elems, quantized_input_or_a_ptr_, false);

  // Buffer for the C matrix: batch_size_ rows of 4 * hidden_size_ columns.
  const int gates_width = 4 * hidden_size_;
  quantized_C_buffer_ = Allocate(allocator_, batch_size_ * gates_width, quantized_C_buffer_ptr_, false);
}
template <typename T>
template <typename WeightT>
void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
@ -247,8 +261,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
// Calculate the max and min length
const auto min_max_pair = std::minmax_element(sequence_lengths.cbegin(), sequence_lengths.cend());
int32_t max_sequence_length = *min_max_pair.second;
int32_t min_sequence_length = std::min(seq_length_, *min_max_pair.first);
int max_sequence_length = *min_max_pair.second;
int min_sequence_length = std::min(seq_length_, *min_max_pair.first);
///**************************LSTM Calculations****************************/
float alpha = 1.0f;
@ -257,10 +271,15 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
const int hidden_size_x4 = 4 * hidden_size_;
const int total_rows = max_sequence_length * batch_size_;
AllocateQuantizeBuffers<WeightT>(max_sequence_length);
// apply the weights to all the inputs and save to output_IOFC
ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(),
input_weights,
beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, allocator_, thread_pool_);
beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4,
quantized_input_or_a_.begin(),
nullptr,
thread_pool_);
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
@ -311,7 +330,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
previous_state, previous_state_end, // Ht-1
recurrent_weights, // R[iofc]
beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4, allocator_, ttp);
hidden_size_x4,
quantized_input_or_a_.begin() + (seq_start * hidden_size_),
quantized_C_buffer_.begin() + (seq_start * hidden_size_x4),
ttp);
DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, num_seq_to_compute_adjusted, hidden_size_x4);

View file

@ -109,6 +109,17 @@ class UniDirectionalLstm {
ActivationInfo<deepcpu::LstmMergeGatesFuncPtr> activation_h_;
concurrency::ThreadPool* thread_pool_;
// Quantized operation related allocation members
template <typename WeightT>
void AllocateQuantizeBuffers(int max_sequence_length);
// Buffer shared by the quantized whole input and by the quantized A matrix of each sequence step
IAllocatorUniquePtr<uint8_t> quantized_input_or_a_ptr_;
gsl::span<uint8_t> quantized_input_or_a_;
IAllocatorUniquePtr<int32_t> quantized_C_buffer_ptr_;
gsl::span<int32_t> quantized_C_buffer_;
};
} // namespace lstm