Register CPU OptionalGetElement, OptionalHasElement on DirectML (#15926)

Register CPU OptionalGetElement, OptionalHasElement on DirectML

Graphs containing OptionalGetElement and OptionalHasElement should now run in a
DML graph without extra memcpy operations on and off the GPU.

CopyCpuTensor is replaced with DataTransferManager.CopyTensor() so that the
CPU operator implementation is usable by other execution providers.

---------

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
Sheil Kumar 2023-05-15 09:53:35 -07:00 committed by GitHub
parent 18133ddadb
commit fa16e2e0f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 13 deletions

View file

@ -1030,6 +1030,10 @@ Do not modify directly.*
|Not|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bool)|
|OneHot|*in* indices:**T1**<br> *in* depth:**T2**<br> *in* values:**T3**<br> *out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|OptionalGetElement|*in* input:**O**<br> *out* output:**V**|18+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||15+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))<br/> **V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|OptionalHasElement|*in* input:**O**<br> *out* output:**B**|18+|**B** = tensor(bool)<br/> **O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||15+|**B** = tensor(bool)<br/> **O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))|
|Or|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)|
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|

View file

@ -61,7 +61,8 @@ ONNX_CPU_OPERATOR_KERNEL(OptionalGetElement,
static void CopySequenceTensor(AllocatorPtr alloc,
const TensorSeq* src,
TensorSeq* tgt) {
TensorSeq* tgt,
const DataTransferManager& data_transfer_mgr) {
// The static allocation planner has deemed that the input can be re-used as the output
// Analogy: Checking if data pointers for the input and output Tensors are the same
// before proceeding to make the copy.
@ -76,13 +77,16 @@ static void CopySequenceTensor(AllocatorPtr alloc,
for (; in_tensor != src->end(); ++in_tensor) {
auto& tensor = in_tensor->Get<Tensor>();
Tensor tmp(tensor.DataType(), tensor.Shape(), alloc);
CopyCpuTensor(&tensor, &tmp);
// Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops
(void)data_transfer_mgr.CopyTensor(tensor, tmp);
tgt->Add(std::move(tmp));
}
}
static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_value,
OpKernelContext* ctx) {
OpKernelContext* ctx,
const DataTransferManager& data_transfer_mgr) {
if (input_ort_value->IsTensor()) {
const auto* input_tensor = &input_ort_value->Get<Tensor>();
auto* output_tensor = ctx->Output(0, input_tensor->Shape());
@ -90,10 +94,9 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu
// If the allocation planner had deemed that we re-use the input OrtValue
// as the output OrtValue, the data pointers in the input_tensor and the
// output_tensor will be the same and the copy is a no-op.
// CopyCpuTensor() already has such copy optimizations - so
// DataTransferManager.CopyTensor() already has such copy optimizations - so
// just re-use it.
CopyCpuTensor(input_tensor, output_tensor);
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(*input_tensor, *output_tensor));
} else if (input_ort_value->IsTensorSequence()) {
const auto* input_tensor_sequence = &input_ort_value->Get<TensorSeq>();
auto* output_tensor_sequence = ctx->Output<TensorSeq>(0);
@ -105,7 +108,7 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu
// as the output OrtValue, the pointers of the source TensorSeq and the
// target TensorSeq will be the same and the copy is a no-op.
// CopySequenceTensor() already has such copy optimizations
CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence);
CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence, data_transfer_mgr);
} else {
// Will not reach here
@ -132,7 +135,8 @@ Status Optional::Compute(OpKernelContext* ctx) const {
if (input_ort_value != nullptr) {
// An input was provided by the user - so just propagate it to the output
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx));
const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager();
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr));
} else { // No input was provided - we use the type proto to construct the output OrtValue
@ -176,7 +180,8 @@ Status OptionalGetElement::Compute(OpKernelContext* ctx) const {
}
// Propagate input to the output
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx));
const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager();
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr));
return Status::OK();
}

View file

