diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f90f8cff2f..589be30918 100755
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2242,6 +2242,8 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.Irfft**
+ This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'.
+
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
@@ -2250,25 +2252,25 @@ This version of the operator has been available since version 1 of the 'com.micr
- normalized : int
-
+- must be 0, normalization currently not supported
- onesided : int
-
+- must be 1, only one sided FFTs supported
- signal_ndim : int (required)
-
+- number of dimensions comprising the signal
#### Inputs
- X : T
-- input tensor
+- input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts
#### Outputs
- Y : T
-- output tensor
+- output tensor with size n in the signal dim
#### Type Constraints
@@ -4394,6 +4396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.Rfft**
+ This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number.
+
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
@@ -4402,25 +4406,25 @@ This version of the operator has been available since version 1 of the 'com.micr
- normalized : int
-
+- must be 0, normalization currently not supported
- onesided : int
-
+- must be 1, only one sided FFTs supported
- signal_ndim : int
-
+- number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)
#### Inputs
- X : T
-- input tensor
+- input tensor of size n in the signal dim
#### Outputs
- Y : T
-- output tensor
+- output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts
#### Type Constraints
diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc
index 5fddcd7e93..4b524dcf79 100644
--- a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc
+++ b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc
@@ -100,7 +100,7 @@ Status FFTBase::DoFFT(OpKernelContext* context, const Tensor* X, bool complex
// process the last dim(s)
if (onesided_) {
if (complex_input && !complex_output) { // IRFFT
- int64_t inferred_size = input_shape[i] * 2 - 1;
+ int64_t inferred_size = 2 * (input_shape[i] - 1);
output_dims.push_back(inferred_size);
signal_dims.push_back(inferred_size);
} else if (!complex_input && complex_output) { // RFFT
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 534303b9b1..f129f8ffdb 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1242,22 +1242,22 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MaxpoolWithMask, 1,
ONNX_MS_OPERATOR_SET_SCHEMA(Rfft, 1,
OpSchema()
- .SetDoc(R"DOC()DOC")
- .Input(0, "X", "input tensor", "T")
- .Attr("signal_ndim", "", AttributeProto::INT, static_cast(1))
- .Attr("normalized", "", AttributeProto::INT, static_cast(0))
- .Attr("onesided", "", AttributeProto::INT, static_cast(1))
- .Output(0, "Y", "output tensor", "T")
+ .SetDoc(R"DOC(This function computes the n-point one dimensional Fourier transform for a real-valued input where n is an even number.)DOC")
+ .Input(0, "X", "input tensor of size n in the signal dim", "T")
+ .Attr("signal_ndim", "number of dimensions comprising the signal, collected in reverse order (e.g. 1 = last dimension is the signal)", AttributeProto::INT, static_cast(1))
+ .Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast(0))
+ .Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast(1))
+ .Output(0, "Y", "output tensor of size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."));
ONNX_MS_OPERATOR_SET_SCHEMA(Irfft, 1,
OpSchema()
- .SetDoc(R"DOC()DOC")
- .Input(0, "X", "input tensor", "T")
- .Attr("signal_ndim", "", AttributeProto::INT)
- .Attr("normalized", "", AttributeProto::INT, static_cast(0))
- .Attr("onesided", "", AttributeProto::INT, static_cast(1))
- .Output(0, "Y", "output tensor", "T")
+ .SetDoc(R"DOC(This function computes the inverse of the one-dimensional n-point RFFT computed in 'com.microsoft.rfft'.)DOC")
+ .Input(0, "X", "input tensor with size (n//2 + 1) in the signal dim and 2 in the last dimension for the real and complex parts", "T")
+ .Attr("signal_ndim", "number of dimensions comprising the signal", AttributeProto::INT)
+ .Attr("normalized", "must be 0, normalization currently not supported", AttributeProto::INT, static_cast(0))
+ .Attr("onesided", "must be 1, only one sided FFTs supported", AttributeProto::INT, static_cast(1))
+ .Output(0, "Y", "output tensor with size n in the signal dim", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."));
ONNX_MS_OPERATOR_SET_SCHEMA(ComplexMul, 1,
diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc
index c259cbaf25..eaadb95c8a 100644
--- a/onnxruntime/test/contrib_ops/fft_op_test.cc
+++ b/onnxruntime/test/contrib_ops/fft_op_test.cc
@@ -11,11 +11,12 @@ TEST(ContribOpTest, Rfft) {
if (DefaultCudaExecutionProvider() == nullptr) return;
OpTester test("Rfft", 1, onnxruntime::kMSDomain);
- test.AddAttribute("signal_ndim", static_cast(2));
+ test.AddAttribute("signal_ndim", static_cast(1));
test.AddAttribute("onesided", static_cast(1));
test.AddAttribute("normalized", static_cast(0));
- test.AddInput("X", {4, 5}, std::vector{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f});
- test.AddOutput("Y", {4, 3, 2}, std::vector{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f});
+ // Target values conputed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward")
+ test.AddInput("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
+ test.AddOutput("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
std::vector> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
@@ -25,11 +26,11 @@ TEST(ContribOpTest, Irfft) {
if (DefaultCudaExecutionProvider() == nullptr) return;
OpTester test("Irfft", 1, onnxruntime::kMSDomain);
- test.AddAttribute("signal_ndim", static_cast(2));
+ test.AddAttribute("signal_ndim", static_cast(1));
test.AddAttribute("onesided", static_cast(1));
test.AddAttribute("normalized", static_cast(0));
- test.AddInput("X", {4, 3, 2}, std::vector{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f});
- test.AddOutput("Y", {4, 5}, std::vector{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f});
+ test.AddInput("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
+ test.AddOutput("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
std::vector> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);