From 3b0dda0acaacdcf455e232dc7f674b7f96f9bbb9 Mon Sep 17 00:00:00 2001 From: nivas-x86 <43652421+nivas-x86@users.noreply.github.com> Date: Tue, 30 Apr 2019 12:10:28 -0700 Subject: [PATCH] nGraph: Avoid input and output data copies (#940) --- onnxruntime/core/framework/graph_partitioner.cc | 2 +- onnxruntime/core/framework/utils.cc | 4 ++-- onnxruntime/core/optimizer/transformer_memcpy.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 426f9fbf1d..d4d250027c 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -176,7 +176,7 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f //prepare the func kernel KernelDefBuilder builder; BuildFusedKernelDef(builder, *node); - if (node->GetExecutionProviderType() == onnxruntime::kTensorrtExecutionProvider) { + if (node->GetExecutionProviderType() == onnxruntime::kTensorrtExecutionProvider || node->GetExecutionProviderType() == onnxruntime::kNGraphExecutionProvider) { builder.SetDefaultInputsMemoryType(OrtMemTypeCPUInput); builder.SetDefaultOutputMemoryType(OrtMemTypeCPUOutput); } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 2e613fe5f2..e38d769f0b 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -128,8 +128,8 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, ORT_ENFORCE(p_input_provider); } - //no copy for TRT - if (required_provider_type == onnxruntime::kTensorrtExecutionProvider) { + //no copy for TRT and nGraph + if (required_provider_type == onnxruntime::kTensorrtExecutionProvider || required_provider_type == onnxruntime::kNGraphExecutionProvider) { new_mlvalue = orig_mlvalue; break; } diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 3175ee6ec8..ac069f5534 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -184,7 +184,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg } else { // TODO: copy between devices? i.e. multiple GPUs if (node.GetExecutionProviderType() != onnxruntime::kCpuExecutionProvider && node.GetExecutionProviderType() != onnxruntime::kTensorrtExecutionProvider && - !node.GetExecutionProviderType().empty()) { + node.GetExecutionProviderType() != onnxruntime::kNGraphExecutionProvider && !node.GetExecutionProviderType().empty()) { ORT_THROW("Execution type '", node.GetExecutionProviderType(), "' doesn't support memcpy "); }