diff --git a/onnxruntime/core/framework/parallel_executor.cc b/onnxruntime/core/framework/parallel_executor.cc index ad9596f5e5..505f6b3bbf 100644 --- a/onnxruntime/core/framework/parallel_executor.cc +++ b/onnxruntime/core/framework/parallel_executor.cc @@ -138,14 +138,22 @@ void ParallelExecutor::RunNodeAsyncInternal(size_t p_node_index, for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { Fence_t fence = op_kernel_context.InputFence(input_index); if (fence) { - fence->BeforeUsingAsInput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } } for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); if (fence) { - fence->BeforeUsingAsInput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } } diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index d465279104..693e4cc2e0 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -78,14 +78,22 @@ Status SequentialExecutor::Execute(const SessionState& session_state, for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { Fence_t fence = op_kernel_context.InputFence(input_index); if (fence) { - fence->BeforeUsingAsInput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } } for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); if (fence) { - fence->BeforeUsingAsInput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } } diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 1ea86f918b..11596f7b2e 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -327,8 +327,6 @@ int real_main(int argc, char* argv[]) { broken_tests["maxpool_3d_default"] = "cudnn pooling only support input dimension >= 3"; broken_tests["maxpool_1d_default"] = "cudnn pooling only support input dimension >= 3"; - broken_tests["tf_inception_resnet_v2"] = "unknown failure on CUDA"; - broken_tests["tf_inception_v4"] = "unknown failure on CUDA"; broken_tests["fp16_tiny_yolov2"] = "unknown failure on CUDA"; broken_tests["fp16_shufflenet"] = "unknown failure on CUDA"; broken_tests["fp16_inception_v1"] = "unknown failure on CUDA";