diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index ff976795b29..0113d9f0e33 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -272,6 +272,23 @@ void copy_device_to_device(TensorIterator& iter,
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
+inline std::tuple<size_t, size_t, size_t, size_t> getCopyParameters(const TensorIteratorBase& iter) {
+  size_t element_size = iter.tensor(0).element_size();
+  if (iter.ndim() == 1) {
+    size_t width_in_bytes = element_size;
+    size_t src_pitch = iter.strides(1)[0];
+    size_t dst_pitch = iter.strides(0)[0];
+    size_t height = iter.shape()[0];
+    return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
+  } else {
+    size_t width_in_bytes = iter.shape()[0] * element_size;
+    size_t src_pitch = iter.strides(1)[1];
+    size_t dst_pitch = iter.strides(0)[1];
+    size_t height = iter.shape()[1];
+    return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
+  }
+}
+
 static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
   Device dst_device = iter.device(0);
   Device src_device = iter.device(1);
@@ -289,11 +306,23 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
   } else if (dst_device.is_cuda() && src_device.is_cuda()) {
     // Copies between GPUs can use the copy kernel if P2P is supported
     return !p2p_enabled;
-  } else {
-    // The remaining cases require temporaries. For example, this includes
-    // non-contiguous copies between CPU and GPU.
-    return true;
   }
+
+  // for cross-device copies we can use memcpy2d if conditions are satisfied
+  if (dst_device.is_cuda() != src_device.is_cuda() && same_dtype && iter.ndim() <= 2) {
+    // TensorIterator reorders strides so that the first one is the smallest
+    if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
+      auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
+      if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
+        return false; // No need for temporaries
+      }
+    }
+  }
+
+  // The remaining cases require temporaries. For example, this includes
+  // non-contiguous copies between CPU and GPU.
+  return true;
 }
 
 static bool maybe_enable_p2p_access(Device dst_device, Device src_device) {
@@ -374,11 +403,28 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
   void* dst = iter.data_ptr(0);
   void* src = iter.data_ptr(1);
-  int64_t nbytes = iter.numel() * iter.element_size(0);
   CUDAStream stream = getCurrentCUDAStream();
+  int64_t nbytes = 0;
+  int64_t width_in_bytes = -1;
+  int64_t src_pitch = -1;
+  int64_t dst_pitch = -1;
+  int64_t height = -1;
+  if (iter.is_contiguous()) {
+    nbytes = iter.numel() * iter.element_size(0);
+  } else {
+    // the only non-contiguous iter situation that can happen here is
+    // acceptable for 2d copy, this has been vetted in requires_temporaries
+    std::tie(width_in_bytes, src_pitch, dst_pitch, height) = getCopyParameters(iter);
+  }
 
   if (non_blocking) {
-    AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
+    if (width_in_bytes == -1) {
+      AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
+    } else {
+      AT_CUDA_CHECK(cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width_in_bytes, height, kind, stream));
+    }
+
     // we use both the storage context and the tensor data pointer as the key
     // for the caching host allocator. This allows us to better attribute the
    // events to the original tensor allocation correctly. The cases we seek to
@@ -399,7 +445,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
       CachingHostAllocator_recordEvent(ptr, ctx, stream);
     }
   } else {
-    at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
+    at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream, width_in_bytes, src_pitch, dst_pitch, height);
   }
 
   if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
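The hunks above reduce an eligible strided cross-device copy to cudaMemcpy2D's model: `width_in_bytes` is one contiguous row, `height` is the number of rows, and the two pitches are the row-to-row strides, which CUDA requires to be at least the row width. A minimal standalone sketch of that mapping, with made-up buffer sizes (illustrative only, not part of the patch):

```cpp
// Copy the first 4 floats of each 8-float pinned host row into a packed
// device buffer -- the same pitched pattern the eligibility check accepts.
// Error checking is omitted to keep the sketch short.
#include <cuda_runtime.h>

int main() {
  const size_t height = 16;                        // rows to copy
  const size_t width_in_bytes = 4 * sizeof(float); // bytes copied per row
  const size_t src_pitch = 8 * sizeof(float);      // row stride, >= width
  const size_t dst_pitch = width_in_bytes;         // destination is packed

  float* src = nullptr;
  float* dst = nullptr;
  cudaMallocHost(reinterpret_cast<void**>(&src), height * src_pitch); // pinned
  cudaMalloc(reinterpret_cast<void**>(&dst), height * dst_pitch);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width_in_bytes, height,
                    cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFree(dst);
  cudaFreeHost(src);
  return 0;
}
```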
diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h
index 192fafbad10..ac7e7d72b23 100644
--- a/c10/cuda/CUDAFunctions.h
+++ b/c10/cuda/CUDAFunctions.h
@@ -73,13 +73,21 @@ C10_CUDA_API __inline__ WarningState& warning_state() {
   return warning_state_;
 }
 // the subsequent functions are defined in the header because for performance
-// reasons we want them to be inline
+// reasons we want them to be inline.
+// performs a contiguous or 2D cudaMemcpy and synchronizes afterwards.
+// if width_in_bytes is not -1, a 2D copy is performed and all 2D params are
+// expected to be set to valid values; no additional checks are performed
+// beyond those done by the CUDA call itself
 C10_CUDA_API void __inline__ memcpy_and_sync(
     void* dst,
     const void* src,
     int64_t nbytes,
     cudaMemcpyKind kind,
-    cudaStream_t stream) {
+    cudaStream_t stream,
+    int64_t width_in_bytes = -1,
+    int64_t src_pitch = -1,
+    int64_t dst_pitch = -1,
+    int64_t height = -1) {
   if (C10_UNLIKELY(
           warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
     warn_or_error_on_sync();
@@ -89,12 +97,18 @@ C10_CUDA_API void __inline__ memcpy_and_sync(
     (*interp)->trace_gpu_stream_synchronization(
         c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
   }
+  if (width_in_bytes == -1) {
 #if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
-  C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
+    C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
 #else
-  C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
-  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+    C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
+    C10_CUDA_CHECK(cudaStreamSynchronize(stream));
 #endif
+  } else {
+    C10_CUDA_CHECK(cudaMemcpy2DAsync(
+        dst, dst_pitch, src, src_pitch, width_in_bytes, height, kind, stream));
+    C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+  }
 }
 
 C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
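For reference, a self-contained sketch of the dispatch that the extended `memcpy_and_sync` performs, assuming the same `-1` sentinel convention; the function name `memcpy_and_sync_sketch` and the bare status handling are hypothetical, not c10 API:

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// width_in_bytes == -1 selects the flat 1D copy; anything else selects the
// pitched 2D copy, mirroring the branch added to memcpy_and_sync above.
static void memcpy_and_sync_sketch(
    void* dst,
    const void* src,
    int64_t nbytes,
    cudaMemcpyKind kind,
    cudaStream_t stream,
    int64_t width_in_bytes = -1,
    int64_t src_pitch = -1,
    int64_t dst_pitch = -1,
    int64_t height = -1) {
  cudaError_t err;
  if (width_in_bytes == -1) {
    // contiguous: one flat range of nbytes
    err = cudaMemcpyAsync(dst, src, nbytes, kind, stream);
  } else {
    // strided: `height` rows of `width_in_bytes` bytes, rows spaced by the
    // source/destination pitches
    err = cudaMemcpy2DAsync(
        dst, dst_pitch, src, src_pitch, width_in_bytes, height, kind, stream);
  }
  if (err == cudaSuccess) {
    err = cudaStreamSynchronize(stream);
  }
  if (err != cudaSuccess) {
    std::fprintf(stderr, "copy failed: %s\n", cudaGetErrorString(err));
  }
}
```

Defaulting the new parameters to `-1` keeps every existing `memcpy_and_sync` call site source-compatible, which is presumably why the patch threads the 2D parameters through rather than adding a second overload.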
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 8377265ede4..3d28ed24263 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -358,22 +358,23 @@ class TestCuda(TestCase):
         self.assertEqual(len(str(uuid)), 36)  # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
         self.assertEqual(len(uuid.bytes), 16)
 
-    def test_copy_non_blocking(self):
-        def _test_copy_non_blocking(a, b):
-            event = torch.cuda.Event()
-            a.copy_(b, non_blocking=True)
+    def _test_copy(self, a, b, non_blocking):
+        event = torch.cuda.Event()
+        a.copy_(b, non_blocking=non_blocking)
+        if non_blocking:
             event.record()
             event.synchronize()
-            self.assertEqual(a, b)
+        self.assertEqual(a.contiguous(), b.contiguous(), atol=0, rtol=0)
 
+    def test_copy_non_blocking(self):
         # 10MB copies
         x = torch.ones(10000000, dtype=torch.uint8).cuda()
         y = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
-        _test_copy_non_blocking(x, y)
+        self._test_copy(x, y, non_blocking=True)
 
         x = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
         y = torch.ones(10000000, dtype=torch.uint8).cuda()
-        _test_copy_non_blocking(x, y)
+        self._test_copy(x, y, non_blocking=True)
 
         # Test the case where the pinned data_ptr is not equal to the storage data_ptr.
         x_base = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
@@ -381,9 +382,11 @@ class TestCuda(TestCase):
         self.assertTrue(x.is_pinned())
         self.assertTrue(x_base.is_pinned())
         self.assertNotEqual(x_base.data_ptr(), x.data_ptr())
-        self.assertEqual(x_base.storage().data_ptr(), x.storage().data_ptr())
+        self.assertEqual(
+            x_base.untyped_storage().data_ptr(), x.untyped_storage().data_ptr()
+        )
         y = torch.ones(10000000 - 1, dtype=torch.uint8).cuda()
-        _test_copy_non_blocking(x, y)
+        self._test_copy(x, y, non_blocking=True)
 
     def test_copy_non_blocking_type_conversion(self):
         a = torch.ones(1, device="cuda")
@@ -394,6 +397,96 @@ class TestCuda(TestCase):
         c.copy_(b, non_blocking=True)
         self.assertEqual(a, c, exact_dtype=False)
 
+    def test_copy_2d(self):
+        # shapes/slices below coalesce to 1d or 2d copies
+        def _test_copy_shape(shape, slice):
+            for dst_device, non_blocking in product(("cuda", "cpu"), (True, False)):
+                src_device = "cpu" if dst_device == "cuda" else "cuda"
+                src = torch.randint(8, shape, device=src_device).__getitem__(slice)
+                dst = torch.empty_like(src, device=dst_device)
+                if non_blocking:
+                    if src_device == "cpu":
+                        src = src.pin_memory()
+                    else:
+                        dst = dst.pin_memory()
+                self._test_copy(dst, src, non_blocking)
+                dst = torch.empty(shape, device=dst_device).__getitem__(slice)
+                src = torch.randint_like(dst, 8, device=src_device)
+                if non_blocking:
+                    if src_device == "cpu":
+                        src = src.pin_memory()
+                    else:
+                        dst = dst.pin_memory()
+                self._test_copy(dst, src, non_blocking)
+
+        _test_copy_shape((12800000,), slice(None, None, 2))
+        _test_copy_shape((4, 5), (slice(None, None, None), slice(None, 4, None)))
+        _test_copy_shape(
+            (4, 5, 6),
+            (slice(None, None, None), slice(None, 4, None), slice(None, None, None)),
+        )
+        _test_copy_shape(
+            (4, 5, 6),
+            (slice(None, None, None), slice(None, None, None), slice(None, 4, None)),
+        )
+        _test_copy_shape(
+            (4, 5, 6, 8),
+            (
+                slice(None, None, None),
+                slice(None, 4, None),
+                slice(None, None, None),
+                slice(None, None, None),
+            ),
+        )
+
+    def test_copy_2d_complex(self):
+        for dst_device, non_blocking, conj in product(
+            ("cuda", "cpu"), (True, False), (True, False)
+        ):
+            src_device = "cpu" if dst_device == "cuda" else "cuda"
+            if dst_device == "cpu" and non_blocking and conj:
+                continue  # FIXME: this is also broken for contiguous tensors
+            src = torch.randn((8,), dtype=torch.complex64, device=src_device)[::2]
+            dst = torch.zeros_like(src, device=dst_device)
+            if non_blocking:
+                if src_device == "cpu":
+                    src = src.pin_memory()
+                else:
+                    dst = dst.pin_memory()
+            if conj:
+                src = src.conj()
+            self._test_copy(dst, src, non_blocking)
+            dst = torch.empty((8,), dtype=torch.complex64, device=dst_device)[::2]
+            src = torch.randn_like(dst, device=src_device)
+            if non_blocking:
+                if src_device == "cpu":
+                    src = src.pin_memory()
+                else:
+                    dst = dst.pin_memory()
+            if conj:
+                src = src.conj()
+
+            self._test_copy(dst, src, non_blocking)
+
+    def test_copy_broadcast(self):
+        # broadcasted copies should not take the 2d path;
+        # it would error out with CUDA invalid-parameter errors
+        def _test_copy_shape(shape_dst, shape_src):
+            for dst_device, non_blocking in product(("cuda", "cpu"), (True, False)):
+                src_device = "cpu" if dst_device == "cuda" else "cuda"
+                src = torch.randint(8, shape_src, device=src_device)
+                dst = torch.empty(shape_dst, dtype=torch.int64, device=dst_device)
+                if non_blocking:
+                    if src_device == "cpu":
+                        src = src.pin_memory()
+                    else:
+                        dst = dst.pin_memory()
+                self._test_copy(dst, src.expand_as(dst), non_blocking)
+
+        _test_copy_shape((128,), (1,))
+        _test_copy_shape((128, 128), (128, 1))
+        _test_copy_shape((128, 128), (1, 128))
+
     @serialTest()
     def test_to_non_blocking(self):
         stream = torch.cuda.current_stream()
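The broadcast tests exist because an expanded tensor has stride 0 along the broadcast dimension, so its pitch falls below the row width and `cudaMemcpy2D` would reject the call with an invalid-argument error. A sketch of how the `src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes` condition screens that case out, using the hypothetical helper name `eligible_for_2d_copy` (not a PyTorch API):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the pitch condition in copy_requires_temporaries; the name is
// hypothetical, for illustration only.
static bool eligible_for_2d_copy(
    int64_t width_in_bytes, int64_t src_pitch, int64_t dst_pitch) {
  // cudaMemcpy2D requires both pitches to be at least the row width
  return src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes;
}

int main() {
  // A (128, 128) destination filled from a (1, 128) source expanded along
  // dim 0: the source repeats one row, so its row-to-row stride is 0 bytes.
  const int64_t width_in_bytes = 128 * sizeof(int64_t);
  const int64_t src_pitch = 0; // broadcasted dimension has stride 0
  const int64_t dst_pitch = 128 * sizeof(int64_t);
  std::printf(
      "2d path: %s\n",
      eligible_for_2d_copy(width_in_bytes, src_pitch, dst_pitch) ? "yes"
                                                                 : "no");
  return 0;
}
```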