Fix IRFFT contrib op output dimension calculation (#15662)

### Description
Fixes the issue with IRFFT output dimension calculation as described in
#13236

### Motivation and Context
Please refer to #13236 for detailed description.

Specifically, [this code](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.cc#L103) computes the output dimension as:
```
out_dim = in_dim * 2 - 1
```
while it should be this instead:
```
out_dim = 2 * (in_dim - 1)
```
(assuming the original signal has even number of samples, of course).

For example, if the original signal has 4 samples, then the round trip should look something like:
```
4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4
```
with the current code the output will be a signal with 5 points.

---------

Co-authored-by: Alexey Kamenev <akamenev@nvidia.com>
Co-authored-by: Nick Geneva <nicholasgeneva@gmail.com>
This commit is contained in:
Alexey Kamenev 2023-07-28 15:52:37 -07:00 committed by GitHub
parent 1743e9a615
commit 7c05f7bab1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 29 deletions

View file

@ -2242,6 +2242,8 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.Irfft"></a><a name="com.microsoft.irfft">**com.microsoft.Irfft**</a>
This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'.
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
@ -2250,25 +2252,25 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>normalized</tt> : int</dt>
<dd></dd>
<dd>must be 0, normalization currently not supported</dd>
<dt><tt>onesided</tt> : int</dt>
<dd></dd>
<dd>must be 1, only one sided FFTs supported</dd>
<dt><tt>signal_ndim</tt> : int (required)</dt>
<dd></dd>
<dd>number of dimensions comprising the signal</dd>
</dl>
#### Inputs
<dl>
<dt><tt>X</tt> : T</dt>
<dd>input tensor</dd>
<dd>input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts</dd>
</dl>
#### Outputs
<dl>
<dt><tt>Y</tt> : T</dt>
<dd>output tensor</dd>
<dd>output tensor with size n in the signal dim</dd>
</dl>
#### Type Constraints
@ -4394,6 +4396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.Rfft"></a><a name="com.microsoft.rfft">**com.microsoft.Rfft**</a>
This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number.
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
@ -4402,25 +4406,25 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>normalized</tt> : int</dt>
<dd></dd>
<dd>must be 0, normalization currently not supported</dd>
<dt><tt>onesided</tt> : int</dt>
<dd></dd>
<dd>must be 1, only one sided FFTs supported</dd>
<dt><tt>signal_ndim</tt> : int</dt>
<dd></dd>
<dd>number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)</dd>
</dl>
#### Inputs
<dl>
<dt><tt>X</tt> : T</dt>
<dd>input tensor</dd>
<dd>input tensor of size n in the signal dim</dd>
</dl>
#### Outputs
<dl>
<dt><tt>Y</tt> : T</dt>
<dd>output tensor</dd>
<dd>output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts</dd>
</dl>
#### Type Constraints

View file

@ -100,7 +100,7 @@ Status FFTBase<T>::DoFFT(OpKernelContext* context, const Tensor* X, bool complex
// process the last dim(s)
if (onesided_) {
if (complex_input && !complex_output) { // IRFFT
int64_t inferred_size = input_shape[i] * 2 - 1;
int64_t inferred_size = 2 * (input_shape[i] - 1);
output_dims.push_back(inferred_size);
signal_dims.push_back(inferred_size);
} else if (!complex_input && complex_output) { // RFFT

View file

@ -1242,22 +1242,22 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MaxpoolWithMask, 1,
ONNX_MS_OPERATOR_SET_SCHEMA(Rfft, 1,
OpSchema()
.SetDoc(R"DOC()DOC")
.Input(0, "X", "input tensor", "T")
.Attr("signal_ndim", "", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("normalized", "", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("onesided", "", AttributeProto::INT, static_cast<int64_t>(1))
.Output(0, "Y", "output tensor", "T")
.SetDoc(R"DOC(This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number.)DOC")
.Input(0, "X", "input tensor of size n in the signal dim", "T")
.Attr("signal_ndim", "number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast<int64_t>(1))
.Output(0, "Y", "output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."));
ONNX_MS_OPERATOR_SET_SCHEMA(Irfft, 1,
OpSchema()
.SetDoc(R"DOC()DOC")
.Input(0, "X", "input tensor", "T")
.Attr("signal_ndim", "", AttributeProto::INT)
.Attr("normalized", "", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("onesided", "", AttributeProto::INT, static_cast<int64_t>(1))
.Output(0, "Y", "output tensor", "T")
.SetDoc(R"DOC(This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'.)DOC")
.Input(0, "X", "input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T")
.Attr("signal_ndim", "number of dimensions comprising the signal", AttributeProto::INT)
.Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast<int64_t>(1))
.Output(0, "Y", "output tensor with size n in the signal dim", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."));
ONNX_MS_OPERATOR_SET_SCHEMA(ComplexMul, 1,

View file

@ -11,11 +11,12 @@ TEST(ContribOpTest, Rfft) {
if (DefaultCudaExecutionProvider() == nullptr) return;
OpTester test("Rfft", 1, onnxruntime::kMSDomain);
test.AddAttribute("signal_ndim", static_cast<int64_t>(2));
test.AddAttribute("signal_ndim", static_cast<int64_t>(1));
test.AddAttribute("onesided", static_cast<int64_t>(1));
test.AddAttribute("normalized", static_cast<int64_t>(0));
test.AddInput<float>("X", {4, 5}, std::vector<float>{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f});
test.AddOutput<float>("Y", {4, 3, 2}, std::vector<float>{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f});
// Target values conputed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward")
test.AddInput<float>("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
test.AddOutput<float>("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
@ -25,11 +26,11 @@ TEST(ContribOpTest, Irfft) {
if (DefaultCudaExecutionProvider() == nullptr) return;
OpTester test("Irfft", 1, onnxruntime::kMSDomain);
test.AddAttribute("signal_ndim", static_cast<int64_t>(2));
test.AddAttribute("signal_ndim", static_cast<int64_t>(1));
test.AddAttribute("onesided", static_cast<int64_t>(1));
test.AddAttribute("normalized", static_cast<int64_t>(0));
test.AddInput<float>("X", {4, 3, 2}, std::vector<float>{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f});
test.AddOutput<float>("Y", {4, 5}, std::vector<float>{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f});
test.AddInput<float>("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
test.AddOutput<float>("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);