Generalize pin memory logic for accelerator when non blocking copy happened (#143783)

# Motivation
Fixes https://github.com/pytorch/pytorch/issues/143641
Generalize the pin-memory logic for accelerators when a non-blocking copy to the CPU is performed. Each accelerator has its own implementation of `empty_strided`. An accelerator that has no pin-memory mechanism can either ignore the request or mimic pinning when `pin_out` is True.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143783
Approved by: https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #144959
This commit is contained in:
Yu, Guangye 2025-01-17 18:09:29 +00:00 committed by PyTorch MergeBot
parent 28b6430823
commit 719938c77f
2 changed files with 11 additions and 2 deletions

View file

@ -343,7 +343,8 @@ Tensor _to_copy(
}
bool pin_out =
(non_blocking && (self.is_cuda() || self.is_privateuseone()) &&
(non_blocking &&
at::accelerator::isAcceleratorExcluded(self.device().type(), at::kMPS) &&
options.device().is_cpu() && (options.layout() == c10::kStrided));
if (memory_format == MemoryFormat::Preserve) {

View file

@ -4,7 +4,7 @@ import sys
import unittest
import torch
from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_MPS, TestCase
if not torch.accelerator.is_available():
@ -102,6 +102,14 @@ class TestAccelerator(TestCase):
self.assertEqual(torch.accelerator.current_stream(), src_prev_stream)
self.assertEqual(torch.accelerator.current_stream(dst_device), dst_prev_stream)
@unittest.skipIf(TEST_MPS, "MPS doesn't support pin memory!")
def test_pin_memory_on_non_blocking_copy(self):
t_acc = torch.randn(100).to(torch.accelerator.current_accelerator())
t_host = t_acc.to("cpu", non_blocking=True)
torch.accelerator.synchronize()
self.assertTrue(t_host.is_pinned())
self.assertEqual(t_acc.cpu(), t_host)
if __name__ == "__main__":
run_tests()