mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Register CPU OptionalGetElement, OptionalHasElement on DirectML (#15926)
Register CPU OptionalGetElement, OptionalHasElement on DirectML. Graphs with OptionalGetElement and OptionalHasElement should work in a DML graph without extra memcpy operations on and off the GPU. CopyCpuTensor is swapped with DataTransferManager.CopyTensor() to make the CPU operator usable by other providers. --------- Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
parent
18133ddadb
commit
fa16e2e0f3
4 changed files with 80 additions and 13 deletions
|
|
@@ -1030,6 +1030,10 @@ Do not modify directly.*
|
|||
|Not|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bool)|
|
||||
|OneHot|*in* indices:**T1**<br> *in* depth:**T2**<br> *in* values:**T3**<br> *out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|OptionalGetElement|*in* input:**O**<br> *out* output:**V**|18+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||15+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))<br/> **V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|OptionalHasElement|*in* input:**O**<br> *out* output:**B**|18+|**B** = tensor(bool)<br/> **O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||15+|**B** = tensor(bool)<br/> **O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))|
|
||||
|Or|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)|
|
||||
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|
||||
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|
||||
|
|
|
|||
|
|
@@ -61,7 +61,8 @@ ONNX_CPU_OPERATOR_KERNEL(OptionalGetElement,
|
|||
|
||||
static void CopySequenceTensor(AllocatorPtr alloc,
|
||||
const TensorSeq* src,
|
||||
TensorSeq* tgt) {
|
||||
TensorSeq* tgt,
|
||||
const DataTransferManager& data_transfer_mgr) {
|
||||
// The static allocation planner has deemed that the input can be re-used as the output
|
||||
// Analogy: Checking if data pointers for the input and output Tensors are the same
|
||||
// before proceeding to make the copy.
|
||||
|
|
@@ -76,13 +77,16 @@ static void CopySequenceTensor(AllocatorPtr alloc,
|
|||
for (; in_tensor != src->end(); ++in_tensor) {
|
||||
auto& tensor = in_tensor->Get<Tensor>();
|
||||
Tensor tmp(tensor.DataType(), tensor.Shape(), alloc);
|
||||
CopyCpuTensor(&tensor, &tmp);
|
||||
// Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops
|
||||
(void)data_transfer_mgr.CopyTensor(tensor, tmp);
|
||||
|
||||
tgt->Add(std::move(tmp));
|
||||
}
|
||||
}
|
||||
|
||||
static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_value,
|
||||
OpKernelContext* ctx) {
|
||||
OpKernelContext* ctx,
|
||||
const DataTransferManager& data_transfer_mgr) {
|
||||
if (input_ort_value->IsTensor()) {
|
||||
const auto* input_tensor = &input_ort_value->Get<Tensor>();
|
||||
auto* output_tensor = ctx->Output(0, input_tensor->Shape());
|
||||
|
|
@@ -90,10 +94,9 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu
|
|||
// If the allocation planner had deemed that we re-use the input OrtValue
|
||||
// as the output OrtValue, the data pointers in the input_tensor and the
|
||||
// output_tensor will be the same and the copy is a no-op.
|
||||
// CopyCpuTensor() already has such copy optimizations - so
|
||||
// DataTransferManager.CopyTensor() already has such copy optimizations - so
|
||||
// just re-use it.
|
||||
CopyCpuTensor(input_tensor, output_tensor);
|
||||
|
||||
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(*input_tensor, *output_tensor));
|
||||
} else if (input_ort_value->IsTensorSequence()) {
|
||||
const auto* input_tensor_sequence = &input_ort_value->Get<TensorSeq>();
|
||||
auto* output_tensor_sequence = ctx->Output<TensorSeq>(0);
|
||||
|
|
@@ -105,7 +108,7 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu
|
|||
// as the output OrtValue, the pointers of the source TensorSeq and the
|
||||
// target TensorSeq will be the same and the copy is a no-op.
|
||||
// CopySequenceTensor() already has such copy optimizations
|
||||
CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence);
|
||||
CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence, data_transfer_mgr);
|
||||
|
||||
} else {
|
||||
// Will not reach here
|
||||
|
|
@@ -132,7 +135,8 @@ Status Optional::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
if (input_ort_value != nullptr) {
|
||||
// An input was provided by the user - so just propagate it to the output
|
||||
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx));
|
||||
const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager();
|
||||
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr));
|
||||
|
||||
} else { // No input was provided - we use the type proto to construct the output OrtValue
|
||||
|
||||
|
|
@@ -176,7 +180,8 @@ Status OptionalGetElement::Compute(OpKernelContext* ctx) const {
|
|||
}
|
||||
|
||||
// Propagate input to the output
|
||||
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx));
|
||||
const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager();
|
||||
ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -624,18 +624,20 @@ namespace Dml
|
|||
|
||||
bool IsCpuOnDmlOperator(const onnxruntime::Node& node)
|
||||
{
|
||||
auto sequence_ops = std::array<char*, 6>{
|
||||
auto cpuOnDmlOperators = std::array<char*, 8>{
|
||||
"SequenceAt",
|
||||
"SequenceConstruct",
|
||||
"SequenceEmpty",
|
||||
"SequenceLength",
|
||||
"SequenceErase",
|
||||
"SequenceInsert"
|
||||
"SequenceInsert",
|
||||
"OptionalGetElement",
|
||||
"OptionalHasElement"
|
||||
};
|
||||
|
||||
for (auto& sequence_op : sequence_ops)
|
||||
for (auto& cpuOnDmlOperator : cpuOnDmlOperators)
|
||||
{
|
||||
if (strcmp(sequence_op, node.OpType().c_str()) == 0)
|
||||
if (strcmp(cpuOnDmlOperator, node.OpType().c_str()) == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -20,6 +20,7 @@ using namespace Microsoft::WRL;
|
|||
#include "core/framework/TensorSeq.h"
|
||||
#include "core/providers/cpu/sequence/sequence_ops.h"
|
||||
#include "core/providers/cpu/tensor/concatbase.h"
|
||||
#include "core/providers/cpu/optional/optional_ops.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@@ -30,6 +31,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, Se
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceErase);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceInsert);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalHasElement);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalGetElement);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalHasElement);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalGetElement);
|
||||
|
||||
}
|
||||
|
||||
|
|
@@ -105,6 +110,53 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
SequenceInsert);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
OptionalHasElement,
|
||||
kOnnxDomain,
|
||||
15,
|
||||
kDmlExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("O", DataTypeImpl::AllOptionalTypes())
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()),
|
||||
OptionalHasElement);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
OptionalGetElement,
|
||||
kOnnxDomain,
|
||||
15,
|
||||
kDmlExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("O", DataTypeImpl::AllOptionalTypes())
|
||||
.TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
|
||||
// We may be able to re-use the input for the output as is unless the output
|
||||
// is a graph output. We provide this hint to the allocation planner
|
||||
// to make the re-use call.
|
||||
.Alias(0, 0),
|
||||
OptionalGetElement);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
OptionalHasElement,
|
||||
kOnnxDomain,
|
||||
18,
|
||||
kDmlExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes())
|
||||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()),
|
||||
OptionalHasElement);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
OptionalGetElement,
|
||||
kOnnxDomain,
|
||||
18,
|
||||
kDmlExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes())
|
||||
.TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
|
||||
// We may be able to re-use the input for the output as is unless the output
|
||||
// is a graph output. We provide this hint to the allocation planner
|
||||
// to make the re-use call.
|
||||
.Alias(0, 0),
|
||||
OptionalGetElement);
|
||||
}
|
||||
|
||||
namespace Dml
|
||||
|
|
@@ -925,6 +977,10 @@ void RegisterCpuOperatorsAsDml(onnxruntime::KernelRegistry* registry)
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceLength)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceErase)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceInsert)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalHasElement)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalGetElement)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalHasElement)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalGetElement)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue