mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
### Description
If we fail to calculate the buffer size (due to overflow), we currently return a nullptr. This is inconsistent with an actual memory allocation failure, which throws. An overflow is typically caused by bad input, so an exception makes more sense. Change this to throw so that code using MakeUniquePtr* and AllocArray* doesn't need to check for nullptr. Add some extra info to the log message to help with debugging.
### Motivation and Context
This should help with #18905 by avoiding the invalid attempted use of a nullptr returned from the allocation. The extra info _might_ help with figuring out where the overflow is coming from, which is the real issue.
194 lines
6 KiB
C++
194 lines
6 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/common/safeint.h"
|
|
#include "core/framework/allocator.h"
|
|
#include "core/mlas/inc/mlas.h"
|
|
#include "core/framework/utils.h"
|
|
#include "core/session/ort_apis.h"
|
|
#include <cstdlib>
|
|
#include <sstream>
|
|
|
|
#if defined(USE_MIMALLOC)
|
|
#include <mimalloc.h>
|
|
#endif
|
|
|
|
#include "core/framework/bfc_arena.h"
|
|
|
|
namespace onnxruntime {
|
|
|
|
// private helper for calculation so SafeInt usage doesn't bleed into the public allocator.h header
|
|
bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept {
|
|
bool ok = true;
|
|
|
|
ORT_TRY {
|
|
SafeInt<size_t> alloc_size(size);
|
|
if (alignment == 0) {
|
|
*out = alloc_size * nmemb;
|
|
} else {
|
|
size_t alignment_mask = alignment - 1;
|
|
*out = (alloc_size * nmemb + alignment_mask) & ~static_cast<size_t>(alignment_mask);
|
|
}
|
|
}
|
|
ORT_CATCH(const OnnxRuntimeException& ex) {
|
|
// overflow in calculating the size thrown by SafeInt.
|
|
ORT_HANDLE_EXCEPTION([&]() {
|
|
LOGS_DEFAULT(ERROR) << ex.what() << " nmemb=" << nmemb << " size=" << size << " alignment=" << alignment;
|
|
ok = false;
|
|
});
|
|
}
|
|
return ok;
|
|
}
|
|
|
|
#ifdef USE_MIMALLOC
|
|
void* AllocatorDefaultAlloc(size_t size) {
|
|
const size_t alignment = MlasGetPreferredBufferAlignment();
|
|
if (size <= 0) return nullptr;
|
|
size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
|
|
void* p;
|
|
#if defined(_MSC_VER)
|
|
p = mi_malloc_aligned(size, alignment);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#elif defined(_LIBCPP_SGX_CONFIG)
|
|
p = mi_memalign(alignment, size);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#else
|
|
int ret = mi_posix_memalign(&p, alignment, size);
|
|
if (ret != 0)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#endif
|
|
return p;
|
|
}
|
|
|
|
// Releases a buffer obtained from the mimalloc-backed AllocatorDefaultAlloc, using the mimalloc free routine
// that matches how the buffer was allocated on this platform.
void AllocatorDefaultFree(void* p) {
#if !defined(_MSC_VER)
  mi_free(p);
#else
  // MSVC builds allocate with mi_malloc_aligned, so release with the aligned variant and the same alignment.
  mi_free_aligned(p, MlasGetPreferredBufferAlignment());
#endif
}
|
|
|
|
#else
|
|
void* AllocatorDefaultAlloc(size_t size) {
|
|
const size_t alignment = MlasGetPreferredBufferAlignment();
|
|
if (size <= 0) return nullptr;
|
|
size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
|
|
void* p;
|
|
#if _MSC_VER
|
|
p = _aligned_malloc(size, alignment);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#elif defined(_LIBCPP_SGX_CONFIG)
|
|
p = memalign(alignment, size);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#else
|
|
int ret = posix_memalign(&p, alignment, size);
|
|
if (ret != 0)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#endif
|
|
return p;
|
|
}
|
|
|
|
// Releases a buffer obtained from AllocatorDefaultAlloc with the free routine that pairs with the
// platform's aligned allocation call.
void AllocatorDefaultFree(void* p) {
#if defined(_MSC_VER)
  // Memory from _aligned_malloc must go back through _aligned_free.
  _aligned_free(p);
#else
  // posix_memalign / memalign results are released with the ordinary C free().
  std::free(p);
#endif
}
|
|
|
|
#endif // USE_MIMALLOC
|
|
|
|
void* CPUAllocator::Alloc(size_t size) {
|
|
return AllocatorDefaultAlloc(size);
|
|
}
|
|
|
|
// Returns a CPU buffer through AllocatorDefaultFree, the release routine paired with AllocatorDefaultAlloc.
void CPUAllocator::Free(void* p) {
  void* const buffer = p;
  AllocatorDefaultFree(buffer);
}
|
|
|
|
// Allocates `size` bytes from `alloc`, choosing the allocation path from the options:
//  - use_reserve: always go through IAllocator::Reserve.
//  - a stream plus an arena allocator (stream builds only): use the stream-aware arena when the arena is one.
//  - otherwise: plain IAllocator::Alloc.
void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn) {
  if (use_reserve) {
    return alloc.Reserve(size);
  }

  const bool arena_with_stream = (stream != nullptr) && (alloc.Info().alloc_type == OrtArenaAllocator);
  if (arena_with_stream) {
#ifdef ORT_ENABLE_STREAM
    // NOTE(review): assumes every OrtArenaAllocator is a BFCArena; FromBFCArena yields null when it is not
    // stream-aware, in which case we fall back to the plain path below.
    auto* stream_arena = StreamAwareArena::FromBFCArena(static_cast<BFCArena&>(alloc));
    if (stream_arena != nullptr) {
      return stream_arena->AllocOnStream(size, stream, wait_fn);
    }
#else
    ORT_UNUSED_PARAMETER(wait_fn);
#endif  // ORT_ENABLE_STREAM
  }

  return alloc.Alloc(size);
}
|
|
} // namespace onnxruntime
|
|
|
|
// Writes the textual form produced by OrtMemoryInfo::ToString() to the stream.
std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) {
  out << info.ToString();
  return out;
}
|
|
#if defined(_MSC_VER) && !defined(__clang__)
|
|
#pragma warning(push)
|
|
#pragma warning(disable : 26409)
|
|
#endif
|
|
// Creates a heap-allocated OrtMemoryInfo for the named device; release it with ReleaseMemoryInfo.
// Returns ORT_INVALID_ARGUMENT for a null name, a null `out`, or an unrecognized device name.
ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1,
                    enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) {
  // strcmp() on a null name, or a write through a null `out`, would be undefined behavior.
  if (name1 == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Memory info name must not be null.");
  }
  if (out == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output OrtMemoryInfo pointer must not be null.");
  }

  // Device names that map to a GPU device with default memory. OrtMemoryInfo appears to keep `name` as a raw
  // pointer (MemoryInfoGetName hands it back directly), so store the library's own constant rather than the
  // caller's buffer, which may not outlive the returned object. The other branches already do this.
  const char* const gpu_names[] = {onnxruntime::CUDA, onnxruntime::OpenVINO_GPU, onnxruntime::DML,
                                   onnxruntime::HIP, onnxruntime::WEBGPU_BUFFER};
  const char* gpu_name = nullptr;
  for (const char* candidate : gpu_names) {
    if (strcmp(name1, candidate) == 0) {
      gpu_name = candidate;
      break;
    }
  }

  if (strcmp(name1, onnxruntime::CPU) == 0) {
    *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
  } else if (gpu_name != nullptr) {
    *out = new OrtMemoryInfo(
        gpu_name, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)),
        id1, mem_type1);
  } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
    // A CPU device tagged with the CUDA_PINNED memory type.
    *out = new OrtMemoryInfo(
        onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
        id1, mem_type1);
  } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) {
    // A CPU device tagged with the HIP_PINNED memory type.
    *out = new OrtMemoryInfo(
        onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
        id1, mem_type1);
  } else {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
  }
  return nullptr;
}
|
|
|
|
ORT_API(void, OrtApis::ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo* p) { delete p; }
|
|
#if defined(_MSC_VER) && !defined(__clang__)
|
|
#pragma warning(pop)
|
|
#endif
|
|
// Reports the `name` field of the memory info through `out`; always succeeds (null status).
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out) {
  const char* const name = ptr->name;
  *out = name;
  return nullptr;
}
|
|
|
|
// Reports the `id` field of the memory info through `out`; always succeeds (null status).
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out) {
  const int id = ptr->id;
  *out = id;
  return nullptr;
}
|
|
|
|
// Reports the `mem_type` field of the memory info through `out`; always succeeds (null status).
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out) {
  const OrtMemType mem_type = ptr->mem_type;
  *out = mem_type;
  return nullptr;
}
|
|
|
|
// Reports the allocator type (`alloc_type` field) of the memory info through `out`; always succeeds (null status).
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out) {
  const OrtAllocatorType alloc_type = ptr->alloc_type;
  *out = alloc_type;
  return nullptr;
}
|
|
|
|
// Writes 0 to `out` when the two memory infos compare equal and -1 otherwise (no ordering is implied);
// always succeeds (null status).
ORT_API_STATUS_IMPL(OrtApis::CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2,
                    _Out_ int* out) {
  if (*info1 == *info2) {
    *out = 0;
  } else {
    *out = -1;
  }
  return nullptr;
}
|
|
|
|
// Reports the device type of the memory info as the public OrtMemoryInfoDeviceType enum.
ORT_API(void, OrtApis::MemoryInfoGetDeviceType, _In_ const OrtMemoryInfo* info, _Out_ OrtMemoryInfoDeviceType* out) {
  // NOTE(review): the direct cast relies on OrtDevice::Type() values lining up with OrtMemoryInfoDeviceType.
  const auto device_type = info->device.Type();
  *out = static_cast<OrtMemoryInfoDeviceType>(device_type);
}
|