User/linneamay/roi align 16 (#15812)

### Description
<!-- Describe your changes. -->
Add registration for DML RoiAlign-16 and tests for the new
`coordinate_transformation_mode` attribute. PR
[7354](https://github.com/microsoft/onnxruntime/pull/7354) is still open
to fix the CPU EP version, which is why there are skipped tests right
now. That will be completed separately so that, for now, we can
officially support opset16 with the next release.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Linnea May <linneamay@microsoft.com>
Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
Linnea May 2023-05-09 21:56:41 -07:00 committed by GitHub
parent c7b27f4486
commit 95a4607dcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 450 additions and 11 deletions

View file

@ -1102,7 +1102,8 @@ Do not modify directly.*
|||11+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|||10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|STFT|*in* signal:**T1**<br> *in* frame_step:**T2**<br> *in* window:**T1**<br> *in* frame_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|

View file

@ -29,14 +29,25 @@ public:
{"max", DML_REDUCE_FUNCTION_MAX},
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
};
constexpr NameAndIndex coordinateTransformationModes[] =
{
{"half_pixel", 0},
{"output_half_pixel", 1},
};
std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::CoordinateTransformationMode, "half_pixel");
auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes);
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
const auto optionalReductionFunction = TryMapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");
ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode.");
DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ROITensor = &inputDescs[1];
operatorDesc.BatchIndicesTensor = &inputDescs[2];
@ -48,12 +59,15 @@ public:
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
operatorDesc.ReductionFunction = *optionalReductionFunction;
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 0.5f : 0.0f;
operatorDesc.OutputPixelOffset = -0.5f;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorRegionOfInterestAlign, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel<DmlOperatorRegionOfInterestAlign, 16>);
} // namespace Dml

View file

@ -206,6 +206,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
@ -551,6 +552,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
@ -807,10 +809,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},

View file

@ -1450,6 +1450,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;

View file

@ -381,6 +381,7 @@ namespace OperatorHelper
static const int sc_sinceVer_LessOrEqual = 16;
static const int sc_sinceVer_ScatterND = 16;
static const int sc_sinceVer_ScatterElements = 16;
static const int sc_sinceVer_RoiAlign = 16;
} // namespace OnnxOperatorSet16
namespace OnnxOperatorSet17

View file

