mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Fix IRFFT contrib op output dimension calculation (#15662)
### Description Fixes the issue with IRFFT output dimension calculation as described in #13236 ### Motivation and Context Please refer to #13236 for detailed description. Specifically, [this code](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.cc#L103) computes the output dimension as: ``` out_dim = in_dim * 2 - 1 ``` while it should be this instead: ``` out_dim = 2 * (in_dim - 1) ``` (assuming the original signal has even number of samples, of course). For example, if the original signal has 4 samples, then the round trip should look something like: ``` 4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4 ``` with the current code the output will be a signal with 5 points. --------- Co-authored-by: Alexey Kamenev <akamenev@nvidia.com> Co-authored-by: Nick Geneva <nicholasgeneva@gmail.com>
This commit is contained in:
parent
1743e9a615
commit
7c05f7bab1
4 changed files with 34 additions and 29 deletions
|
|
@ -2242,6 +2242,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
### <a name="com.microsoft.Irfft"></a><a name="com.microsoft.irfft">**com.microsoft.Irfft**</a>
|
||||
|
||||
This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'.
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
|
@ -2250,25 +2252,25 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
<dl>
|
||||
<dt><tt>normalized</tt> : int</dt>
|
||||
<dd></dd>
|
||||
<dd>must be 0, normalization currently not supported</dd>
|
||||
<dt><tt>onesided</tt> : int</dt>
|
||||
<dd></dd>
|
||||
<dd>must be 1, only one sided FFTs supported</dd>
|
||||
<dt><tt>signal_ndim</tt> : int (required)</dt>
|
||||
<dd></dd>
|
||||
<dd>number of dimensions comprising the signal</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>X</tt> : T</dt>
|
||||
<dd>input tensor</dd>
|
||||
<dd>input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>Y</tt> : T</dt>
|
||||
<dd>output tensor</dd>
|
||||
<dd>output tensor with size n in the signal dim</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
|
@ -4394,6 +4396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
### <a name="com.microsoft.Rfft"></a><a name="com.microsoft.rfft">**com.microsoft.Rfft**</a>
|
||||
|
||||
This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number.
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
|
@ -4402,25 +4406,25 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
<dl>
|
||||
<dt><tt>normalized</tt> : int</dt>
|
||||
<dd></dd>
|
||||
<dd>must be 0, normalization currently not supported</dd>
|
||||
<dt><tt>onesided</tt> : int</dt>
|
||||
<dd></dd>
|
||||
<dd>must be 1, only one sided FFTs supported</dd>
|
||||
<dt><tt>signal_ndim</tt> : int</dt>
|
||||
<dd></dd>
|
||||
<dd>number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>X</tt> : T</dt>
|
||||
<dd>input tensor</dd>
|
||||
<dd>input tensor of size n in the signal dim</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>Y</tt> : T</dt>
|
||||
<dd>output tensor</dd>
|
||||
<dd>output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ Status FFTBase<T>::DoFFT(OpKernelContext* context, const Tensor* X, bool complex
|
|||
// process the last dim(s)
|
||||
if (onesided_) {
|
||||
if (complex_input && !complex_output) { // IRFFT
|
||||
int64_t inferred_size = input_shape[i] * 2 - 1;
|
||||
int64_t inferred_size = 2 * (input_shape[i] - 1);
|
||||
output_dims.push_back(inferred_size);
|
||||
signal_dims.push_back(inferred_size);
|
||||
} else if (!complex_input && complex_output) { // RFFT
|
||||
|
|
|
|||
|
|
@ -1242,22 +1242,22 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MaxpoolWithMask, 1,
|
|||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(Rfft, 1,
|
||||
OpSchema()
|
||||
.SetDoc(R"DOC()DOC")
|
||||
.Input(0, "X", "input tensor", "T")
|
||||
.Attr("signal_ndim", "", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Attr("normalized", "", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("onesided", "", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Output(0, "Y", "output tensor", "T")
|
||||
.SetDoc(R"DOC(This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number.)DOC")
|
||||
.Input(0, "X", "input tensor of size n in the signal dim", "T")
|
||||
.Attr("signal_ndim", "number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Output(0, "Y", "output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T")
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."));
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(Irfft, 1,
|
||||
OpSchema()
|
||||
.SetDoc(R"DOC()DOC")
|
||||
.Input(0, "X", "input tensor", "T")
|
||||
.Attr("signal_ndim", "", AttributeProto::INT)
|
||||
.Attr("normalized", "", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("onesided", "", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Output(0, "Y", "output tensor", "T")
|
||||
.SetDoc(R"DOC(This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'.)DOC")
|
||||
.Input(0, "X", "input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T")
|
||||
.Attr("signal_ndim", "number of dimensions comprising the signal", AttributeProto::INT)
|
||||
.Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Output(0, "Y", "output tensor with size n in the signal dim", "T")
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."));
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(ComplexMul, 1,
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ TEST(ContribOpTest, Rfft) {
|
|||
if (DefaultCudaExecutionProvider() == nullptr) return;
|
||||
|
||||
OpTester test("Rfft", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("signal_ndim", static_cast<int64_t>(2));
|
||||
test.AddAttribute("signal_ndim", static_cast<int64_t>(1));
|
||||
test.AddAttribute("onesided", static_cast<int64_t>(1));
|
||||
test.AddAttribute("normalized", static_cast<int64_t>(0));
|
||||
test.AddInput<float>("X", {4, 5}, std::vector<float>{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f});
|
||||
test.AddOutput<float>("Y", {4, 3, 2}, std::vector<float>{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f});
|
||||
// Target values conputed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward")
|
||||
test.AddInput<float>("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
|
||||
test.AddOutput<float>("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
|
|
@ -25,11 +26,11 @@ TEST(ContribOpTest, Irfft) {
|
|||
if (DefaultCudaExecutionProvider() == nullptr) return;
|
||||
|
||||
OpTester test("Irfft", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("signal_ndim", static_cast<int64_t>(2));
|
||||
test.AddAttribute("signal_ndim", static_cast<int64_t>(1));
|
||||
test.AddAttribute("onesided", static_cast<int64_t>(1));
|
||||
test.AddAttribute("normalized", static_cast<int64_t>(0));
|
||||
test.AddInput<float>("X", {4, 3, 2}, std::vector<float>{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f});
|
||||
test.AddOutput<float>("Y", {4, 5}, std::vector<float>{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f});
|
||||
test.AddInput<float>("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
|
||||
test.AddOutput<float>("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
|
|
|
|||
Loading…
Reference in a new issue