diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 97051e99ef..7f438b6d86 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1030,6 +1030,10 @@ Do not modify directly.* |Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)| |OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|OptionalGetElement|*in* input:**O**
*out* output:**V**|18+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||15+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|OptionalHasElement|*in* input:**O**
*out* output:**B**|18+|**B** = tensor(bool)
**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||15+|**B** = tensor(bool)
**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))| |Or|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)| |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| diff --git a/onnxruntime/core/providers/cpu/optional/optional_ops.cc b/onnxruntime/core/providers/cpu/optional/optional_ops.cc index 33e9401935..8bdc783059 100644 --- a/onnxruntime/core/providers/cpu/optional/optional_ops.cc +++ b/onnxruntime/core/providers/cpu/optional/optional_ops.cc @@ -61,7 +61,8 @@ ONNX_CPU_OPERATOR_KERNEL(OptionalGetElement, static void CopySequenceTensor(AllocatorPtr alloc, const TensorSeq* src, - TensorSeq* tgt) { + TensorSeq* tgt, + const DataTransferManager& data_transfer_mgr) { // The static allocation planner has deemed that the input can be re-used as the output // Analogy: Checking if data pointers for the input and output Tensors are the same // before proceeding to make the copy. @@ -76,13 +77,16 @@ static void CopySequenceTensor(AllocatorPtr alloc, for (; in_tensor != src->end(); ++in_tensor) { auto& tensor = in_tensor->Get(); Tensor tmp(tensor.DataType(), tensor.Shape(), alloc); - CopyCpuTensor(&tensor, &tmp); + // Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops + (void)data_transfer_mgr.CopyTensor(tensor, tmp); + tgt->Add(std::move(tmp)); } } static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_value, - OpKernelContext* ctx) { + OpKernelContext* ctx, + const DataTransferManager& data_transfer_mgr) { if (input_ort_value->IsTensor()) { const auto* input_tensor = &input_ort_value->Get(); auto* output_tensor = ctx->Output(0, input_tensor->Shape()); @@ -90,10 +94,9 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu // If the allocation planner had deemed that we re-use the input OrtValue // as the output OrtValue, the data pointers in the input_tensor and the // output_tensor will be the same and the copy is a no-op. - // CopyCpuTensor() already has such copy optimizations - so + // DataTransferManager.CopyTensor() already has such copy optimizations - so // just re-use it. - CopyCpuTensor(input_tensor, output_tensor); - + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(*input_tensor, *output_tensor)); } else if (input_ort_value->IsTensorSequence()) { const auto* input_tensor_sequence = &input_ort_value->Get(); auto* output_tensor_sequence = ctx->Output(0); @@ -105,7 +108,7 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu // as the output OrtValue, the pointers of the source TensorSeq and the // target TensorSeq will be the same and the copy is a no-op. // CopySequenceTensor() already has such copy optimizations - CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence); + CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence, data_transfer_mgr); } else { // Will not reach here @@ -132,7 +135,8 @@ Status Optional::Compute(OpKernelContext* ctx) const { if (input_ort_value != nullptr) { // An input was provided by the user - so just propagate it to the output - ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx)); + const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager(); + ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr)); } else { // No input was provided - we use the type proto to construct the output OrtValue @@ -176,7 +180,8 @@ Status OptionalGetElement::Compute(OpKernelContext* ctx) const { } // Propagate input to the output - ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx)); + const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager(); + ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr)); return Status::OK(); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 7edd2592dd..411e48722a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -624,18 +624,20 @@ namespace Dml bool IsCpuOnDmlOperator(const onnxruntime::Node& node) { - auto sequence_ops = std::array{ + auto cpuOnDmlOperators = std::array{ "SequenceAt", "SequenceConstruct", "SequenceEmpty", "SequenceLength", "SequenceErase", - "SequenceInsert" + "SequenceInsert", + "OptionalGetElement", + "OptionalHasElement" }; - for (auto& sequence_op : sequence_ops) + for (auto& cpuOnDmlOperator : cpuOnDmlOperators) { - if (strcmp(sequence_op, node.OpType().c_str()) == 0) + if (strcmp(cpuOnDmlOperator, node.OpType().c_str()) == 0) { return true; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index fd0ad8385f..53bce5c715 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -20,6 +20,7 @@ using namespace Microsoft::WRL; #include "core/framework/TensorSeq.h" #include "core/providers/cpu/sequence/sequence_ops.h" #include "core/providers/cpu/tensor/concatbase.h" +#include "core/providers/cpu/optional/optional_ops.h" namespace onnxruntime { @@ -30,6 +31,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, Se class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, ConcatFromSequence); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceErase); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceInsert); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalHasElement); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalGetElement); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalHasElement); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalGetElement); } @@ -105,6 +110,53 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType()}), SequenceInsert); +ONNX_OPERATOR_KERNEL_EX( + OptionalHasElement, + kOnnxDomain, + 15, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("O", DataTypeImpl::AllOptionalTypes()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()), + OptionalHasElement); + +ONNX_OPERATOR_KERNEL_EX( + OptionalGetElement, + kOnnxDomain, + 15, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("O", DataTypeImpl::AllOptionalTypes()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()) + // We may be able to re-use the input for the output as is unless the output + // is a graph output. We provide this hint to the allocation planner + // to make the re-use call. + .Alias(0, 0), + OptionalGetElement); + +ONNX_OPERATOR_KERNEL_EX( + OptionalHasElement, + kOnnxDomain, + 18, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()), + OptionalHasElement); + +ONNX_OPERATOR_KERNEL_EX( + OptionalGetElement, + kOnnxDomain, + 18, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()) + // We may be able to re-use the input for the output as is unless the output + // is a graph output. We provide this hint to the allocation planner + // to make the re-use call. + .Alias(0, 0), + OptionalGetElement); } namespace Dml @@ -925,6 +977,10 @@ void RegisterCpuOperatorsAsDml(onnxruntime::KernelRegistry* registry) BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) {