From 2c3697be00dfe074d2368637c4830cca89957c04 Mon Sep 17 00:00:00 2001 From: Linnea May Date: Thu, 27 Apr 2023 20:32:11 -0700 Subject: [PATCH] User/linneamay/reduce 18 (#15701) ### Description Add registration for DML reduce functions in opset 18. ### Motivation and Context --------- Co-authored-by: Linnea May --- docs/OperatorKernels.md | 27 ++++++++++++------- .../src/Operators/OperatorRegistration.cpp | 9 +++++++ .../OperatorAuthorHelper/OperatorVersions.h | 13 +++++++++ .../cpu/reduction/reduction_ops_test.cc | 10 +++---- 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index cfd15f5cc2..114ab09b5a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1049,36 +1049,45 @@ Do not modify directly.* |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(float), tensor(int16), tensor(int32), tensor(int64)| |Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||1+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16)| +|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16)| +|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16)| +|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16)| +|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||1+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||1+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||1+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index be5a56007a..6fdbb5b099 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -718,32 +718,41 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, + {REG_INFO( 18, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported)}, {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported)}, {REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 77b297cd3e..5e2ca4cb11 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -388,6 +388,19 @@ namespace OperatorHelper static const int sc_sinceVer_LayerNormalization = 17; } // namespace OnnxOperatorSet17 + namespace OnnxOperatorSet18 + { + static const int sc_sinceVer_ReduceL1 = 18; + static const int sc_sinceVer_ReduceL2 = 18; + static const int sc_sinceVer_ReduceLogSum = 18; + static const int sc_sinceVer_ReduceLogSumExp = 18; + static const int sc_sinceVer_ReduceMax = 18; + static const int sc_sinceVer_ReduceMean = 18; + static const int sc_sinceVer_ReduceMin = 18; + static const int sc_sinceVer_ReduceProd = 18; + static const int sc_sinceVer_ReduceSumSquare = 18; + } + namespace MsftOperatorSet1 { static const int sc_sinceVer_DmlFusedConv = 1; diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index f0d6cf1269..d8841b4f5d 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -1004,7 +1004,7 @@ TEST(ReductionOpTest, ReduceMaxAxesInitializerOpset18) { test.AddOutput("reduced", {3, 1, 1}, {4.0f, 8.0f, 12.0f}); // TODO: DNNL, TensorRT, and OpenVINO dont support "axes" input in opset 18 test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDmlExecutionProvider}); } #if defined(USE_DNNL) @@ -1483,7 +1483,7 @@ TEST(ReductionOpTest, ReduceMeanAxesInitializerOpset18) { // TODO: DNNL, TensorRT, and OpenVINO dont support "axes" input in opset 18, re-enable after test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDmlExecutionProvider}); } #ifdef USE_DNNL @@ -1732,7 +1732,7 @@ TEST(ReductionOpTest, ReduceMinAxesInitializerOpset18) { test.AddOutput("reduced", {1, 2, 1}, {1.0f, 3.0f}); // TODO: DNNL, TensorRT, and OpenVINO dont support "axes" input in opset 18, re-enable after test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDmlExecutionProvider}); } #if defined(USE_DNNL) @@ -1908,7 +1908,7 @@ TEST(ReductionOpTest, ReduceSumAxesInitializerOpset13) { test.AddInput("axes", {2}, {0, 2}, true); test.AddOutput("reduced", {1, 2, 1}, {33.0f, 45.0f}); // TODO: TensorRT and OpenVINO dont support "axes" input in opset 13, re-enable after - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDmlExecutionProvider}); } TEST(ReductionOpTest, ReduceSum_double) { @@ -2697,7 +2697,7 @@ TEST(ReductionOpTest, ReduceProdAxesInitializerOpset18) { test.AddOutput("reduced", {1, 2, 1}, {5400.f, 88704.f}); // TODO: DNNL, TensorRT, and OpenVINO dont support "axes" input in opset 18, re-enable after test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + {kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDmlExecutionProvider}); } #if defined(USE_DNNL)