@ -624,18 +624,20 @@ namespace Dml
bool IsCpuOnDmlOperator(const onnxruntime::Node& node)
{
auto sequence_ops = std::array<char*, 6>{
auto cpuOnDmlOperators = std::array<char*, 8>{
"SequenceAt",
"SequenceConstruct",
"SequenceEmpty",
"SequenceLength",
"SequenceErase",
"SequenceInsert"
"SequenceInsert",
"OptionalGetElement",
"OptionalHasElement"
};
for (auto& sequence_op : sequence_ops)
for (auto& cpuOnDmlOperator : cpuOnDmlOperators)
{
if (strcmp(sequence_op, node.OpType().c_str()) == 0)
if (strcmp(cpuOnDmlOperator, node.OpType().c_str()) == 0)
{
return true;
}

View file

@ -20,6 +20,7 @@ using namespace Microsoft::WRL;
#include "core/framework/TensorSeq.h"
#include "core/providers/cpu/sequence/sequence_ops.h"
#include "core/providers/cpu/tensor/concatbase.h"
#include "core/providers/cpu/optional/optional_ops.h"
namespace onnxruntime {
@ -30,6 +31,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, Se
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceErase);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceInsert);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalHasElement);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalGetElement);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalHasElement);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalGetElement);
}
@ -105,6 +110,53 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType<int64_t>()}),
SequenceInsert);
// Register the CPU OptionalHasElement kernel (ONNX opset 15) on the DML
// execution provider, so graphs using it can be assigned to DML without an
// extra CPU<->GPU transfer for this node.
ONNX_OPERATOR_KERNEL_EX(
OptionalHasElement,
kOnnxDomain,
15,
kDmlExecutionProvider,
(*KernelDefBuilder::Create())
// Opset 15 restricts the input "O" to optional types only.
.TypeConstraint("O", DataTypeImpl::AllOptionalTypes())
// Output "B" is always a scalar bool tensor.
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()),
OptionalHasElement);
// Register the CPU OptionalGetElement kernel (ONNX opset 15) on the DML
// execution provider.
ONNX_OPERATOR_KERNEL_EX(
OptionalGetElement,
kOnnxDomain,
15,
kDmlExecutionProvider,
(*KernelDefBuilder::Create())
// Opset 15 restricts the input "O" to optional types only; the output "V"
// is the contained (non-optional) tensor or tensor-sequence type.
.TypeConstraint("O", DataTypeImpl::AllOptionalTypes())
.TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
// We may be able to re-use the input for the output as is unless the output
// is a graph output. We provide this hint to the allocation planner
// to make the re-use call.
.Alias(0, 0),
OptionalGetElement);
// Register the CPU OptionalHasElement kernel (ONNX opset 18) on the DML
// execution provider.
ONNX_OPERATOR_KERNEL_EX(
OptionalHasElement,
kOnnxDomain,
18,
kDmlExecutionProvider,
(*KernelDefBuilder::Create())
// Opset 18 widens "O" to also accept plain tensor and tensor-sequence
// inputs, not just optional types (see the opset-15 registration above).
.TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes())
// Output "B" is always a scalar bool tensor.
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()),
OptionalHasElement);
// Register the CPU OptionalGetElement kernel (ONNX opset 18) on the DML
// execution provider.
ONNX_OPERATOR_KERNEL_EX(
OptionalGetElement,
kOnnxDomain,
18,
kDmlExecutionProvider,
(*KernelDefBuilder::Create())
// Opset 18 widens "O" to also accept plain tensor and tensor-sequence
// inputs in addition to optional types.
.TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes())
.TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
// We may be able to re-use the input for the output as is unless the output
// is a graph output. We provide this hint to the allocation planner
// to make the re-use call.
.Alias(0, 0),
OptionalGetElement);
}
namespace Dml
@ -925,6 +977,10 @@ void RegisterCpuOperatorsAsDml(onnxruntime::KernelRegistry* registry)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceLength)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceErase)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceInsert)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalHasElement)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalGetElement)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalHasElement)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalGetElement)>,
};
for (auto& function_table_entry : function_table) {