From 7c05f7bab109cfa2ea519902920133305efccab9 Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Fri, 28 Jul 2023 15:52:37 -0700 Subject: [PATCH] Fix IRFFT contrib op output dimension calculation (#15662) ### Description Fixes the issue with IRFFT output dimension calculation as described in #13236 ### Motivation and Context Please refer to #13236 for detailed description. Specifically, [this code](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.cc#L103) computes the output dimension as: ``` out_dim = in_dim * 2 - 1 ``` while it should be this instead: ``` out_dim = 2 * (in_dim - 1) ``` (assuming the original signal has even number of samples, of course). For example, if the original signal has 4 samples, then the round trip should look something like: ``` 4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4 ``` with the current code the output will be a signal with 5 points. --------- Co-authored-by: Alexey Kamenev Co-authored-by: Nick Geneva --- docs/ContribOperators.md | 24 +++++++++++-------- onnxruntime/contrib_ops/cuda/math/fft_ops.cc | 2 +- .../core/graph/contrib_ops/contrib_defs.cc | 24 +++++++++---------- onnxruntime/test/contrib_ops/fft_op_test.cc | 13 +++++----- 4 files changed, 34 insertions(+), 29 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f90f8cff2f..589be30918 100755 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2242,6 +2242,8 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.Irfft** + This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'. + #### Version This version of the operator has been available since version 1 of the 'com.microsoft' operator set. @@ -2250,25 +2252,25 @@ This version of the operator has been available since version 1 of the 'com.micr
normalized : int
-
+
must be 0, normalization currently not supported
onesided : int
-
+
must be 1, only one sided FFTs supported
signal_ndim : int (required)
-
+
number of dimensions comprising the signal
#### Inputs
X : T
-
input tensor
+
input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts
#### Outputs
Y : T
-
output tensor
+
output tensor with size n in the signal dim
#### Type Constraints @@ -4394,6 +4396,8 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.Rfft** + This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number. + #### Version This version of the operator has been available since version 1 of the 'com.microsoft' operator set. @@ -4402,25 +4406,25 @@ This version of the operator has been available since version 1 of the 'com.micr
normalized : int
-
+
must be 0, normalization currently not supported
onesided : int
-
+
must be 1, only one sided FFTs supported
signal_ndim : int
-
+
number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)
#### Inputs
X : T
-
input tensor
+
input tensor of size n in the signal dim
#### Outputs
Y : T
-
output tensor
+
output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts
#### Type Constraints diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc index 5fddcd7e93..4b524dcf79 100644 --- a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc +++ b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc @@ -100,7 +100,7 @@ Status FFTBase::DoFFT(OpKernelContext* context, const Tensor* X, bool complex // process the last dim(s) if (onesided_) { if (complex_input && !complex_output) { // IRFFT - int64_t inferred_size = input_shape[i] * 2 - 1; + int64_t inferred_size = 2 * (input_shape[i] - 1); output_dims.push_back(inferred_size); signal_dims.push_back(inferred_size); } else if (!complex_input && complex_output) { // RFFT diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 534303b9b1..f129f8ffdb 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1242,22 +1242,22 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MaxpoolWithMask, 1, ONNX_MS_OPERATOR_SET_SCHEMA(Rfft, 1, OpSchema() - .SetDoc(R"DOC()DOC") - .Input(0, "X", "input tensor", "T") - .Attr("signal_ndim", "", AttributeProto::INT, static_cast(1)) - .Attr("normalized", "", AttributeProto::INT, static_cast(0)) - .Attr("onesided", "", AttributeProto::INT, static_cast(1)) - .Output(0, "Y", "output tensor", "T") + .SetDoc(R"DOC(This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number.)DOC") + .Input(0, "X", "input tensor of size n in the signal dim", "T") + .Attr("signal_ndim", "number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)", AttributeProto::INT, static_cast(1)) + .Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast(0)) + .Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast(1)) + .Output(0, "Y", "output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T") .TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")); ONNX_MS_OPERATOR_SET_SCHEMA(Irfft, 1, OpSchema() - .SetDoc(R"DOC()DOC") - .Input(0, "X", "input tensor", "T") - .Attr("signal_ndim", "", AttributeProto::INT) - .Attr("normalized", "", AttributeProto::INT, static_cast(0)) - .Attr("onesided", "", AttributeProto::INT, static_cast(1)) - .Output(0, "Y", "output tensor", "T") + .SetDoc(R"DOC(This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'.)DOC") + .Input(0, "X", "input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T") + .Attr("signal_ndim", "number of dimensions comprising the signal", AttributeProto::INT) + .Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast(0)) + .Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast(1)) + .Output(0, "Y", "output tensor with size n in the signal dim", "T") .TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")); ONNX_MS_OPERATOR_SET_SCHEMA(ComplexMul, 1, diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc index c259cbaf25..eaadb95c8a 100644 --- a/onnxruntime/test/contrib_ops/fft_op_test.cc +++ b/onnxruntime/test/contrib_ops/fft_op_test.cc @@ -11,11 +11,12 @@ TEST(ContribOpTest, Rfft) { if (DefaultCudaExecutionProvider() == nullptr) return; OpTester test("Rfft", 1, onnxruntime::kMSDomain); - test.AddAttribute("signal_ndim", static_cast(2)); + test.AddAttribute("signal_ndim", static_cast(1)); test.AddAttribute("onesided", static_cast(1)); test.AddAttribute("normalized", static_cast(0)); - test.AddInput("X", {4, 5}, std::vector{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f}); - test.AddOutput("Y", {4, 3, 2}, std::vector{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f}); + // Target values conputed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward") + test.AddInput("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); + test.AddOutput("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -25,11 +26,11 @@ TEST(ContribOpTest, Irfft) { if (DefaultCudaExecutionProvider() == nullptr) return; OpTester test("Irfft", 1, onnxruntime::kMSDomain); - test.AddAttribute("signal_ndim", static_cast(2)); + test.AddAttribute("signal_ndim", static_cast(1)); test.AddAttribute("onesided", static_cast(1)); test.AddAttribute("normalized", static_cast(0)); - test.AddInput("X", {4, 3, 2}, std::vector{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f}); - test.AddOutput("Y", {4, 5}, std::vector{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f}); + test.AddInput("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); + test.AddOutput("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);