pytorch/test/cpp/api/fft.cpp

Adds fft namespace (#41911)

Summary: This PR creates a new namespace, torch.fft (torch::fft), and puts a single function, fft, in it. This function is a simplified version of NumPy's [numpy.fft.fft](https://numpy.org/doc/1.18/reference/generated/numpy.fft.fft.html?highlight=fft#numpy.fft.fft) that accepts no optional arguments. It is intended to demonstrate how to add and document functions in the namespace, not to deprecate the existing torch.fft function.

Adding this namespace was complicated by the existence of the torch.fft function in Python. Creating a torch.fft Python module makes this name ambiguous: does it refer to a function or a module? If the JIT didn't exist, one solution would have been to make torch.fft a callable class that mimicked both the function and the module. The JIT, however, cannot understand this pattern. As a workaround, an explicit `import torch.fft` is required to access the torch.fft.fft function in Python:

```
import torch.fft

t = torch.randn(128, dtype=torch.cdouble)
torch.fft.fft(t)
```

See https://github.com/pytorch/pytorch/issues/42175 for future work. Another possible future PR is to teach the JIT to understand torch.fft as a callable class so it need not be imported explicitly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41911
Reviewed By: glaringlee
Differential Revision: D22941894
Pulled By: mruberry
fbshipit-source-id: c8e0b44cbe90d21e998ca3832cf3a533f28dbe8d

2020-08-06 07:18:51 +00:00
#include <gtest/gtest.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
// Naive DFT of a 1 dimensional tensor
torch::Tensor naive_dft(torch::Tensor x, bool forward = true) {
  TORCH_INTERNAL_ASSERT(x.dim() == 1);
  x = x.contiguous();
  auto out_tensor = torch::zeros_like(x);
  const int64_t len = x.size(0);

  // Roots of unity, exp(-2*pi*j*n/N) for n in [0, N),
  // reversed for inverse transform
  std::vector<c10::complex<double>> roots(len);
  const auto angle_base = (forward ? -2.0 : 2.0) * M_PI / len;
  for (int64_t i = 0; i < len; ++i) {
    auto angle = i * angle_base;
    roots[i] = c10::complex<double>(std::cos(angle), std::sin(angle));
  }

  const auto in = x.data_ptr<c10::complex<double>>();
  const auto out = out_tensor.data_ptr<c10::complex<double>>();
  for (int64_t i = 0; i < len; ++i) {
    for (int64_t j = 0; j < len; ++j) {
      out[i] += roots[(j * i) % len] * in[j];
    }
  }
  return out_tensor;
}
// NOTE: Visual Studio and ROCm builds don't understand complex literals
// as of August 2020
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, fft) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(128, torch::kComplexDouble);
  auto actual = torch::fft::fft(t);
  auto expect = naive_dft(t);
  ASSERT_TRUE(torch::allclose(actual, expect));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, fft_real) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(128, torch::kDouble);
  auto actual = torch::fft::fft(t);
  auto expect = torch::fft::fft(t.to(torch::kComplexDouble));
  ASSERT_TRUE(torch::allclose(actual, expect));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, fft_pad) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(128, torch::kComplexDouble);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto actual = torch::fft::fft(t, 200);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto expect = torch::fft::fft(torch::constant_pad_nd(t, {0, 72}));
  ASSERT_TRUE(torch::allclose(actual, expect));

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  actual = torch::fft::fft(t, 64);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  expect = torch::fft::fft(torch::constant_pad_nd(t, {0, -64}));
  ASSERT_TRUE(torch::allclose(actual, expect));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, fft_norm) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(128, torch::kComplexDouble);
  // NOLINTNEXTLINE(bugprone-argument-comment)
  auto unnorm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/{});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  auto norm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/"forward");
  ASSERT_TRUE(torch::allclose(unnorm / 128, norm));

  // NOLINTNEXTLINE(bugprone-argument-comment)
  auto ortho_norm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/"ortho");
  ASSERT_TRUE(torch::allclose(unnorm / std::sqrt(128), ortho_norm));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, ifft) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto T = torch::randn(128, torch::kComplexDouble);
  auto actual = torch::fft::ifft(T);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto expect = naive_dft(T, /*forward=*/false) / 128;
  ASSERT_TRUE(torch::allclose(actual, expect));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, fft_ifft) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(77, torch::kComplexDouble);
  auto T = torch::fft::fft(t);
  ASSERT_EQ(T.size(0), 77);
  ASSERT_EQ(T.scalar_type(), torch::kComplexDouble);

  auto t_round_trip = torch::fft::ifft(T);
  ASSERT_EQ(t_round_trip.size(0), 77);
  ASSERT_EQ(t_round_trip.scalar_type(), torch::kComplexDouble);
  ASSERT_TRUE(torch::allclose(t, t_round_trip));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, rfft) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(129, torch::kDouble);
  auto actual = torch::fft::rfft(t);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto expect = torch::fft::fft(t.to(torch::kComplexDouble)).slice(0, 0, 65);
  ASSERT_TRUE(torch::allclose(actual, expect));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, rfft_irfft) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(128, torch::kDouble);
  auto T = torch::fft::rfft(t);
  ASSERT_EQ(T.size(0), 65);
  ASSERT_EQ(T.scalar_type(), torch::kComplexDouble);

  auto t_round_trip = torch::fft::irfft(T);
  ASSERT_EQ(t_round_trip.size(0), 128);
  ASSERT_EQ(t_round_trip.scalar_type(), torch::kDouble);
  ASSERT_TRUE(torch::allclose(t, t_round_trip));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, ihfft) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto T = torch::randn(129, torch::kDouble);
  auto actual = torch::fft::ihfft(T);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto expect = torch::fft::ifft(T.to(torch::kComplexDouble)).slice(0, 0, 65);
  ASSERT_TRUE(torch::allclose(actual, expect));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(FFTTest, hfft_ihfft) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto t = torch::randn(64, torch::kComplexDouble);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  t[0] = .5; // Must be purely real to satisfy hermitian symmetry
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto T = torch::fft::hfft(t, 127);
  ASSERT_EQ(T.size(0), 127);
  ASSERT_EQ(T.scalar_type(), torch::kDouble);

  auto t_round_trip = torch::fft::ihfft(T);
  ASSERT_EQ(t_round_trip.size(0), 64);
  ASSERT_EQ(t_round_trip.scalar_type(), torch::kComplexDouble);
  ASSERT_TRUE(torch::allclose(t, t_round_trip));
}