Fix a memory leak in QGemm (#15703)

### Description
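The `BufferUniquePtr`s in the old code have no knowledge of the
allocator that the memory was allocated from, so they cannot free that
memory. The buffers are now held in `std::unique_ptr<Tensor>` objects
that are constructed with the allocator.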
This commit is contained in:
Changming Sun 2023-04-26 18:48:00 -07:00 committed by GitHub
parent 740d553c42
commit e63bb5acef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 9 deletions

View file

@ -47,14 +47,14 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
bool a_is_signed = a->IsDataType<int8_t>();
const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
BufferUniquePtr a_trans_buffer;
std::unique_ptr<Tensor> a_trans_buffer;
if (trans_A_ == CblasTrans) {
a_data = quantization::TransPoseInputData(a_data, a_trans_buffer, allocator, K, M);
}
bool b_is_signed;
const uint8_t* b_data = nullptr;
BufferUniquePtr b_trans_buffer;
std::unique_ptr<Tensor> b_trans_buffer;
if (nullptr == b) {
b_data = static_cast<const uint8_t*>(packed_b_.get());
b_is_signed = b_is_signed_;
@ -71,11 +71,12 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
// prepare output buffer of GEMM
int32_t* gemm_output_data = nullptr;
BufferUniquePtr gemm_output_buffer;
std::unique_ptr<Tensor> gemm_output_buffer;
bool need_requant = y_scale != nullptr;
if (need_requant) {
gemm_output_data = static_cast<int32_t*>(allocator->Alloc(SafeInt<size_t>(M * N) * sizeof(int32_t)));
gemm_output_buffer.reset(gemm_output_data);
TensorShape outputshape{static_cast<int64_t>(M), static_cast<int64_t>(N)};
gemm_output_buffer = std::make_unique<Tensor>(DataTypeImpl::GetType<int32_t>(), outputshape, allocator);
gemm_output_data = gemm_output_buffer->MutableData<int32_t>();
} else {
gemm_output_data = static_cast<int32_t*>(y->MutableDataRaw());
}

View file

@ -37,7 +37,7 @@ class MatMulIntegerBase : public OpKernel {
const auto* b_data = static_cast<const uint8_t*>(tensor.DataRaw());
BufferUniquePtr b_trans_buffer;
std::unique_ptr<Tensor> b_trans_buffer;
if (IsBTransposed()) {
std::swap(K, N);
b_data = quantization::TransPoseInputData(b_data, b_trans_buffer, alloc, N, K);

View file

@ -188,13 +188,14 @@ void Dequantize(const std::vector<T>& values,
// Transpose the input and store it to a new allocated buffer.
inline uint8_t* TransPoseInputData(const uint8_t* input,
BufferUniquePtr& buffer_holder,
std::unique_ptr<Tensor>& buffer_holder,
AllocatorPtr& allocator,
size_t M,
size_t N) {
uint8_t* output = static_cast<uint8_t*>(allocator->Alloc(M * N * sizeof(uint8_t)));
TensorShape outputshape{static_cast<int64_t>(M), static_cast<int64_t>(N)};
buffer_holder = std::make_unique<Tensor>(DataTypeImpl::GetType<uint8_t>(), outputshape, allocator);
uint8_t* output = buffer_holder->MutableData<uint8_t>();
MlasTranspose(input, output, M, N);
buffer_holder.reset(output);
return output;
}