mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fixes the following error, which is raised when compiling and loading this program:
```
window = torch.hann_window(N_FFT).to(x.device)
stft = torch.stft(
    x, N_FFT, HOP_LENGTH, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2
return magnitudes
```
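For readers who want to run the snippet on its own, here is a minimal sketch that wraps it in a module and calls it in eager mode. The class name `StftMagnitudes`, the values of `N_FFT` and `HOP_LENGTH`, and the input shape are assumptions chosen for illustration; the original snippet does not state them. The sketch does not reproduce the AOTInductor failure itself, only the computation being compiled.

```python
import torch

# Assumed values: the original snippet does not define these constants.
N_FFT = 400
HOP_LENGTH = 160


class StftMagnitudes(torch.nn.Module):  # hypothetical name
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Build the Hann window and move it to the input's device.
        window = torch.hann_window(N_FFT).to(x.device)
        stft = torch.stft(
            x, N_FFT, HOP_LENGTH, window=window, return_complex=True
        )
        # Drop the last frame and take the squared magnitude.
        magnitudes = stft[..., :-1].abs() ** 2
        return magnitudes


model = StftMagnitudes()
x = torch.randn(2, 16000)  # assumed input: a batch of two 16000-sample signals
out = model(x)
```

With these assumed sizes the output has `N_FFT // 2 + 1` frequency bins and one frame fewer than `torch.stft` produces.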
```
Traceback (most recent call last):
File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 57, in testPartExecutor
yield
File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 623, in run
self._callTestMethod(testMethod)
File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 579, in _callTestMethod
if method() is not None:
^^^^^^^^
File "/home/zhxchen17/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
method(*args, **kwargs)
File "/home/zhxchen17/pytorch/test/inductor/test_torchinductor.py", line 12356, in new_test
return value(self)
^^^^^^^^^^^
File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor.py", line 4334, in test_stft
self.check_model(model, example_inputs)
File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 185, in check_model
actual = AOTIRunnerUtil.run(
^^^^^^^^^^^^^^^^^^^
File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 137, in run
optimized = AOTIRunnerUtil.load(device, so_path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 119, in load
return torch._export.aot_load(so_path, device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zhxchen17/pytorch/torch/_export/__init__.py", line 165, in aot_load
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected extern kernel aten::hann_window to have serialized argument type as_scalar_type for argument 1 but got as_device
```
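The final `RuntimeError` reports a mismatch between the type tag a serialized argument carries and the tag expected at that position. The sketch below illustrates that kind of positional check in plain Python. It is not the PyTorch implementation: `SerializedArg`, `check_arg_types`, and the tag lists are hypothetical, and the mismatching input is constructed only to produce a message of the same shape, not to claim what the root cause was.

```python
from dataclasses import dataclass


@dataclass
class SerializedArg:  # hypothetical container for one serialized argument
    tag: str          # e.g. "as_int", "as_scalar_type", "as_device"
    value: object


def check_arg_types(kernel, expected_tags, args):
    """Raise if any argument's tag differs from the tag expected at its index."""
    for index, (expected, arg) in enumerate(zip(expected_tags, args)):
        if arg.tag != expected:
            raise RuntimeError(
                f"Expected extern kernel {kernel} to have serialized argument "
                f"type {expected} for argument {index} but got {arg.tag}"
            )


# Illustrative expectation: an integer first, then a scalar type.
expected_tags = ["as_int", "as_scalar_type"]
# Constructed mismatch: a device appears where a scalar type is expected.
args = [SerializedArg("as_int", 400), SerializedArg("as_device", "cuda")]

try:
    check_arg_types("aten::hann_window", expected_tags, args)
    message = None
except RuntimeError as err:
    message = str(err)
```

A check of this style fails as soon as the serialized argument list and the expected list disagree at any index, which is why the error names a specific argument position.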
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146263
Approved by: https://github.com/angelayi
| Name |
|---|
| aoti_eager |
| aoti_include |
| aoti_package |
| aoti_runner |
| aoti_runtime |
| aoti_torch |
| cpp_wrapper |
| array_ref_impl.h |
| inductor_ops.cpp |
| inductor_ops.h |
| resize_storage_bytes.cpp |