From 95a4607dcf330c128f6b6a2ffbcedb3193f6f81c Mon Sep 17 00:00:00 2001 From: Linnea May Date: Tue, 9 May 2023 21:56:41 -0700 Subject: [PATCH] User/linneamay/roi align 16 (#15812) ### Description Add registration for DML RoiAlign-16 and tests for new coordinate_transform_mode attribute. PR [7354](https://github.com/microsoft/onnxruntime/pull/7354) is still open to fix the CPU EP version, which is why there are skipped tests right now. That will be completed separately so that, for now, we can officially support opset16 with the next release. ### Motivation and Context --------- Co-authored-by: Linnea May Co-authored-by: Dwayne Robinson --- docs/OperatorKernels.md | 3 +- .../src/Operators/DmlOperatorRoiAlign.cpp | 18 +- .../src/Operators/OperatorRegistration.cpp | 6 +- .../dml/OperatorAuthorHelper/OperatorHelper.h | 1 + .../OperatorAuthorHelper/OperatorVersions.h | 1 + onnxruntime/test/onnx/main.cc | 1 + .../cpu/object_detection/roialign_test.cc | 430 +++++++++++++++++- .../onnx_backend_test_series_filters.jsonc | 1 + 8 files changed, 450 insertions(+), 11 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 4dc9e90377..97051e99ef 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1102,7 +1102,8 @@ Do not modify directly.* |||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| +|||10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| |STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp index c3a25ca8d4..892efca305 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp @@ -29,14 +29,25 @@ public: {"max", DML_REDUCE_FUNCTION_MAX}, {"avg", DML_REDUCE_FUNCTION_AVERAGE}, }; + + constexpr NameAndIndex coordinateTransformationModes[] = + { + {"half_pixel", 0}, + {"output_half_pixel", 1}, + }; + + std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute(AttrName::CoordinateTransformationMode, "half_pixel"); + auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes); const std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "avg"); const auto optionalReductionFunction = TryMapStringToIndex(mode, mapping); const float spatialScale = kernelCreationContext.GetOptionalAttribute(AttrName::SpatialScale, 1.0f); const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute(AttrName::SamplingRatio, 0u); ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive."); ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode."); + ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode."); - DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {}; + + DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = &inputDescs[0]; operatorDesc.ROITensor = &inputDescs[1]; operatorDesc.BatchIndicesTensor = &inputDescs[2]; @@ -48,12 +59,15 @@ public: operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput; operatorDesc.ReductionFunction = *optionalReductionFunction; operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc }; + operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 0.5f : 0.0f; + operatorDesc.OutputPixelOffset = -0.5f; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); } }; DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 6b8cdf65db..fd0ad8385f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -206,6 +206,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool); DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10); +DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16); DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization); DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization); DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15); @@ -551,6 +552,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)}, {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. @@ -807,10 +809,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16 {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, - {REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16 {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index f2be3cf05b..8629668bd4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1450,6 +1450,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper; using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper; using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper; using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper; +using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper; using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 5e2ca4cb11..054faad5ba 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -381,6 +381,7 @@ namespace OperatorHelper static const int sc_sinceVer_LessOrEqual = 16; static const int sc_sinceVer_ScatterND = 16; static const int sc_sinceVer_ScatterElements = 16; + static const int sc_sinceVer_RoiAlign = 16; } // namespace OnnxOperatorSet16 namespace OnnxOperatorSet17 diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 3ee2f4ddcf..a00e3bc4f2 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -800,6 +800,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); {"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."}, {"test_roialign_aligned_true", "Opset 16 not supported yet."}, {"test_roialign_aligned_false", "Opset 16 not supported yet."}, + {"test_roialign_mode_max", "Onnx roialign mode expected output is incorrect."}, {"test_scatternd_add", "Opset 16 not supported yet."}, {"test_scatternd_multiply", "Opset 16 not supported yet."}, {"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."}, diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index a7cc7c536a..ad9c561ffb 100644 --- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -9,11 +9,10 @@ namespace onnxruntime { namespace test { TEST(RoiAlignTest, AvgModePositive) { - // TODO: Unskip when fixed #41968513 + // TODO: Unskip when fixed ort issue #3428 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.9583299160003662, which exceeds threshold"; } - OpTester test("RoiAlign", 10); test.AddAttribute("output_height", 3); test.AddAttribute("output_width", 4); @@ -30,7 +29,241 @@ TEST(RoiAlignTest, AvgModePositive) { test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f}); + // As per ORT issue https://github.com/microsoft/onnxruntime/issues/6921, the above output values are INCORRECT. + // DML has the correct outputs, which are defined below. + /*test.AddOutput("Y", {5, 3, 3, 4}, { + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + });*/ + test.Run(); +} + +TEST(RoiAlignTest, AvgModePositive_half_pixel) { + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f / 16.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 5; + constexpr int W = 5; + + std::vector rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.}; + test.AddInput("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.}); + test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); + test.AddOutput("Y", {5, 3, 3, 4}, {0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000}); + test.Run(); +} + +TEST(RoiAlignTest, AvgModePositive_output_half_pixel) { + // TODO: Unskip when fixed ort issue #3428 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.95832991600036621, which exceeds threshold"; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f / 16.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 5; + constexpr int W = 5; + + std::vector rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.}; + test.AddInput("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.}); + test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); + test.AddOutput("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f}); test.Run(); } @@ -230,12 +463,11 @@ static void BasicTest() { 0.3661f, 0.2349f, }); - test.Run(); } TEST(RoiAlignTest, OnnxTest) { - // TODO: Unskip when fixed #41968513 + // TODO: Unskip when fixed ort issue #3428 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.051382988691329956, which exceeds threshold"; } @@ -245,7 +477,7 @@ TEST(RoiAlignTest, OnnxTest) { } TEST(RoiAlignTest, MaxModePositive) { - // TODO: Unskip when fixed #41968513 + // TODO: Unskip when fixed ort issue #3428 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.1093800067901611, which exceeds threshold"; } @@ -267,10 +499,196 @@ TEST(RoiAlignTest, MaxModePositive) { test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5, 3, 3, 4}, {2.10938f, 2.95313f, 3.375f, 2.53125f, 3.35938f, 4.70313f, 5.375f, 4.03125f, 3.51563f, 4.92188f, 5.625f, 4.21875f, 10.8984f, 15.2578f, 17.4375f, 13.0781f, 17.3568f, 24.2995f, 27.7708f, 20.8281f, 18.1641f, 25.4297f, 29.0625f, 21.7969f, 19.6875f, 27.5625f, 31.5f, 23.625f, 31.3542f, 43.8958f, 50.1667f, 37.625f, 32.8125f, 45.9375f, 52.5f, 39.375f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 5.625f, 5.625f, 5.625f, 4.57031f, 8.95833f, 8.95833f, 8.95833f, 7.27865f, 9.375f, 9.375f, 9.375f, 7.61719f, 19.6875f, 19.6875f, 19.6875f, 15.9961f, 31.3542f, 31.3542f, 31.3542f, 25.4753f, 32.8125f, 32.8125f, 32.8125f, 26.6602f, 33.75f, 33.75f, 33.75f, 27.4219f, 53.75f, 53.75f, 53.75f, 43.6719f, 56.25f, 56.25f, 56.25f, 45.7031f, 4.5f, 3.9375f, 2.8125f, 3.9375f, 5.5f, 4.8125f, 3.4375f, 4.8125f, 4.58333f, 4.01042f, 2.86458f, 3.9375f, 23.25f, 20.3438f, 14.5313f, 18.f, 28.4167f, 24.86458f, 17.76042f, 22.f, 23.25f, 20.3437f, 14.5312f, 18.f, 42.f, 36.75f, 26.25f, 32.0625f, 51.3333f, 44.9167f, 32.08333f, 39.1875f, 42.f, 36.75f, 26.25f, 32.0625f, 4.375f, 4.375f, 4.375f, 4.375f, 7.70833f, 7.70833f, 7.70833f, 7.70833f, 9.375f, 9.375f, 9.375f, 9.375f, 21.875f, 21.875f, 21.875f, 21.875f, 26.9792f, 26.9792f, 26.9792f, 26.9792f, 32.8125f, 32.8125f, 32.8125f, 32.8125f, 40.1042f, 40.1042f, 40.1042f, 40.1042f, 46.25f, 46.25f, 46.25f, 46.25f, 56.25f, 56.25f, 56.25f, 56.25f}); + // As per ort issue #3428, the above output values are INCORRECT. + // DML has the correct outputs, which are defined below. + /*test.AddOutput("Y",{5, 3, 3, 4}, { + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + });*/ test.Run(); } - TEST(RoiAlignTest, AvgModeNegativeInvalidMode) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 34e2087f1d..cd9a90ee8e 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -105,6 +105,7 @@ "^test_col2im_pads*", // remove this when using ONNX with this: https://github.com/onnx/onnx/pull/4769 // Following tests are for opset 16 ops and are not yet implemented in ORT "^test_roialign_aligned_*", + "^test_roialign_mode_max", // TODO: Remove once onnx test is fixed //GPU failures "^test_batchnorm_epsilon_training_mode_cuda", "^test_batchnorm_example_training_mode_cuda",