mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-15 01:23:42 +00:00
User/linneamay/roi align 16 (#15812)
### Description <!-- Describe your changes. --> Add registration for DML RoiAlign-16 and tests for new coordinate_transform_mode attribute. PR [7354](https://github.com/microsoft/onnxruntime/pull/7354) is still open to fix the CPU EP version, which is why there are skipped tests right now. That will be completed separately so that, for now, we can officially support opset16 with the next release. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Linnea May <linneamay@microsoft.com> Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
parent
c7b27f4486
commit
95a4607dcf
8 changed files with 450 additions and 11 deletions
|
|
@ -1102,7 +1102,8 @@ Do not modify directly.*
|
|||
|||11+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)|
|
||||
|||10+|**T** = tensor(float), tensor(float16)|
|
||||
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|||10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|
||||
|STFT|*in* signal:**T1**<br> *in* frame_step:**T2**<br> *in* window:**T1**<br> *in* frame_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -29,14 +29,25 @@ public:
|
|||
{"max", DML_REDUCE_FUNCTION_MAX},
|
||||
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
|
||||
};
|
||||
|
||||
constexpr NameAndIndex coordinateTransformationModes[] =
|
||||
{
|
||||
{"half_pixel", 0},
|
||||
{"output_half_pixel", 1},
|
||||
};
|
||||
|
||||
std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::CoordinateTransformationMode, "half_pixel");
|
||||
auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes);
|
||||
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
|
||||
const auto optionalReductionFunction = TryMapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
|
||||
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
|
||||
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
|
||||
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
|
||||
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");
|
||||
ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode.");
|
||||
|
||||
DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
|
||||
|
||||
DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.ROITensor = &inputDescs[1];
|
||||
operatorDesc.BatchIndicesTensor = &inputDescs[2];
|
||||
|
|
@ -48,12 +59,15 @@ public:
|
|||
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
|
||||
operatorDesc.ReductionFunction = *optionalReductionFunction;
|
||||
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
|
||||
operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 0.5f : 0.0f;
|
||||
operatorDesc.OutputPixelOffset = -0.5f;
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc };
|
||||
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorRegionOfInterestAlign, 10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel<DmlOperatorRegionOfInterestAlign, 16>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -206,6 +206,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
|
||||
|
|
@ -551,6 +552,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
|
||||
|
|
@ -807,10 +809,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
|
||||
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
|
||||
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -1450,6 +1450,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper;
|
|||
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
|
||||
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
|
||||
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
|
||||
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
|
||||
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
|
||||
|
|
|
|||
|
|
@ -381,6 +381,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_LessOrEqual = 16;
|
||||
static const int sc_sinceVer_ScatterND = 16;
|
||||
static const int sc_sinceVer_ScatterElements = 16;
|
||||
static const int sc_sinceVer_RoiAlign = 16;
|
||||
} // namespace OnnxOperatorSet16
|
||||
|
||||
namespace OnnxOperatorSet17
|
||||
|
|
|
|||
|
|
@ -800,6 +800,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
{"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
|
||||
{"test_roialign_aligned_true", "Opset 16 not supported yet."},
|
||||
{"test_roialign_aligned_false", "Opset 16 not supported yet."},
|
||||
{"test_roialign_mode_max", "Onnx roialign mode expected output is incorrect."},
|
||||
{"test_scatternd_add", "Opset 16 not supported yet."},
|
||||
{"test_scatternd_multiply", "Opset 16 not supported yet."},
|
||||
{"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."},
|
||||
|
|
|
|||
|
|
@ -9,11 +9,10 @@ namespace onnxruntime {
|
|||
namespace test {
|
||||
|
||||
TEST(RoiAlignTest, AvgModePositive) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
// TODO: Unskip when fixed ort issue #3428
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.9583299160003662, which exceeds threshold";
|
||||
}
|
||||
|
||||
OpTester test("RoiAlign", 10);
|
||||
test.AddAttribute<int64_t>("output_height", 3);
|
||||
test.AddAttribute<int64_t>("output_width", 4);
|
||||
|
|
@ -30,7 +29,241 @@ TEST(RoiAlignTest, AvgModePositive) {
|
|||
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
|
||||
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
|
||||
test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});
|
||||
// As per ORT issue https://github.com/microsoft/onnxruntime/issues/6921, the above output values are INCORRECT.
|
||||
// DML has the correct outputs, which are defined below.
|
||||
/*test.AddOutput<float>("Y", {5, 3, 3, 4}, {
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
});*/
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(RoiAlignTest, AvgModePositive_half_pixel) {
|
||||
OpTester test("RoiAlign", 16);
|
||||
test.AddAttribute<int64_t>("output_height", 3);
|
||||
test.AddAttribute<int64_t>("output_width", 4);
|
||||
test.AddAttribute<int64_t>("sampling_ratio", 2);
|
||||
test.AddAttribute<float>("spatial_scale", 1.0f / 16.0f);
|
||||
test.AddAttribute<std::string>("coordinate_transformation_mode", "half_pixel");
|
||||
|
||||
constexpr int N = 1;
|
||||
constexpr int C = 3;
|
||||
constexpr int H = 5;
|
||||
constexpr int W = 5;
|
||||
|
||||
std::vector<float> rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.};
|
||||
test.AddInput<float>("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
|
||||
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
|
||||
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
|
||||
test.AddOutput<float>("Y", {5, 3, 3, 4}, {0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(RoiAlignTest, AvgModePositive_output_half_pixel) {
|
||||
// TODO: Unskip when fixed ort issue #3428
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.95832991600036621, which exceeds threshold";
|
||||
}
|
||||
|
||||
OpTester test("RoiAlign", 16);
|
||||
test.AddAttribute<int64_t>("output_height", 3);
|
||||
test.AddAttribute<int64_t>("output_width", 4);
|
||||
test.AddAttribute<int64_t>("sampling_ratio", 2);
|
||||
test.AddAttribute<float>("spatial_scale", 1.0f / 16.0f);
|
||||
test.AddAttribute<std::string>("coordinate_transformation_mode", "output_half_pixel");
|
||||
|
||||
constexpr int N = 1;
|
||||
constexpr int C = 3;
|
||||
constexpr int H = 5;
|
||||
constexpr int W = 5;
|
||||
|
||||
std::vector<float> rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.};
|
||||
test.AddInput<float>("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
|
||||
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
|
||||
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
|
||||
test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -230,12 +463,11 @@ static void BasicTest() {
|
|||
0.3661f,
|
||||
0.2349f,
|
||||
});
|
||||
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(RoiAlignTest, OnnxTest) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
// TODO: Unskip when fixed ort issue #3428
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.051382988691329956, which exceeds threshold";
|
||||
}
|
||||
|
|
@ -245,7 +477,7 @@ TEST(RoiAlignTest, OnnxTest) {
|
|||
}
|
||||
|
||||
TEST(RoiAlignTest, MaxModePositive) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
// TODO: Unskip when fixed ort issue #3428
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.1093800067901611, which exceeds threshold";
|
||||
}
|
||||
|
|
@ -267,10 +499,196 @@ TEST(RoiAlignTest, MaxModePositive) {
|
|||
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
|
||||
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
|
||||
test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.10938f, 2.95313f, 3.375f, 2.53125f, 3.35938f, 4.70313f, 5.375f, 4.03125f, 3.51563f, 4.92188f, 5.625f, 4.21875f, 10.8984f, 15.2578f, 17.4375f, 13.0781f, 17.3568f, 24.2995f, 27.7708f, 20.8281f, 18.1641f, 25.4297f, 29.0625f, 21.7969f, 19.6875f, 27.5625f, 31.5f, 23.625f, 31.3542f, 43.8958f, 50.1667f, 37.625f, 32.8125f, 45.9375f, 52.5f, 39.375f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 5.625f, 5.625f, 5.625f, 4.57031f, 8.95833f, 8.95833f, 8.95833f, 7.27865f, 9.375f, 9.375f, 9.375f, 7.61719f, 19.6875f, 19.6875f, 19.6875f, 15.9961f, 31.3542f, 31.3542f, 31.3542f, 25.4753f, 32.8125f, 32.8125f, 32.8125f, 26.6602f, 33.75f, 33.75f, 33.75f, 27.4219f, 53.75f, 53.75f, 53.75f, 43.6719f, 56.25f, 56.25f, 56.25f, 45.7031f, 4.5f, 3.9375f, 2.8125f, 3.9375f, 5.5f, 4.8125f, 3.4375f, 4.8125f, 4.58333f, 4.01042f, 2.86458f, 3.9375f, 23.25f, 20.3438f, 14.5313f, 18.f, 28.4167f, 24.86458f, 17.76042f, 22.f, 23.25f, 20.3437f, 14.5312f, 18.f, 42.f, 36.75f, 26.25f, 32.0625f, 51.3333f, 44.9167f, 32.08333f, 39.1875f, 42.f, 36.75f, 26.25f, 32.0625f, 4.375f, 4.375f, 4.375f, 4.375f, 7.70833f, 7.70833f, 7.70833f, 7.70833f, 9.375f, 9.375f, 9.375f, 9.375f, 21.875f, 21.875f, 21.875f, 21.875f, 26.9792f, 26.9792f, 26.9792f, 26.9792f, 32.8125f, 32.8125f, 32.8125f, 32.8125f, 40.1042f, 40.1042f, 40.1042f, 40.1042f, 46.25f, 46.25f, 46.25f, 46.25f, 56.25f, 56.25f, 56.25f, 56.25f});
|
||||
// As per ort issue #3428, the above output values are INCORRECT.
|
||||
// DML has the correct outputs, which are defined below.
|
||||
/*test.AddOutput<float>("Y",{5, 3, 3, 4}, {
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
2.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
27.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
52.0000,
|
||||
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
25.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
50.0000,
|
||||
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
6.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
31.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
56.5625,
|
||||
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
3.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
28.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
53.3125,
|
||||
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
5.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
30.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
55.9375,
|
||||
});*/
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(RoiAlignTest, AvgModeNegativeInvalidMode) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@
|
|||
"^test_col2im_pads*", // remove this when using ONNX with this: https://github.com/onnx/onnx/pull/4769
|
||||
// Following tests are for opset 16 ops and are not yet implemented in ORT
|
||||
"^test_roialign_aligned_*",
|
||||
"^test_roialign_mode_max", // TODO: Remove once onnx test is fixed
|
||||
//GPU failures
|
||||
"^test_batchnorm_epsilon_training_mode_cuda",
|
||||
"^test_batchnorm_example_training_mode_cuda",
|
||||
|
|
|
|||
Loading…
Reference in a new issue