mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Implement dft(20) (#17821)
### Description dft is updated in opset20. implement it in ort ### Motivation and Context this is for ort 1.17.0 release Fixes #17723 --------- Signed-off-by: Liqun Fu <liqfu@microsoft.com>
This commit is contained in:
parent
5f00bc9931
commit
32fcf73740
6 changed files with 101 additions and 37 deletions
|
|
@ -80,7 +80,8 @@ Do not modify directly.*
|
|||
|Crop|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|
||||
|CumSum|*in* x:**T**<br> *in* axis:**T2**<br> *out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|||[17, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|
||||
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|
||||
|||[11, 12]|**T** = tensor(double), tensor(float)|
|
||||
|||[1, 10]|**T** = tensor(double), tensor(float)|
|
||||
|
|
|
|||
|
|
@ -823,7 +823,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int64_t, LessOrEqual);
|
||||
|
||||
// Opset 17
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, DFT);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, 19, DFT);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, BlackmanWindow);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HammingWindow);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HannWindow);
|
||||
|
|
@ -960,6 +960,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
|
|||
|
||||
// Opset 20
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid);
|
||||
|
|
@ -2217,7 +2218,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
// Opset 17
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, BlackmanWindow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, DFT)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, 19, DFT)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HammingWindow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HannWindow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MelWeightMatrix)>,
|
||||
|
|
@ -2403,6 +2404,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
// Opset 20
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid)>,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,15 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(DFT, 17,
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
DFT,
|
||||
17, 19,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
|
||||
.TypeConstraint("T2", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
DFT);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(DFT, 20,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
|
||||
.TypeConstraint("T2", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
|
|
@ -442,7 +450,13 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo
|
|||
}
|
||||
|
||||
Status DFT::Compute(OpKernelContext* ctx) const {
|
||||
ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis_, is_onesided_, is_inverse_));
|
||||
int64_t axis = axis_;
|
||||
if (opset_ >= 20 && ctx->InputCount() >= 3) {
|
||||
const Tensor* axes_tensor = ctx->Input<Tensor>(2);
|
||||
axis = axes_tensor->Data<int64_t>()[0];
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis, is_onesided_, is_inverse_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
class DFT final : public OpKernel {
|
||||
int opset_;
|
||||
bool is_onesided_ = true;
|
||||
int64_t axis_ = 0;
|
||||
bool is_inverse_ = false;
|
||||
|
|
@ -14,7 +15,11 @@ class DFT final : public OpKernel {
|
|||
public:
|
||||
explicit DFT(const OpKernelInfo& info) : OpKernel(info) {
|
||||
is_onesided_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("onesided", 0));
|
||||
axis_ = info.GetAttrOrDefault<int64_t>("axis", 1);
|
||||
opset_ = info.node().SinceVersion();
|
||||
if (opset_ < 20)
|
||||
axis_ = info.GetAttrOrDefault<int64_t>("axis", 1);
|
||||
else
|
||||
axis_ = -2; // default axis of DFT(20)
|
||||
is_inverse_ = info.GetAttrOrDefault<int64_t>("inverse", 0);
|
||||
}
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
|
|
|
|||
|
|
@ -16,9 +16,10 @@ namespace onnxruntime {
|
|||
namespace test {
|
||||
|
||||
static constexpr int kMinOpsetVersion = 17;
|
||||
static constexpr int kOpsetVersion20 = 20;
|
||||
|
||||
static void TestNaiveDFTFloat(bool onesided) {
|
||||
OpTester test("DFT", kMinOpsetVersion);
|
||||
static void TestNaiveDFTFloat(bool onesided, int since_version) {
|
||||
OpTester test("DFT", since_version);
|
||||
|
||||
vector<int64_t> shape = {1, 5, 1};
|
||||
vector<int64_t> output_shape = {1, 5, 2};
|
||||
|
|
@ -37,8 +38,8 @@ static void TestNaiveDFTFloat(bool onesided) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
static void TestRadix2DFTFloat(bool onesided) {
|
||||
OpTester test("DFT", kMinOpsetVersion);
|
||||
static void TestRadix2DFTFloat(bool onesided, int since_version) {
|
||||
OpTester test("DFT", since_version);
|
||||
|
||||
vector<int64_t> shape = {1, 8, 1};
|
||||
vector<int64_t> output_shape = {1, 8, 2};
|
||||
|
|
@ -57,20 +58,8 @@ static void TestRadix2DFTFloat(bool onesided) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFTFloat_naive) {
|
||||
TestNaiveDFTFloat(false);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFTFloat_naive_onesided) {
|
||||
TestNaiveDFTFloat(true);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFTFloat_radix2) { TestRadix2DFTFloat(false); }
|
||||
|
||||
TEST(SignalOpsTest, DFTFloat_radix2_onesided) { TestRadix2DFTFloat(true); }
|
||||
|
||||
TEST(SignalOpsTest, DFTFloat_inverse) {
|
||||
OpTester test("DFT", kMinOpsetVersion);
|
||||
static void TestInverseFloat(int since_version) {
|
||||
OpTester test("DFT", since_version);
|
||||
|
||||
vector<int64_t> shape = {1, 5, 2};
|
||||
vector<float> input = {15.000000f, 0.0000000f, -2.499999f, 3.4409550f, -2.500000f,
|
||||
|
|
@ -83,12 +72,44 @@ TEST(SignalOpsTest, DFTFloat_inverse) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT17_Float_naive) {
|
||||
TestNaiveDFTFloat(false, kMinOpsetVersion);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT20_Float_naive) {
|
||||
TestNaiveDFTFloat(false, kOpsetVersion20);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT17_Float_naive_onesided) {
|
||||
TestNaiveDFTFloat(true, kMinOpsetVersion);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT20_Float_naive_onesided) {
|
||||
TestNaiveDFTFloat(true, kOpsetVersion20);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT17_Float_radix2) { TestRadix2DFTFloat(false, kMinOpsetVersion); }
|
||||
|
||||
TEST(SignalOpsTest, DFT20_Float_radix2) { TestRadix2DFTFloat(false, kOpsetVersion20); }
|
||||
|
||||
TEST(SignalOpsTest, DFT17_Float_radix2_onesided) { TestRadix2DFTFloat(true, kMinOpsetVersion); }
|
||||
|
||||
TEST(SignalOpsTest, DFT20_Float_radix2_onesided) { TestRadix2DFTFloat(true, kOpsetVersion20); }
|
||||
|
||||
TEST(SignalOpsTest, DFT17_Float_inverse) {
|
||||
TestInverseFloat(kMinOpsetVersion);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT20_Float_inverse) {
|
||||
TestInverseFloat(kOpsetVersion20);
|
||||
}
|
||||
|
||||
// Tests that FFT(FFT(x), inverse=true) == x
|
||||
static void TestDFTInvertible(bool complex) {
|
||||
static void TestDFTInvertible(bool complex, int since_version) {
|
||||
// TODO: test dft_length
|
||||
class DFTInvertibleTester : public OpTester {
|
||||
public:
|
||||
DFTInvertibleTester(int64_t axis) : OpTester("DFT", kMinOpsetVersion), axis_(axis) {}
|
||||
DFTInvertibleTester(int64_t axis, int since_version) : OpTester("DFT", since_version), axis_(axis) {}
|
||||
|
||||
protected:
|
||||
void AddNodes(Graph& graph, vector<NodeArg*>& graph_inputs, vector<NodeArg*>& graph_outputs,
|
||||
|
|
@ -98,11 +119,20 @@ static void TestDFTInvertible(bool complex) {
|
|||
|
||||
// call base implementation to add the DFT node.
|
||||
OpTester::AddNodes(graph, graph_inputs, intermediate_outputs, add_attribute_funcs);
|
||||
OpTester::AddAttribute("axis", axis_);
|
||||
if (this->Opset() < kOpsetVersion20) {
|
||||
OpTester::AddAttribute("axis", axis_);
|
||||
} else {
|
||||
assert(intermediate_outputs.size() == 1);
|
||||
assert(graph_inputs.size() == 3);
|
||||
intermediate_outputs.push_back(graph_inputs[1]);
|
||||
intermediate_outputs.push_back(graph_inputs[2]);
|
||||
}
|
||||
|
||||
Node& inverse = graph.AddNode("inverse", "DFT", "inverse", intermediate_outputs, graph_outputs);
|
||||
inverse.AddAttribute("inverse", static_cast<int64_t>(true));
|
||||
inverse.AddAttribute("axis", axis_);
|
||||
if (this->Opset() < kOpsetVersion20) {
|
||||
inverse.AddAttribute("axis", axis_);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -112,14 +142,21 @@ static void TestDFTInvertible(bool complex) {
|
|||
RandomValueGenerator random(GetTestRandomSeed());
|
||||
// TODO(smk2007): Add tests for different dft_length values.
|
||||
constexpr int64_t num_batches = 2;
|
||||
for (int64_t axis = 1; axis < 2; axis += 1) {
|
||||
for (int64_t axis = 0; axis < 2; axis += 1) {
|
||||
for (int64_t signal_dim1 = 2; signal_dim1 <= 5; signal_dim1 += 1) {
|
||||
for (int64_t signal_dim2 = 2; signal_dim2 <= 5; signal_dim2 += 1) {
|
||||
DFTInvertibleTester test(axis);
|
||||
if (axis == 0 && since_version < kOpsetVersion20)
|
||||
continue;
|
||||
DFTInvertibleTester test(axis, since_version);
|
||||
vector<int64_t> input_shape{num_batches, signal_dim1, signal_dim2, 1 + (complex ? 1 : 0)};
|
||||
vector<float> input_data = random.Uniform<float>(input_shape, -100.f, 100.f);
|
||||
test.AddInput("input", input_shape, input_data);
|
||||
|
||||
if (since_version >= kOpsetVersion20) {
|
||||
test.AddInput<int64_t>("", {0}, {});
|
||||
test.AddInput<int64_t>("axis", {1}, {axis});
|
||||
}
|
||||
|
||||
vector<int64_t> output_shape(input_shape);
|
||||
vector<float>* output_data_p;
|
||||
vector<float> output_data;
|
||||
|
|
@ -141,12 +178,20 @@ static void TestDFTInvertible(bool complex) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT_invertible_real) {
|
||||
TestDFTInvertible(false);
|
||||
TEST(SignalOpsTest, DFT17_invertible_real) {
|
||||
TestDFTInvertible(false, kMinOpsetVersion);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT_invertible_complex) {
|
||||
TestDFTInvertible(true);
|
||||
TEST(SignalOpsTest, DFT20_invertible_real) {
|
||||
TestDFTInvertible(false, kOpsetVersion20);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT17_invertible_complex) {
|
||||
TestDFTInvertible(true, kMinOpsetVersion);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, DFT20_invertible_complex) {
|
||||
TestDFTInvertible(true, kOpsetVersion20);
|
||||
}
|
||||
|
||||
TEST(SignalOpsTest, STFTFloat) {
|
||||
|
|
|
|||
|
|
@ -262,9 +262,6 @@
|
|||
"^test_string_split_empty_tensor",
|
||||
"^test_string_split_maxsplit",
|
||||
"^test_string_split_no_delimiter",
|
||||
"^test_dft_axis",
|
||||
"^test_dft",
|
||||
"^test_dft_inverse",
|
||||
"^test_reduce_max_bool_inputs",
|
||||
"^test_reduce_min_bool_inputs",
|
||||
"^test_reduce_min_empty_set",
|
||||
|
|
|
|||
Loading…
Reference in a new issue