diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index fa3955ce85..44d683700e 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -21,6 +21,12 @@ enum class SimpleBroadcast : int32_t { RightPerChannelBatchN = (int32_t)-5, }; +enum class BroadcastIndexType : int32_t { + NoBroadcast = (int32_t)0, + Scalar = (int32_t)1, + NeedCompute = (int32_t)2, +}; + template class IConstantBuffer { public: diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index da5492d842..4be5578b28 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -73,6 +73,9 @@ struct TernaryElementwisePreparation { TArray b_padded_strides; // for b shape == output shape, this is nullptr TArray c_padded_strides; // for c shape == output shape, this is nullptr TArray fdm_output_strides; + BroadcastIndexType a_index_type = BroadcastIndexType::NoBroadcast; + BroadcastIndexType b_index_type = BroadcastIndexType::NoBroadcast; + BroadcastIndexType c_index_type = BroadcastIndexType::NoBroadcast; TernaryElementwisePreparation(const Tensor* a, const Tensor* b, const Tensor* c) : a_tensor(a), b_tensor(b), c_tensor(c) {} @@ -108,16 +111,34 @@ struct TernaryElementwisePreparation { } }; - if (a_shape != output_shape) { + bool has_need_compute = false; + if (a_shape.Size() == 1) { + a_index_type = BroadcastIndexType::Scalar; + } else if (a_shape != output_shape) { padder(a_rank, a_shape, a_padded_strides); + a_index_type = BroadcastIndexType::NeedCompute; + has_need_compute = true; } - if (b_shape != output_shape) { - padder(b_rank, b_shape, b_padded_strides); + if (b_shape.Size() == 1) { + b_index_type = BroadcastIndexType::Scalar; + } else if (b_shape != output_shape) { + padder(b_rank, b_shape, b_padded_strides); + b_index_type = BroadcastIndexType::NeedCompute; + has_need_compute = true; } - if (c_shape != output_shape) { + if (c_shape.Size() == 1) { + c_index_type = BroadcastIndexType::Scalar; + } else if (c_shape != output_shape) { padder(c_rank, c_shape, c_padded_strides); + c_index_type = BroadcastIndexType::NeedCompute; + has_need_compute = true; + } + + if (!has_need_compute) { + output_rank_or_simple_broadcast = static_cast(SimpleBroadcast::NoBroadcast); + return Status::OK(); } TensorPitches output_pitches(output_shape.GetDims()); @@ -154,10 +175,13 @@ Status Where::ComputeInternal(OpKernelContext* context) const { WhereImpl( prepare.output_rank_or_simple_broadcast, + prepare.a_index_type, prepare.a_padded_strides, reinterpret_cast(prepare.a_tensor->template Data()), + prepare.b_index_type, prepare.b_padded_strides, reinterpret_cast(prepare.b_tensor->template Data()), + prepare.c_index_type, prepare.c_padded_strides, reinterpret_cast(prepare.c_tensor->template Data()), prepare.fdm_output_strides, diff --git a/onnxruntime/core/providers/cuda/tensor/where_impl.cu b/onnxruntime/core/providers/cuda/tensor/where_impl.cu index f9b475318b..438c4f1eac 100644 --- a/onnxruntime/core/providers/cuda/tensor/where_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/where_impl.cu @@ -10,7 +10,7 @@ namespace onnxruntime { namespace cuda { // broadcast by computing output coordinate from offset, using fast_divmod -template +template __global__ void _TenaryElementWise( size_t output_rank, const TArray cond_padded_strides, @@ -22,184 +22,208 @@ __global__ void _TenaryElementWise( const TArray fdm_output_strides, T* output_data, CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - CUDA_LONG cond_index = (cond_need_compute ? 0 : id); - CUDA_LONG x_index = (x_need_compute ? 0 : id); - CUDA_LONG y_index = (y_need_compute ? 0 : id); + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + bool cond_value[NumElementsPerThread]; + T x_value[NumElementsPerThread]; + T y_value[NumElementsPerThread]; - // compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md - CUDA_LONG offset = id; - for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) { - if (dim >= output_rank) { - break; + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + // compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md + CUDA_LONG cond_index = (CondIndexType == BroadcastIndexType::NoBroadcast ? id : 0); + CUDA_LONG x_index = (XIndexType == BroadcastIndexType::NoBroadcast ? id : 0); + CUDA_LONG y_index = (YIndexType == BroadcastIndexType::NoBroadcast ? id : 0); + CUDA_LONG offset = id; +#pragma unroll + for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) { + if (dim >= output_rank) { + break; + } + + int q, r; + fdm_output_strides[dim].divmod(offset, q, r); + + if (CondIndexType == BroadcastIndexType::NeedCompute) { + cond_index += static_cast(cond_padded_strides[dim]) * q; + } + + if (XIndexType == BroadcastIndexType::NeedCompute) { + x_index += static_cast(x_padded_strides[dim]) * q; + } + + if (YIndexType == BroadcastIndexType::NeedCompute) { + y_index += static_cast(y_padded_strides[dim]) * q; + } + + offset = r; + } + + cond_value[i] = cond_data[cond_index]; + x_value[i] = x_data[x_index]; + y_value[i] = y_data[y_index]; + id += NumThreadsPerBlock; } - - int q, r; - fdm_output_strides[dim].divmod(offset, q, r); - - if (cond_need_compute) { - cond_index += static_cast(cond_padded_strides[dim]) * q; - } - - if (x_need_compute) { - x_index += static_cast(x_padded_strides[dim]) * q; - } - - if (y_need_compute) { - y_index += static_cast(y_padded_strides[dim]) * q; - } - - offset = r; } - output_data[id] = cond_data[cond_index] ? x_data[x_index] : y_data[y_index]; + id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + output_data[id] = cond_value[i] ? x_value[i] : y_value[i]; + id += NumThreadsPerBlock; + } + } } // for scalar broadcast or non-broadcast case -template +template __global__ void _TenaryElementWiseSimple( const bool* cond_data, const T* x_data, const T* y_data, T* output_data, CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - output_data[id] = cond_data[id] ? x_data[id] : y_data[id]; + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + bool cond_value[NumElementsPerThread]; + T x_value[NumElementsPerThread]; + T y_value[NumElementsPerThread]; + + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + cond_value[i] = cond_data[CondIndexType == BroadcastIndexType::NoBroadcast ? id : 0]; + x_value[i] = x_data[XIndexType == BroadcastIndexType::NoBroadcast ? id : 0]; + y_value[i] = y_data[YIndexType == BroadcastIndexType::NoBroadcast ? id : 0]; + id += NumThreadsPerBlock; + } + } + + id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + output_data[id] = cond_value[i] ? x_value[i] : y_value[i]; + id += NumThreadsPerBlock; + } + } } +#define HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \ + case Y_INDEX_TYPE: { \ + _TenaryElementWiseSimple \ + <<>>(cond_data, \ + x_data, \ + y_data, \ + output_data, \ + N); \ + } break + +#define HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \ + case X_INDEX_TYPE: { \ + switch(Y_INDEX_TYPE_VAL) { \ + HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \ + HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \ + } \ + } break + +#define HANDLE_COND_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \ + case COND_INDEX_TYPE: { \ + switch(X_INDEX_TYPE_VAL) { \ + HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \ + HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \ + } \ + } break + +#define HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \ + case Y_INDEX_TYPE: { \ + _TenaryElementWise \ + <<>>(output_rank_or_simple_broadcast, \ + cond_padded_strides, \ + cond_data, \ + x_padded_strides, \ + x_data, \ + y_padded_strides, \ + y_data, \ + fdm_output_strides, \ + output_data, \ + N); \ + } break + +#define HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \ + case X_INDEX_TYPE: { \ + switch(Y_INDEX_TYPE_VAL) { \ + HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \ + HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \ + HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NeedCompute); \ + } \ + } break + +#define HANDLE_COND_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \ + case COND_INDEX_TYPE: { \ + switch(X_INDEX_TYPE_VAL) { \ + HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \ + HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \ + HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NeedCompute, Y_INDEX_TYPE_VAL); \ + } \ + } break + template void WhereImpl( size_t output_rank_or_simple_broadcast, + BroadcastIndexType cond_index_type, const TArray& cond_padded_strides, const bool* cond_data, + BroadcastIndexType x_index_type, const TArray& x_padded_strides, const T* x_data, + BroadcastIndexType y_index_type, const TArray& y_padded_strides, const T* y_data, const TArray& fdm_output_strides, T* output_data, size_t count) { - int blocksPerGrid = (int)(ceil(static_cast(count) / GridDim::maxThreadsPerBlock)); + int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); CUDA_LONG N = static_cast(count); - if (output_rank_or_simple_broadcast == static_cast(SimpleBroadcast::NoBroadcast)) { - _TenaryElementWiseSimple<<>>( - cond_data, - x_data, - y_data, - output_data, - N); + switch(cond_index_type) { + HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type); + HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::Scalar, x_index_type, y_index_type); + } } else { - if (cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } else if (cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } else if (cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } else if (!cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } else if (cond_padded_strides.size_ && !x_padded_strides.size_ && !y_padded_strides.size_) { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } else if (!cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } else if (!cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } else { - _TenaryElementWise<<>>( - output_rank_or_simple_broadcast, - cond_padded_strides, - cond_data, - x_padded_strides, - x_data, - y_padded_strides, - y_data, - fdm_output_strides, - output_data, - N); - } + switch(cond_index_type) { + HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type); + HANDLE_COND_INDEX_TYPE(BroadcastIndexType::Scalar, x_index_type, y_index_type); + HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NeedCompute, x_index_type, y_index_type); + } } } -#define SPECIALIZED_IMPL(T) \ - template void WhereImpl(size_t output_rank_or_simple_broadcast, \ - const TArray& cond_padded_strides, \ - const bool* cond_data, \ - const TArray& x_padded_strides, \ - const T* x_data, \ - const TArray& y_padded_strides, \ - const T* y_data, \ - const TArray& fdm_output_strides, \ - T* output_data, \ +#define SPECIALIZED_IMPL(T) \ + template void WhereImpl(size_t output_rank_or_simple_broadcast, \ + BroadcastIndexType cond_index_type, \ + const TArray& cond_padded_strides, \ + const bool* cond_data, \ + BroadcastIndexType x_index_type, \ + const TArray& x_padded_strides, \ + const T* x_data, \ + BroadcastIndexType y_index_type, \ + const TArray& y_padded_strides, \ + const T* y_data, \ + const TArray& fdm_output_strides, \ + T* output_data, \ size_t count); SPECIALIZED_IMPL(uint8_t) @@ -209,4 +233,4 @@ SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(half) } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/where_impl.h b/onnxruntime/core/providers/cuda/tensor/where_impl.h index aa7e98b4d6..24cf54f351 100644 --- a/onnxruntime/core/providers/cuda/tensor/where_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/where_impl.h @@ -12,10 +12,13 @@ namespace cuda { template void WhereImpl( size_t output_rank_or_simple_broadcast, + BroadcastIndexType cond_index_type, const TArray& cond_padded_strides, const bool* cond_data, + BroadcastIndexType x_index_type, const TArray& x_padded_strides, const T* x_data, + BroadcastIndexType y_index_type, const TArray& y_padded_strides, const T* y_data, const TArray& fdm_output_strides,