From 7e88ca19ee63af33fabbb72a3d2ec5dd6fc7a147 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 6 May 2019 10:54:45 -0700 Subject: [PATCH] Support the option to disable memory arena in MLDNN provider (#970) Support the option to disable memory arena in MLDNN provider so we can do memory profiling when necessary. --- .../mkldnn/mkldnn_execution_provider.cc | 46 +++++++++++-------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc index cbbfc646d4..aae5959264 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc +++ b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc @@ -32,16 +32,26 @@ ONNX_OPERATOR_KERNEL_EX( } // namespace mkl_dnn -MKLDNNExecutionProvider::MKLDNNExecutionProvider(const MKLDNNExecutionProviderInfo& /*info*/) - : IExecutionProvider{onnxruntime::kMklDnnExecutionProvider} { +MKLDNNExecutionProvider::MKLDNNExecutionProvider(const MKLDNNExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kMklDnnExecutionProvider} { DeviceAllocatorRegistrationInfo default_allocator_info({OrtMemTypeDefault, [](int) { return std::make_unique(std::make_unique(MKLDNN, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault)); }, std::numeric_limits::max()}); - InsertAllocator(CreateAllocator(default_allocator_info)); DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUOutput, [](int) { return std::make_unique(std::make_unique(MKLDNN_CPU, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeCPUOutput)); }, std::numeric_limits::max()}); - InsertAllocator(CreateAllocator(cpu_allocator_info)); -} + + if (info.create_arena) { + InsertAllocator(CreateAllocator(default_allocator_info)); + + InsertAllocator(CreateAllocator(cpu_allocator_info)); + } else { + InsertAllocator(std::shared_ptr( + std::make_unique(default_allocator_info.factory(0)))); + + InsertAllocator(std::shared_ptr( + std::make_unique(cpu_allocator_info.factory(0)))); + } +} // namespace onnxruntime MKLDNNExecutionProvider::~MKLDNNExecutionProvider() { } @@ -80,19 +90,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomai void RegisterMKLDNNKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) {