mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Generalize pin memory logic for accelerator when non blocking copy happened (#143783)
# Motivation
Fixes https://github.com/pytorch/pytorch/issues/143641. Generalize the pin-memory logic for accelerators when a non-blocking copy happens. Each accelerator has its own implementation of `empty_strided`. An accelerator that doesn't have a pin-memory mechanism can ignore or mimic it when `pin_out` is True.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143783
Approved by: https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #144959
This commit is contained in:
parent
28b6430823
commit
719938c77f
2 changed files with 11 additions and 2 deletions
|
|
@ -343,7 +343,8 @@ Tensor _to_copy(
|
|||
}
|
||||
|
||||
bool pin_out =
|
||||
(non_blocking && (self.is_cuda() || self.is_privateuseone()) &&
|
||||
(non_blocking &&
|
||||
at::accelerator::isAcceleratorExcluded(self.device().type(), at::kMPS) &&
|
||||
options.device().is_cpu() && (options.layout() == c10::kStrided));
|
||||
|
||||
if (memory_format == MemoryFormat::Preserve) {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sys
|
|||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
|
||||
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_MPS, TestCase
|
||||
|
||||
|
||||
if not torch.accelerator.is_available():
|
||||
|
|
@ -102,6 +102,14 @@ class TestAccelerator(TestCase):
|
|||
self.assertEqual(torch.accelerator.current_stream(), src_prev_stream)
|
||||
self.assertEqual(torch.accelerator.current_stream(dst_device), dst_prev_stream)
|
||||
|
||||
@unittest.skipIf(TEST_MPS, "MPS doesn't support pin memory!")
|
||||
def test_pin_memory_on_non_blocking_copy(self):
|
||||
t_acc = torch.randn(100).to(torch.accelerator.current_accelerator())
|
||||
t_host = t_acc.to("cpu", non_blocking=True)
|
||||
torch.accelerator.synchronize()
|
||||
self.assertTrue(t_host.is_pinned())
|
||||
self.assertEqual(t_acc.cpu(), t_host)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
Loading…
Reference in a new issue