From ad1701dfb1bdbcaefa56396863bc92b3f1859140 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Tue, 1 Sep 2020 09:25:32 -0700 Subject: [PATCH] Rename DeviceAllocatorRegistrationInfo to a more generic name; Use OrtArenaCfg for arena members; Remove unused OrtMemType; Simplify CreateAllocator interface. (#4970) * Rename DeviceAllocatorRegistrationInfo to a more generic name; Remove OrtMemType; Simplify CreateAllocator interface. * - fix builds - fixed mixed aggregation + constructor calls (which were coded before this PR) - changed default value of max_mem in API header - added some validation of values for arena_extend_strategy * fix tensorrt and cuda tests --- .../core/session/onnxruntime_c_api.h | 9 ++- onnxruntime/core/framework/allocatormgr.cc | 39 ++++++++--- onnxruntime/core/framework/allocatormgr.h | 42 +++++------- onnxruntime/core/framework/bfc_arena.h | 1 + .../core/framework/provider_bridge_ort.cc | 14 ++-- .../providers/acl/acl_execution_provider.cc | 16 ++--- .../armnn/armnn_execution_provider.cc | 14 ++-- .../providers/cpu/cpu_execution_provider.h | 9 ++- .../providers/cuda/cuda_execution_provider.cc | 67 ++++++++++--------- .../providers/dnnl/dnnl_execution_provider.cc | 32 +++++---- .../migraphx/migraphx_execution_provider.cc | 13 ++-- .../ngraph/ngraph_execution_provider.cc | 12 ++-- .../nnapi_builtin/nnapi_execution_provider.cc | 22 +++--- .../nuphar/nuphar_execution_provider.cc | 13 ++-- .../openvino/openvino_execution_provider.cc | 10 ++- .../rknpu/rknpu_execution_provider.cc | 12 ++-- .../providers/shared_library/provider_api.h | 2 +- .../provider_bridge_provider.cc | 9 ++- .../shared_library/provider_interfaces.h | 21 ++++-- .../tensorrt/tensorrt_execution_provider.cc | 16 ++--- .../vitisai/vitisai_execution_provider.cc | 6 +- onnxruntime/core/session/device_allocator.cc | 50 +++++++------- .../framework/cuda/allocator_cuda_test.cc | 23 +++---- .../test/framework/inference_session_test.cc | 28 +++----- 
onnxruntime/test/shared_lib/test_inference.cc | 2 +- 25 files changed, 240 insertions(+), 242 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 852ce7dd1f..c88add157a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -146,12 +146,11 @@ typedef enum OrtErrorCode { // This configures the arena based allocator used by ORT // See ONNX_Runtime_Perf_Tuning.md for details on what these mean and how to choose these values -// Use -1 to allow ORT to choose defaults for all the options below typedef struct OrtArenaCfg { - int max_mem; - int arena_extend_strategy; // 0 = kNextPowerOfTwo, 1 = kSameAsRequested - int initial_chunk_size_bytes; - int max_dead_bytes_per_chunk; + size_t max_mem; // use 0 to allow ORT to choose the default + int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default + int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default } OrtArenaCfg; #define ORT_RUNTIME_CLASS(X) \ diff --git a/onnxruntime/core/framework/allocatormgr.cc b/onnxruntime/core/framework/allocatormgr.cc index 8b6bb47e3d..244b6c5677 100644 --- a/onnxruntime/core/framework/allocatormgr.cc +++ b/onnxruntime/core/framework/allocatormgr.cc @@ -4,6 +4,7 @@ #include "core/framework/allocatormgr.h" #include "core/framework/bfc_arena.h" #include "core/framework/mimalloc_arena.h" +#include "core/common/logging/logging.h" #include #include #include @@ -12,21 +13,41 @@ namespace onnxruntime { using namespace common; -AllocatorPtr CreateAllocator(const DeviceAllocatorRegistrationInfo& info, - OrtDevice::DeviceId device_id, bool use_arena) { - auto device_allocator = std::unique_ptr(info.factory(device_id)); +AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) { + auto 
device_allocator = std::unique_ptr(info.device_alloc_factory(info.device_id)); + + if (info.use_arena) { + size_t max_mem = info.arena_cfg.max_mem == 0 ? BFCArena::DEFAULT_MAX_MEM : info.arena_cfg.max_mem; + int initial_chunk_size_bytes = info.arena_cfg.initial_chunk_size_bytes == -1 + ? BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES + : info.arena_cfg.initial_chunk_size_bytes; + int max_dead_bytes_per_chunk = info.arena_cfg.max_dead_bytes_per_chunk == -1 + ? BFCArena::DEFAULT_MAX_DEAD_BYTES_PER_CHUNK + : info.arena_cfg.max_dead_bytes_per_chunk; + ArenaExtendStrategy arena_extend_str; + switch (info.arena_cfg.arena_extend_strategy) { + case static_cast(ArenaExtendStrategy::kSameAsRequested): + arena_extend_str = ArenaExtendStrategy::kSameAsRequested; + break; + case -1: // default value supplied by user + case static_cast(ArenaExtendStrategy::kNextPowerOfTwo): + arena_extend_str = ArenaExtendStrategy::kNextPowerOfTwo; + break; + default: + LOGS_DEFAULT(ERROR) << "Received invalid value of arena_extend_strategy " << info.arena_cfg.arena_extend_strategy; + return nullptr; + } - if (use_arena) { #ifdef USE_MIMALLOC return std::shared_ptr( - onnxruntime::make_unique(std::move(device_allocator), info.max_mem)); + onnxruntime::make_unique(std::move(device_allocator), max_mem)); #else return std::shared_ptr( onnxruntime::make_unique(std::move(device_allocator), - info.max_mem, - info.arena_extend_strategy, - info.initial_chunk_size_bytes, - info.max_dead_bytes_per_chunk)); + max_mem, + arena_extend_str, + initial_chunk_size_bytes, + max_dead_bytes_per_chunk)); #endif } diff --git a/onnxruntime/core/framework/allocatormgr.h b/onnxruntime/core/framework/allocatormgr.h index c23c954da8..c0b0123b13 100644 --- a/onnxruntime/core/framework/allocatormgr.h +++ b/onnxruntime/core/framework/allocatormgr.h @@ -6,38 +6,32 @@ #include "core/common/common.h" #include "core/framework/arena.h" #include "core/framework/bfc_arena.h" +#include "core/session/onnxruntime_c_api.h" namespace 
onnxruntime { using DeviceAllocatorFactory = std::function(OrtDevice::DeviceId)>; -// TODO why does DeviceAllocatorRegistrationInfo have arena related configs? -// TODO even if it should, they should be inside their own struct (OrtArenaCfg) as opposed to -// littering them as individual members of DeviceAllocatorRegistrationInfo -struct DeviceAllocatorRegistrationInfo { - DeviceAllocatorRegistrationInfo(OrtMemType ort_mem_type, - DeviceAllocatorFactory alloc_factory, - size_t mem, - ArenaExtendStrategy strategy = BFCArena::DEFAULT_ARENA_EXTEND_STRATEGY, - int initial_chunk_size_bytes0 = BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES, - int max_dead_bytes_per_chunk0 = BFCArena::DEFAULT_MAX_DEAD_BYTES_PER_CHUNK) - : mem_type(ort_mem_type), - factory(alloc_factory), - max_mem(mem), - arena_extend_strategy(strategy), - initial_chunk_size_bytes(initial_chunk_size_bytes0), - max_dead_bytes_per_chunk(max_dead_bytes_per_chunk0) { +struct AllocatorCreationInfo { + AllocatorCreationInfo(DeviceAllocatorFactory device_alloc_factory0, + OrtDevice::DeviceId device_id0 = 0, + bool use_arena0 = true, + OrtArenaCfg arena_cfg0 = {0, -1, -1, -1}) + : device_alloc_factory(device_alloc_factory0), + device_id(device_id0), + use_arena(use_arena0), + arena_cfg(arena_cfg0) { } - OrtMemType mem_type; - DeviceAllocatorFactory factory; - size_t max_mem; - ArenaExtendStrategy arena_extend_strategy; - int initial_chunk_size_bytes; - int max_dead_bytes_per_chunk; + DeviceAllocatorFactory device_alloc_factory; + OrtDevice::DeviceId device_id; + bool use_arena; + OrtArenaCfg arena_cfg; }; -AllocatorPtr CreateAllocator(const DeviceAllocatorRegistrationInfo& info, OrtDevice::DeviceId device_id = 0, - bool use_arena = true); +// Returns an allocator based on the creation info provided. +// Returns nullptr if an invalid value of info.arena_cfg.arena_extend_strategy is supplied. +// Valid values can be found in onnxruntime_c_api.h. 
+AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info); } // namespace onnxruntime diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index 5606559720..08ea6c8264 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -57,6 +57,7 @@ class BFCArena : public IArenaAllocator { static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo; static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1048576; static const int DEFAULT_MAX_DEAD_BYTES_PER_CHUNK = 128 * 1024 * 1024; + static const size_t DEFAULT_MAX_MEM = std::numeric_limits::max(); BFCArena(std::unique_ptr resource_allocator, size_t total_memory, diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index 5ae24a2c1b..0fe2a7cba1 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -215,16 +215,16 @@ struct ProviderHostImpl : ProviderHost { return onnxruntime::make_unique(name_, type_, device_ ? 
static_cast(device_)->v_ : OrtDevice(), id_, mem_type_); } - Provider_AllocatorPtr CreateAllocator(const Provider_DeviceAllocatorRegistrationInfo& info, - OrtDevice::DeviceId device_id = 0, - bool use_arena = true) override { - DeviceAllocatorRegistrationInfo info_real{ - info.mem_type, [&info](int value) { + Provider_AllocatorPtr CreateAllocator(const Provider_AllocatorCreationInfo& info) override { + AllocatorCreationInfo info_real{ + [&info](int value) { return std::move(static_cast(&*info.factory(value))->p_); }, - info.max_mem}; + info.device_id, + info.use_arena, + info.arena_cfg}; - return std::make_shared(onnxruntime::CreateAllocator(info_real, device_id, use_arena)); + return std::make_shared(onnxruntime::CreateAllocator(info_real)); } std::unique_ptr CreateCPUAllocator( diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index 3809dd41b4..474773da6b 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -70,24 +70,24 @@ ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kAclExecutionProvider} { ORT_UNUSED_PARAMETER(info); - DeviceAllocatorRegistrationInfo default_memory_info{ - OrtMemTypeDefault, + AllocatorCreationInfo default_memory_info{ [](int) { return onnxruntime::make_unique(OrtMemoryInfo(ACL, OrtAllocatorType::OrtDeviceAllocator)); }, - std::numeric_limits::max()}; + 0, + info.create_arena}; - InsertAllocator(CreateAllocator(default_memory_info, 0, info.create_arena)); + InsertAllocator(CreateAllocator(default_memory_info)); - DeviceAllocatorRegistrationInfo cpu_memory_info{ - OrtMemTypeCPUOutput, + AllocatorCreationInfo cpu_memory_info{ [](int) { return onnxruntime::make_unique( OrtMemoryInfo(ACL_CPU, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); }, - std::numeric_limits::max()}; + 0, + 
info.create_arena}; - InsertAllocator(CreateAllocator(cpu_memory_info, 0, info.create_arena)); + InsertAllocator(CreateAllocator(cpu_memory_info)); } ACLExecutionProvider::~ACLExecutionProvider() { diff --git a/onnxruntime/core/providers/armnn/armnn_execution_provider.cc b/onnxruntime/core/providers/armnn/armnn_execution_provider.cc index 40b0a67d94..f43e8a905f 100644 --- a/onnxruntime/core/providers/armnn/armnn_execution_provider.cc +++ b/onnxruntime/core/providers/armnn/armnn_execution_provider.cc @@ -43,7 +43,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kArmNNExecutionProvider, kOnnxDo class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kArmNNExecutionProvider, kOnnxDomain, 4, 10, Concat); static void RegisterArmNNKernels(KernelRegistry& kernel_registry) { - #ifdef RELU_ARMNN kernel_registry.Register(BuildKernelCreateInfo()); #endif @@ -82,22 +81,19 @@ ArmNNExecutionProvider::ArmNNExecutionProvider(const ArmNNExecutionProviderInfo& : IExecutionProvider{onnxruntime::kArmNNExecutionProvider} { ORT_UNUSED_PARAMETER(info); - DeviceAllocatorRegistrationInfo default_memory_info{ - OrtMemTypeDefault, + AllocatorCreationInfo default_memory_info{ [](int) { return onnxruntime::make_unique(OrtMemoryInfo(ArmNN, OrtAllocatorType::OrtDeviceAllocator)); }, - std::numeric_limits::max()}; + 0}; InsertAllocator(CreateAllocator(default_memory_info)); - DeviceAllocatorRegistrationInfo cpu_memory_info{ - OrtMemTypeCPUOutput, + AllocatorCreationInfo cpu_memory_info{ [](int) { return onnxruntime::make_unique( OrtMemoryInfo(ArmNN_CPU, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); - }, - std::numeric_limits::max()}; + }}; InsertAllocator(CreateAllocator(cpu_memory_info)); } @@ -112,7 +108,7 @@ std::shared_ptr ArmNNExecutionProvider::GetKernelRegistry() cons std::vector> ArmNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const std::vector& kernel_registries) const { + const std::vector& kernel_registries) const { std::vector> 
result = IExecutionProvider::GetCapability(graph, kernel_registries); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.h b/onnxruntime/core/providers/cpu/cpu_execution_provider.h index c5f34e1520..2287c7771c 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.h +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.h @@ -27,10 +27,6 @@ class CPUExecutionProvider : public IExecutionProvider { public: explicit CPUExecutionProvider(const CPUExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kCpuExecutionProvider} { - DeviceAllocatorRegistrationInfo device_info{OrtMemTypeDefault, - [](int) { return onnxruntime::make_unique(); }, - std::numeric_limits::max()}; - bool create_arena = info.create_arena; #ifdef USE_JEMALLOC @@ -44,7 +40,10 @@ class CPUExecutionProvider : public IExecutionProvider { create_arena = false; #endif - InsertAllocator(CreateAllocator(device_info, 0, create_arena)); + AllocatorCreationInfo device_info{[](int) { return onnxruntime::make_unique(); }, + 0, create_arena}; + + InsertAllocator(CreateAllocator(device_info)); } std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 70b99b9a98..72b60accee 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -62,16 +62,18 @@ CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_)); CURAND_CALL_THROW(curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); - DeviceAllocatorRegistrationInfo default_memory_info( - {OrtMemTypeDefault, - [](OrtDevice::DeviceId id) { - return onnxruntime::make_unique(id, CUDA); - }, - cuda_mem_limit, - arena_extend_strategy}); + AllocatorCreationInfo default_memory_info( + [](OrtDevice::DeviceId id) { + return 
onnxruntime::make_unique(id, CUDA); + }, + device_id, + true, + {cuda_mem_limit, + static_cast(arena_extend_strategy), + -1, -1}); // CUDA malloc/free is expensive so always use an arena - allocator_ = CreateAllocator(default_memory_info, device_id, /*create_arena*/ true); + allocator_ = CreateAllocator(default_memory_info); } CUDAExecutionProvider::PerThreadContext::~PerThreadContext() { @@ -135,36 +137,37 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in size_t total = 0; CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); - DeviceAllocatorRegistrationInfo default_memory_info( - {OrtMemTypeDefault, - [](OrtDevice::DeviceId device_id) { - return onnxruntime::make_unique(device_id, CUDA); - }, - cuda_mem_limit_}); + AllocatorCreationInfo default_memory_info( + [](OrtDevice::DeviceId device_id) { + return onnxruntime::make_unique(device_id, CUDA); + }, + device_id_, + true, + {cuda_mem_limit_, + static_cast(arena_extend_strategy_), + -1, -1}); - InsertAllocator(CreateAllocator(default_memory_info, device_id_)); + InsertAllocator(CreateAllocator(default_memory_info)); - DeviceAllocatorRegistrationInfo pinned_memory_info( - {OrtMemTypeCPUOutput, - [](OrtDevice::DeviceId device_id) { - return onnxruntime::make_unique(device_id, CUDA_PINNED); - }, - std::numeric_limits::max()}); + AllocatorCreationInfo pinned_memory_info( + [](OrtDevice::DeviceId device_id) { + return onnxruntime::make_unique(device_id, CUDA_PINNED); + }, + CPU_ALLOCATOR_DEVICE_ID); - InsertAllocator(CreateAllocator(pinned_memory_info, CPU_ALLOCATOR_DEVICE_ID)); + InsertAllocator(CreateAllocator(pinned_memory_info)); // TODO: this is actually used for the cuda kernels which explicitly ask for inputs from CPU. // This will be refactored/removed when allocator and execution provider are decoupled. 
- DeviceAllocatorRegistrationInfo cpu_memory_info( - {OrtMemTypeCPUInput, - [](int device_id) { - return onnxruntime::make_unique( - OrtMemoryInfo("CUDA_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), device_id, - OrtMemTypeCPUInput)); - }, - std::numeric_limits::max()}); + AllocatorCreationInfo cpu_memory_info( + [](int device_id) { + return onnxruntime::make_unique( + OrtMemoryInfo("CUDA_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), device_id, + OrtMemTypeCPUInput)); + }, + CPU_ALLOCATOR_DEVICE_ID); - InsertAllocator(CreateAllocator(cpu_memory_info, CPU_ALLOCATOR_DEVICE_ID)); + InsertAllocator(CreateAllocator(cpu_memory_info)); UpdateProviderOptionsInfo(); } @@ -812,7 +815,7 @@ KernelCreateInfo BuildKernelCreateInfo() { static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, //default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, //default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 5006e53153..afe603d146 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -28,25 +28,23 @@ constexpr const char* DNNL_CPU = "DnnlCpu"; DNNLExecutionProvider::DNNLExecutionProvider(const DNNLExecutionProviderInfo& info) : Provider_IExecutionProvider{onnxruntime::kDnnlExecutionProvider} { - Provider_DeviceAllocatorRegistrationInfo default_memory_info( - {OrtMemTypeDefault, - [](int) { - return onnxruntime::Provider_CreateCPUAllocator( - onnxruntime::Provider_OrtMemoryInfo::Create(DNNL, OrtAllocatorType::OrtDeviceAllocator)); - }, - std::numeric_limits::max()}); + Provider_AllocatorCreationInfo default_memory_info( + {[](int) { + return 
onnxruntime::Provider_CreateCPUAllocator( + onnxruntime::Provider_OrtMemoryInfo::Create(DNNL, OrtAllocatorType::OrtDeviceAllocator)); + }}, + 0, info.create_arena); - Provider_DeviceAllocatorRegistrationInfo cpu_memory_info( - {OrtMemTypeCPUOutput, - [](int) { - return onnxruntime::Provider_CreateCPUAllocator( - onnxruntime::Provider_OrtMemoryInfo::Create(DNNL_CPU, OrtAllocatorType::OrtDeviceAllocator, nullptr, 0, - OrtMemTypeCPUOutput)); - }, - std::numeric_limits::max()}); + Provider_AllocatorCreationInfo cpu_memory_info( + {[](int) { + return onnxruntime::Provider_CreateCPUAllocator( + onnxruntime::Provider_OrtMemoryInfo::Create(DNNL_CPU, OrtAllocatorType::OrtDeviceAllocator, nullptr, 0, + OrtMemTypeCPUOutput)); + }}, + 0, info.create_arena); - Provider_InsertAllocator(CreateAllocator(default_memory_info, 0, info.create_arena)); - Provider_InsertAllocator(CreateAllocator(cpu_memory_info, 0, info.create_arena)); + Provider_InsertAllocator(CreateAllocator(default_memory_info)); + Provider_InsertAllocator(CreateAllocator(cpu_memory_info)); } // namespace onnxruntime DNNLExecutionProvider::~DNNLExecutionProvider() { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index d7662d9c51..5ff47f3a8c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -87,14 +87,15 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider} { // Set GPU device to be used hipSetDevice(info.device_id); - DeviceAllocatorRegistrationInfo default_memory_info( - {OrtMemTypeDefault, [](int id) { return onnxruntime::make_unique(id, MIGRAPHX); }, std::numeric_limits::max()}); - allocator_ = CreateAllocator(default_memory_info, device_id_); + AllocatorCreationInfo default_memory_info( + [](int id) { return 
onnxruntime::make_unique(id, MIGRAPHX); }, device_id_); + allocator_ = CreateAllocator(default_memory_info); InsertAllocator(allocator_); - DeviceAllocatorRegistrationInfo pinned_memory_info( - {OrtMemTypeCPUOutput, [](int) { return onnxruntime::make_unique(0, MIGRAPHX_PINNED); }, std::numeric_limits::max()}); - InsertAllocator(CreateAllocator(pinned_memory_info, device_id_)); + AllocatorCreationInfo pinned_memory_info( + [](int) { return onnxruntime::make_unique(0, MIGRAPHX_PINNED); }, + device_id_); + InsertAllocator(CreateAllocator(pinned_memory_info)); // create the target based on the device_id hipDeviceProp_t prop; diff --git a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc index e2610bc4f2..f2d4b16a4d 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc +++ b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc @@ -35,22 +35,18 @@ NGRAPHExecutionProvider::NGRAPHExecutionProvider(const NGRAPHExecutionProviderIn : IExecutionProvider{onnxruntime::kNGraphExecutionProvider} { ORT_ENFORCE(info.ng_backend_type == "CPU", "nGraph Execution Provider for onnxruntime currently is only supported for CPU backend."); - DeviceAllocatorRegistrationInfo default_memory_info{ - OrtMemTypeDefault, + AllocatorCreationInfo default_memory_info{ [](int) { return onnxruntime::make_unique(OrtMemoryInfo(NGRAPH, OrtAllocatorType::OrtDeviceAllocator)); - }, - std::numeric_limits::max()}; + }}; InsertAllocator(CreateAllocator(default_memory_info)); - DeviceAllocatorRegistrationInfo cpu_memory_info{ - OrtMemTypeCPUOutput, + AllocatorCreationInfo cpu_memory_info{ [](int) { return onnxruntime::make_unique( OrtMemoryInfo(NGRAPH, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); - }, - std::numeric_limits::max()}; + }}; InsertAllocator(CreateAllocator(cpu_memory_info)); diff --git 
a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 641c0602d9..4416c0b4c4 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -15,22 +15,18 @@ constexpr const char* NNAPI = "Nnapi"; NnapiExecutionProvider::NnapiExecutionProvider() : IExecutionProvider{onnxruntime::kNnapiExecutionProvider} { - DeviceAllocatorRegistrationInfo device_info( - {OrtMemTypeDefault, - [](int) { - return onnxruntime::make_unique(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator)); - }, - std::numeric_limits::max()}); + AllocatorCreationInfo device_info( + [](int) { + return onnxruntime::make_unique(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator)); + }); InsertAllocator(CreateAllocator(device_info)); - DeviceAllocatorRegistrationInfo cpu_memory_info( - {OrtMemTypeCPUOutput, - [](int) { - return onnxruntime::make_unique( - OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); - }, - std::numeric_limits::max()}); + AllocatorCreationInfo cpu_memory_info( + [](int) { + return onnxruntime::make_unique( + OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); + }); InsertAllocator(CreateAllocator(cpu_memory_info)); } diff --git a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc index 94af286db1..000022f21b 100644 --- a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc +++ b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc @@ -103,14 +103,13 @@ NupharExecutionProvider::NupharExecutionProvider(const NupharExecutionProviderIn whole_graph_shape_infer_ = std::make_shared(); - DeviceAllocatorRegistrationInfo memory_info( - {OrtMemTypeDefault, - [](int /*id*/) { - return 
onnxruntime::make_unique(OrtMemoryInfo("Nuphar", OrtAllocatorType::OrtDeviceAllocator)); - }, - std::numeric_limits::max()}); + AllocatorCreationInfo memory_info( + [](int /*id*/) { + return onnxruntime::make_unique(OrtMemoryInfo("Nuphar", OrtAllocatorType::OrtDeviceAllocator)); + }, + static_cast(tvm_ctx_.device_id)); - InsertAllocator(CreateAllocator(memory_info, tvm_ctx_.device_id)); + InsertAllocator(CreateAllocator(memory_info)); // TODO add multi-target support tvm_codegen_manager_ = onnxruntime::make_unique(); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index abd2800831..d2fc3cdc01 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -20,12 +20,10 @@ constexpr const char* OpenVINO = "OpenVINO"; OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kOpenVINOExecutionProvider}, info_(info) { - DeviceAllocatorRegistrationInfo device_info( - {OrtMemTypeDefault, - [](int) { - return std::make_unique(OrtMemoryInfo(OpenVINO, OrtDeviceAllocator)); - }, - std::numeric_limits::max()}); + AllocatorCreationInfo device_info( + [](int) { + return std::make_unique(OrtMemoryInfo(OpenVINO, OrtDeviceAllocator)); + }); InsertAllocator(CreateAllocator(device_info)); } diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index 41b714837a..c43796cdc6 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -39,22 +39,18 @@ struct RknpuFuncState { RknpuExecutionProvider::RknpuExecutionProvider() : IExecutionProvider{onnxruntime::kRknpuExecutionProvider} { - DeviceAllocatorRegistrationInfo default_memory_info{ - OrtMemTypeDefault, + 
AllocatorCreationInfo default_memory_info{ [](int) { return onnxruntime::make_unique(OrtMemoryInfo(RKNPU, OrtAllocatorType::OrtDeviceAllocator)); - }, - std::numeric_limits::max()}; + }}; InsertAllocator(CreateAllocator(default_memory_info)); - DeviceAllocatorRegistrationInfo cpu_memory_info{ - OrtMemTypeCPUOutput, + AllocatorCreationInfo cpu_memory_info{ [](int) { return onnxruntime::make_unique( OrtMemoryInfo(RKNPU, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); - }, - std::numeric_limits::max()}; + }}; InsertAllocator(CreateAllocator(cpu_memory_info)); } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 97bfe4fb4e..d2ed69a5a0 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -160,7 +160,7 @@ using IAllocatorUniquePtr = std::unique_ptr>; std::unique_ptr Provider_CreateCPUAllocator(std::unique_ptr memory_info); std::unique_ptr Provider_CreateCUDAAllocator(int16_t device_id, const char* name); std::unique_ptr Provider_CreateCUDAPinnedAllocator(int16_t device_id, const char* name); -Provider_AllocatorPtr CreateAllocator(const Provider_DeviceAllocatorRegistrationInfo& info, int16_t device_id = 0, bool use_arena = true); +Provider_AllocatorPtr CreateAllocator(const Provider_AllocatorCreationInfo& info); std::unique_ptr Provider_CreateGPUDataTransfer(); diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 9fd266daed..2c24d7898a 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -51,9 +51,8 @@ void operator delete(void* p, size_t /*size*/) { return onnxruntime::g_host->Hea namespace onnxruntime { -Provider_AllocatorPtr CreateAllocator(const 
Provider_DeviceAllocatorRegistrationInfo& info, int16_t device_id, - bool use_arena) { - return g_host->CreateAllocator(info, device_id, use_arena); +Provider_AllocatorPtr CreateAllocator(const Provider_AllocatorCreationInfo& info) { + return g_host->CreateAllocator(info); } std::unique_ptr Provider_OrtMemoryInfo::Create( @@ -144,8 +143,8 @@ bool CPUIDInfo::HasAVX512f() const { return g_host->CPU_HasAVX512f(); } -Provider_AllocatorPtr CreateAllocator(Provider_DeviceAllocatorRegistrationInfo info, int16_t device_id) { - return g_host->CreateAllocator(info, device_id); +Provider_AllocatorPtr CreateAllocator(Provider_AllocatorCreationInfo info) { + return g_host->CreateAllocator(info); } std::unique_ptr Provider_CreateCPUAllocator(std::unique_ptr info) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 027bd4a4f2..14ef1e95ab 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -152,10 +152,22 @@ struct Provider_IDeviceAllocator : Provider_IAllocator {}; using Provider_AllocatorPtr = std::shared_ptr; using Provider_DeviceAllocatorFactory = std::function(int)>; -struct Provider_DeviceAllocatorRegistrationInfo { - OrtMemType mem_type; +using DeviceId = int16_t; +struct Provider_AllocatorCreationInfo { + Provider_AllocatorCreationInfo(Provider_DeviceAllocatorFactory device_alloc_factory0, + DeviceId device_id0 = 0, + bool use_arena0 = true, + OrtArenaCfg arena_cfg0 = {0, -1, -1, -1}) + : factory(device_alloc_factory0), + device_id(device_id0), + use_arena(use_arena0), + arena_cfg(arena_cfg0) { + } + Provider_DeviceAllocatorFactory factory; - size_t max_mem; + DeviceId device_id; + bool use_arena; + OrtArenaCfg arena_cfg; }; struct Provider_OpKernel { @@ -261,8 +273,7 @@ struct Provider { // calls the virtual function (which will lead to infinite recursion in the bridge). 
There is no known way to get the non virtual member // function pointer implementation in this case. struct ProviderHost { - virtual Provider_AllocatorPtr CreateAllocator(const Provider_DeviceAllocatorRegistrationInfo& info, - int16_t device_id = 0, bool use_arena = true) = 0; + virtual Provider_AllocatorPtr CreateAllocator(const Provider_AllocatorCreationInfo& info) = 0; virtual logging::Logger* LoggingManager_GetDefaultLogger() = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 84db1c23bf..a694b632a7 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -46,7 +46,7 @@ std::string GetEnginePath(const ::std::string& root, const std::string& name) { } } -std::string GetVecHash(const ::std::vector & vec) { +std::string GetVecHash(const ::std::vector& vec) { std::size_t ret = 0; for (auto& i : vec) { ret ^= std::hash()(i); @@ -171,14 +171,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv : Provider_IExecutionProvider{onnxruntime::kTensorrtExecutionProvider}, device_id_(info.device_id) { CUDA_CALL_THROW(cudaSetDevice(device_id_)); - Provider_DeviceAllocatorRegistrationInfo default_memory_info( - {OrtMemTypeDefault, [](int id) { return Provider_CreateCUDAAllocator(id, TRT); }, std::numeric_limits::max()}); - allocator_ = CreateAllocator(default_memory_info, device_id_); + Provider_AllocatorCreationInfo default_memory_info( + [](int id) { return Provider_CreateCUDAAllocator(id, TRT); }, device_id_); + allocator_ = CreateAllocator(default_memory_info); Provider_InsertAllocator(allocator_); - Provider_DeviceAllocatorRegistrationInfo pinned_allocator_info( - {OrtMemTypeCPUOutput, [](int) { return Provider_CreateCUDAPinnedAllocator(0, TRT_PINNED); }, std::numeric_limits::max()}); - 
Provider_InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_)); + Provider_AllocatorCreationInfo pinned_allocator_info( + [](int) { return Provider_CreateCUDAPinnedAllocator(0, TRT_PINNED); }, device_id_); + Provider_InsertAllocator(CreateAllocator(pinned_allocator_info)); // Get environment variables const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); @@ -1062,7 +1062,7 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vectorruntime; trt_state->engine->reset(); *(trt_state->engine) = tensorrt_ptr::unique_pointer( - runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); if (trt_state->engine->get() == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 7345e03b66..3b9e6d15e9 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -31,12 +31,10 @@ typedef std::shared_ptr XLayerHolder; VitisAIExecutionProvider::VitisAIExecutionProvider(const VitisAIExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, backend_type_(info.backend_type), device_id_(info.device_id) { - DeviceAllocatorRegistrationInfo default_memory_info{ - OrtMemTypeDefault, + AllocatorCreationInfo default_memory_info{ [](int) { return onnxruntime::make_unique(OrtMemoryInfo(VITISAI, OrtAllocatorType::OrtDeviceAllocator)); - }, - std::numeric_limits::max()}; + }}; InsertAllocator(CreateAllocator(default_memory_info)); } diff --git a/onnxruntime/core/session/device_allocator.cc b/onnxruntime/core/session/device_allocator.cc index 3b4012e4b3..f87920dddc 100644 --- 
a/onnxruntime/core/session/device_allocator.cc +++ b/onnxruntime/core/session/device_allocator.cc @@ -65,37 +65,41 @@ ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _I #endif AllocatorPtr allocator_ptr; - size_t max_mem = std::numeric_limits::max(); - // create appropriate DeviceAllocatorRegistrationInfo and allocator based on create_arena if (create_arena) { - ArenaExtendStrategy arena_extend_strategy = BFCArena::DEFAULT_ARENA_EXTEND_STRATEGY; - int initial_chunk_size_bytes = BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES; - int max_dead_bytes_per_chunk = BFCArena::DEFAULT_MAX_DEAD_BYTES_PER_CHUNK; + // defaults in case arena_cfg is nullptr (not supplied by the user) + size_t max_mem = 0; + int arena_extend_strategy = -1; + int initial_chunk_size_bytes = -1; + int max_dead_bytes_per_chunk = -1; + + // override with values from the user supplied arena_cfg object if (arena_cfg) { - if (arena_cfg->max_mem != -1) max_mem = arena_cfg->max_mem; - if (arena_cfg->arena_extend_strategy == 0) { - arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo; - } else if (arena_cfg->arena_extend_strategy == 1) { - arena_extend_strategy = ArenaExtendStrategy::kSameAsRequested; + max_mem = arena_cfg->max_mem; + + arena_extend_strategy = arena_cfg->arena_extend_strategy; + // validate the value here + if (!(arena_extend_strategy == -1 || arena_extend_strategy == 0 || arena_extend_strategy == 1)) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Received invalid value for arena extend strategy." 
+ " Valid values can be either 0, 1 or -1."); } - if (arena_cfg->initial_chunk_size_bytes != -1) initial_chunk_size_bytes = arena_cfg->initial_chunk_size_bytes; - if (arena_cfg->max_dead_bytes_per_chunk != -1) max_dead_bytes_per_chunk = arena_cfg->max_dead_bytes_per_chunk; + + initial_chunk_size_bytes = arena_cfg->initial_chunk_size_bytes; + max_dead_bytes_per_chunk = arena_cfg->max_dead_bytes_per_chunk; } - DeviceAllocatorRegistrationInfo device_info{ - OrtMemTypeDefault, + OrtArenaCfg l_arena_cfg{max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk}; + AllocatorCreationInfo alloc_creation_info{ [mem_info](int) { return onnxruntime::make_unique(*mem_info); }, - max_mem, - arena_extend_strategy, - initial_chunk_size_bytes, - max_dead_bytes_per_chunk}; - allocator_ptr = CreateAllocator(device_info, 0, create_arena); + 0, + create_arena, + l_arena_cfg}; + allocator_ptr = CreateAllocator(alloc_creation_info); } else { - DeviceAllocatorRegistrationInfo device_info{OrtMemTypeDefault, - [](int) { return onnxruntime::make_unique(); }, - max_mem}; - allocator_ptr = CreateAllocator(device_info, 0, create_arena); + AllocatorCreationInfo alloc_creation_info{[](int) { return onnxruntime::make_unique(); }, + 0, create_arena}; + allocator_ptr = CreateAllocator(alloc_creation_info); } auto st = env->RegisterAllocator(allocator_ptr); diff --git a/onnxruntime/test/framework/cuda/allocator_cuda_test.cc b/onnxruntime/test/framework/cuda/allocator_cuda_test.cc index 9b9fc5895b..2b159f5b35 100644 --- a/onnxruntime/test/framework/cuda/allocator_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/allocator_cuda_test.cc @@ -12,12 +12,10 @@ namespace onnxruntime { namespace test { TEST(AllocatorTest, CUDAAllocatorTest) { OrtDevice::DeviceId cuda_device_id = 0; - DeviceAllocatorRegistrationInfo default_memory_info( - {OrtMemTypeDefault, - [](OrtDevice::DeviceId id) { return onnxruntime::make_unique(id, CUDA); }, - std::numeric_limits::max()}); + 
AllocatorCreationInfo default_memory_info( + {[](OrtDevice::DeviceId id) { return onnxruntime::make_unique(id, CUDA); }, cuda_device_id}); - auto cuda_arena = CreateAllocator(default_memory_info, cuda_device_id); + auto cuda_arena = CreateAllocator(default_memory_info); size_t size = 1024; @@ -30,10 +28,8 @@ TEST(AllocatorTest, CUDAAllocatorTest) { auto cuda_addr = cuda_arena->Alloc(size); EXPECT_TRUE(cuda_addr); - DeviceAllocatorRegistrationInfo pinned_memory_info( - {OrtMemTypeCPUOutput, - [](int) { return onnxruntime::make_unique(static_cast(0), CUDA_PINNED); }, - std::numeric_limits::max()}); + AllocatorCreationInfo pinned_memory_info( + [](int) { return onnxruntime::make_unique(static_cast(0), CUDA_PINNED); }); auto pinned_allocator = CreateAllocator(pinned_memory_info); @@ -86,12 +82,11 @@ TEST(AllocatorTest, CUDAAllocatorFallbackTest) { // need extra test logic if this ever happens. EXPECT_NE(free, total) << "All memory is free. Test logic does not handle this."; - DeviceAllocatorRegistrationInfo default_memory_info( - {OrtMemTypeDefault, - [](OrtDevice::DeviceId id) { return onnxruntime::make_unique(id, CUDA); }, - std::numeric_limits::max()}); + AllocatorCreationInfo default_memory_info( + {[](OrtDevice::DeviceId id) { return onnxruntime::make_unique(id, CUDA); }, + cuda_device_id}); - auto cuda_arena = CreateAllocator(default_memory_info, cuda_device_id); + auto cuda_arena = CreateAllocator(default_memory_info); // initial allocation that sets the growth size for the next allocation size_t size = total / 2; diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 06a7888fad..bd68287015 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -96,13 +96,11 @@ KernelRegistryAndStatus GetFusedKernelRegistry() { class FuseExecutionProvider : public IExecutionProvider { public: explicit FuseExecutionProvider() : 
IExecutionProvider{kFuseExecutionProvider} { - DeviceAllocatorRegistrationInfo device_info( - {OrtMemTypeDefault, - [](int) { - return onnxruntime::make_unique(OrtMemoryInfo("Fuse", OrtAllocatorType::OrtDeviceAllocator)); - }, - std::numeric_limits::max()}); - InsertAllocator(device_info.factory(0)); + AllocatorCreationInfo device_info{ + [](int) { + return onnxruntime::make_unique(OrtMemoryInfo("Fuse", OrtAllocatorType::OrtDeviceAllocator)); + }}; + InsertAllocator(device_info.device_alloc_factory(0)); } std::vector> @@ -2329,13 +2327,11 @@ TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsUseSameOrtCreatedAllo use_arena = false; #endif OrtMemoryInfo mem_info{onnxruntime::CPU, use_arena ? OrtArenaAllocator : OrtDeviceAllocator}; - size_t max_mem = std::numeric_limits::max(); - DeviceAllocatorRegistrationInfo device_info{ - OrtMemTypeDefault, + AllocatorCreationInfo device_info{ [mem_info](int) { return onnxruntime::make_unique(mem_info); }, - max_mem}; + 0, use_arena}; - AllocatorPtr allocator_ptr = CreateAllocator(device_info, 0, use_arena); + AllocatorPtr allocator_ptr = CreateAllocator(device_info); st = env->RegisterAllocator(allocator_ptr); ASSERT_STATUS_OK(st); // create sessions to share the allocator @@ -2376,13 +2372,11 @@ TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsDontUseSameOrtCreated use_arena = false; #endif OrtMemoryInfo mem_info{onnxruntime::CPU, use_arena ? 
OrtArenaAllocator : OrtDeviceAllocator}; - size_t max_mem = std::numeric_limits::max(); - DeviceAllocatorRegistrationInfo device_info{ - OrtMemTypeDefault, + AllocatorCreationInfo device_info{ [mem_info](int) { return onnxruntime::make_unique(mem_info); }, - max_mem}; + 0, use_arena}; - AllocatorPtr allocator_ptr = CreateAllocator(device_info, 0, use_arena); + AllocatorPtr allocator_ptr = CreateAllocator(device_info); st = env->RegisterAllocator(allocator_ptr); ASSERT_STATUS_OK(st); // create sessions to share the allocator diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index e0ee29b77b..d36d1d353d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -883,7 +883,7 @@ TEST(CApiTest, TestSharedAllocatorUsingCreateAndRegisterAllocator) { std::unique_ptr rel_info(mem_info, api.ReleaseMemoryInfo); ASSERT_TRUE(api.CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info) == nullptr); - OrtArenaCfg arena_cfg{-1, -1, -1, -1}; + OrtArenaCfg arena_cfg{0, -1, -1, -1}; ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, &arena_cfg) == nullptr); // test for duplicates