diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index b612d37e123..c1784ef8b45 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -826,7 +826,7 @@ static Stream& write_opt(Stream& SS, const std::optional& value) { Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional hop_lengthOpt, const std::optional win_lengthOpt, const std::optional& window_opt, const bool center, std::string_view mode, const bool normalized, - const std::optional onesidedOpt, const std::optional return_complexOpt) { + const std::optional onesidedOpt, const std::optional return_complexOpt, const std::optional align_to_windowOpt) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned window_maybe_owned = at::borrow_from_optional_tensor(window_opt); const Tensor& window = *window_maybe_owned; @@ -853,11 +853,14 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional> 2); @@ -869,7 +872,6 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional hop_lengthOpt, const std::optional win_lengthOpt, const std::optional& window_opt, const bool normalized, - const std::optional onesidedOpt, const std::optional return_complexOpt) { + const std::optional onesidedOpt, const std::optional return_complexOpt, + const std::optional align_to_windowOpt) { return at::stft( self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt, /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt, - return_complexOpt); + return_complexOpt, align_to_windowOpt); } // Create complex tensor from the old style of real tensor with size=(..., 2) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e412752a1dc..4eb313cf5bd 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5758,11 +5758,11 @@ - func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) # Overload without center & pad mode, needed for forward-compatibility -- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor +- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor variants: function, method cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized'] -- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor +- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor variants: function, method - func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 76f128f4369..154c36832f7 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -1226,6 +1226,14 @@ class TestFFT(TestCase): with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'): y = x.stft(10, pad_mode='constant') + @onlyNativeDeviceTypes + @skipCPUIfNoFFT + def test_stft_align_to_window_only_requires_non_center(self, device): + x = torch.rand(100) + for align_to_window in [True, False]: + with self.assertRaisesRegex(RuntimeError, 'stft align_to_window should only be set when center = false'): + y = x.stft(10, center=True, return_complex=True, align_to_window=align_to_window) + # stft and istft are currently warning if a window is not provided @onlyNativeDeviceTypes @skipCPUIfNoFFT diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 200039ce81a..3db1dad6c59 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3362,6 +3362,7 @@ def stft( normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, ) -> Tensor: torch._check( window is None or window.device == input.device, @@ -3370,6 +3371,10 @@ def stft( + f" and window on {window.device}" # type: ignore[union-attr] ), ) + torch._check( + not center or align_to_window is None, + "stft only supports align_to_window for center = False.", + ) hop_length_ = hop_length if hop_length is not None else n_fft // 4 win_length_ = win_length if win_length is not None else n_fft @@ -3433,6 +3438,9 @@ def stft( window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left]) input = input.unfold(dimension=-1, size=n_fft, step=hop_length_) + if not center and align_to_window: + input_pad_amount = (n_fft - win_length_) // 2 + input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode) if window is not None: input = input * window diff --git a/torch/_tensor.py b/torch/_tensor.py index f33951e0488..1af81357cd6 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -940,6 +940,7 @@ class Tensor(torch._C.TensorBase): normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, ): r"""See :func:`torch.stft` @@ -961,6 +962,7 @@ class Tensor(torch._C.TensorBase): normalized=normalized, onesided=onesided, return_complex=return_complex, + align_to_window=align_to_window, ) return torch.stft( self, @@ -973,6 +975,7 @@ class Tensor(torch._C.TensorBase): normalized, onesided, return_complex=return_complex, + align_to_window=align_to_window, ) def istft( diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 351a2886bd9..c5f7aa1983c 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -6398,7 +6398,8 @@ See :func:`torch.dsplit` add_docstr_all( "stft", r""" -stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor +stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, + pad_end=0, align_to_window=None) -> Tensor See :func:`torch.stft` """, diff --git a/torch/functional.py b/torch/functional.py index 852661253a9..6201dc7b63c 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -551,6 +551,7 @@ def stft( normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, ) -> Tensor: r"""Short-time Fourier transform (STFT). @@ -698,6 +699,11 @@ def stft( normalized=normalized, onesided=onesided, return_complex=return_complex, + align_to_window=align_to_window, + ) + if center and align_to_window is not None: + raise RuntimeError( + "stft align_to_window should only be set when center = false" ) # NOTE: Do not edit. This code will be removed once the forward-compatibility # period is over for PR #73432 @@ -716,6 +722,7 @@ def stft( normalized, onesided, return_complex, + align_to_window, ) diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index ed5cd93ab05..bcf80058fe2 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -98,7 +98,7 @@ def _compute_edge_sizes(n_fft, window_size): @_onnx_symbolic("aten::stft") -@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b") +@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b") def stft( g: jit_utils.GraphContext, input: _C.Value, @@ -109,6 +109,7 @@ def stft( normalized: bool = False, onesided: Optional[bool] = True, return_complex: Optional[bool] = False, + align_to_window: Optional[bool] = None, ) -> _C.Value: """Associates `torch.stft` with the `STFT` ONNX operator. Note that torch.stft calls _VF.stft, without centering or padding options. @@ -137,6 +138,12 @@ def stft( msg="STFT does not currently support complex types", value=input ) + if align_to_window is not None: + raise errors.SymbolicValueError( + msg="STFT does not currently support the align_to_window option", + value=input, + ) # TODO(#145944): add compatibility with align_to_window option. + # Get STFT sizes frame_step_value = hop_length if hop_length is not None else n_fft // 4 frame_step_const = g.op( diff --git a/torch/overrides.py b/torch/overrides.py index 56249324252..97fc86af920 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1130,7 +1130,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.std: lambda input, dim=None: -1, torch.std_mean: lambda input, dim=None: -1, torch.stft: ( - lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950 + lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 # noqa: B950 ), torch.sub: lambda input, other, out=None: -1, torch.subtract: lambda input, other, out=None: -1,