mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add CUDA If operator. (#2377)
* Add CUDA If operator. Uses CPU operator for implementation. By adding a CUDA version the inputs/outputs (with the exception of the 'cond' input) stay on GPU, and no other logic is required to avoid a copy to CPU across the control flow node.
This commit is contained in:
parent
1cb6bdc33c
commit
be12cdc73f
9 changed files with 106 additions and 15 deletions
|
|
@ -156,8 +156,6 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
|
|||
// return Status::OK();
|
||||
//}
|
||||
|
||||
assert(source_mlvalue.IsTensor());
|
||||
|
||||
auto& source_tensor = source_mlvalue.Get<Tensor>();
|
||||
if (!target_mlvalue.IsAllocated()) {
|
||||
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
|
||||
|
|
@ -647,7 +645,7 @@ void DumpNodeOutputs(OpKernelContext& context, const Node& node, const SessionSt
|
|||
}
|
||||
}
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(session_state);
|
||||
ORT_UNUSED_PARAMETER(session_state);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -304,10 +304,10 @@ Status IfImpl::Execute(const FeedsFetchesManager& ffm) {
|
|||
// allocation plan for the If node's output is used.
|
||||
fetch_allocators[i] = [this, i, &fetches](const TensorShape& shape, const OrtMemoryInfo& location,
|
||||
OrtValue& ort_value, bool& allocated) {
|
||||
// for now we only allocate on CPU as currently all 'If' outputs are on CPU.
|
||||
// if that does not match the required device we don't update the provided OrtValue and return false for
|
||||
// 'allocated'. the execution frame will allocate a buffer on the required device, and the fetches copy
|
||||
// logic in utils::ExecuteSubgraph will handle moving it to CPU (and into the tensor we allocated here)
|
||||
// if the device the If output is allocated on does not match the required device for the subgraph output
|
||||
// we don't update the provided OrtValue and return false for 'allocated'.
|
||||
// the execution frame will allocate a buffer on the required device, and the fetches copy
|
||||
// logic in utils::ExecuteSubgraph will handle moving it into the tensor we allocated here.
|
||||
auto* tensor = context_.Output(i, shape);
|
||||
if (!tensor)
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for If output ", i);
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
namespace onnxruntime {
|
||||
class SessionState;
|
||||
|
||||
class If final : public OpKernel, public controlflow::IControlFlowKernel {
|
||||
class If : public OpKernel, public controlflow::IControlFlowKernel {
|
||||
public:
|
||||
If(const OpKernelInfo& info);
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,19 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
Gather,
|
||||
1,
|
||||
10,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Gather);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Gather,
|
||||
11,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Gather);
|
||||
|
||||
Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
|
||||
|
|
@ -63,7 +69,7 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
|
|||
if (idx < -axis_dim_limit || idx >= axis_dim_limit) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"indices element out of data bounds, idx=", idx,
|
||||
" must be within the inclusive range [", -axis_dim_limit,",", axis_dim_limit - 1, "]");
|
||||
" must be within the inclusive range [", -axis_dim_limit, ",", axis_dim_limit - 1, "]");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
47
onnxruntime/core/providers/cuda/controlflow/if.cc
Normal file
47
onnxruntime/core/providers/cuda/controlflow/if.cc
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/controlflow/if.h"
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Kernel registration for the ONNX-domain 'If' operator on the CUDA execution provider,
// for the versioned range passed below (1, 10).
// Input 0 is tagged OrtMemTypeCPUInput; constraint "B" admits bool tensors only and
// constraint "V" admits all tensor types. The kernel class bound here is cuda::If.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
|
||||
kOnnxDomain,
|
||||
1, 10,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) // 'cond' (input 0) needs to be on CPU
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
|
||||
If);
|
||||
|
||||
// Separate registration for opset 11: the output shape rules requiring the output shapes
|
||||
// of the 'THEN' and 'ELSE' branches to be the same were relaxed in opset-11.
// The memory-type and type constraints are the same as in the 1-10 registration.
|
||||
ONNX_OPERATOR_KERNEL_EX(If,
|
||||
kOnnxDomain,
|
||||
11,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) // 'cond' (input 0) needs to be on CPU
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
|
||||
If);
|
||||
|
||||
// Runs the 'If' node by delegating to the base CPU implementation (onnxruntime::If::Compute)
// and returning its Status unchanged.
Status If::Compute(OpKernelContext* ctx) const {
|
||||
// Call the base CPU version.
|
||||
// This CUDA kernel exists so the inputs/outputs stay on GPU where possible;
|
||||
// the logic to run the subgraph must be on CPU either way.
|
||||
// Strictly speaking this override of Compute is not needed. It is kept because it is expected to be
|
||||
// optimized out, and it makes it easier to confirm in a debugger that this implementation is the one being called.
|
||||
auto status = onnxruntime::If::Compute(ctx);
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/core/providers/cuda/controlflow/if.h
Normal file
24
onnxruntime/core/providers/cuda/controlflow/if.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include "gsl/gsl"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cpu/controlflow/if.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class SessionState;
|
||||
|
||||
namespace cuda {
|
||||
|
||||
// CUDA execution provider kernel for 'If'.
// Uses the CPU implementation (onnxruntime::If) for the logic; this class only derives from it
// and overrides Compute.
|
||||
class If final : public onnxruntime::If {
|
||||
public:
|
||||
// Forwards the kernel info to the CPU base class; no additional state is set up here.
If(const OpKernelInfo& info) : onnxruntime::If(info) {}
|
||||
|
||||
// Defined in the accompanying if.cc.
Status Compute(OpKernelContext* ctx) const override;
|
||||
};
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -544,6 +544,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If);
|
||||
|
||||
// opset 10
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, AveragePool);
|
||||
|
|
@ -585,6 +586,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, G
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, If);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ReduceL1);
|
||||
|
|
@ -999,6 +1001,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If)>,
|
||||
|
||||
// opset 11
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMax)>,
|
||||
|
|
@ -1015,6 +1018,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, If)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ReduceL1)>,
|
||||
|
|
|
|||
|
|
@ -502,10 +502,16 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
|
|||
/// @remarks We pass in graph and session_state so we can handle nested subgraphs in the future
|
||||
common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, SessionState& session_state) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
// We only need subgraph session state for control flow nodes being handled by the CPU execution provider.
|
||||
// We only need subgraph session state for control flow nodes being handled by our CPU or CUDA execution provider.
|
||||
// Remove it if it's not needed.
|
||||
if (node.ContainsSubgraph() && node.GetExecutionProviderType() != kCpuExecutionProvider) {
|
||||
session_state.RemoveSubgraphSessionState(node.Index());
|
||||
if (node.ContainsSubgraph()) {
|
||||
const auto ep = node.GetExecutionProviderType();
|
||||
if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
|
||||
session_state.RemoveSubgraphSessionState(node.Index());
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// not a control flow node
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -526,7 +532,7 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
|
|||
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
|
||||
// &*subgraph_info.session_state);
|
||||
|
||||
// setup all the info for handling the feeds and fetches used in subraph execution
|
||||
// setup all the info for handling the feeds and fetches used in subgraph execution
|
||||
auto* p_op_kernel = session_state.GetMutableKernel(node.Index());
|
||||
ORT_ENFORCE(p_op_kernel);
|
||||
auto& control_flow_kernel = dynamic_cast<controlflow::IControlFlowKernel&>(*p_op_kernel);
|
||||
|
|
|
|||
|
|
@ -297,6 +297,12 @@ TEST(If, MixedExecutionProviders) {
|
|||
RunTest(true, options);
|
||||
}
|
||||
|
||||
// Runs the 'If' test with mixed execution providers enabled and expects success.
// NOTE(review): the trailing 11 is presumably the opset version (per the test name) - confirm against RunTest's signature.
TEST(If, MixedExecutionProvidersOpset11) {
|
||||
RunOptions options{};
|
||||
options.mixed_execution_providers = true;
|
||||
// NOTE(review): the meaning of the positional `false` and `""` arguments is not visible here - see RunTest.
RunTest(true, options, false, test::OpTester::ExpectResult::kExpectSuccess, "", 11);
|
||||
}
|
||||
|
||||
TEST(If, MixedExecutionProvidersNoShapeInSubgraph) {
|
||||
RunOptions options{};
|
||||
options.mixed_execution_providers = true;
|
||||
|
|
|
|||
Loading…
Reference in a new issue