diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc index cbbfc646d4..aae5959264 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc +++ b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc @@ -32,16 +32,26 @@ ONNX_OPERATOR_KERNEL_EX( } // namespace mkl_dnn -MKLDNNExecutionProvider::MKLDNNExecutionProvider(const MKLDNNExecutionProviderInfo& /*info*/) - : IExecutionProvider{onnxruntime::kMklDnnExecutionProvider} { +MKLDNNExecutionProvider::MKLDNNExecutionProvider(const MKLDNNExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kMklDnnExecutionProvider} { DeviceAllocatorRegistrationInfo default_allocator_info({OrtMemTypeDefault, [](int) { return std::make_unique(std::make_unique(MKLDNN, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault)); }, std::numeric_limits::max()}); - InsertAllocator(CreateAllocator(default_allocator_info)); DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUOutput, [](int) { return std::make_unique(std::make_unique(MKLDNN_CPU, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeCPUOutput)); }, std::numeric_limits::max()}); - InsertAllocator(CreateAllocator(cpu_allocator_info)); -} + + if (info.create_arena) { + InsertAllocator(CreateAllocator(default_allocator_info)); + + InsertAllocator(CreateAllocator(cpu_allocator_info)); + } else { + InsertAllocator(std::shared_ptr( + std::make_unique(default_allocator_info.factory(0)))); + + InsertAllocator(std::shared_ptr( + std::make_unique(cpu_allocator_info.factory(0)))); + } +} // namespace onnxruntime MKLDNNExecutionProvider::~MKLDNNExecutionProvider() { } @@ -80,19 +90,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomai void RegisterMKLDNNKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) {