diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 97051e99ef..7f438b6d86 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1030,6 +1030,10 @@ Do not modify directly.*
|Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)|
|OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|OptionalGetElement|*in* input:**O**
*out* output:**V**|18+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||15+|**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|OptionalHasElement|*in* input:**O**
*out* output:**B**|18+|**B** = tensor(bool)
**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||15+|**B** = tensor(bool)
**O** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8))|
|Or|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)|
|PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
diff --git a/onnxruntime/core/providers/cpu/optional/optional_ops.cc b/onnxruntime/core/providers/cpu/optional/optional_ops.cc
index 33e9401935..8bdc783059 100644
--- a/onnxruntime/core/providers/cpu/optional/optional_ops.cc
+++ b/onnxruntime/core/providers/cpu/optional/optional_ops.cc
@@ -61,7 +61,8 @@ ONNX_CPU_OPERATOR_KERNEL(OptionalGetElement,
static void CopySequenceTensor(AllocatorPtr alloc,
const TensorSeq* src,
- TensorSeq* tgt) {
+ TensorSeq* tgt,
+ const DataTransferManager& data_transfer_mgr) {
// The static allocation planner has deemed that the input can be re-used as the output
// Analogy: Checking if data pointers for the input and output Tensors are the same
// before proceeding to make the copy.
@@ -76,13 +77,16 @@ static void CopySequenceTensor(AllocatorPtr alloc,
for (; in_tensor != src->end(); ++in_tensor) {
auto& tensor = in_tensor->Get();
Tensor tmp(tensor.DataType(), tensor.Shape(), alloc);
- CopyCpuTensor(&tensor, &tmp);
+ // Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops
+ (void)data_transfer_mgr.CopyTensor(tensor, tmp);
+
tgt->Add(std::move(tmp));
}
}
static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_value,
- OpKernelContext* ctx) {
+ OpKernelContext* ctx,
+ const DataTransferManager& data_transfer_mgr) {
if (input_ort_value->IsTensor()) {
const auto* input_tensor = &input_ort_value->Get();
auto* output_tensor = ctx->Output(0, input_tensor->Shape());
@@ -90,10 +94,9 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu
// If the allocation planner had deemed that we re-use the input OrtValue
// as the output OrtValue, the data pointers in the input_tensor and the
// output_tensor will be the same and the copy is a no-op.
- // CopyCpuTensor() already has such copy optimizations - so
+ // DataTransferManager.CopyTensor() already has such copy optimizations - so
// just re-use it.
- CopyCpuTensor(input_tensor, output_tensor);
-
+ ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(*input_tensor, *output_tensor));
} else if (input_ort_value->IsTensorSequence()) {
const auto* input_tensor_sequence = &input_ort_value->Get();
auto* output_tensor_sequence = ctx->Output(0);
@@ -105,7 +108,7 @@ static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_valu
// as the output OrtValue, the pointers of the source TensorSeq and the
// target TensorSeq will be the same and the copy is a no-op.
// CopySequenceTensor() already has such copy optimizations
- CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence);
+ CopySequenceTensor(alloc, input_tensor_sequence, output_tensor_sequence, data_transfer_mgr);
} else {
// Will not reach here
@@ -132,7 +135,8 @@ Status Optional::Compute(OpKernelContext* ctx) const {
if (input_ort_value != nullptr) {
// An input was provided by the user - so just propagate it to the output
- ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx));
+ const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager();
+ ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr));
} else { // No input was provided - we use the type proto to construct the output OrtValue
@@ -176,7 +180,8 @@ Status OptionalGetElement::Compute(OpKernelContext* ctx) const {
}
// Propagate input to the output
- ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx));
+ const DataTransferManager& data_transfer_mgr = Info().GetDataTransferManager();
+ ORT_RETURN_IF_ERROR(PropagateInputOrtValueToFirstOutput(input_ort_value, ctx, data_transfer_mgr));
return Status::OK();
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 7edd2592dd..411e48722a 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -624,18 +624,20 @@ namespace Dml
bool IsCpuOnDmlOperator(const onnxruntime::Node& node)
{
- auto sequence_ops = std::array{
+ auto cpuOnDmlOperators = std::array{
"SequenceAt",
"SequenceConstruct",
"SequenceEmpty",
"SequenceLength",
"SequenceErase",
- "SequenceInsert"
+ "SequenceInsert",
+ "OptionalGetElement",
+ "OptionalHasElement"
};
- for (auto& sequence_op : sequence_ops)
+ for (auto& cpuOnDmlOperator : cpuOnDmlOperators)
{
- if (strcmp(sequence_op, node.OpType().c_str()) == 0)
+ if (strcmp(cpuOnDmlOperator, node.OpType().c_str()) == 0)
{
return true;
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index fd0ad8385f..53bce5c715 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -20,6 +20,7 @@ using namespace Microsoft::WRL;
#include "core/framework/TensorSeq.h"
#include "core/providers/cpu/sequence/sequence_ops.h"
#include "core/providers/cpu/tensor/concatbase.h"
+#include "core/providers/cpu/optional/optional_ops.h"
namespace onnxruntime {
@@ -30,6 +31,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, Se
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceErase);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceInsert);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalHasElement);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 15, OptionalGetElement);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalHasElement);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 18, OptionalGetElement);
}
@@ -105,6 +110,53 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType()}),
SequenceInsert);
+ONNX_OPERATOR_KERNEL_EX(
+ OptionalHasElement,
+ kOnnxDomain,
+ 15,
+ kDmlExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("O", DataTypeImpl::AllOptionalTypes())
+ .TypeConstraint("B", DataTypeImpl::GetTensorType()),
+ OptionalHasElement);
+
+ONNX_OPERATOR_KERNEL_EX(
+ OptionalGetElement,
+ kOnnxDomain,
+ 15,
+ kDmlExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("O", DataTypeImpl::AllOptionalTypes())
+ .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
+ // We may be able to re-use the input for the output as is unless the output
+ // is a graph output. We provide this hint to the allocation planner
+ // to make the re-use call.
+ .Alias(0, 0),
+ OptionalGetElement);
+
+ONNX_OPERATOR_KERNEL_EX(
+ OptionalHasElement,
+ kOnnxDomain,
+ 18,
+ kDmlExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes())
+ .TypeConstraint("B", DataTypeImpl::GetTensorType()),
+ OptionalHasElement);
+
+ONNX_OPERATOR_KERNEL_EX(
+ OptionalGetElement,
+ kOnnxDomain,
+ 18,
+ kDmlExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("O", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypes())
+ .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
+ // We may be able to re-use the input for the output as is unless the output
+ // is a graph output. We provide this hint to the allocation planner
+ // to make the re-use call.
+ .Alias(0, 0),
+ OptionalGetElement);
}
namespace Dml
@@ -925,6 +977,10 @@ void RegisterCpuOperatorsAsDml(onnxruntime::KernelRegistry* registry)
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
};
for (auto& function_table_entry : function_table) {