diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 07dc8eab41..0b37ea6383 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1204,8 +1204,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, continue; const auto& node = *p_node; - const auto* cuda_kernel_def = GetKernelRegistry()->TryFindKernel(node, Type()); - if (cuda_kernel_def == nullptr || !node.GetExecutionProviderType().empty()) { + const KernelCreateInfo* cuda_kernel_def = nullptr; + if (!node.GetExecutionProviderType().empty() || + !(cuda_kernel_def = GetKernelRegistry()->TryFindKernel(node, Type()))) { // node is not in cuda exeuction provider if no kernel def found, // or if other execution provider already assigned to it defs_outside_cuda.insert(node.OutputDefs().cbegin(), node.OutputDefs().cend());