mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Explicitly pass cuda stream to thrust function rather than use cuda default stream implicitly (#7414)
* Pass cuda stream to thrust function to not use default stream.
In the commit 299ace0, ORT has been changed to not use cuda default stream.
* update amd_hipify.py
* remove un-necessary stream sync
Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
This commit is contained in:
parent
b9cbbc41ff
commit
ca9b3f18e9
4 changed files with 3 additions and 6 deletions
|
|
@ -395,13 +395,12 @@ Status NonMaxSuppressionImpl(
|
|||
// STEP 2. filter boxes by scores
|
||||
int limited_num_boxes = num_boxes;
|
||||
if (pc.score_threshold_ != nullptr) {
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
|
||||
thrust::device_ptr<float> sorted_scores_device_ptr(d_sorted_scores);
|
||||
limited_num_boxes = thrust::count_if(
|
||||
thrust::cuda::par.on(stream),
|
||||
sorted_scores_device_ptr,
|
||||
sorted_scores_device_ptr + num_boxes,
|
||||
DeviceGreaterThan(score_threshold));
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(0));
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
|
||||
if (limited_num_boxes == 0) {
|
||||
|
|
|
|||
|
|
@ -25,14 +25,11 @@ Status All<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
ORT_ENFORCE(size <= std::numeric_limits<int>::max(), "Number of reduced elements (",
|
||||
size, ") exceeds the max allowed value (", std::numeric_limits<int>::max(), ").");
|
||||
|
||||
// TODO: LaunchAllKernel is implemented with thrust, which always uses default CUDA stream.
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream()));
|
||||
LaunchAllKernel(
|
||||
Stream(),
|
||||
input.Data<T>(),
|
||||
static_cast<int>(size),
|
||||
output.MutableData<bool>());
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ __global__ void assign_false(bool* ptr) {
|
|||
|
||||
template<>
|
||||
void LaunchAllKernel(cudaStream_t stream, const bool* data, const int size, bool* output) {
|
||||
if(thrust::all_of(thrust::device, data, data + size, thrust::identity<bool>())) {
|
||||
if(thrust::all_of(thrust::cuda::par.on(stream), data, data + size, thrust::identity<bool>())) {
|
||||
assign_true<<<1, 1, 0, stream>>>(output);
|
||||
}
|
||||
else
|
||||
|
|
|
|||
|
|
@ -231,6 +231,7 @@ def hipify(src_file_path, dst_file_path):
|
|||
s = s.replace('CUDA_KERNEL_ASSERT', 'HIP_KERNEL_ASSERT')
|
||||
s = s.replace('CUDA_CALL', 'HIP_CALL')
|
||||
s = s.replace('SliceCuda', 'SliceRocm')
|
||||
s = s.replace('thrust::cuda', 'thrust::hip')
|
||||
s = s.replace('cuda', 'rocm')
|
||||
# s = s.replace('Cuda', 'Rocm')
|
||||
s = s.replace('CUDA', 'ROCM')
|
||||
|
|
|
|||
Loading…
Reference in a new issue