diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index cbc2208b6b..9015b23296 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -3,12 +3,14 @@ #pragma once +#include + #include "core/common/common.h" #include "core/framework/allocator_stats.h" +// some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemoryinfo.h #include "core/session/onnxruntime_c_api.h" -#include "ortdevice.h" -#include "ortmemoryinfo.h" -#include +#include "core/framework/ortdevice.h" +#include "core/framework/ortmemoryinfo.h" // This configures the arena based allocator used by ORT // See docs/C_API.md for details on what these mean and how to choose these values @@ -68,8 +70,12 @@ class IAllocator { IAllocator(const OrtMemoryInfo& info) : memory_info_(info) {} virtual ~IAllocator() = default; /** - @remarks Use SafeInt when calculating the size of memory to allocate using Alloc. - */ + * Allocate memory of the specified size. + * If size is 0, nullptr is returned. + * If allocation fails, an exception is thrown. + * + * @remarks Use SafeInt when calculating the size of memory to allocate using Alloc. + */ virtual void* Alloc(size_t size) = 0; virtual void Free(void* p) = 0; @@ -100,7 +106,8 @@ class IAllocator { * \param out Total size required after any alignment is applied * \return true, successful. 
false, overflow */ - [[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept; + [[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, + size_t* out) noexcept; /** * https://cwe.mitre.org/data/definitions/190.html @@ -120,8 +127,10 @@ class IAllocator { */ void* AllocArray(size_t nmemb, size_t size) { size_t len; - if (!CalcMemSizeForArray(nmemb, size, &len)) - return nullptr; + if (!CalcMemSizeForArray(nmemb, size, &len)) { + ORT_THROW("Invalid size requested for allocation: ", nmemb, " * ", size); + } + return Alloc(len); } @@ -131,8 +140,10 @@ class IAllocator { template void* AllocArrayWithAlignment(size_t nmemb, size_t size) { size_t len; - if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len)) - return nullptr; + if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len)) { + ORT_THROW("Invalid size requested for allocation: ", nmemb, " * ", size, " with alignment ", alignment); + } + return Alloc(len); } @@ -144,13 +155,14 @@ class IAllocator { @param stream Which stream instance allocated chunk will be used with. @param wait_fn If the allocator want to dynamic reuse a chunk from another stream, use this wait_fn to sync on the target stream to make the reuse safe. - @returns std::unique_ptr with allocated memory and deleter. + @returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory. */ template static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes, bool use_reserve = false, Stream* stream = nullptr, WaitNotificationFn wait_fn = nullptr) { - if (allocator == nullptr) return nullptr; + ValidateAllocator(allocator); + // for now limit to fundamental types. 
we could support others, but to do so either we or the caller // needs to call the dtor for the objects, for buffers allocated on device we don't have destructor // static_assert(std::is_fundamental::value, "Fundamental type required as no destructors are called."); @@ -161,38 +173,73 @@ class IAllocator { if constexpr (!std::is_void::value) { // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't // reachable if T is void. use std::conditional to 'use' void* in the sizeof call - if (!CalcMemSizeForArray( - count_or_bytes, sizeof(typename std::conditional::value, void*, T>::type), &alloc_size)) { - return nullptr; - } + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + alloc_size = ValidatedCalcMemSizeForArray(count_or_bytes, size); } // allocate T* p = static_cast(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, std::move(wait_fn))); - return IAllocatorUniquePtr{ - p, - [allocator = std::move(allocator)](T* p) { allocator->Free(p); }}; + ValidateAllocation(p, alloc_size); + + return IAllocatorUniquePtr{p, + [allocator = std::move(allocator)](T* p) { + allocator->Free(p); + }}; } + /** + Create a std::unique_ptr that is allocated and freed by the provided OrtAllocator. + @param ort_allocator The allocator. + @param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate. + @returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory. 
+ */ template static IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) { - if (!ort_allocator) return nullptr; + ValidateAllocator(ort_allocator); size_t alloc_size = count_or_bytes; // if T is not void, 'count_or_bytes' == number of items so allow for that if constexpr (!std::is_void::value) { // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't // reachable if T is void. use std::conditional to 'use' void* in the sizeof call - if (!CalcMemSizeForArray( - count_or_bytes, sizeof(typename std::conditional::value, void*, T>::type), &alloc_size)) { - return nullptr; - } + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + alloc_size = ValidatedCalcMemSizeForArray(count_or_bytes, size); } - T* p = static_cast(ort_allocator->Alloc(ort_allocator, count_or_bytes)); - return IAllocatorUniquePtr{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }}; + + T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + ValidateAllocation(p, alloc_size); + + return IAllocatorUniquePtr{p, + [ort_allocator](T* p) { + ort_allocator->Free(ort_allocator, p); + }}; } private: + // + // validation functions. split out from methods that are templatized on the data type to minimize binary size. + // + + template + static void ValidateAllocator(const T& allocator) { + ORT_ENFORCE(allocator != nullptr); + } + + static size_t ValidatedCalcMemSizeForArray(size_t count, size_t size) { + size_t alloc_size = 0; + if (!CalcMemSizeForArray(count, size, &alloc_size)) { + ORT_THROW("Invalid size requested for allocation: ", count, " * ", size); + } + + return alloc_size; + } + + static void ValidateAllocation(void* p, size_t size) { + // allocator should throw directly but in case it didn't ensure we do here so that calling code doesn't + // need to check for nullptr when an actual allocation was expected. 
+ ORT_ENFORCE(p != nullptr || size == 0, "Memory allocation failed. Size=", size); + }; + OrtMemoryInfo memory_info_; }; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index b060d500c6..a9703dc68d 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -71,9 +71,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat if (packed_b_size_ == 0) return Status::OK(); auto qptr = tensor.Data(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - if (packed_b_ == nullptr) { - return Status::OK(); - } std::memset(packed_b_.get(), 0, packed_b_size_); MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), is_asym_, false, compt_type, pool); diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 2499ead9ef..c3e96e450c 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -33,7 +33,7 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz ORT_CATCH(const OnnxRuntimeException& ex) { // overflow in calculating the size thrown by SafeInt. 
ORT_HANDLE_EXCEPTION([&]() { - LOGS_DEFAULT(ERROR) << ex.what(); + LOGS_DEFAULT(ERROR) << ex.what() << " nmemb=" << nmemb << " size=" << size << " alignment=" << alignment; ok = false; }); } diff --git a/onnxruntime/core/framework/sparse_tensor.cc b/onnxruntime/core/framework/sparse_tensor.cc index 5af2f4e4b5..a3bcea4762 100644 --- a/onnxruntime/core/framework/sparse_tensor.cc +++ b/onnxruntime/core/framework/sparse_tensor.cc @@ -220,7 +220,6 @@ Status SparseTensor::AllocateBuffer(int64_t buffer_size, size_t num_values) { ORT_RETURN_IF_NOT(buffer_size_t > values_bytes, "Values size ", static_cast(values_bytes), " must be less than total buffer size: ", buffer_size); auto data_ptr = IAllocator::MakeUniquePtr(allocator_, buffer_size_t); - ORT_RETURN_IF(data_ptr == nullptr, "SparseTensor Allocation failed for size: ", buffer_size); if (IsDataTypeString()) { // We own the buffer, so we must properly construct strings. Neither of the Tensors // we construct on top of the buffer own it. We are constructing empty strings, hopefully @@ -592,4 +591,4 @@ Status SparseTensor::Copy(const IDataTransfer& data_transfer, SparseTensor& dst_ } // namespace onnxruntime -#endif // !defined(DISABLE_SPARSE_TENSORS) \ No newline at end of file +#endif // !defined(DISABLE_SPARSE_TENSORS)