diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index c6d0260229b..2f0eefa88be 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -541,7 +541,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional, bool, c10::string_view), fp32) KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional, bool, c10::string_view), fp32) KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional, bool, c10::string_view), fp32) - KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional, c10::optional, const c10::optional &, bool, c10::string_view, bool, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional, c10::optional, const c10::optional &, bool, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional, c10::optional, const c10::optional &, bool, c10::string_view, bool, c10::optional, c10::optional), fp32) KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional), fp32) KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional), fp32) KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional), fp32) diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 5b9b273e923..af000cc70d9 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -907,6 +907,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop } } +Tensor stft( + const Tensor& self, const int64_t n_fft, const optional hop_lengthOpt, + const optional win_lengthOpt, const c10::optional& window_opt, + const bool normalized, + const optional onesidedOpt, const optional return_complexOpt) { + return at::stft( + self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt, + /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt, + return_complexOpt); +} + // Create complex tensor from the old style of real tensor with size=(..., 2) // This is to support istft in the transition to requiring complex input. // NOTE: This may return a view of the input tensor, or might clone if necessary @@ -1100,6 +1111,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho #undef REPR } +Tensor istft(const Tensor& self, const int64_t n_fft, const optional hop_lengthOpt, + const optional win_lengthOpt, const Tensor& window, + const bool center, const bool normalized, const optional onesidedOpt, + const optional lengthOpt) { + return at::native::istft( + self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized, + onesidedOpt, lengthOpt, /*return_complex=*/false); +} + void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { const auto input_sizes = input.sizes(); const auto input_strides = input.strides(); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5b93adfdc7a..2eb18bed11f 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4320,7 +4320,12 @@ - func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) -- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor +# Overload without center & pad mode, needed for forward-compatibility +- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor + variants: function, method + cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized'] + +- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor variants: function, method - func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index 88b0c19b095..78a91c64fe8 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -12,7 +12,7 @@ namespace serialize { constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; #if ENABLE_UPGRADERS -constexpr uint64_t kMaxSupportedFileFormatVersion = 11; +constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL; #else constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L; #endif @@ -83,9 +83,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L; // Bump the version number to 10 to update aten::gelu and // and aten::gelu.out to support the new approximate kwarg. // (see: https://github.com/pytorch/pytorch/pull/61439) -// 4) [02/25/2022] -// Bump version number to 11 to update aten::stft to do padding in ATen -constexpr uint64_t kProducedFileFormatVersion = 11L; +constexpr uint64_t kProducedFileFormatVersion = 0xAL; #else constexpr uint64_t kProducedFileFormatVersion = 0x3L; #endif diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 861f2260457..28e501a5c77 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -118,7 +118,6 @@ ALLOW_LIST = [ ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)), ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)), ("aten::scatter_reduce.two", datetime.date(2022, 4, 15)), - ("aten::stft", datetime.date(2022, 6, 1)), ("aten::_s_where", datetime.date(2022, 9, 30)), ("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)), ("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)), diff --git a/test/jit/fixtures/test_versioned_stft_v10.ptl b/test/jit/fixtures/test_versioned_stft_v10.ptl deleted file mode 100644 index 7dcb8cc8f71..00000000000 Binary files a/test/jit/fixtures/test_versioned_stft_v10.ptl and /dev/null differ diff --git a/test/jit/fixtures_srcs/fixtures_src.py b/test/jit/fixtures_srcs/fixtures_src.py index ba1322fff2e..dff23702311 100644 --- a/test/jit/fixtures_srcs/fixtures_src.py +++ b/test/jit/fixtures_srcs/fixtures_src.py @@ -57,11 +57,3 @@ class TestVersionedGeluOutV9(torch.nn.Module): def forward(self, x): out = torch.zeros_like(x) return torch._C._nn.gelu(x, out=out) - -class TestVersionedStftV10(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, n_fft: int, window): - # calling aten::stft direct instead of torch.functional.stft - return torch.ops.aten.stft(x, n_fft=n_fft, window=window, return_complex=True) diff --git a/test/jit/fixtures_srcs/generate_models.py b/test/jit/fixtures_srcs/generate_models.py index 92c25cef188..e0015374513 100644 --- a/test/jit/fixtures_srcs/generate_models.py +++ b/test/jit/fixtures_srcs/generate_models.py @@ -96,7 +96,6 @@ ALL_MODULES = { TestVersionedLogspaceOutV8(): "aten::logspace.out", TestVersionedGeluV9(): "aten::gelu", TestVersionedGeluOutV9(): "aten::gelu.out", - TestVersionedStftV10(): "aten::stft", } """ diff --git a/test/jit/test_save_load_for_op_version.py b/test/jit/test_save_load_for_op_version.py index ff793404e3b..b5e38b37d3e 100644 --- a/test/jit/test_save_load_for_op_version.py +++ b/test/jit/test_save_load_for_op_version.py @@ -540,20 +540,3 @@ class TestSaveLoadForOpVersion(JitTestCase): self.assertTrue(output.size(dim=0) == 100) # "Upgraded" model should match the new version output self.assertEqual(output, output_current) - - def test_versioned_stft_v10(self): - model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl" - loaded_model = torch.jit.load(model_path) - buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter()) - buffer.seek(0) - v10_mobile_module = _load_for_lite_interpreter(buffer) - - for in_dtype, window_dtype in product( - [torch.float32, torch.complex64], repeat=2): - input = torch.rand((100,), dtype=in_dtype) - window = torch.rand((10,), dtype=window_dtype) - n_fft = 10 - output = v10_mobile_module(input, n_fft, window) - output_expected = torch.stft(input, n_fft=n_fft, window=window, - center=False, return_complex=True) - self.assertEqual(output, output_expected) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index e4acd74a62e..94c89a90671 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -109,6 +109,7 @@ blocklist = [ "block_diag", "norm", "chain_matmul", + "stft", "tensordot", "split", "unique_consecutive", diff --git a/torch/_tensor.py b/torch/_tensor.py index 993c95b980b..35eba292233 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -2,7 +2,7 @@ from collections import OrderedDict import enum import functools from numbers import Number -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import warnings import copyreg from copy import deepcopy @@ -545,6 +545,40 @@ class Tensor(torch._C._TensorBase): else: return LU, pivots + def stft(self, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: 'Optional[Tensor]' = None, + center: bool = True, pad_mode: str = 'reflect', normalized: bool = False, + onesided: Optional[bool] = None, return_complex: Optional[bool] = None): + r"""See :func:`torch.stft` + + .. warning:: + This function changed signature at version 0.4.1. Calling with + the previous signature may cause error or return incorrect result. + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.stft, (self,), self, n_fft, hop_length=hop_length, + win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex + ) + return torch.stft(self, n_fft, hop_length, win_length, window, center, + pad_mode, normalized, onesided, return_complex=return_complex) + + def istft(self, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: 'Optional[Tensor]' = None, + center: bool = True, normalized: bool = False, + onesided: Optional[bool] = None, length: Optional[int] = None, + return_complex: bool = False): + r"""See :func:`torch.istft`""" + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, normalized=normalized, onesided=onesided, length=length, + return_complex=return_complex + ) + return torch.istft(self, n_fft, hop_length, win_length, window, center, + normalized, onesided, length, return_complex=return_complex) + def resize(self, *sizes): if has_torch_function_unary(self): return handle_torch_function(Tensor.resize, (self,), self, *sizes) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 0b368e8e553..6cc15a8aeec 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -4752,21 +4752,16 @@ See :func:`torch.dsplit` """) add_docstr_all('stft', - "stft(n_fft, hop_length=None, win_length=None, window=None, center=True, " - "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor" r""" +stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor See :func:`torch.stft` - -.. warning:: - This function changed signature at version 0.4.1. Calling with - the previous signature may cause error or return incorrect result. """) add_docstr_all('istft', - "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " - "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor" r""" +istft(n_fft, hop_length=None, win_length=None, window=None, + center=True, normalized=False, onesided=True, length=None) -> Tensor See :func:`torch.istft` """) diff --git a/torch/csrc/jit/mobile/upgrader_mobile.cpp b/torch/csrc/jit/mobile/upgrader_mobile.cpp index eed4a676a9c..0e52829255d 100644 --- a/torch/csrc/jit/mobile/upgrader_mobile.cpp +++ b/torch/csrc/jit/mobile/upgrader_mobile.cpp @@ -67,10 +67,6 @@ getOperatorVersionMapForMobile() { std::vector({ Upgrader({0, 8, "logspace_out_0_8", 10}) })}, - {std::string("aten::stft"), - std::vector({ - Upgrader({0, 10, "stft_0_10", 11}) - })}, }); return operatorVersionMapForMobile; } @@ -531,35 +527,6 @@ const std::vector& getUpgraderBytecodeList() { OperatorString({"prim::unchecked_cast", "", 1}), }), // operators list }), - ByteCodeFunctionWithOperator({ - mobile::Function::registerFunc( - "stft_0_10", - std::vector({ - Instruction{OpCode::STOREN, 1, 8}, - Instruction{OpCode::MOVE, 1, 0}, - Instruction{OpCode::MOVE, 2, 0}, - Instruction{OpCode::MOVE, 3, 0}, - Instruction{OpCode::MOVE, 4, 0}, - Instruction{OpCode::MOVE, 5, 0}, - Instruction{OpCode::LOADC, 1, 0}, - Instruction{OpCode::LOADC, 0, 0}, - Instruction{OpCode::MOVE, 6, 0}, - Instruction{OpCode::MOVE, 7, 0}, - Instruction{OpCode::MOVE, 8, 0}, - Instruction{OpCode::OP, 0, 0}, - Instruction{OpCode::RET, 0, 0}, - }), // instructions list, - std::vector({ - c10::IValue("reflect"), - c10::IValue(false), - }), // constants list, - std::vector(), // types list, - 8 - ), - std::vector({ - OperatorString({"aten::stft", "", 10}), - }), // operators list - }), }); for (const auto& upgrader_function : upgrader_function_list) { for (const auto& op : upgrader_function.operators) { diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp index e50227d18ae..7b09cc409a4 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp @@ -15,17 +15,6 @@ namespace torch { namespace jit { static std::unordered_map kUpgradersEntryMap({ - {"stft_0_10", R"SCRIPT( -def stft_0_10( - self: Tensor, n_fft: int, hop_length: Optional[int] = None, - win_length: Optional[int] = None, window: Optional[Tensor] = None, - normalized: bool = False, onesided: Optional[bool] = None, - return_complex: Optional[bool] = None) -> Tensor: - return torch.stft( - self, n_fft=n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=False, normalized=normalized, onesided=onesided, - return_complex=return_complex) -)SCRIPT"}, {"logspace_0_8", R"SCRIPT( def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): diff --git a/torch/csrc/jit/operator_upgraders/version_map.cpp b/torch/csrc/jit/operator_upgraders/version_map.cpp index d96527b66fc..1e19f4cc39d 100644 --- a/torch/csrc/jit/operator_upgraders/version_map.cpp +++ b/torch/csrc/jit/operator_upgraders/version_map.cpp @@ -16,11 +16,7 @@ static bool isVersionMapSorted = false; // Note for developers: The list of upgraders should be SORTED // by the version number where the upgrader is registered. static std::unordered_map> operatorVersionMap( - {{"aten::stft", - {{11, - "stft_0_10", - "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}}, - {"aten::logspace", + {{"aten::logspace", {{9, "logspace_0_8", "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}}, diff --git a/torch/functional.py b/torch/functional.py index dad5fc63dc0..29a66f7f160 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -4,7 +4,7 @@ from typing import ( import torch from torch._C import _add_docstr -import torch.nn.functional +import torch.nn.functional as F from ._lowrank import svd_lowrank, pca_lowrank from .overrides import ( has_torch_function, has_torch_function_unary, has_torch_function_variadic, @@ -478,121 +478,133 @@ def _meshgrid(*tensors, indexing: Optional[str]): return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined] -stft = _add_docstr( - torch.stft, - "stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " - "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor" - r""" +def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: Optional[Tensor] = None, + center: bool = True, pad_mode: str = 'reflect', normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None) -> Tensor: + r"""Short-time Fourier transform (STFT). -Short-time Fourier transform (STFT). + .. warning:: + From version 1.8.0, :attr:`return_complex` must always be given + explicitly for real inputs and `return_complex=False` has been + deprecated. Strongly prefer `return_complex=True` as in a future + pytorch release, this function will only return complex tensors. -.. warning:: - From version 1.8.0, :attr:`return_complex` must always be given - explicitly for real inputs and `return_complex=False` has been - deprecated. Strongly prefer `return_complex=True` as in a future - pytorch release, this function will only return complex tensors. + Note that :func:`torch.view_as_real` can be used to recover a real + tensor with an extra last dimension for real and imaginary components. - Note that :func:`torch.view_as_real` can be used to recover a real - tensor with an extra last dimension for real and imaginary components. + The STFT computes the Fourier transform of short overlapping windows of the + input. This giving frequency components of the signal as they change over + time. The interface of this function is modeled after (but *not* a drop-in + replacement for) librosa_ stft function. -The STFT computes the Fourier transform of short overlapping windows of the -input. This giving frequency components of the signal as they change over -time. The interface of this function is modeled after (but *not* a drop-in -replacement for) librosa_ stft function. + .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html -.. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html + Ignoring the optional batch dimension, this method computes the following + expression: -Ignoring the optional batch dimension, this method computes the following -expression: + .. math:: + X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}% + \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ % + \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right), -.. math:: - X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}% - \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ % - \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right), + where :math:`m` is the index of the sliding window, and :math:`\omega` is + the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``, + or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``. -where :math:`m` is the index of the sliding window, and :math:`\omega` is -the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``, -or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``. + * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time + sequences. -* :attr:`input` must be either a 1-D time sequence or a 2-D batch of time - sequences. + * If :attr:`hop_length` is ``None`` (default), it is treated as equal to + ``floor(n_fft / 4)``. -* If :attr:`hop_length` is ``None`` (default), it is treated as equal to - ``floor(n_fft / 4)``. + * If :attr:`win_length` is ``None`` (default), it is treated as equal to + :attr:`n_fft`. -* If :attr:`win_length` is ``None`` (default), it is treated as equal to - :attr:`n_fft`. + * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from + :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is + treated as if having :math:`1` everywhere in the window. If + :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on + both sides to length :attr:`n_fft` before being applied. -* :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from - :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is - treated as if having :math:`1` everywhere in the window. If - :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on - both sides to length :attr:`n_fft` before being applied. + * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on + both sides so that the :math:`t`-th frame is centered at time + :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame + begins at time :math:`t \times \text{hop\_length}`. -* If :attr:`center` is ``True`` (default), :attr:`input` will be padded on - both sides so that the :math:`t`-th frame is centered at time - :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame - begins at time :math:`t \times \text{hop\_length}`. + * :attr:`pad_mode` determines the padding method used on :attr:`input` when + :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for + all available options. Default is ``"reflect"``. -* :attr:`pad_mode` determines the padding method used on :attr:`input` when - :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for - all available options. Default is ``"reflect"``. + * If :attr:`onesided` is ``True`` (default for real input), only values for + :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor + \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because + the real-to-complex Fourier transform satisfies the conjugate symmetry, + i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`. + Note if the input or window tensors are complex, then :attr:`onesided` + output is not possible. -* If :attr:`onesided` is ``True`` (default for real input), only values for - :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor - \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because - the real-to-complex Fourier transform satisfies the conjugate symmetry, - i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`. - Note if the input or window tensors are complex, then :attr:`onesided` - output is not possible. + * If :attr:`normalized` is ``True`` (default is ``False``), the function + returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`. -* If :attr:`normalized` is ``True`` (default is ``False``), the function - returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`. + * If :attr:`return_complex` is ``True`` (default if input is complex), the + return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``, + the output is a ``input.dim() + 2`` dimensional real tensor where the last + dimension represents the real and imaginary components. -* If :attr:`return_complex` is ``True`` (default if input is complex), the - return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``, - the output is a ``input.dim() + 2`` dimensional real tensor where the last - dimension represents the real and imaginary components. + Returns either a complex tensor of size :math:`(* \times N \times T)` if + :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N + \times T \times 2)`. Where :math:`*` is the optional batch size of + :attr:`input`, :math:`N` is the number of frequencies where STFT is applied + and :math:`T` is the total number of frames used. -Returns either a complex tensor of size :math:`(* \times N \times T)` if -:attr:`return_complex` is true, or a real tensor of size :math:`(* \times N -\times T \times 2)`. Where :math:`*` is the optional batch size of -:attr:`input`, :math:`N` is the number of frequencies where STFT is applied -and :math:`T` is the total number of frames used. + .. warning:: + This function changed signature at version 0.4.1. Calling with the + previous signature may cause error or return incorrect result. -.. warning:: - This function changed signature at version 0.4.1. Calling with the - previous signature may cause error or return incorrect result. + Args: + input (Tensor): the input tensor + n_fft (int): size of Fourier transform + hop_length (int, optional): the distance between neighboring sliding window + frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) + win_length (int, optional): the size of window frame and STFT filter. + Default: ``None`` (treated as equal to :attr:`n_fft`) + window (Tensor, optional): the optional window function. + Default: ``None`` (treated as window of all :math:`1` s) + center (bool, optional): whether to pad :attr:`input` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + Default: ``True`` + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. Default: ``"reflect"`` + normalized (bool, optional): controls whether to return the normalized STFT results + Default: ``False`` + onesided (bool, optional): controls whether to return half of results to + avoid redundancy for real inputs. + Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. + return_complex (bool, optional): whether to return a complex tensor, or + a real tensor with an extra last dimension for the real and + imaginary components. -Args: - input (Tensor): the input tensor - n_fft (int): size of Fourier transform - hop_length (int, optional): the distance between neighboring sliding window - frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) - win_length (int, optional): the size of window frame and STFT filter. - Default: ``None`` (treated as equal to :attr:`n_fft`) - window (Tensor, optional): the optional window function. - Default: ``None`` (treated as window of all :math:`1` s) - center (bool, optional): whether to pad :attr:`input` on both sides so - that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. - Default: ``True`` - pad_mode (string, optional): controls the padding method used when - :attr:`center` is ``True``. Default: ``"reflect"`` - normalized (bool, optional): controls whether to return the normalized STFT results - Default: ``False`` - onesided (bool, optional): controls whether to return half of results to - avoid redundancy for real inputs. - Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. - return_complex (bool, optional): whether to return a complex tensor, or - a real tensor with an extra last dimension for the real and - imaginary components. + Returns: + Tensor: A tensor containing the STFT result with shape described above -Returns: - Tensor: A tensor containing the STFT result with shape described above - -""") -# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 -stft.__module__ = "torch.functional" + """ + if has_torch_function_unary(input): + return handle_torch_function( + stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex) + # NOTE: Do not edit. This code will be removed once the forward-compatibility + # period is over for PR #73432 + if center: + signal_dim = input.dim() + extended_shape = [1] * (3 - signal_dim) + list(input.size()) + pad = int(n_fft // 2) + input = F.pad(input.view(extended_shape), [pad, pad], pad_mode) + input = input.view(input.shape[-signal_dim:]) + return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined] + normalized, onesided, return_complex) istft = _add_docstr(