From e63bb5acefa2c931d66566013278ee9a1a0e2c7e Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 26 Apr 2023 18:48:00 -0700 Subject: [PATCH] Fix a memory leak in QGemm (#15703) ### Description The BufferUniquePtrs in the old code doesn't have knowledge of the allocator where the allocated memory was from, so it cannot free the memory. --- .../contrib_ops/cpu/quantization/quant_gemm.cc | 11 ++++++----- .../providers/cpu/quantization/matmul_integer_base.h | 2 +- onnxruntime/core/quantization/quantization.h | 7 ++++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index 0eb10ef84a..e49e6a9fa6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -47,14 +47,14 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { bool a_is_signed = a->IsDataType(); const uint8_t* a_data = static_cast(a->DataRaw()); - BufferUniquePtr a_trans_buffer; + std::unique_ptr a_trans_buffer; if (trans_A_ == CblasTrans) { a_data = quantization::TransPoseInputData(a_data, a_trans_buffer, allocator, K, M); } bool b_is_signed; const uint8_t* b_data = nullptr; - BufferUniquePtr b_trans_buffer; + std::unique_ptr b_trans_buffer; if (nullptr == b) { b_data = static_cast(packed_b_.get()); b_is_signed = b_is_signed_; @@ -71,11 +71,12 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { // prepare output buffer of GEMM int32_t* gemm_output_data = nullptr; - BufferUniquePtr gemm_output_buffer; + std::unique_ptr gemm_output_buffer; bool need_requant = y_scale != nullptr; if (need_requant) { - gemm_output_data = static_cast(allocator->Alloc(SafeInt(M * N) * sizeof(int32_t))); - gemm_output_buffer.reset(gemm_output_data); + TensorShape outputshape{static_cast(M), static_cast(N)}; + gemm_output_buffer = std::make_unique(DataTypeImpl::GetType(), outputshape, allocator); + gemm_output_data = gemm_output_buffer->MutableData(); } else { gemm_output_data = static_cast(y->MutableDataRaw()); } diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index 5aac7adda3..965552cc7c 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -37,7 +37,7 @@ class MatMulIntegerBase : public OpKernel { const auto* b_data = static_cast(tensor.DataRaw()); - BufferUniquePtr b_trans_buffer; + std::unique_ptr b_trans_buffer; if (IsBTransposed()) { std::swap(K, N); b_data = quantization::TransPoseInputData(b_data, b_trans_buffer, alloc, N, K); diff --git a/onnxruntime/core/quantization/quantization.h b/onnxruntime/core/quantization/quantization.h index 8e505157ad..9d4601ad1b 100644 --- a/onnxruntime/core/quantization/quantization.h +++ b/onnxruntime/core/quantization/quantization.h @@ -188,13 +188,14 @@ void Dequantize(const std::vector& values, // Transpose the input and store it to a new allocated buffer. inline uint8_t* TransPoseInputData(const uint8_t* input, - BufferUniquePtr& buffer_holder, + std::unique_ptr& buffer_holder, AllocatorPtr& allocator, size_t M, size_t N) { - uint8_t* output = static_cast(allocator->Alloc(M * N * sizeof(uint8_t))); + TensorShape outputshape{static_cast(M), static_cast(N)}; + buffer_holder = std::make_unique(DataTypeImpl::GetType(), outputshape, allocator); + uint8_t* output = buffer_holder->MutableData(); MlasTranspose(input, output, M, N); - buffer_holder.reset(output); return output; }