From 2c5bf12584a8ec359cbce34fac73fb2bc3cd0af0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 9 May 2022 19:59:43 +0000 Subject: [PATCH] Revert "stft: remove non-center overload and python functional wrapper" This reverts commit d23ecbfc9ac157560611b242f015743f189dbf48. Reverted https://github.com/pytorch/pytorch/pull/73434 on behalf of https://github.com/albanD --- aten/src/ATen/autocast_mode.cpp | 3 +- aten/src/ATen/native/SpectralOps.cpp | 20 ++ aten/src/ATen/native/native_functions.yaml | 7 +- caffe2/serialize/versions.h | 6 +- .../check_forward_backward_compatibility.py | 1 - test/jit/fixtures/test_versioned_stft_v10.ptl | Bin 2630 -> 0 bytes test/jit/fixtures_srcs/fixtures_src.py | 8 - test/jit/fixtures_srcs/generate_models.py | 1 - test/jit/test_save_load_for_op_version.py | 17 -- tools/pyi/gen_pyi.py | 1 + torch/_tensor.py | 36 +++- torch/_tensor_docs.py | 11 +- torch/csrc/jit/mobile/upgrader_mobile.cpp | 33 --- .../operator_upgraders/upgraders_entry.cpp | 11 - .../jit/operator_upgraders/version_map.cpp | 6 +- torch/functional.py | 202 ++++++++++-------- 16 files changed, 177 insertions(+), 186 deletions(-) delete mode 100644 test/jit/fixtures/test_versioned_stft_v10.ptl diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index c6d0260229b..2f0eefa88be 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -541,7 +541,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional, bool, c10::string_view), fp32) KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional, bool, c10::string_view), fp32) KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional, bool, c10::string_view), fp32) - KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional, c10::optional, const c10::optional &, bool, c10::string_view, bool, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional, c10::optional, const c10::optional &, bool, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional, c10::optional, const c10::optional &, bool, c10::string_view, bool, c10::optional, c10::optional), fp32) KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional), fp32) KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional), fp32) KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional), fp32) diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 5b9b273e923..af000cc70d9 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -907,6 +907,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop } } +Tensor stft( + const Tensor& self, const int64_t n_fft, const optional hop_lengthOpt, + const optional win_lengthOpt, const c10::optional& window_opt, + const bool normalized, + const optional onesidedOpt, const optional return_complexOpt) { + return at::stft( + self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt, + /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt, + return_complexOpt); +} + // Create complex tensor from the old style of real tensor with size=(..., 2) // This is to support istft in the transition to requiring complex input. // NOTE: This may return a view of the input tensor, or might clone if necessary @@ -1100,6 +1111,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho #undef REPR } +Tensor istft(const Tensor& self, const int64_t n_fft, const optional hop_lengthOpt, + const optional win_lengthOpt, const Tensor& window, + const bool center, const bool normalized, const optional onesidedOpt, + const optional lengthOpt) { + return at::native::istft( + self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized, + onesidedOpt, lengthOpt, /*return_complex=*/false); +} + void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { const auto input_sizes = input.sizes(); const auto input_strides = input.strides(); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5b93adfdc7a..2eb18bed11f 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4320,7 +4320,12 @@ - func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) -- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor +# Overload without center & pad mode, needed for forward-compatibility +- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor + variants: function, method + cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized'] + +- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor variants: function, method - func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index 88b0c19b095..78a91c64fe8 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -12,7 +12,7 @@ namespace serialize { constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; #if ENABLE_UPGRADERS -constexpr uint64_t kMaxSupportedFileFormatVersion = 11; +constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL; #else constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L; #endif @@ -83,9 +83,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L; // Bump the version number to 10 to update aten::gelu and // and aten::gelu.out to support the new approximate kwarg. // (see: https://github.com/pytorch/pytorch/pull/61439) -// 4) [02/25/2022] -// Bump version number to 11 to update aten::stft to do padding in ATen -constexpr uint64_t kProducedFileFormatVersion = 11L; +constexpr uint64_t kProducedFileFormatVersion = 0xAL; #else constexpr uint64_t kProducedFileFormatVersion = 0x3L; #endif diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 861f2260457..28e501a5c77 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -118,7 +118,6 @@ ALLOW_LIST = [ ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)), ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)), ("aten::scatter_reduce.two", datetime.date(2022, 4, 15)), - ("aten::stft", datetime.date(2022, 6, 1)), ("aten::_s_where", datetime.date(2022, 9, 30)), ("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)), ("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)), diff --git a/test/jit/fixtures/test_versioned_stft_v10.ptl b/test/jit/fixtures/test_versioned_stft_v10.ptl deleted file mode 100644 index 7dcb8cc8f715f26fe90d445210f3910a349817b3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2630 zcmWIWW@cev;NW1u016CF3?-?>CGlmcMa7x#8Tt9yg-m{>NrlXM0p9E!x;wMO zGk_{VH~{Ex5um>@J#NNeg45s0`6;RTaQEsXyH~HE(#^?$qfum0 zh2GOef4{HJnB%+r=#Dd1@(LnRHoU(@de}S`e(br?ZsDzZ{)EZ5@7vyrFc?VNsunL? zdq$#3q9bx!RoBg_FP~*TeE9Qw_BFG2Uwce9rY|)s-Y#fWae2w@-Huij8$16WVGg~& zPTJUA&DlN|WKlp8wXpkIrD zKpUf};cpRNAD(QE3N^>g?2i{deAXzfXb_uqL+qB~2J^?Q>K}S-B<7!LS`)4AtA6?Y z|MXk?D(y;*j=Ma+!dJTN+|ptr?Y(82^O-w3xg-*Ucw1N1@bYq>kPw(?+-CTYiP@+i z@MNNN{s}4bB|5X(x@|QDTx7CSJhav=aCv-z!_`$qStB&EeT}$kk!F|LoYHz6wmAjX`|B}4u-IRN~zWzM@?#?eCrrCFHt86>D z*tg8Aa?{DrF*TRgcUA>^X1c%m+^xFj*J zq!?0E2xE(%1}05#-ptkK39TT*37Dky(KEK|c~8J3c&fytXA z!Wd{NJXfQXLVDOrA-%NxB2WQS$effdUNe5g>2+`iHp|vUoEA`57$U%phTqIv;;W7w1A&u(}K= z#p2BZQpM`c0HHu;0C^y-5llcm{sDz-V7UlpATP)@q>vrVEoqD31PbIAq!uLt3)$jA zjtDLwKd~e=&&mo^niq0L0Qn4sT;3e*tq{e@Kwp4kt&qFZucN&U?5L8;g4E(d9?iB0 zUZC=#)bz~alGLL3;>@blLSAo{wi2*CiACw9xv9ViFXYnzdoeFDH?@!-%q&jLNh=g+ z1G7P<6$(aZ1C1gf*b9Xqwx^_~C6?v@E26}l($qp>zqa-a{s>?qWvD0=Y0D4<2Y*Ov zUU7a=p=etMR2~$vdGTpLAB%z1fPyJAucS~Mt^gEi<(YXY`Q?QYKouEKt-ugu2g($s z0xKq9>`G?vXK*1&x3;ylw}BII&+JJHT!C=`!nhN#JcAz21PrZb-JGMj$G?0fFw`=p&A}y8lt8|F3`y@_rIa>?U> zVn;U6XXM#|s3_3G2f5@?MX^l{m`icnhE%qp8;P72R8WkZg2zbY0u0?mt5IEg1FkAwqcpw1k1GAoh$`Bv`<*@*7R Tensor" r""" +stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor See :func:`torch.stft` - -.. warning:: - This function changed signature at version 0.4.1. Calling with - the previous signature may cause error or return incorrect result. """) add_docstr_all('istft', - "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " - "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor" r""" +istft(n_fft, hop_length=None, win_length=None, window=None, + center=True, normalized=False, onesided=True, length=None) -> Tensor See :func:`torch.istft` """) diff --git a/torch/csrc/jit/mobile/upgrader_mobile.cpp b/torch/csrc/jit/mobile/upgrader_mobile.cpp index eed4a676a9c..0e52829255d 100644 --- a/torch/csrc/jit/mobile/upgrader_mobile.cpp +++ b/torch/csrc/jit/mobile/upgrader_mobile.cpp @@ -67,10 +67,6 @@ getOperatorVersionMapForMobile() { std::vector({ Upgrader({0, 8, "logspace_out_0_8", 10}) })}, - {std::string("aten::stft"), - std::vector({ - Upgrader({0, 10, "stft_0_10", 11}) - })}, }); return operatorVersionMapForMobile; } @@ -531,35 +527,6 @@ const std::vector& getUpgraderBytecodeList() { OperatorString({"prim::unchecked_cast", "", 1}), }), // operators list }), - ByteCodeFunctionWithOperator({ - mobile::Function::registerFunc( - "stft_0_10", - std::vector({ - Instruction{OpCode::STOREN, 1, 8}, - Instruction{OpCode::MOVE, 1, 0}, - Instruction{OpCode::MOVE, 2, 0}, - Instruction{OpCode::MOVE, 3, 0}, - Instruction{OpCode::MOVE, 4, 0}, - Instruction{OpCode::MOVE, 5, 0}, - Instruction{OpCode::LOADC, 1, 0}, - Instruction{OpCode::LOADC, 0, 0}, - Instruction{OpCode::MOVE, 6, 0}, - Instruction{OpCode::MOVE, 7, 0}, - Instruction{OpCode::MOVE, 8, 0}, - Instruction{OpCode::OP, 0, 0}, - Instruction{OpCode::RET, 0, 0}, - }), // instructions list, - std::vector({ - c10::IValue("reflect"), - c10::IValue(false), - }), // constants list, - std::vector(), // types list, - 8 - ), - std::vector({ - OperatorString({"aten::stft", "", 10}), - }), // operators list - }), }); for (const auto& upgrader_function : upgrader_function_list) { for (const auto& op : upgrader_function.operators) { diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp index e50227d18ae..7b09cc409a4 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp @@ -15,17 +15,6 @@ namespace torch { namespace jit { static std::unordered_map kUpgradersEntryMap({ - {"stft_0_10", R"SCRIPT( -def stft_0_10( - self: Tensor, n_fft: int, hop_length: Optional[int] = None, - win_length: Optional[int] = None, window: Optional[Tensor] = None, - normalized: bool = False, onesided: Optional[bool] = None, - return_complex: Optional[bool] = None) -> Tensor: - return torch.stft( - self, n_fft=n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=False, normalized=normalized, onesided=onesided, - return_complex=return_complex) -)SCRIPT"}, {"logspace_0_8", R"SCRIPT( def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): diff --git a/torch/csrc/jit/operator_upgraders/version_map.cpp b/torch/csrc/jit/operator_upgraders/version_map.cpp index d96527b66fc..1e19f4cc39d 100644 --- a/torch/csrc/jit/operator_upgraders/version_map.cpp +++ b/torch/csrc/jit/operator_upgraders/version_map.cpp @@ -16,11 +16,7 @@ static bool isVersionMapSorted = false; // Note for developers: The list of upgraders should be SORTED // by the version number where the upgrader is registered. static std::unordered_map> operatorVersionMap( - {{"aten::stft", - {{11, - "stft_0_10", - "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}}, - {"aten::logspace", + {{"aten::logspace", {{9, "logspace_0_8", "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}}, diff --git a/torch/functional.py b/torch/functional.py index dad5fc63dc0..29a66f7f160 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -4,7 +4,7 @@ from typing import ( import torch from torch._C import _add_docstr -import torch.nn.functional +import torch.nn.functional as F from ._lowrank import svd_lowrank, pca_lowrank from .overrides import ( has_torch_function, has_torch_function_unary, has_torch_function_variadic, @@ -478,121 +478,133 @@ def _meshgrid(*tensors, indexing: Optional[str]): return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined] -stft = _add_docstr( - torch.stft, - "stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " - "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor" - r""" +def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: Optional[Tensor] = None, + center: bool = True, pad_mode: str = 'reflect', normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None) -> Tensor: + r"""Short-time Fourier transform (STFT). -Short-time Fourier transform (STFT). + .. warning:: + From version 1.8.0, :attr:`return_complex` must always be given + explicitly for real inputs and `return_complex=False` has been + deprecated. Strongly prefer `return_complex=True` as in a future + pytorch release, this function will only return complex tensors. -.. warning:: - From version 1.8.0, :attr:`return_complex` must always be given - explicitly for real inputs and `return_complex=False` has been - deprecated. Strongly prefer `return_complex=True` as in a future - pytorch release, this function will only return complex tensors. + Note that :func:`torch.view_as_real` can be used to recover a real + tensor with an extra last dimension for real and imaginary components. - Note that :func:`torch.view_as_real` can be used to recover a real - tensor with an extra last dimension for real and imaginary components. + The STFT computes the Fourier transform of short overlapping windows of the + input. This giving frequency components of the signal as they change over + time. The interface of this function is modeled after (but *not* a drop-in + replacement for) librosa_ stft function. -The STFT computes the Fourier transform of short overlapping windows of the -input. This giving frequency components of the signal as they change over -time. The interface of this function is modeled after (but *not* a drop-in -replacement for) librosa_ stft function. + .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html -.. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html + Ignoring the optional batch dimension, this method computes the following + expression: -Ignoring the optional batch dimension, this method computes the following -expression: + .. math:: + X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}% + \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ % + \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right), -.. math:: - X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}% - \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ % - \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right), + where :math:`m` is the index of the sliding window, and :math:`\omega` is + the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``, + or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``. -where :math:`m` is the index of the sliding window, and :math:`\omega` is -the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``, -or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``. + * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time + sequences. -* :attr:`input` must be either a 1-D time sequence or a 2-D batch of time - sequences. + * If :attr:`hop_length` is ``None`` (default), it is treated as equal to + ``floor(n_fft / 4)``. -* If :attr:`hop_length` is ``None`` (default), it is treated as equal to - ``floor(n_fft / 4)``. + * If :attr:`win_length` is ``None`` (default), it is treated as equal to + :attr:`n_fft`. -* If :attr:`win_length` is ``None`` (default), it is treated as equal to - :attr:`n_fft`. + * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from + :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is + treated as if having :math:`1` everywhere in the window. If + :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on + both sides to length :attr:`n_fft` before being applied. -* :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from - :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is - treated as if having :math:`1` everywhere in the window. If - :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on - both sides to length :attr:`n_fft` before being applied. + * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on + both sides so that the :math:`t`-th frame is centered at time + :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame + begins at time :math:`t \times \text{hop\_length}`. -* If :attr:`center` is ``True`` (default), :attr:`input` will be padded on - both sides so that the :math:`t`-th frame is centered at time - :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame - begins at time :math:`t \times \text{hop\_length}`. + * :attr:`pad_mode` determines the padding method used on :attr:`input` when + :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for + all available options. Default is ``"reflect"``. -* :attr:`pad_mode` determines the padding method used on :attr:`input` when - :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for - all available options. Default is ``"reflect"``. + * If :attr:`onesided` is ``True`` (default for real input), only values for + :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor + \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because + the real-to-complex Fourier transform satisfies the conjugate symmetry, + i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`. + Note if the input or window tensors are complex, then :attr:`onesided` + output is not possible. -* If :attr:`onesided` is ``True`` (default for real input), only values for - :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor - \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because - the real-to-complex Fourier transform satisfies the conjugate symmetry, - i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`. - Note if the input or window tensors are complex, then :attr:`onesided` - output is not possible. + * If :attr:`normalized` is ``True`` (default is ``False``), the function + returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`. -* If :attr:`normalized` is ``True`` (default is ``False``), the function - returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`. + * If :attr:`return_complex` is ``True`` (default if input is complex), the + return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``, + the output is a ``input.dim() + 2`` dimensional real tensor where the last + dimension represents the real and imaginary components. -* If :attr:`return_complex` is ``True`` (default if input is complex), the - return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``, - the output is a ``input.dim() + 2`` dimensional real tensor where the last - dimension represents the real and imaginary components. + Returns either a complex tensor of size :math:`(* \times N \times T)` if + :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N + \times T \times 2)`. Where :math:`*` is the optional batch size of + :attr:`input`, :math:`N` is the number of frequencies where STFT is applied + and :math:`T` is the total number of frames used. -Returns either a complex tensor of size :math:`(* \times N \times T)` if -:attr:`return_complex` is true, or a real tensor of size :math:`(* \times N -\times T \times 2)`. Where :math:`*` is the optional batch size of -:attr:`input`, :math:`N` is the number of frequencies where STFT is applied -and :math:`T` is the total number of frames used. + .. warning:: + This function changed signature at version 0.4.1. Calling with the + previous signature may cause error or return incorrect result. -.. warning:: - This function changed signature at version 0.4.1. Calling with the - previous signature may cause error or return incorrect result. + Args: + input (Tensor): the input tensor + n_fft (int): size of Fourier transform + hop_length (int, optional): the distance between neighboring sliding window + frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) + win_length (int, optional): the size of window frame and STFT filter. + Default: ``None`` (treated as equal to :attr:`n_fft`) + window (Tensor, optional): the optional window function. + Default: ``None`` (treated as window of all :math:`1` s) + center (bool, optional): whether to pad :attr:`input` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + Default: ``True`` + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. Default: ``"reflect"`` + normalized (bool, optional): controls whether to return the normalized STFT results + Default: ``False`` + onesided (bool, optional): controls whether to return half of results to + avoid redundancy for real inputs. + Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. + return_complex (bool, optional): whether to return a complex tensor, or + a real tensor with an extra last dimension for the real and + imaginary components. -Args: - input (Tensor): the input tensor - n_fft (int): size of Fourier transform - hop_length (int, optional): the distance between neighboring sliding window - frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) - win_length (int, optional): the size of window frame and STFT filter. - Default: ``None`` (treated as equal to :attr:`n_fft`) - window (Tensor, optional): the optional window function. - Default: ``None`` (treated as window of all :math:`1` s) - center (bool, optional): whether to pad :attr:`input` on both sides so - that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. - Default: ``True`` - pad_mode (string, optional): controls the padding method used when - :attr:`center` is ``True``. Default: ``"reflect"`` - normalized (bool, optional): controls whether to return the normalized STFT results - Default: ``False`` - onesided (bool, optional): controls whether to return half of results to - avoid redundancy for real inputs. - Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. - return_complex (bool, optional): whether to return a complex tensor, or - a real tensor with an extra last dimension for the real and - imaginary components. + Returns: + Tensor: A tensor containing the STFT result with shape described above -Returns: - Tensor: A tensor containing the STFT result with shape described above - -""") -# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 -stft.__module__ = "torch.functional" + """ + if has_torch_function_unary(input): + return handle_torch_function( + stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex) + # NOTE: Do not edit. This code will be removed once the forward-compatibility + # period is over for PR #73432 + if center: + signal_dim = input.dim() + extended_shape = [1] * (3 - signal_dim) + list(input.size()) + pad = int(n_fft // 2) + input = F.pad(input.view(extended_shape), [pad, pad], pad_mode) + input = input.view(input.shape[-signal_dim:]) + return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined] + normalized, onesided, return_complex) istft = _add_docstr(