@ -800,6 +800,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
{"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
{"test_roialign_aligned_true", "Opset 16 not supported yet."},
{"test_roialign_aligned_false", "Opset 16 not supported yet."},
{"test_roialign_mode_max", "Onnx roialign mode expected output is incorrect."},
{"test_scatternd_add", "Opset 16 not supported yet."},
{"test_scatternd_multiply", "Opset 16 not supported yet."},
{"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."},

View file

@ -9,11 +9,10 @@ namespace onnxruntime {
namespace test {
TEST(RoiAlignTest, AvgModePositive) {
// TODO: Unskip when fixed #41968513
// TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.9583299160003662, which exceeds threshold";
}
OpTester test("RoiAlign", 10);
test.AddAttribute<int64_t>("output_height", 3);
test.AddAttribute<int64_t>("output_width", 4);
@ -30,7 +29,241 @@ TEST(RoiAlignTest, AvgModePositive) {
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});
// As per ORT issue https://github.com/microsoft/onnxruntime/issues/6921, the above output values are INCORRECT.
// DML has the correct outputs, which are defined below.
/*test.AddOutput<float>("Y", {5, 3, 3, 4}, {
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
});*/
test.Run();
}
// Opset-16 RoiAlign with coordinate_transformation_mode = "half_pixel".
// The "mode" attribute is left unset, so the operator's default reduction is used.
// Five ROIs are pooled from a single {1, 3, 5, 5} image into {5, 3, 3, 4}
// (ROI, channel, output_height, output_width).
TEST(RoiAlignTest, AvgModePositive_half_pixel) {
  OpTester test("RoiAlign", 16);
  test.AddAttribute<int64_t>("output_height", 3);
  test.AddAttribute<int64_t>("output_width", 4);
  test.AddAttribute<int64_t>("sampling_ratio", 2);
  test.AddAttribute<float>("spatial_scale", 1.0f / 16.0f);
  test.AddAttribute<std::string>("coordinate_transformation_mode", "half_pixel");

  constexpr int N = 1;
  constexpr int C = 3;
  constexpr int H = 5;
  constexpr int W = 5;

  // X holds the ramp 0..74 so every channel plane is offset by 25 from the previous one.
  test.AddInput<float>("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
  // One row of four coordinates per ROI; all ROIs read from batch element 0.
  test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
  test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});

  // Expected output, one row per (ROI, channel) plane of 3 x 4 values.
  test.AddOutput<float>("Y", {5, 3, 3, 4}, {
      // ROI 0
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f,
      50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f,
      // ROI 1
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      // ROI 2
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      // ROI 3
      0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f,
      25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f,
      50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f,
      // ROI 4
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f,
      0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f});

  test.Run();
}
// Opset-16 RoiAlign with coordinate_transformation_mode = "output_half_pixel".
// The "mode" attribute is left unset, so the operator's default reduction is used.
// Inputs are the same five ROIs over a {1, 3, 5, 5} ramp image, pooled into
// {5, 3, 3, 4} (ROI, channel, output_height, output_width).
TEST(RoiAlignTest, AvgModePositive_output_half_pixel) {
  // TODO: Unskip when fixed ort issue #3428
  // The expected values below do not match what the DML execution provider
  // produces, so the test is skipped whenever that provider is available.
  if (DefaultDmlExecutionProvider().get() != nullptr) {
    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.95832991600036621, which exceeds threshold";
  }

  OpTester test("RoiAlign", 16);
  test.AddAttribute<int64_t>("output_height", 3);
  test.AddAttribute<int64_t>("output_width", 4);
  test.AddAttribute<int64_t>("sampling_ratio", 2);
  test.AddAttribute<float>("spatial_scale", 1.0f / 16.0f);
  test.AddAttribute<std::string>("coordinate_transformation_mode", "output_half_pixel");

  constexpr int N = 1;
  constexpr int C = 3;
  constexpr int H = 5;
  constexpr int W = 5;

  // X holds the ramp 0..74 so every channel plane is offset by 25 from the previous one.
  test.AddInput<float>("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
  // One row of four coordinates per ROI; all ROIs read from batch element 0.
  test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
  test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});

  // Expected output, one row per (ROI, channel) plane of 3 x 4 values.
  test.AddOutput<float>("Y", {5, 3, 3, 4}, {
      // ROI 0
      2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f,
      27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f,
      52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f,
      // ROI 1
      0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
      25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f,
      50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f,
      // ROI 2
      7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f,
      32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f,
      57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f,
      // ROI 3
      4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f,
      29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f,
      54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f,
      // ROI 4
      6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f,
      31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f,
      56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});

  test.Run();
}
@ -230,12 +463,11 @@ static void BasicTest() {
0.3661f,
0.2349f,
});
test.Run();
}
TEST(RoiAlignTest, OnnxTest) {
// TODO: Unskip when fixed #41968513
// TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.051382988691329956, which exceeds threshold";
}
@ -245,7 +477,7 @@ TEST(RoiAlignTest, OnnxTest) {
}
TEST(RoiAlignTest, MaxModePositive) {
// TODO: Unskip when fixed #41968513
// TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.1093800067901611, which exceeds threshold";
}
@ -267,10 +499,196 @@ TEST(RoiAlignTest, MaxModePositive) {
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.10938f, 2.95313f, 3.375f, 2.53125f, 3.35938f, 4.70313f, 5.375f, 4.03125f, 3.51563f, 4.92188f, 5.625f, 4.21875f, 10.8984f, 15.2578f, 17.4375f, 13.0781f, 17.3568f, 24.2995f, 27.7708f, 20.8281f, 18.1641f, 25.4297f, 29.0625f, 21.7969f, 19.6875f, 27.5625f, 31.5f, 23.625f, 31.3542f, 43.8958f, 50.1667f, 37.625f, 32.8125f, 45.9375f, 52.5f, 39.375f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 5.625f, 5.625f, 5.625f, 4.57031f, 8.95833f, 8.95833f, 8.95833f, 7.27865f, 9.375f, 9.375f, 9.375f, 7.61719f, 19.6875f, 19.6875f, 19.6875f, 15.9961f, 31.3542f, 31.3542f, 31.3542f, 25.4753f, 32.8125f, 32.8125f, 32.8125f, 26.6602f, 33.75f, 33.75f, 33.75f, 27.4219f, 53.75f, 53.75f, 53.75f, 43.6719f, 56.25f, 56.25f, 56.25f, 45.7031f, 4.5f, 3.9375f, 2.8125f, 3.9375f, 5.5f, 4.8125f, 3.4375f, 4.8125f, 4.58333f, 4.01042f, 2.86458f, 3.9375f, 23.25f, 20.3438f, 14.5313f, 18.f, 28.4167f, 24.86458f, 17.76042f, 22.f, 23.25f, 20.3437f, 14.5312f, 18.f, 42.f, 36.75f, 26.25f, 32.0625f, 51.3333f, 44.9167f, 32.08333f, 39.1875f, 42.f, 36.75f, 26.25f, 32.0625f, 4.375f, 4.375f, 4.375f, 4.375f, 7.70833f, 7.70833f, 7.70833f, 7.70833f, 9.375f, 9.375f, 9.375f, 9.375f, 21.875f, 21.875f, 21.875f, 21.875f, 26.9792f, 26.9792f, 26.9792f, 26.9792f, 32.8125f, 32.8125f, 32.8125f, 32.8125f, 40.1042f, 40.1042f, 40.1042f, 40.1042f, 46.25f, 46.25f, 46.25f, 46.25f, 56.25f, 56.25f, 56.25f, 56.25f});
// As per ort issue #3428, the above output values are INCORRECT.
// DML has the correct outputs, which are defined below.
/*test.AddOutput<float>("Y",{5, 3, 3, 4}, {
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
27.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
52.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
25.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
50.0000,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
6.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
31.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
56.5625,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
3.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
28.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
53.3125,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
5.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
30.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
55.9375,
});*/
test.Run();
}
TEST(RoiAlignTest, AvgModeNegativeInvalidMode) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {

View file

@ -105,6 +105,7 @@
"^test_col2im_pads*", // remove this when using ONNX with this: https://github.com/onnx/onnx/pull/4769
// Following tests are for opset 16 ops and are not yet implemented in ORT
"^test_roialign_aligned_*",
"^test_roialign_mode_max", // TODO: Remove once onnx test is fixed
//GPU failures
"^test_batchnorm_epsilon_training_mode_cuda",
"^test_batchnorm_example_training_mode_cuda",