diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index a0ee87bbf4..854d797f9e 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -156,8 +156,6 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr, // return Status::OK(); //} - assert(source_mlvalue.IsTensor()); - auto& source_tensor = source_mlvalue.Get(); if (!target_mlvalue.IsAllocated()) { ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device, @@ -647,7 +645,7 @@ void DumpNodeOutputs(OpKernelContext& context, const Node& node, const SessionSt } } #else - ORT_UNUSED_PARAMETER(session_state); + ORT_UNUSED_PARAMETER(session_state); #endif } } diff --git a/onnxruntime/core/providers/cpu/controlflow/if.cc b/onnxruntime/core/providers/cpu/controlflow/if.cc index 58e265a1bd..01d2675e2d 100644 --- a/onnxruntime/core/providers/cpu/controlflow/if.cc +++ b/onnxruntime/core/providers/cpu/controlflow/if.cc @@ -304,10 +304,10 @@ Status IfImpl::Execute(const FeedsFetchesManager& ffm) { // allocation plan for the If node's output is used. fetch_allocators[i] = [this, i, &fetches](const TensorShape& shape, const OrtMemoryInfo& location, OrtValue& ort_value, bool& allocated) { - // for now we only allocate on CPU as currently all 'If' outputs are on CPU. - // if that does not match the required device we don't update the provided OrtValue and return false for - // 'allocated'. the execution frame will allocate a buffer on the required device, and the fetches copy - // logic in utils::ExecuteSubgraph will handle moving it to CPU (and into the tensor we allocated here) + // if the device the If output is allocated on does not match the required device for the subgraph output + // we don't update the provided OrtValue and return false for 'allocated'. + // the execution frame will allocate a buffer on the required device, and the fetches copy + // logic in utils::ExecuteSubgraph will handle moving it into the tensor we allocated here. auto* tensor = context_.Output(i, shape); if (!tensor) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for If output ", i); diff --git a/onnxruntime/core/providers/cpu/controlflow/if.h b/onnxruntime/core/providers/cpu/controlflow/if.h index 5c76c908e1..788927b5ae 100644 --- a/onnxruntime/core/providers/cpu/controlflow/if.h +++ b/onnxruntime/core/providers/cpu/controlflow/if.h @@ -13,7 +13,7 @@ namespace onnxruntime { class SessionState; -class If final : public OpKernel, public controlflow::IControlFlowKernel { +class If : public OpKernel, public controlflow::IControlFlowKernel { public: If(const OpKernelInfo& info); diff --git a/onnxruntime/core/providers/cpu/tensor/gather.cc b/onnxruntime/core/providers/cpu/tensor/gather.cc index 01665d4deb..c0d4a4dac9 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather.cc @@ -11,13 +11,19 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Gather, 1, 10, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Gather); ONNX_CPU_OPERATOR_KERNEL( Gather, 11, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Gather); Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { @@ -63,7 +69,7 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin if (idx < -axis_dim_limit || idx >= axis_dim_limit) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices element out of data bounds, idx=", idx, - " must be within the inclusive range [", -axis_dim_limit,",", axis_dim_limit - 1, "]"); + " must be within the inclusive range [", -axis_dim_limit, ",", axis_dim_limit - 1, "]"); } } diff --git a/onnxruntime/core/providers/cuda/controlflow/if.cc b/onnxruntime/core/providers/cuda/controlflow/if.cc new file mode 100644 index 0000000000..4a79145588 --- /dev/null +++ b/onnxruntime/core/providers/cuda/controlflow/if.cc @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/controlflow/if.h" + +#include "core/providers/cuda/cuda_common.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + If); + +// output shape rules requiring the output shapes of the 'THEN' and 'ELSE' +// branches to be the same were relaxed in opset-11 +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 11, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + If); + +Status If::Compute(OpKernelContext* ctx) const { + // call the base CPU version. + // we have this CUDA implementation so the inputs/outputs stay on GPU where possible. + // the logic to run the subgraph must be on CPU either way. + // technically we don't need this override of Compute, but it will be optimized out and it's easier to debug + // that this implementation is being called with it. + auto status = onnxruntime::If::Compute(ctx); + return status; +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/controlflow/if.h b/onnxruntime/core/providers/cuda/controlflow/if.h new file mode 100644 index 0000000000..f182bbeba2 --- /dev/null +++ b/onnxruntime/core/providers/cuda/controlflow/if.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "gsl/gsl" + +#include "core/common/common.h" +#include "core/providers/cpu/controlflow/if.h" + +namespace onnxruntime { +class SessionState; + +namespace cuda { + +// Use the CPU implementation for the logic +class If final : public onnxruntime::If { + public: + If(const OpKernelInfo& info) : onnxruntime::If(info) {} + + Status Compute(OpKernelContext* ctx) const override; +}; +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 0b37ea6383..20ac14cc36 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -544,6 +544,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, NonZero); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If); // opset 10 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, AveragePool); @@ -585,6 +586,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, G class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, If); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ReduceL1); @@ -999,6 +1001,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -1015,6 +1018,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 56000e5a3d..e18aa82de3 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -502,10 +502,16 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio /// @remarks We pass in graph and session_state so we can handled nested subgraphs in the future common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, SessionState& session_state) { for (auto& node : graph.Nodes()) { - // We only need subgraph session state for control flow nodes being handled by the CPU execution provider. + // We only need subgraph session state for control flow nodes being handled by our CPU or CUDA execution provider. // Remove it if it's not needed. - if (node.ContainsSubgraph() && node.GetExecutionProviderType() != kCpuExecutionProvider) { - session_state.RemoveSubgraphSessionState(node.Index()); + if (node.ContainsSubgraph()) { + const auto ep = node.GetExecutionProviderType(); + if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) { + session_state.RemoveSubgraphSessionState(node.Index()); + continue; + } + } else { + // not a control flow node continue; } @@ -526,7 +532,7 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), // &*subgraph_info.session_state); - // setup all the info for handling the feeds and fetches used in subraph execution + // setup all the info for handling the feeds and fetches used in subgraph execution auto* p_op_kernel = session_state.GetMutableKernel(node.Index()); ORT_ENFORCE(p_op_kernel); auto& control_flow_kernel = dynamic_cast(*p_op_kernel); diff --git a/onnxruntime/test/providers/cpu/controlflow/if_test.cc b/onnxruntime/test/providers/cpu/controlflow/if_test.cc index 43b405b610..79cc219e48 100644 --- a/onnxruntime/test/providers/cpu/controlflow/if_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/if_test.cc @@ -297,6 +297,12 @@ TEST(If, MixedExecutionProviders) { RunTest(true, options); } +TEST(If, MixedExecutionProvidersOpset11) { + RunOptions options{}; + options.mixed_execution_providers = true; + RunTest(true, options, false, test::OpTester::ExpectResult::kExpectSuccess, "", 11); +} + TEST(If, MixedExecutionProvidersNoShapeInSubgraph) { RunOptions options{}; options.mixed_execution_providers = true;