From 4275055868359e41f436c2b170f03d2a52fda057 Mon Sep 17 00:00:00 2001
From: DeyuHuang <10047193+hwangdeyu@users.noreply.github.com>
Date: Thu, 22 Jul 2021 15:39:28 +0800
Subject: [PATCH] Add Gridsampler contrib op (#8372)
* add Gridsampler contrib op
* fix gridsampler_paddingmode_border test
* disable the tests until the kernel added
* fix CI failure
* change GridSampler to GridSample
---
docs/ContribOperators.md | 53 +++++++
.../core/graph/contrib_ops/contrib_defs.cc | 89 +++++++++++-
.../test/contrib_ops/gridsample_test.cc | 137 ++++++++++++++++++
3 files changed, 278 insertions(+), 1 deletion(-)
create mode 100644 onnxruntime/test/contrib_ops/gridsample_test.cc
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 60fbad8254..577637a336 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -24,6 +24,7 @@ Do not modify directly.*
* com.microsoft.FusedMatMul
* com.microsoft.GatherND
* com.microsoft.Gelu
+ * com.microsoft.GridSample
* com.microsoft.Inverse
* com.microsoft.Irfft
* com.microsoft.LongformerAttention
@@ -1185,6 +1186,58 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.GridSample**
+
+ Given an `input` and a flow-field `grid`, computes the `output` using `input` values and pixel locations from `grid`.
+ Currently, only spatial (4-D) inputs are supported. For `input` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2),
+ the `output` will have shape (N, C, H_out, W_out).
+ For each output location `output[n, :, h, w]`, the size-2 vector `grid[n, h, w]` specifies `input` pixel locations `x` and `y`,
+ which are used to interpolate the output value `output[n, :, h, w]`.
+ The GridSample operator is often used in doing grid generator and sampler in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
+ See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample).
+
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- align_corners : int
+- If align_corners=1, the extrema (-1 and 1) are considered as referring to the center points of the input's corner pixels. If align_corners=0, they are instead considered as referring to the corner points of the input's corner pixels, making the sampling more resolution agnostic.
+- mode : string
+- Three interpolation modes: bilinear (default), nearest and bicubic.
+- padding_mode : string
+- Support padding modes for outside grid values: `zeros`(default), `border`, `reflection`. zeros: use 0 for out-of-bound grid locations, border: use border values for out-of-bound grid locations, reflection: use values at locations reflected by the border for out-of-bound grid locations.
+
+
+#### Inputs
+
+
+- X : T1
+- 4-D tensor of shape (N, C, H, W), where N is the batch size, C is the numbers of channels, H and W are the height and width of the input data.
+- Grid : T1
+- Input offset, 4-D tensor of shape (N, H_out, W_out, 2), where H_out and W_out are the height and width of grid and output, Grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. If grid has values outside the range of [-1, 1], the corresponding outputs will be handled as defined by padding_mode.
+
+
+#### Outputs
+
+
+- Y : T2
+- 4-D tensor of shape (N, C, H_out, W_out).
+
+
+#### Type Constraints
+
+
+- T1 : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)
+- Constrain input types to all tensor types.
+- T2 : tensor(float16), tensor(float), tensor(double)
+- Constrain output types to float tensors.
+
+
+
### **com.microsoft.Inverse**
#### Version
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index d0050ac069..8019d8758f 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2985,7 +2985,94 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
ctx.getOutputType(0)
->CopyFrom(input_type->optional_type().elem_type());
});
-
+
+ static const char* GridSample_ver1_doc = R"DOC(
+ Given an `input` and a flow-field `grid`, computes the `output` using `input` values and pixel locations from `grid`.
+ Currently, only spatial (4-D) inputs are supported. For `input` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2),
+ the `output` will have shape (N, C, H_out, W_out).
+ For each output location `output[n, :, h, w]`, the size-2 vector `grid[n, h, w]` specifies `input` pixel locations `x` and `y`,
+ which are used to interpolate the output value `output[n, :, h, w]`.
+ The GridSample operator is often used in doing grid generator and sampler in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
+ See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample).
+ )DOC";
+
+ ONNX_CONTRIB_OPERATOR_SCHEMA(GridSample)
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc(GridSample_ver1_doc)
+ .Attr(
+ "mode",
+ "Three interpolation modes: bilinear (default), nearest and bicubic.",
+ AttributeProto::STRING,
+ std::string("bilinear"))
+ .Attr(
+ "padding_mode",
+ "Support padding modes for outside grid values: `zeros`(default), `border`, `reflection`. "
+ "zeros: use 0 for out-of-bound grid locations, "
+ "border: use border values for out-of-bound grid locations, "
+ "reflection: use values at locations reflected by the border for out-of-bound grid locations.",
+ AttributeProto::STRING,
+ std::string("zeros"))
+ .Attr(
+ "align_corners",
+ "If align_corners=1, the extrema (-1 and 1) are considered as referring to the center points of the input's corner pixels. "
+ "If align_corners=0, they are instead considered as referring to the corner points of the input's corner pixels, making the sampling more resolution agnostic.",
+ AttributeProto::INT,
+ static_cast(0))
+ .Input(
+ 0,
+ "X",
+ "4-D tensor of shape (N, C, H, W), "
+ "where N is the batch size, C is the numbers of channels, "
+ "H and W are the height and width of the input data.",
+ "T1")
+ .Input(
+ 1,
+ "Grid",
+ "Input offset, 4-D tensor of shape (N, H_out, W_out, 2), "
+ "where H_out and W_out are the height and width of grid and output, "
+ "Grid specifies the sampling pixel locations normalized by the input spatial dimensions. "
+ "Therefore, it should have most values in the range of [-1, 1]. "
+ "If grid has values outside the range of [-1, 1], the corresponding outputs will be handled as defined by padding_mode.",
+ "T1")
+ .Output(
+ 0,
+ "Y",
+ "4-D tensor of shape (N, C, H_out, W_out).",
+ "T2")
+ .TypeConstraint(
+ "T1",
+ OpSchema::all_tensor_types(),
+ "Constrain input types to all tensor types.")
+ .TypeConstraint(
+ "T2",
+ {"tensor(float16)", "tensor(float)", "tensor(double)"},
+ "Constrain output types to float tensors.")
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+ size_t input_param = 0, grid_param = 1;
+
+ checkInputRank(ctx, input_param, 4);
+ checkInputRank(ctx, grid_param, 4);
+
+ // Output dimensions, initialized to an unknown-dimension-value
+ Dim N, C, H_out, W_out;
+
+ // Get value of N from dim 0 of input_param, if available
+ unifyInputDim(ctx, input_param, 0, N);
+ // Get value of C from dim 1 of input_param, if available
+ unifyInputDim(ctx, input_param, 1, C);
+
+ // Get value of H_out from dim 1 of grid_param, if available
+ unifyInputDim(ctx, grid_param, 1, H_out);
+ // Get value of W_out from dim 2 of grid_param, if available
+ unifyInputDim(ctx, grid_param, 2, W_out);
+
+ // set output shape:
+ updateOutputShape(ctx, 0, {N, C, H_out, W_out});
+ });
+
#ifndef _OPSCHEMA_LIB_
// Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
diff --git a/onnxruntime/test/contrib_ops/gridsample_test.cc b/onnxruntime/test/contrib_ops/gridsample_test.cc
new file mode 100644
index 0000000000..a729a9c4fa
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/gridsample_test.cc
@@ -0,0 +1,137 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifdef GridSampleKernal // disable the unit tests until the kernel is added and remove it.
+
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+#include "core/util/math.h"
+
+namespace onnxruntime {
+namespace test {
+
+TEST(GridsampleContribOpTest, gridsample_default) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 4, 4}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f});
+ test.AddInput("Grid", {1, 6, 6, 2},
+ {-1.0000f, -1.0000f, -0.6000f, -1.0000f, -0.2000f, -1.0000f, 0.2000f, -1.0000f,
+ 0.6000f, -1.0000f, 1.0000f, -1.0000f, -1.0000f, -0.6000f, -0.6000f, -0.6000f,
+ -0.2000f, -0.6000f, 0.2000f, -0.6000f, 0.6000f, -0.6000f, 1.0000f, -0.6000f,
+ -1.0000f, -0.2000f, -0.6000f, -0.2000f, -0.2000f, -0.2000f, 0.2000f, -0.2000f,
+ 0.6000f, -0.2000f, 1.0000f, -0.2000f, -1.0000f, 0.2000f, -0.6000f, 0.2000f,
+ -0.2000f, 0.2000f, 0.2000f, 0.2000f, 0.6000f, 0.2000f, 1.0000f, 0.2000f,
+ -1.0000f, 0.6000f, -0.6000f, 0.6000f, -0.2000f, 0.6000f, 0.2000f, 0.6000f,
+ 0.6000f, 0.6000f, 1.0000f, 0.6000f, -1.0000f, 1.0000f, -0.6000f, 1.0000f,
+ -0.2000f, 1.0000f, 0.2000f, 1.0000f, 0.6000f, 1.0000f, 1.0000f, 1.0000f});
+ int64_t align_corners = 0;
+ test.AddAttribute("mode", "bilinear");
+ test.AddAttribute("padding_mode", "zeros");
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", {1, 1, 6, 6},
+ {0.0000f, 0.1500f, 0.5500f, 0.9500f, 1.3500f, 0.7500f,
+ 0.6000f, 1.5000f, 2.3000f, 3.1000f, 3.9000f, 2.1000f,
+ 2.2000f, 4.7000f, 5.5000f, 6.3000f, 7.1000f, 3.7000f,
+ 3.8000f, 7.9000f, 8.7000f, 9.5000f, 10.3000f, 5.3000f,
+ 5.4000f, 11.1000f, 11.9000f, 12.7000f, 13.5000f, 6.9000f,
+ 3.0000f, 6.1500f, 6.5500f, 6.9500f, 7.3500f, 3.7500f});
+ test.Run();
+}
+
+TEST(GridsampleContribOpTest, gridsample_paddingmode_zeros) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 3, 2}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ test.AddInput("Grid", {1, 2, 4, 2},
+ {-10.0000f, -10.0000f, -5.0000f, -5.0000f,
+ -0.2000f, -0.2000f, 10.0000f, 10.0000f,
+ 10.0000f, 10.0000f, -0.2000f, -0.2000f,
+ 5.0000f, 5.0000f, 10.0000f, 10.0000f});
+ test.AddAttribute("padding_mode", "zeros");
+ test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f});
+ test.Run();
+}
+
+TEST(GridsampleContribOpTest, gridsample_paddingmode_border) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 3, 2}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ test.AddInput("Grid", {1, 2, 4, 2},
+ {-10.0000f, -10.0000f, -5.0000f, -5.0000f,
+ -0.2000f, -0.2000f, 10.0000f, 10.0000f,
+ 10.0000f, 10.0000f, -0.2000f, -0.2000f,
+ 5.0000f, 5.0000f, 10.0000f, 10.0000f});
+ test.AddAttribute("padding_mode", "border");
+ test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 5.0000f, 5.0000f, 1.7000f, 5.0000f, 5.0000f});
+ test.Run();
+}
+
+TEST(GridsampleContribOpTest, gridsample_paddingmode_reflection) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 3, 2}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ test.AddInput("Grid", {1, 2, 4, 2},
+ {-10.0000f, -10.0000f, -5.0000f, -5.0000f,
+ -0.2000f, -0.2000f, 10.0000f, 10.0000f,
+ 10.0000f, 10.0000f, -0.2000f, -0.2000f,
+ 5.0000f, 5.0000f, 10.0000f, 10.0000f});
+ test.AddAttribute("padding_mode", "reflection");
+ test.AddOutput("Y", {1, 1, 2, 4}, {2.5000f, 0.0000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 5.0000f, 2.5000f});
+ test.Run();
+}
+
+TEST(GridsampleContribOpTest, gridsample_aligncorners_true) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 3, 2}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ test.AddInput("Grid", {1, 2, 4, 2},
+ {-1.0000f, -1.0000f, -0.5000f, -0.5000f,
+ -0.2000f, -0.2000f, 0.0000f, 0.0000f,
+ 0.0000f, 0.0000f, -0.2000f, -0.2000f,
+ 0.5000f, 0.5000f, 1.0000f, 1.0000f});
+ int64_t align_corners = 1;
+ test.AddAttribute("mode", "bilinear");
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 1.2500f, 2.0000f, 2.5000f, 2.5000f, 2.0000f, 3.7500f, 5.0000f});
+ test.Run();
+}
+
+TEST(GridsampleContribOpTest, gridsample_mode_bilinear) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 3, 2}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ test.AddInput("Grid", {1, 2, 4, 2},
+ {-1.0000f, -1.0000f, -0.5000f, -0.5000f,
+ -0.2000f, -0.2000f, 0.0000f, 0.0000f,
+ 0.0000f, 0.0000f, -0.2000f, -0.2000f,
+ 0.5000f, 0.5000f, 1.0000f, 1.0000f});
+ test.AddAttribute("mode", "bilinear");
+ test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.5000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 4.5000f, 1.2500f});
+ test.Run();
+}
+
+TEST(GridsampleContribOpTest, gridsample_mode_nearest) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 3, 2}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ test.AddInput("Grid", {1, 2, 4, 2},
+ {-1.0000f, -1.0000f, -0.5000f, -0.5000f,
+ -0.2000f, -0.2000f, 0.0000f, 0.0000f,
+ 0.0000f, 0.0000f, -0.2000f, -0.2000f,
+ 0.5000f, 0.5000f, 1.0000f, 1.0000f});
+ test.AddAttribute("mode", "nearest");
+ test.AddOutput("Y", {1, 1, 2, 4}, {0.f, 0.f, 2.f, 2.f, 2.f, 2.f, 5.f, 0.f});
+ test.Run();
+}
+
+TEST(GridsampleContribOpTest, gridsample_mode_bicubic) {
+ OpTester test("GridSample", 1, kMSDomain);
+ test.AddInput("X", {1, 1, 3, 2}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ test.AddInput("Grid", {1, 2, 4, 2},
+ {-1.0000f, -1.0000f, -0.5000f, -0.5000f,
+ -0.2000f, -0.2000f, 0.0000f, 0.0000f,
+ 0.0000f, 0.0000f, -0.2000f, -0.2000f,
+ 0.5000f, 0.5000f, 1.0000f, 1.0000f});
+ test.AddAttribute("mode", "bicubic");
+ test.AddOutput("Y", {1, 1, 2, 4}, {-0.1406f, 0.3828f, 1.7556f, 2.9688f, 2.9688f, 1.7556f, 5.1445f, 1.3906f});
+ test.Run();
+}
+
+} // namespace test
+} // namespace onnxruntime
+
+#endif
+