diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc index a16fc0992e..0311a20e26 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc @@ -214,7 +214,8 @@ void ComputeGemm(const int M, float* C, float* C_end, const int ldc, - AllocatorPtr /*allocator*/, + uint8_t* /* quantized_A_buffer */, + int32_t* /* quantize_agg_C_buffer */, concurrency::ThreadPool* thread_pool) { // validate all the inputs // need to use the lda/ldb/ldc strides which should be >= the columns for the span @@ -249,7 +250,8 @@ void ComputeGemm(const int M, float* C, float* C_end, const int ldc, - AllocatorPtr allocator, + uint8_t* quantized_A_buffer, + int32_t* quantize_agg_C_buffer, concurrency::ThreadPool* thread_pool) { // validate all the inputs // need to use the lda/ldb/ldc strides which should be >= the columns for the span @@ -262,10 +264,8 @@ void ComputeGemm(const int M, uint8_t a_zero_point; GetQuantizationParameter(A, M * K, a_scale, a_zero_point, thread_pool); - uint8_t* a_data_quant = static_cast(allocator->Alloc(SafeInt(M * K) * sizeof(uint8_t))); - BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(allocator)); // quantize the data - ParQuantizeLinear(A, a_data_quant, M * K, a_scale, a_zero_point, thread_pool); + ParQuantizeLinear(A, quantized_A_buffer, M * K, a_scale, a_zero_point, thread_pool); bool b_is_signed = weights.quant_para_->is_signed; uint8_t b_zero_point = weights.quant_para_->zero_point ? *static_cast(weights.quant_para_->zero_point) : 0; @@ -277,11 +277,9 @@ void ComputeGemm(const int M, size_t ld_C_buffer = ldc; int32_t* C_buffer = reinterpret_cast(C); - BufferUniquePtr tmp_res_buffer_holder; if (beta == 1.0f) { - C_buffer = static_cast(allocator->Alloc(SafeInt(M * N) * sizeof(int32_t))); + C_buffer = quantize_agg_C_buffer; ld_C_buffer = static_cast(N); - tmp_res_buffer_holder = BufferUniquePtr(C_buffer, BufferDeleter(allocator)); } MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR output_processor( @@ -296,7 +294,7 @@ void ComputeGemm(const int M, gemm_shape.BIsSigned = b_is_signed; MLAS_GEMM_U8X8_DATA_PARAMS gemm_params; - gemm_params.A = a_data_quant; + gemm_params.A = quantized_A_buffer; gemm_params.lda = static_cast(K); gemm_params.ZeroPointA = a_zero_point; gemm_params.B = weights.buffer_; diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index 2b534c7295..e23d516350 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -227,7 +227,8 @@ void ComputeGemm(const int M, float* C, float* C_end, const int ldc, - AllocatorPtr /*allocator*/, + uint8_t* /* quantized_A_buffer */, + int32_t* /* quantize_agg_C_buffer */, concurrency::ThreadPool* thread_pool); void ComputeGemm(const int M, @@ -241,7 +242,8 @@ void ComputeGemm(const int M, float* C, float* C_end, const int ldc, - AllocatorPtr allocator, + uint8_t* quantized_A_buffer, + int32_t* quantize_agg_C_buffer, concurrency::ThreadPool* thread_pool); // helper to convert a span to a raw pointer diff --git a/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc b/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc index bbf4beddb1..39a2d86296 100644 --- a/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc @@ -198,6 +198,20 @@ void UniDirectionalLstm::LoadBias(const gsl::span& WbRb_values) { */ } +template +template +void UniDirectionalLstm::AllocateQuantizeBuffers(int max_sequence_length) { + // Can not specialize on WeightT without specify T explicitly, so use sizeof + if (sizeof(WeightT) == 1) { + const int hidden_size_x4 = 4 * hidden_size_; + const int total_rows = max_sequence_length * batch_size_; + + int input_or_a_size = std::max(total_rows * input_size_, batch_size_ * hidden_size_); + quantized_input_or_a_ = Allocate(allocator_, input_or_a_size, quantized_input_or_a_ptr_, false); + quantized_C_buffer_ = Allocate(allocator_, batch_size_ * hidden_size_x4, quantized_C_buffer_ptr_, false); + } +} + template template void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, @@ -247,8 +261,8 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, // Calculate the max and min length const auto min_max_pair = std::minmax_element(sequence_lengths.cbegin(), sequence_lengths.cend()); - int32_t max_sequence_length = *min_max_pair.second; - int32_t min_sequence_length = std::min(seq_length_, *min_max_pair.first); + int max_sequence_length = *min_max_pair.second; + int min_sequence_length = std::min(seq_length_, *min_max_pair.first); ///**************************LSTM Calculations****************************/ float alpha = 1.0f; @@ -257,10 +271,15 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, const int hidden_size_x4 = 4 * hidden_size_; const int total_rows = max_sequence_length * batch_size_; + AllocateQuantizeBuffers(max_sequence_length); + // apply the weights to all the inputs and save to output_IOFC ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(), input_weights, - beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, allocator_, thread_pool_); + beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, + quantized_input_or_a_.begin(), + nullptr, + thread_pool_); DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4); @@ -311,7 +330,10 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, previous_state, previous_state_end, // Ht-1 recurrent_weights, // R[iofc] beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4, allocator_, ttp); + hidden_size_x4, + quantized_input_or_a_.begin() + (seq_start * hidden_size_), + quantized_C_buffer_.begin() + (seq_start * hidden_size_x4), + ttp); DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, num_seq_to_compute_adjusted, hidden_size_x4); diff --git a/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.h b/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.h index 9c18b79c72..9eb71112d2 100644 --- a/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.h @@ -109,6 +109,17 @@ class UniDirectionalLstm { ActivationInfo activation_h_; concurrency::ThreadPool* thread_pool_; + + // Quantized operation related allocation members + template + void AllocateQuantizeBuffers(int max_sequence_length); + + // Buffer shared for quantized input whole, and quantized a each sequence step + IAllocatorUniquePtr quantized_input_or_a_ptr_; + gsl::span quantized_input_or_a_; + + IAllocatorUniquePtr quantized_C_buffer_ptr_; + gsl::span quantized_C_buffer_; }; } // namespace lstm