Add CUDA If operator. (#2377)

* Add CUDA If operator.
Uses CPU operator for implementation.
By adding a CUDA version the inputs/outputs (with the exception of the 'cond' input) stay on GPU, and no other logic is required to avoid a copy to CPU across the control flow node.
This commit is contained in:
Scott McKay 2019-11-19 12:01:46 +10:00 committed by GitHub
parent 1cb6bdc33c
commit be12cdc73f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 106 additions and 15 deletions

View file

@ -156,8 +156,6 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
// return Status::OK();
//}
assert(source_mlvalue.IsTensor());
auto& source_tensor = source_mlvalue.Get<Tensor>();
if (!target_mlvalue.IsAllocated()) {
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
@ -647,7 +645,7 @@ void DumpNodeOutputs(OpKernelContext& context, const Node& node, const SessionSt
}
}
#else
ORT_UNUSED_PARAMETER(session_state);
ORT_UNUSED_PARAMETER(session_state);
#endif
}
}

View file

@ -304,10 +304,10 @@ Status IfImpl::Execute(const FeedsFetchesManager& ffm) {
// allocation plan for the If node's output is used.
fetch_allocators[i] = [this, i, &fetches](const TensorShape& shape, const OrtMemoryInfo& location,
OrtValue& ort_value, bool& allocated) {
// for now we only allocate on CPU as currently all 'If' outputs are on CPU.
// if that does not match the required device we don't update the provided OrtValue and return false for
// 'allocated'. the execution frame will allocate a buffer on the required device, and the fetches copy
// logic in utils::ExecuteSubgraph will handle moving it to CPU (and into the tensor we allocated here)
// if the device the If output is allocated on does not match the required device for the subgraph output
// we don't update the provided OrtValue and return false for 'allocated'.
// the execution frame will allocate a buffer on the required device, and the fetches copy
// logic in utils::ExecuteSubgraph will handle moving it into the tensor we allocated here.
auto* tensor = context_.Output(i, shape);
if (!tensor)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for If output ", i);

View file

@ -13,7 +13,7 @@
namespace onnxruntime {
class SessionState;
class If final : public OpKernel, public controlflow::IControlFlowKernel {
class If : public OpKernel, public controlflow::IControlFlowKernel {
public:
If(const OpKernelInfo& info);

View file

@ -11,13 +11,19 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Gather,
1,
10,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Gather);
ONNX_CPU_OPERATOR_KERNEL(
Gather,
11,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Gather);
Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
@ -63,7 +69,7 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
if (idx < -axis_dim_limit || idx >= axis_dim_limit) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"indices element out of data bounds, idx=", idx,
" must be within the inclusive range [", -axis_dim_limit,",", axis_dim_limit - 1, "]");
" must be within the inclusive range [", -axis_dim_limit, ",", axis_dim_limit - 1, "]");
}
}

View file

@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/controlflow/if.h"
#include "core/providers/cuda/cuda_common.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
1, 10,
kCudaExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
If);
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE'
// branches to be the same were relaxed in opset-11
ONNX_OPERATOR_KERNEL_EX(If,
kOnnxDomain,
11,
kCudaExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
If);
Status If::Compute(OpKernelContext* ctx) const {
// call the base CPU version.
// we have this CUDA implementation so the inputs/outputs stay on GPU where possible.
// the logic to run the subgraph must be on CPU either way.
// technically we don't need this override of Compute, but it will be optimized out and it's easier to debug
// that this implementation is being called with it.
auto status = onnxruntime::If::Compute(ctx);
return status;
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include "gsl/gsl"
#include "core/common/common.h"
#include "core/providers/cpu/controlflow/if.h"
namespace onnxruntime {
class SessionState;
namespace cuda {
// Use the CPU implementation for the logic
class If final : public onnxruntime::If {
public:
If(const OpKernelInfo& info) : onnxruntime::If(info) {}
Status Compute(OpKernelContext* ctx) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -544,6 +544,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, NonZero);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If);
// opset 10
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, AveragePool);
@ -585,6 +586,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, G
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ReduceL1);
@ -999,6 +1001,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If)>,
// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMax)>,
@ -1015,6 +1018,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ReduceL1)>,

View file

@ -502,10 +502,16 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
/// @remarks We pass in graph and session_state so we can handled nested subgraphs in the future
common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, SessionState& session_state) {
for (auto& node : graph.Nodes()) {
// We only need subgraph session state for control flow nodes being handled by the CPU execution provider.
// We only need subgraph session state for control flow nodes being handled by our CPU or CUDA execution provider.
// Remove it if it's not needed.
if (node.ContainsSubgraph() && node.GetExecutionProviderType() != kCpuExecutionProvider) {
session_state.RemoveSubgraphSessionState(node.Index());
if (node.ContainsSubgraph()) {
const auto ep = node.GetExecutionProviderType();
if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
session_state.RemoveSubgraphSessionState(node.Index());
continue;
}
} else {
// not a control flow node
continue;
}
@ -526,7 +532,7 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
// &*subgraph_info.session_state);
// setup all the info for handling the feeds and fetches used in subraph execution
// setup all the info for handling the feeds and fetches used in subgraph execution
auto* p_op_kernel = session_state.GetMutableKernel(node.Index());
ORT_ENFORCE(p_op_kernel);
auto& control_flow_kernel = dynamic_cast<controlflow::IControlFlowKernel&>(*p_op_kernel);

View file

@ -297,6 +297,12 @@ TEST(If, MixedExecutionProviders) {
RunTest(true, options);
}
TEST(If, MixedExecutionProvidersOpset11) {
RunOptions options{};
options.mixed_execution_providers = true;
RunTest(true, options, false, test::OpTester::ExpectResult::kExpectSuccess, "", 11);
}
TEST(If, MixedExecutionProvidersNoShapeInSubgraph) {
RunOptions options{};
options.mixed_execution_providers = true;