diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h index 7229c7b522..80f5851164 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -292,32 +292,32 @@ Status ComputeSoftmax( const dim3 grid(sequence_length * num_heads, batch_size, 1); if (all_sequence_length <= 32) { const int blockSize = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 64) { const int blockSize = 64; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 128) { const int blockSize = 128; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 256) { const int blockSize = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 512) { const int blockSize = 512; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (!is_unidirectional) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, add_before_softmax, input, output); + SoftmaxKernel<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output); } else { ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); } @@ -403,39 +403,39 @@ Status ComputeSoftmaxWithMask1D( if (all_sequence_length <= 32) { const int blockSize = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + MaskedSoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 64) { const int blockSize = 64; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + MaskedSoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 128) { const int blockSize = 128; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + MaskedSoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 256) { const int blockSize = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + MaskedSoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 512) { const int blockSize = 512; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + MaskedSoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + MaskedSoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (!is_unidirectional) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output); + MaskedSoftmaxKernel<<>>( + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output); } else { ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); } @@ -465,46 +465,46 @@ Status ComputeSoftmaxWithRawMask(hipStream_t stream, T* out = use_persistent_softmax ? persistent_softmax_workspace : output; if (all_sequence_length <= 32) { const int blockSize = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax); + SoftmaxWithRawMaskSmallKernel<<>>( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 64) { const int blockSize = 64; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax); + SoftmaxWithRawMaskSmallKernel<<>>( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 128) { const int blockSize = 128; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax); + SoftmaxWithRawMaskSmallKernel<<>>( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 256) { const int blockSize = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax); + SoftmaxWithRawMaskSmallKernel<<>>( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 512) { const int blockSize = 512; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax); + SoftmaxWithRawMaskSmallKernel<<>>( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax); + SoftmaxWithRawMaskSmallKernel<<>>( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else { ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); } diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu index 0ef3e0af93..4df7e0b5a5 100644 --- a/onnxruntime/core/providers/rocm/fpgeneric.cu +++ b/onnxruntime/core/providers/rocm/fpgeneric.cu @@ -58,7 +58,7 @@ rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocbla dim3 dimGrid((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1); dim3 dimBlock(TRANS_TILE_DIM, BLOCK_ROWS, 1); - hipLaunchKernelGGL(transposeNoOverlap, dim3(dimGrid), dim3(dimBlock), 0, stream, C, A, n, m); + transposeNoOverlap<<>>(C, A, n, m); } else { return rocblas_status_not_implemented; } @@ -68,7 +68,7 @@ rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocbla rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const half* x, int incx, half* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - hipLaunchKernelGGL(CopyVectorHalf, dim3(dimGrid), dim3(dimBlock), 0, stream, x, incx, y, incy, n); + CopyVectorHalf<<>>(x, incx, y, incy, n); return rocblas_status_success; } @@ -76,6 +76,6 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons onnxruntime::BFloat16* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - hipLaunchKernelGGL(CopyVectorBFloat16, dim3(dimGrid), dim3(dimBlock), 0, stream, x, incx, y, incy, n); + CopyVectorBFloat16<<>>(x, incx, y, incy, n); return rocblas_status_success; -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu index 7b8630568a..94bee88a46 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu @@ -63,21 +63,21 @@ void DiagonalImpl( switch (element_size) { case sizeof(int32_t): - hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, + _DiagonalKernel<<>>( reinterpret_cast::MappedType*>(input_data), input_rank, dim_1, dim_2, input_strides, reinterpret_cast::MappedType*>(output_data), output_strides, output_size); break; case sizeof(int64_t): - hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, + _DiagonalKernel<<>>( reinterpret_cast::MappedType*>(input_data), input_rank, dim_1, dim_2, input_strides, reinterpret_cast::MappedType*>(output_data), output_strides, output_size); break; case sizeof(int16_t): - hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, + _DiagonalKernel<<>>( reinterpret_cast(input_data), input_rank, dim_1, dim_2, input_strides, reinterpret_cast(output_data), output_strides, output_size); diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu index 4addee44bc..f5a26ef045 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu +++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu @@ -52,37 +52,37 @@ void dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { case 0: // 1 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 1: // 2 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 2: // 4 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 3: // 8 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 4: // 16 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 5: // 32 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 6: // 64 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 7: // 128 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 8: // 256 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 9: // 512 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; case 10: // 1024 - hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, stream, dst, src, batch_count, softmax_elements_stride, softmax_elements); + softmax_warp_forward<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); break; default: break; diff --git a/onnxruntime/core/providers/rocm/rocm_utils.cu b/onnxruntime/core/providers/rocm/rocm_utils.cu index ef65d70eea..cbf410e78a 100644 --- a/onnxruntime/core/providers/rocm/rocm_utils.cu +++ b/onnxruntime/core/providers/rocm/rocm_utils.cu @@ -30,7 +30,7 @@ template void Fill(hipStream_t stream, T* output, T value, int64_t count) { int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); HIP_LONG N = static_cast(count); - hipLaunchKernelGGL(HIP_KERNEL_NAME(_Fill), dim3(blocksPerGrid), dim3(GridDim::maxThreadsPerBlock), 0, stream, output, value, N); + _Fill<<>>(output, value, N); } template class ConstantBufferImpl : public IConstantBuffer {