mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-16 01:33:39 +00:00
### Description Enables using the MLTensor to pass data between models. ### Motivation and Context Using MLTensor instead of ArrayBuffers reduces the number of copies between the CPU and devices as well as the renderer and GPU process in Chromium.
199 lines
6.3 KiB
C++
199 lines
6.3 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/common/safeint.h"
|
|
#include "core/framework/allocator.h"
|
|
#include "core/mlas/inc/mlas.h"
|
|
#include "core/framework/utils.h"
|
|
#include "core/session/ort_apis.h"
|
|
#include <cstdlib>
|
|
#include <sstream>
|
|
|
|
#if defined(USE_MIMALLOC)
|
|
#include <mimalloc.h>
|
|
#endif
|
|
|
|
#include "core/framework/bfc_arena.h"
|
|
|
|
namespace onnxruntime {
|
|
|
|
// private helper for calculation so SafeInt usage doesn't bleed into the public allocator.h header
|
|
bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept {
|
|
bool ok = true;
|
|
|
|
ORT_TRY {
|
|
SafeInt<size_t> alloc_size(size);
|
|
if (alignment == 0) {
|
|
*out = alloc_size * nmemb;
|
|
} else {
|
|
size_t alignment_mask = alignment - 1;
|
|
*out = (alloc_size * nmemb + alignment_mask) & ~static_cast<size_t>(alignment_mask);
|
|
}
|
|
}
|
|
ORT_CATCH(const OnnxRuntimeException& ex) {
|
|
// overflow in calculating the size thrown by SafeInt.
|
|
ORT_HANDLE_EXCEPTION([&]() {
|
|
LOGS_DEFAULT(ERROR) << ex.what() << " nmemb=" << nmemb << " size=" << size << " alignment=" << alignment;
|
|
ok = false;
|
|
});
|
|
}
|
|
return ok;
|
|
}
|
|
|
|
#ifdef USE_MIMALLOC
|
|
void* AllocatorDefaultAlloc(size_t size) {
|
|
const size_t alignment = MlasGetPreferredBufferAlignment();
|
|
if (size <= 0) return nullptr;
|
|
size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
|
|
void* p;
|
|
#if defined(_MSC_VER)
|
|
p = mi_malloc_aligned(size, alignment);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#elif defined(_LIBCPP_SGX_CONFIG)
|
|
p = mi_memalign(alignment, size);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#else
|
|
int ret = mi_posix_memalign(&p, alignment, size);
|
|
if (ret != 0)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#endif
|
|
return p;
|
|
}
|
|
|
|
// Releases a pointer through mimalloc, using the aligned free on MSVC builds and the plain
// mi_free everywhere else.
void AllocatorDefaultFree(void* p) {
#if defined(_MSC_VER)
  // The aligned free takes the alignment again; query the same MLAS value the allocator queries.
  mi_free_aligned(p, MlasGetPreferredBufferAlignment());
#else
  mi_free(p);
#endif
}
|
|
|
|
#else
|
|
void* AllocatorDefaultAlloc(size_t size) {
|
|
const size_t alignment = MlasGetPreferredBufferAlignment();
|
|
if (size <= 0) return nullptr;
|
|
size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
|
|
void* p;
|
|
#if _MSC_VER
|
|
p = _aligned_malloc(size, alignment);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#elif defined(_LIBCPP_SGX_CONFIG)
|
|
p = memalign(alignment, size);
|
|
if (p == nullptr)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#else
|
|
int ret = posix_memalign(&p, alignment, size);
|
|
if (ret != 0)
|
|
ORT_THROW_EX(std::bad_alloc);
|
|
#endif
|
|
return p;
|
|
}
|
|
|
|
// Releases a pointer with the platform routine that matches the non-mimalloc allocation path:
// _aligned_free on MSVC builds, free everywhere else.
void AllocatorDefaultFree(void* p) {
#if _MSC_VER
  // Pairs with the _aligned_malloc call used by the MSVC allocation path.
  _aligned_free(p);
#else
  // Pairs with the memalign / posix_memalign calls used by the other allocation paths.
  free(p);
#endif
}
|
|
|
|
#endif // USE_MIMALLOC
|
|
|
|
void* CPUAllocator::Alloc(size_t size) {
|
|
return AllocatorDefaultAlloc(size);
|
|
}
|
|
|
|
// Forwards the pointer unchanged to AllocatorDefaultFree.
void CPUAllocator::Free(void* p) {
  AllocatorDefaultFree(p);
}
|
|
|
|
// Allocates `size` bytes from `alloc`, picking one of three paths:
//   1. use_reserve is set                          -> alloc.Reserve(size)
//   2. a stream is supplied, alloc reports OrtArenaAllocator, the build defines ORT_ENABLE_STREAM,
//      and a stream-aware arena is obtained        -> AllocOnStream(size, stream, wait_fn)
//   3. anything else                               -> alloc.Alloc(size)
// wait_fn is only consumed by path 2.
void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn) {
  if (use_reserve) {
    return alloc.Reserve(size);
  }

  // Short-circuit keeps alloc.Info() from being queried when no stream was supplied.
  const bool arena_with_stream = stream && alloc.Info().alloc_type == OrtArenaAllocator;
  if (arena_with_stream) {
#ifdef ORT_ENABLE_STREAM
    auto* stream_arena = StreamAwareArena::FromBFCArena(static_cast<BFCArena&>(alloc));
    if (stream_arena) {
      return stream_arena->AllocOnStream(size, stream, wait_fn);
    }
#else
    ORT_UNUSED_PARAMETER(wait_fn);
#endif  // ORT_ENABLE_STREAM
  }

  // Fallback: plain allocation.
  return alloc.Alloc(size);
}
|
|
} // namespace onnxruntime
|
|
|
|
// Streams the text produced by info.ToString() and returns the stream for chaining.
std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) {
  return out << info.ToString();
}
|
|
#if defined(_MSC_VER) && !defined(__clang__)
|
|
#pragma warning(push)
|
|
#pragma warning(disable : 26409)
|
|
#endif
|
|
// Creates an OrtMemoryInfo for the device family identified by `name1` and stores it in *out.
//
// Name handling (names are compared with strcmp against the onnxruntime:: constants):
//   CPU                                              -> default OrtDevice, constructed with onnxruntime::CPU
//   CUDA / OpenVINO_GPU / DML / HIP /
//   WEBGPU_BUFFER / WEBNN_TENSOR                     -> GPU device, constructed with name1 itself
//   OpenVINO_RT_NPU                                  -> NPU device, constructed with name1 itself
//   CUDA_PINNED                                      -> CPU device with CUDA_PINNED memory, constructed with
//                                                       onnxruntime::CUDA_PINNED
//   HIP_PINNED                                       -> CPU device with HIP_PINNED memory, constructed with
//                                                       onnxruntime::HIP_PINNED
//
// Returns nullptr after writing *out, or an ORT_INVALID_ARGUMENT status (leaving *out unwritten)
// when the name matches none of the above.
ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1,
                    enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) {
  const auto name_is = [name1](const char* candidate) { return strcmp(name1, candidate) == 0; };
  const auto device_id = static_cast<OrtDevice::DeviceId>(id1);

  if (name_is(onnxruntime::CPU)) {
    *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
    return nullptr;
  }

  if (name_is(onnxruntime::CUDA) ||
      name_is(onnxruntime::OpenVINO_GPU) ||
      name_is(onnxruntime::DML) ||
      name_is(onnxruntime::HIP) ||
      name_is(onnxruntime::WEBGPU_BUFFER) ||
      name_is(onnxruntime::WEBNN_TENSOR)) {
    *out = new OrtMemoryInfo(name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), id1,
                             mem_type1);
    return nullptr;
  }

  if (name_is(onnxruntime::OpenVINO_RT_NPU)) {
    *out = new OrtMemoryInfo(name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, device_id), id1,
                             mem_type1);
    return nullptr;
  }

  if (name_is(onnxruntime::CUDA_PINNED)) {
    *out = new OrtMemoryInfo(onnxruntime::CUDA_PINNED, type,
                             OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, device_id), id1, mem_type1);
    return nullptr;
  }

  if (name_is(onnxruntime::HIP_PINNED)) {
    *out = new OrtMemoryInfo(onnxruntime::HIP_PINNED, type,
                             OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, device_id), id1, mem_type1);
    return nullptr;
  }

  return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
}
|
|
|
|
// Destroys an OrtMemoryInfo with `delete`; a null pointer is a no-op.
ORT_API(void, OrtApis::ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo* p) {
  delete p;
}
|
|
#if defined(_MSC_VER) && !defined(__clang__)
|
|
#pragma warning(pop)
|
|
#endif
|
|
// Copies the `name` pointer held by the memory info into *out. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out) {
  const OrtMemoryInfo& info = *ptr;
  *out = info.name;
  return nullptr;
}
|
|
|
|
// Copies the `id` field of the memory info into *out. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out) {
  const OrtMemoryInfo& info = *ptr;
  *out = info.id;
  return nullptr;
}
|
|
|
|
// Copies the `mem_type` field of the memory info into *out. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out) {
  const OrtMemoryInfo& info = *ptr;
  *out = info.mem_type;
  return nullptr;
}
|
|
|
|
// Copies the `alloc_type` field of the memory info into *out. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out) {
  const OrtMemoryInfo& info = *ptr;
  *out = info.alloc_type;
  return nullptr;
}
|
|
|
|
// Compares two memory infos with OrtMemoryInfo's operator==.
// Writes 0 to *out when they compare equal and -1 otherwise. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtApis::CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2,
                    _Out_ int* out) {
  if (*info1 == *info2) {
    *out = 0;
  } else {
    *out = -1;
  }
  return nullptr;
}
|
|
|
|
// Writes the device type reported by info->device.Type(), cast to OrtMemoryInfoDeviceType, into *out.
ORT_API(void, OrtApis::MemoryInfoGetDeviceType, _In_ const OrtMemoryInfo* info, _Out_ OrtMemoryInfoDeviceType* out) {
  const auto device_type = info->device.Type();
  *out = static_cast<OrtMemoryInfoDeviceType>(device_type);
}
|