Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00.

This PR re-implements pin memory to get rid of the optional `device` argument and make all related APIs device-agnostic. It adds two new abstract APIs to [AcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/detail/AcceleratorHooksInterface.h#L12) and redefines pin memory as: "Pin memory is always pinned for the current accelerator device." Concretely, `pin_memory`/`is_pinned` use [getAcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Context.h#L61) to get the appropriate device and invoke the corresponding overridden interfaces, instead of going through BackendSelect and then dispatching to the implementation methods of CUDA or other specific backends.

Note: a new backend that wants to implement and use pin memory only needs to inherit from AcceleratorHooksInterface and override the `isPinnedPtr` and `getPinnedMemoryAllocator` methods.

Additional context: to avoid breaking backward compatibility (BC), this PR keeps the `device` argument of the related APIs but emits a deprecation warning when it is passed. A follow-up PR will update all PyTorch callers (`Tensor.is_pinned()`, `Tensor.pin_memory()`...) so that they no longer pass this argument. In the future, the `device` argument will be removed entirely.

Relates #124908
Relates #14560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126376
Approved by: https://github.com/albanD
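The hook-based dispatch described in the commit message can be illustrated with a small, self-contained C++ toy model. This is a hedged sketch, not ATen code: `ToyAllocator`, `ToyAcceleratorHooks`, `ToyBackendHooks`, `current_accelerator_hooks`, `warn_if_device_passed`, `toy_pin_memory` and `toy_is_pinned` are all hypothetical names, and `malloc` stands in for a real page-locking allocator. Only the method names `isPinnedPtr` and `getPinnedMemoryAllocator`, the "inherit and override" pattern, and the deprecated-but-accepted `device` argument come from the PR description.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>

// Toy stand-in for an allocator interface (hypothetical, not c10::Allocator).
struct ToyAllocator {
  virtual ~ToyAllocator() = default;
  virtual void* allocate(std::size_t nbytes) = 0;
};

// Toy stand-in for AcceleratorHooksInterface. Only the two pin-memory hooks
// named in the PR are modeled; the defaults mean "no pinned memory here".
struct ToyAcceleratorHooks {
  virtual ~ToyAcceleratorHooks() = default;
  virtual bool isPinnedPtr(const void* /*data*/) const { return false; }
  virtual ToyAllocator* getPinnedMemoryAllocator() const {
    throw std::runtime_error("pinned memory is not supported by this backend");
  }
};

// A hypothetical new backend: inherit the interface and override both hooks.
struct ToyBackendHooks : ToyAcceleratorHooks {
  struct PinnedAllocator : ToyAllocator {
    std::set<const void*> pinned;  // pointers handed out as "pinned"
    void* allocate(std::size_t nbytes) override {
      // A real backend would call its page-locking host allocator here;
      // plain malloc keeps the sketch self-contained.
      void* p = std::malloc(nbytes);
      pinned.insert(p);
      return p;
    }
    ~PinnedAllocator() override {
      for (const void* p : pinned) std::free(const_cast<void*>(p));
    }
  };
  mutable PinnedAllocator allocator;

  bool isPinnedPtr(const void* data) const override {
    return allocator.pinned.count(data) > 0;
  }
  ToyAllocator* getPinnedMemoryAllocator() const override { return &allocator; }
};

// Hypothetical lookup of the current accelerator's hooks (the role the PR
// gives to getAcceleratorHooksInterface). No per-call device is consulted.
const ToyAcceleratorHooks& current_accelerator_hooks() {
  static ToyBackendHooks hooks;
  return hooks;
}

// `device` is kept only for backward compatibility: passing it prints a
// deprecation warning, and in this toy model its value is otherwise ignored.
void warn_if_device_passed(const std::optional<std::string>& device) {
  if (device) {
    std::cerr << "Warning: the `device` argument is deprecated\n";
  }
}

void* toy_pin_memory(std::size_t nbytes,
                     const std::optional<std::string>& device = std::nullopt) {
  warn_if_device_passed(device);
  return current_accelerator_hooks().getPinnedMemoryAllocator()->allocate(nbytes);
}

bool toy_is_pinned(const void* ptr,
                   const std::optional<std::string>& device = std::nullopt) {
  warn_if_device_passed(device);
  return current_accelerator_hooks().isPinnedPtr(ptr);
}
```

In this model, `toy_pin_memory(64)` returns a pointer for which `toy_is_pinned` is true, while an ordinary stack address is reported as not pinned. Because both entry points only consult the current accelerator's hooks, supporting a new backend means overriding two virtual methods rather than adding a BackendSelect-dispatched kernel per backend, which is the trade-off the PR describes.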

| Name |
|---|
| no_python_abi_suffix_test |
| self_compiler_include_dirs_test |
| torch_test_cpp_extension |
| cpp_c10d_extension.cpp |
| cpp_c10d_extension.hpp |
| cpp_frontend_extension.cpp |
| cublas_extension.cpp |
| cuda_dlink_extension.cpp |
| cuda_dlink_extension_add.cu |
| cuda_dlink_extension_add.cuh |
| cuda_dlink_extension_kernel.cu |
| cuda_extension.cpp |
| cuda_extension.cu |
| cuda_extension_kernel.cu |
| cuda_extension_kernel2.cu |
| cudnn_extension.cpp |
| cusolver_extension.cpp |
| dangling_impl_extension.cpp |
| doubler.h |
| extension.cpp |
| identity.cpp |
| jit_extension.cpp |
| jit_extension2.cpp |
| maia_extension.cpp |
| mps_extension.mm |
| mtia_extension.cpp |
| open_registration_extension.cpp |
| rng_extension.cpp |
| setup.py |
| torch_library.cu |