diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index 09c88b6027..429aabd666 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -18,8 +18,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ROIAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ROIAlign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence); // This section includes all opkernel declarations for former experimental ops which have now been removed from onnx. @@ -63,8 +61,6 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index dd5835d949..3f42ab4d7e 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -96,9 +96,9 @@ If scale is not provided, crop the borders as provided.)DOC"; .Output(0, "output", "Result, has same type as input, with H and W dimensions reduced.", "T") .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors."); - static const char* ThresholdedRelu_ver1_doc = R"DOC( -ThresholdedRelu takes one input data (Tensor) and produces one output data -(Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise, + static const char* ThresholdedRelu_ver1_doc = R"DOC( +ThresholdedRelu takes one input data (Tensor) and produces one output data +(Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise, is applied to the tensor elementwise. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(ThresholdedRelu) @@ -1015,61 +1015,6 @@ Example 4: "Constrain to tensor(float).") .SetDoc(R"DOC(The WordConvEmbedding takes in a batch of sequence words and embed each word to a vector.)DOC"); - ONNX_CONTRIB_OPERATOR_SCHEMA(ROIAlign) - .SetDomain(kMSDomain) - .SinceVersion(1) - .Attr( - "spatial_scale", - "Multiplicative spatial scale factor to translate ROI coordinates " - "from their input spatial scale to the scale used when pooling, " - "i.e., spatial scale of the input feature map X relative to the " - "input image. E.g.; default is 1.0f. 
", - AttributeProto::FLOAT, - 1.f) - .Attr( - "pooled_h", - "default 1; Pooled output Y's height.", - AttributeProto::INT, - static_cast(1)) - .Attr( - "pooled_w", - "default 1; Pooled output Y's width.", - AttributeProto::INT, - static_cast(1)) - .Attr( - "sampling_ratio", - "Number of sampling points in the interpolation grid used to compute " - "the output value of each pooled output bin. If > 0, then exactly " - "sampling_ratio x sampling_ratio grid points are used. If == 0, then " - "an adaptive number of grid points are used (computed as " - "ceil(roi_width / pooled_w), and likewise for height). Default is 0.", - AttributeProto::INT, - static_cast(0)) - .Attr( - "mode", - "The pooling method. Two modes are supported: 'avg' and 'max'. " - "Default is 'avg'.", - AttributeProto::STRING, - std::string("avg")) - .Input(0, "X", "Input data tensor from the previous operator; 4-D feature map of shape (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and the width of the data.", "T") - .Input(1, "rois", "RoIs (Regions of Interest2) to pool over; rois is 2-D input of shape (num_rois, 5) given as [[batch_id, x1, y1, x2, y2], ...]. The RoIs' coordinates are in the coordinate system of the input image.", "T") - .Output(0, "Y", "RoI pooled output, 4-D tesnor of shape (num_rois, C, pooled_h, pooled_w). The r-th batch element Y[r-1] is a pooled feature map corresponding to the r-th RoI X[r-1].", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain to float, float16 and double tensors.") - .SetDoc(R"DOC(Region of Interest (RoI) align operation described in the - [Mask R-CNN paper](https://arxiv.org/abs/1703.06870). - RoIAlign consumes an input tensor X and region of interests (rois) - to apply pooling across each RoI; it produces a 4-D tensor of shape - (num_rois, C, pooled_h, pooled_w). 
- - RoIAlign is proposed to avoid the misalignment by removing - quantizations while converting from original image into feature - map and from feature map into RoI feature; in each ROI bin, - the value of the sampled locations are computed directly - through bilinear interpolation.)DOC"); - #ifdef MICROSOFT_INTERNAL // register internal ops RegisterInternalSchemas(); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 67125aa8ca..a98796a91a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -266,6 +266,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign); void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -521,6 +523,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc similarity index 80% rename from onnxruntime/contrib_ops/cpu/roialign.cc rename to onnxruntime/core/providers/cpu/object_detection/roialign.cc index 1da4ce5103..13d04ffe11 100644 --- a/onnxruntime/contrib_ops/cpu/roialign.cc +++ 
b/onnxruntime/core/providers/cpu/object_detection/roialign.cc @@ -25,18 +25,18 @@ using namespace onnxruntime::concurrency; namespace onnxruntime { -namespace contrib { const int64_t EXPECTED_NUM_ROI_DIMS = 2; -const int64_t EXPECTED_SECOND_ROI_DIM = 5; +const int64_t EXPECTED_SECOND_ROI_DIM = 4; -#define ADD_TYPED_ROIALIGN_OP(data_type) \ - ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( \ - ROIAlign, \ - 1, \ - data_type, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ROIAlign); +#define ADD_TYPED_ROIALIGN_OP(data_type) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + RoiAlign, \ + 10, \ + data_type, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); ADD_TYPED_ROIALIGN_OP(float); ADD_TYPED_ROIALIGN_OP(double); @@ -151,7 +151,7 @@ void pre_calc_for_bilinear_interpolate( } template -void ROIAlignForward( +void RoiAlignForward( int64_t nthreads, const T* bottom_data, float spatial_scale, @@ -162,18 +162,18 @@ void ROIAlignForward( int64_t pooled_width, int64_t sampling_ratio, const T* bottom_rois, - int64_t roi_cols, + int64_t num_roi_cols, T* top_data, const std::string& mode, + const int64_t* batch_indices_ptr, const ThreadPool* ttp) { int64_t n_rois = nthreads / channels / pooled_width / pooled_height; std::function work_object = [&](int32_t n) { int64_t index_n = n * channels * pooled_width * pooled_height; - const T* offset_bottom_rois = bottom_rois + n * roi_cols; - const T roi_batch_ind = offset_bottom_rois[0]; - offset_bottom_rois++; + const T* offset_bottom_rois = bottom_rois + n * num_roi_cols; + const auto roi_batch_ind = batch_indices_ptr[n]; // Do not using rounding; this implementation detail is critical T roi_start_w = offset_bottom_rois[0] * spatial_scale; @@ -264,53 +264,80 @@ void ROIAlignForward( } // for ph } // for c }; // for n - const_cast(ttp)->ParallelFor((int32_t)n_rois, work_object); + 
const_cast(ttp)->ParallelFor(static_cast(n_rois), work_object); } } // namespace template -Status ROIAlign::Compute(OpKernelContext* context) const { +Status RoiAlign::Compute(OpKernelContext* context) const { + using namespace onnxruntime::common; + + // X const Tensor* X_ptr = context->Input(0); if (!X_ptr) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null input X ptr"); } + // rois const Tensor* rois_ptr = context->Input(1); if (!rois_ptr) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null rois_ptr"); } - auto& x_dims = X_ptr->Shape(); + // batch indices + const Tensor* batch_indices_ptr = context->Input(2); + if (!batch_indices_ptr) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null batch_indices_ptr"); + } - auto& rois_dims = rois_ptr->Shape(); + const auto& x_dims = X_ptr->Shape(); + const auto& rois_dims = rois_ptr->Shape(); + const auto& batch_indices_dims = batch_indices_ptr->Shape(); + + if (batch_indices_dims.NumDimensions() != 1) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions for batch indices should be exactly 1"); + } // validate rois_dims if (rois_dims.NumDimensions() != EXPECTED_NUM_ROI_DIMS) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Number of dimensions for rois should be exactly 2"); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions for rois should be exactly " + std::to_string(EXPECTED_NUM_ROI_DIMS)); } if (rois_dims[1] != EXPECTED_SECOND_ROI_DIM) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Second dimension for rois should be exactly 5"); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + "Second dimension for rois should be exactly " + std::to_string(EXPECTED_SECOND_ROI_DIM)); } - auto& Y = *context->Output(0, {rois_dims[0], x_dims[1], pooled_h_, pooled_w_}); + auto num_rois = batch_indices_dims[0]; + auto num_rois_from_rois = rois_dims[0]; + auto num_roi_cols = rois_dims[1]; + + // first dimension of 
batch_indices and rois should match + if (num_rois != num_rois_from_rois) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + "First dimension (num_rois) of batch_indices and rois don't match"); + } + + auto& Y = *context->Output(0, {num_rois, x_dims[1], output_height_, output_width_}); int64_t output_size = Y.Shape().Size(); - ROIAlignForward( - output_size, + RoiAlignForward( + output_size, // nthreads: total number of output elements X_ptr->Data(), spatial_scale_, - x_dims[1], - x_dims[2], - x_dims[3], - pooled_h_, - pooled_w_, + x_dims[1], // num channels + x_dims[2], // height + x_dims[3], // width + output_height_, + output_width_, sampling_ratio_, rois_ptr->Data(), - rois_dims[1], + num_roi_cols, Y.template MutableData(), mode_, + batch_indices_ptr->Data(), static_cast(context)->GetOperatorThreadPool()); return Status::OK(); } -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/roialign.h b/onnxruntime/core/providers/cpu/object_detection/roialign.h similarity index 70% rename from onnxruntime/contrib_ops/cpu/roialign.h rename to onnxruntime/core/providers/cpu/object_detection/roialign.h index 5e52b7ac98..80555d5925 100644 --- a/onnxruntime/contrib_ops/cpu/roialign.h +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.h @@ -7,11 +7,10 @@ #include namespace onnxruntime { -namespace contrib { template -class ROIAlign final : public OpKernel { +class RoiAlign final : public OpKernel { public: - explicit ROIAlign(const OpKernelInfo& info) : OpKernel(info) { + explicit RoiAlign(const OpKernelInfo& info) : OpKernel(info) { // mode std::string mode_tmp; if (info.GetAttr("mode", &mode_tmp).IsOK()) { @@ -22,16 +21,16 @@ class ROIAlign final : public OpKernel { } } - // pooled_h - int64_t pooled_h_tmp; - if (info.GetAttr("pooled_h", &pooled_h_tmp).IsOK()) { - pooled_h_ = pooled_h_tmp; + // output_height + int64_t output_height_tmp; + if (info.GetAttr("output_height", &output_height_tmp).IsOK()) { + output_height_ = output_height_tmp; } - // pooled_w 
- int64_t pooled_w_tmp; - if (info.GetAttr("pooled_w", &pooled_w_tmp).IsOK()) { - pooled_w_ = pooled_w_tmp; + // output_width + int64_t output_width_tmp; + if (info.GetAttr("output_width", &output_width_tmp).IsOK()) { + output_width_ = output_width_tmp; } // sampling_ratio @@ -52,12 +51,11 @@ class ROIAlign final : public OpKernel { private: std::string mode_{"avg"}; - int64_t pooled_h_{1}; - int64_t pooled_w_{1}; + int64_t output_height_{1}; + int64_t output_width_{1}; int64_t sampling_ratio_{0}; float spatial_scale_{1.0f}; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ROIAlign); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RoiAlign); }; -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc similarity index 67% rename from onnxruntime/test/contrib_ops/roialign_test.cc rename to onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index 3dd52b8b0c..250cc629c8 100644 --- a/onnxruntime/test/contrib_ops/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -7,10 +7,10 @@ namespace onnxruntime { namespace test { -TEST(ROIAlignTest, AvgModePositive) { - OpTester test("ROIAlign", 1, onnxruntime::kMSDomain); - test.AddAttribute("pooled_h", 3); - test.AddAttribute("pooled_w", 4); +TEST(RoiAlignTest, AvgModePositive) { + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); test.AddAttribute("sampling_ratio", 2); test.AddAttribute("spatial_scale", 1.0f / 16.0f); @@ -24,7 +24,8 @@ TEST(ROIAlignTest, AvgModePositive) { 25.,26.,27.,28.,29.,30.,31.,32.,33.,34.,35.,36.,37.,38.,39.,40.,41.,42.,43.,44.,45.,46., 47.,48.,49.,50.,51.,52.,53.,54.,55.,56.,57.,58.,59.,60.,61.,62.,63.,64.,65.,66.,67.,68., 69.,70.,71.,72.,73.,74.}); - test.AddInput("rois", {5, 5}, {0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.,0.,-14.,19.,-14.,19.}); + 
test.AddInput("rois", {5, 4}, {7.,5.,7.,5.,-15.,-15.,-15.,-15.,-10.,21.,-10.,21.,13.,8.,13.,8.,-14.,19.,-14.,19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5,3,3,4}, {2.95833f,3.20833f,3.45833f,3.70833f,4.625f,4.875f,5.125f,5.375f, 6.29167f,6.54167f,6.79167f,7.04167f,27.9583f,28.2083f,28.4583f, 28.7083f,29.625f,29.875f,30.125f,30.375f,31.2917f,31.5417f,31.7917f, @@ -48,11 +49,59 @@ TEST(ROIAlignTest, AvgModePositive) { test.Run(); } -TEST(ROIAlignTest, MaxModePositive) { - OpTester test("ROIAlign", 1, onnxruntime::kMSDomain); +TEST(RoiAlignTest, OnnxTest) { + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 5); + test.AddAttribute("output_width", 5); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + + const int N = 1; + const int C = 1; + const int H = 10; + const int W = 10; + + test.AddInput("X", {N, C, H, W}, { + 0.2764f, 0.7150f, 0.1958f, 0.3416f, 0.4638f, 0.0259f, 0.2963f, 0.6518f, 0.4856f, 0.7250f, + 0.9637f, 0.0895f, 0.2919f, 0.6753f, 0.0234f, 0.6132f, 0.8085f, 0.5324f, 0.8992f, 0.4467f, + 0.3265f, 0.8479f, 0.9698f, 0.2471f, 0.9336f, 0.1878f, 0.4766f, 0.4308f, 0.3400f, 0.2162f, + 0.0206f, 0.1720f, 0.2155f, 0.4394f, 0.0653f, 0.3406f, 0.7724f, 0.3921f, 0.2541f, 0.5799f, + 0.4062f, 0.2194f, 0.4473f, 0.4687f, 0.7109f, 0.9327f, 0.9815f, 0.6320f, 0.1728f, 0.6119f, + 0.3097f, 0.1283f, 0.4984f, 0.5068f, 0.4279f, 0.0173f, 0.4388f, 0.0430f, 0.4671f, 0.7119f, + 0.1011f, 0.8477f, 0.4726f, 0.1777f, 0.9923f, 0.4042f, 0.1869f, 0.7795f, 0.9946f, 0.9689f, + 0.1366f, 0.3671f, 0.7011f, 0.6234f, 0.9867f, 0.5585f, 0.6985f, 0.5609f, 0.8788f, 0.9928f, + 0.5697f, 0.8511f, 0.6711f, 0.9406f, 0.8751f, 0.7496f, 0.1650f, 0.1049f, 0.1559f, 0.2514f, + 0.7012f, 0.4056f, 0.7879f, 0.3461f, 0.0415f, 0.2998f, 0.5094f, 0.3727f, 0.5482f, 0.0502f,}); + test.AddInput("rois", {3, 4}, {0., 0., 9., 9., 0., 5., 4., 9., 5., 5., 9., 9.}); + test.AddInput("batch_indices", {3}, {0, 0, 0}); + test.AddOutput("Y", 
{3,1,5,5}, { + 0.4664f, 0.4466f, 0.3405f, 0.5688f, 0.6068f, + 0.3714f, 0.4296f, 0.3835f, 0.5562f, 0.3510f, + 0.2768f, 0.4883f, 0.5222f, 0.5528f, 0.4171f, + 0.4713f, 0.4844f, 0.6904f, 0.4920f, 0.8774f, + 0.6239f, 0.7125f, 0.6289f, 0.3355f, 0.3495f, + + 0.3022f, 0.4305f, 0.4696f, 0.3978f, 0.5423f, + 0.3656f, 0.7050f, 0.5165f, 0.3172f, 0.7015f, + 0.2912f, 0.5059f, 0.6476f, 0.6235f, 0.8299f, + 0.5916f, 0.7389f, 0.7048f, 0.8372f, 0.8893f, + 0.6227f, 0.6153f, 0.7097f, 0.6154f, 0.4585f, + + 0.2384f, 0.3379f, 0.3717f, 0.6100f, 0.7601f, + 0.3767f, 0.3785f, 0.7147f, 0.9243f, 0.9727f, + 0.5749f, 0.5826f, 0.5709f, 0.7619f, 0.8770f, + 0.5355f, 0.2566f, 0.2141f, 0.2796f, 0.3600f, + 0.4365f, 0.3504f, 0.2887f, 0.3661f, 0.2349f, + }); + + test.Run(); +} + +TEST(RoiAlignTest, MaxModePositive) { + OpTester test("RoiAlign", 10); test.AddAttribute("mode", "max"); - test.AddAttribute("pooled_h", 3); - test.AddAttribute("pooled_w", 4); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); test.AddAttribute("sampling_ratio", 2); test.AddAttribute("spatial_scale", 1.0f / 16.0f); @@ -66,7 +115,8 @@ TEST(ROIAlignTest, MaxModePositive) { 25.,26.,27.,28.,29.,30.,31.,32.,33.,34.,35.,36.,37.,38.,39.,40.,41.,42.,43.,44.,45.,46., 47.,48.,49.,50.,51.,52.,53.,54.,55.,56.,57.,58.,59.,60.,61.,62.,63.,64.,65.,66.,67.,68., 69.,70.,71.,72.,73.,74.}); - test.AddInput("rois", {5, 5}, {0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.,0.,-14.,19.,-14.,19.}); + test.AddInput("rois", {5, 4}, {7.,5.,7.,5.,-15.,-15.,-15.,-15.,-10.,21.,-10.,21.,13.,8.,13.,8.,-14.,19.,-14.,19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5,3,3,4}, {2.10938f,2.95313f,3.375f,2.53125f,3.35938f,4.70313f,5.375f,4.03125f,3.51563f,4.92188f,5.625f, 4.21875f,10.8984f,15.2578f,17.4375f,13.0781f,17.3568f,24.2995f,27.7708f,20.8281f,18.1641f, 25.4297f,29.0625f,21.7969f,19.6875f,27.5625f,31.5f,23.625f,31.3542f,43.8958f,50.1667f,37.625f, @@ -85,11 +135,11 @@ 
TEST(ROIAlignTest, MaxModePositive) { test.Run(); } -TEST(ROIAlignTest, AvgModeNegativeInvalidMode) { - OpTester test("ROIAlign", 1, onnxruntime::kMSDomain); - test.AddAttribute("mode", "foobar"); // <-- - test.AddAttribute("pooled_h", 3); - test.AddAttribute("pooled_w", 4); +TEST(RoiAlignTest, AvgModeNegativeInvalidMode) { + OpTester test("RoiAlign", 10); + test.AddAttribute("mode", "foobar"); // <-- failure condition + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); test.AddAttribute("sampling_ratio", -2); test.AddAttribute("spatial_scale", 1.0f / 16.0f); @@ -103,7 +153,8 @@ TEST(ROIAlignTest, AvgModeNegativeInvalidMode) { 25.,26.,27.,28.,29.,30.,31.,32.,33.,34.,35.,36.,37.,38.,39.,40.,41.,42.,43.,44.,45.,46., 47.,48.,49.,50.,51.,52.,53.,54.,55.,56.,57.,58.,59.,60.,61.,62.,63.,64.,65.,66.,67.,68., 69.,70.,71.,72.,73.,74.}); - test.AddInput("rois", {5, 5}, {0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.,0.,-14.,19.,-14.,19.}); + test.AddInput("rois", {5, 4}, {7.,5.,7.,5.,-15.,-15.,-15.,-15.,-10.,21.,-10.,21.,13.,8.,13.,8.,-14.,19.,-14.,19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5,3,3,4}, {2.95833f,3.20833f,3.45833f,3.70833f,4.625f,4.875f,5.125f,5.375f, 6.29167f,6.54167f,6.79167f,7.04167f,27.9583f,28.2083f,28.4583f, 28.7083f,29.625f,29.875f,30.125f,30.375f,31.2917f,31.5417f,31.7917f, @@ -127,11 +178,11 @@ TEST(ROIAlignTest, AvgModeNegativeInvalidMode) { test.Run(OpTester::ExpectResult::kExpectFailure, "Invalid mode"); } -TEST(ROIAlignTest, AvgModeNegativeSamplingRatio) { - OpTester test("ROIAlign", 1, onnxruntime::kMSDomain); - test.AddAttribute("pooled_h", 3); - test.AddAttribute("pooled_w", 4); - test.AddAttribute("sampling_ratio", -2); // <-- +TEST(RoiAlignTest, AvgModeNegativeSamplingRatio) { + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); + test.AddAttribute("sampling_ratio", -2); // <-- failure 
condition test.AddAttribute("spatial_scale", 1.0f / 16.0f); const int N = 1; @@ -139,12 +190,12 @@ TEST(ROIAlignTest, AvgModeNegativeSamplingRatio) { const int H = 5; const int W = 5; - std::vector rois{0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.,0.,-14.,19.,-14.,19.}; test.AddInput("X", {N, C, H, W}, {0.,1.,2.,3.,4.,5.,6.,7.,8.,9.,10.,11.,12.,13.,14.,15.,16.,17.,18.,19.,20.,21.,22.,23.,24., 25.,26.,27.,28.,29.,30.,31.,32.,33.,34.,35.,36.,37.,38.,39.,40.,41.,42.,43.,44.,45.,46., 47.,48.,49.,50.,51.,52.,53.,54.,55.,56.,57.,58.,59.,60.,61.,62.,63.,64.,65.,66.,67.,68., 69.,70.,71.,72.,73.,74.}); - test.AddInput("rois", {5, 5}, {0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.,0.,-14.,19.,-14.,19.}); + test.AddInput("rois", {5, 4}, {7.,5.,7.,5.,-15.,-15.,-15.,-15.,-10.,21.,-10.,21.,13.,8.,13.,8.,-14.,19.,-14.,19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5,3,3,4}, {2.95833f,3.20833f,3.45833f,3.70833f,4.625f,4.875f,5.125f,5.375f, 6.29167f,6.54167f,6.79167f,7.04167f,27.9583f,28.2083f,28.4583f, 28.7083f,29.625f,29.875f,30.125f,30.375f,31.2917f,31.5417f,31.7917f, @@ -168,10 +219,10 @@ TEST(ROIAlignTest, AvgModeNegativeSamplingRatio) { test.Run(OpTester::ExpectResult::kExpectFailure, "Sampling ratio should be >=0"); } -TEST(ROIAlignTest, AvgModeNegativeInvalidNumRoiDims) { - OpTester test("ROIAlign", 1, onnxruntime::kMSDomain); - test.AddAttribute("pooled_h", 3); - test.AddAttribute("pooled_w", 4); +TEST(RoiAlignTest, AvgModeNegativeInvalidNumRoiDims) { + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); test.AddAttribute("sampling_ratio", 2); test.AddAttribute("spatial_scale", 1.0f / 16.0f); @@ -180,12 +231,13 @@ TEST(ROIAlignTest, AvgModeNegativeInvalidNumRoiDims) { const int H = 5; const int W = 5; - std::vector rois{0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.,0.,-14.,19.,-14.,19.}; 
+ std::vector rois{0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.,0.,-14.,19.,-14.,19.}; test.AddInput("X", {N, C, H, W}, {0.,1.,2.,3.,4.,5.,6.,7.,8.,9.,10.,11.,12.,13.,14.,15.,16.,17.,18.,19.,20.,21.,22.,23.,24., 25.,26.,27.,28.,29.,30.,31.,32.,33.,34.,35.,36.,37.,38.,39.,40.,41.,42.,43.,44.,45.,46., 47.,48.,49.,50.,51.,52.,53.,54.,55.,56.,57.,58.,59.,60.,61.,62.,63.,64.,65.,66.,67.,68., 69.,70.,71.,72.,73.,74.}); - test.AddInput("rois", {5, 4, 1}, {0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.}); // <-- + test.AddInput("rois", {5, 4, 1}, {7.,5.,7.,5.,-15.,-15.,-15.,-15.,-10.,21.,-10.,21.,13.,8.,13.,8.,-14.,19.,-14.,19.}); // <-- failure condition + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5,3,3,4}, {2.95833f,3.20833f,3.45833f,3.70833f,4.625f,4.875f,5.125f,5.375f, 6.29167f,6.54167f,6.79167f,7.04167f,27.9583f,28.2083f,28.4583f, 28.7083f,29.625f,29.875f,30.125f,30.375f,31.2917f,31.5417f,31.7917f, @@ -206,13 +258,54 @@ TEST(ROIAlignTest, AvgModeNegativeInvalidNumRoiDims) { 35.1354f,56.7708f,56.7708f,56.7708f,56.8021f,58.4375f,58.4375f,58.4375f,58.4688f,60.1042f, 60.1042f,60.1042f,60.1354f}); - test.Run(OpTester::ExpectResult::kExpectFailure, "Number of dimensions for rois should be exactly 2"); + test.Run(OpTester::ExpectResult::kExpectFailure, "[ShapeInferenceError] rois input tensor has wrong dimension"); } -TEST(ROIAlignTest, AvgModeNegativeInvalidSecondRoiDims) { - OpTester test("ROIAlign", 1, onnxruntime::kMSDomain); - test.AddAttribute("pooled_h", 3); - test.AddAttribute("pooled_w", 4); +TEST(RoiAlignTest, AvgModeNegativeInvalidSecondRoiDims) { + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f / 16.0f); + + const int N = 1; + const int C = 3; + const int H = 5; + const int W = 5; + + test.AddInput("X", {N, C, H, W}, 
{0.,1.,2.,3.,4.,5.,6.,7.,8.,9.,10.,11.,12.,13.,14.,15.,16.,17.,18.,19.,20.,21.,22.,23.,24., + 25.,26.,27.,28.,29.,30.,31.,32.,33.,34.,35.,36.,37.,38.,39.,40.,41.,42.,43.,44.,45.,46., + 47.,48.,49.,50.,51.,52.,53.,54.,55.,56.,57.,58.,59.,60.,61.,62.,63.,64.,65.,66.,67.,68., + 69.,70.,71.,72.,73.,74.}); + test.AddInput("rois", {5, 3}, {7.,5.,7.,5.,-15.,-15.,-15.,-15.,-10.,21.,-10.,21.,13.,8.,13.}); // <-- failure condition + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); + test.AddOutput("Y", {5,3,3,4}, {2.95833f,3.20833f,3.45833f,3.70833f,4.625f,4.875f,5.125f,5.375f, + 6.29167f,6.54167f,6.79167f,7.04167f,27.9583f,28.2083f,28.4583f, + 28.7083f,29.625f,29.875f,30.125f,30.375f,31.2917f,31.5417f,31.7917f, + 32.0417f,52.9583f,53.2083f,53.4583f,53.7083f,54.625f,54.875f,55.125f, + 55.375f,56.2917f,56.5417f,56.7917f,57.0417f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f, + 25.f,25.f,25.f,25.f,25.f,25.f,25.f,25.f,25.f,25.f,25.f,25.f,50.f,50.f,50.f,50.f,50.f,50.f,50.f,50.f,50.f, + 50.f,50.f,50.f,7.39583f,7.39583f,7.42708f,7.64583f,9.0625f,9.0625f,9.09375f, + 9.3125f,10.7292f,10.7292f,10.7604f,10.9792f,32.3958f,32.3958f,32.4271f, + 32.6458f,34.0625f,34.0625f,34.0938f,34.3125f,35.7292f,35.7292f,35.7604f, + 35.9792f,57.3958f,57.3958f,57.4271f,57.6458f,59.0625f,59.0625f,59.0938f, + 59.3125f,60.7292f,60.7292f,60.7604f,60.9792f,4.27083f,4.52083f,4.77083f, + 5.02083f,5.9375f,6.1875f,6.4375f,6.6875f,7.60417f,7.85417f,8.10417f,8.35417f, + 29.2708f,29.5208f,29.7708f,30.0208f,30.9375f,31.1875f,31.4375f,31.6875f,32.6042f, + 32.8542f,33.1042f,33.3542f,54.2708f,54.5208f,54.7708f,55.0208f,55.9375f,56.1875f, + 56.4375f,56.6875f,57.6042f,57.8542f,58.1042f,58.3542f,6.77083f,6.77083f,6.77083f, + 6.80208f,8.4375f,8.4375f,8.4375f,8.46875f,10.1042f,10.1042f,10.1042f,10.1354f,31.7708f, + 31.7708f,31.7708f,31.8021f,33.4375f,33.4375f,33.4375f,33.4688f,35.1042f,35.1042f,35.1042f, + 35.1354f,56.7708f,56.7708f,56.7708f,56.8021f,58.4375f,58.4375f,58.4375f,58.4688f,60.1042f, + 
60.1042f,60.1042f,60.1354f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "Second dimension for rois should be exactly 4"); +} + +TEST(RoiAlignTest, MismatchNumRois) { + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); test.AddAttribute("sampling_ratio", 2); test.AddAttribute("spatial_scale", 1.0f / 16.0f); @@ -226,7 +319,8 @@ TEST(ROIAlignTest, AvgModeNegativeInvalidSecondRoiDims) { 25.,26.,27.,28.,29.,30.,31.,32.,33.,34.,35.,36.,37.,38.,39.,40.,41.,42.,43.,44.,45.,46., 47.,48.,49.,50.,51.,52.,53.,54.,55.,56.,57.,58.,59.,60.,61.,62.,63.,64.,65.,66.,67.,68., 69.,70.,71.,72.,73.,74.}); - test.AddInput("rois", {5, 4}, {0.,7.,5.,7.,5.,0.,-15.,-15.,-15.,-15.,0.,-10.,21.,-10.,21.,0.,13.,8.,13.,8.}); // <-- + test.AddInput("rois", {5, 4}, {7.,5.,7.,5.,-15.,-15.,-15.,-15.,-10.,21.,-10.,21.,13.,8.,13.,8.,-14.,19.,-14.,19.}); + test.AddInput("batch_indices", {4}, {0, 0, 0, 0}); // <-- failure condition test.AddOutput("Y", {5,3,3,4}, {2.95833f,3.20833f,3.45833f,3.70833f,4.625f,4.875f,5.125f,5.375f, 6.29167f,6.54167f,6.79167f,7.04167f,27.9583f,28.2083f,28.4583f, 28.7083f,29.625f,29.875f,30.125f,30.375f,31.2917f,31.5417f,31.7917f, @@ -247,7 +341,7 @@ TEST(ROIAlignTest, AvgModeNegativeInvalidSecondRoiDims) { 35.1354f,56.7708f,56.7708f,56.7708f,56.8021f,58.4375f,58.4375f,58.4375f,58.4688f,60.1042f, 60.1042f,60.1042f,60.1354f}); - test.Run(OpTester::ExpectResult::kExpectFailure, "Second dimension for rois should be exactly 5"); + test.Run(OpTester::ExpectResult::kExpectFailure, "First dimension (num_rois) of batch_indices and rois don't match"); } } // namespace test } // namespace onnxruntime