From ca9b3f18e9cb5a2ce686d5b13bfc76164c2a5c2a Mon Sep 17 00:00:00 2001 From: Weixing Zhang Date: Sun, 25 Apr 2021 01:18:56 -0700 Subject: [PATCH] Explicitly pass cuda stream to thrust function rather than use cuda default stream implicitly (#7414) * Pass cuda stream to thrust function to not use default stream. In the commit 299ace0, ORT has been changed to not use cuda default stream. * update amd_hipify.py * remove unnecessary stream sync Co-authored-by: Weixing Zhang --- .../cuda/object_detection/non_max_suppression_impl.cu | 3 +-- orttraining/orttraining/training_ops/cuda/reduction/all.cc | 3 --- orttraining/orttraining/training_ops/cuda/reduction/all.cu | 2 +- tools/ci_build/amd_hipify.py | 1 + 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu index deb1e60e0f..ec98e00413 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu @@ -395,13 +395,12 @@ Status NonMaxSuppressionImpl( // STEP 2. 
filter boxes by scores int limited_num_boxes = num_boxes; if (pc.score_threshold_ != nullptr) { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); thrust::device_ptr sorted_scores_device_ptr(d_sorted_scores); limited_num_boxes = thrust::count_if( + thrust::cuda::par.on(stream), sorted_scores_device_ptr, sorted_scores_device_ptr + num_boxes, DeviceGreaterThan(score_threshold)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(0)); CUDA_RETURN_IF_ERROR(cudaGetLastError()); if (limited_num_boxes == 0) { diff --git a/orttraining/orttraining/training_ops/cuda/reduction/all.cc b/orttraining/orttraining/training_ops/cuda/reduction/all.cc index 9e1c282667..53b80154b2 100644 --- a/orttraining/orttraining/training_ops/cuda/reduction/all.cc +++ b/orttraining/orttraining/training_ops/cuda/reduction/all.cc @@ -25,14 +25,11 @@ Status All::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(size <= std::numeric_limits::max(), "Number of reduced elements (", size, ") exceeds the max allowed value (", std::numeric_limits::max(), ")."); - // TODO: LaunchAllKernel is implemented with thrust, which always uses default CUDA stream. 
- CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream())); LaunchAllKernel( Stream(), input.Data(), static_cast(size), output.MutableData()); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(0)); return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cuda/reduction/all.cu b/orttraining/orttraining/training_ops/cuda/reduction/all.cu index 678d01893d..3647da6903 100644 --- a/orttraining/orttraining/training_ops/cuda/reduction/all.cu +++ b/orttraining/orttraining/training_ops/cuda/reduction/all.cu @@ -22,7 +22,7 @@ __global__ void assign_false(bool* ptr) { template<> void LaunchAllKernel(cudaStream_t stream, const bool* data, const int size, bool* output) { - if(thrust::all_of(thrust::device, data, data + size, thrust::identity())) { + if(thrust::all_of(thrust::cuda::par.on(stream), data, data + size, thrust::identity())) { assign_true<<<1, 1, 0, stream>>>(output); } else diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index c869fff2c9..93c5f86964 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -231,6 +231,7 @@ def hipify(src_file_path, dst_file_path): s = s.replace('CUDA_KERNEL_ASSERT', 'HIP_KERNEL_ASSERT') s = s.replace('CUDA_CALL', 'HIP_CALL') s = s.replace('SliceCuda', 'SliceRocm') + s = s.replace('thrust::cuda', 'thrust::hip') s = s.replace('cuda', 'rocm') # s = s.replace('Cuda', 'Rocm') s = s.replace('CUDA', 'ROCM')