diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
index 07a8896210..67b52b466f 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -1315,6 +1315,106 @@ template void BufferExpansionKernelLauncher(const int32_t* input,
                                             int chunk_size,
                                             cudaStream_t stream);
 
+// Tile extents, in float4 (16-byte) elements, of the shared-memory staging tile used by
+// ReorderPastStatesKernel. kTileSize caps head_size / chunk_size, i.e. head_size <= 128 for
+// float (chunk_size 4) and head_size <= 256 for half (chunk_size 8).
+constexpr unsigned int kTileSize = 32;
+constexpr unsigned int kSeqTileSize = 16;
+
+// Reorders one past-state buffer from [B, N, max_length, H2, chunk] to [B, N, H2, max_length, chunk],
+// where H2 = chunked_head_size and each 16-byte chunk is moved as a single float4.
+//
+// Expected launch configuration (set up by ReorderPastStatesKernelLauncher):
+//   grid  = (ceil(max_length / kSeqTileSize), num_heads, batch_size)
+//   block = (chunked_head_size, kSeqTileSize), with chunked_head_size <= kTileSize
+// Static shared memory: one kSeqTileSize x (kTileSize + 1) tile of float4.
+// out_buffer must not alias in_buffer: a block overwrites positions that other blocks still read.
+__global__ void ReorderPastStatesKernel(float4* out_buffer,
+                                        const float4* in_buffer,
+                                        int batch_size,
+                                        int num_heads,
+                                        int max_length,
+                                        int chunked_head_size) {
+  __shared__ float4 tile[kSeqTileSize][kTileSize + 1];
+
+  const int b = blockIdx.z;
+  const int n = blockIdx.y;
+  const int s_base = blockIdx.x * kSeqTileSize;
+  const int s = s_base + threadIdx.y;
+  // Offsets are 64-bit: B * N * max_length * chunked_head_size can exceed INT_MAX for large caches.
+  const int64_t base_offset = (static_cast<int64_t>(b) * num_heads + n) * max_length * chunked_head_size;
+
+  // Stage: thread (x, y) loads chunk x of sequence position s_base + y.
+  if (s < max_length) {
+    const int64_t in_offset = base_offset + static_cast<int64_t>(s) * chunked_head_size + threadIdx.x;
+    tile[threadIdx.y][threadIdx.x] = in_buffer[in_offset];
+  }
+
+  // Outside the bounds check so that every thread of the block reaches the barrier.
+  __syncthreads();
+
+  // Re-index the block's threads so that the sequence position varies fastest; consecutive
+  // threads then write consecutive float4 elements of the transposed layout.
+  const int tidx = threadIdx.x + threadIdx.y * chunked_head_size;
+  const int tidx_x = tidx % kSeqTileSize;
+  const int tidx_y = tidx / kSeqTileSize;
+
+  const int s2 = s_base + tidx_x;
+
+  if (s2 < max_length) {
+    const int64_t out_offset = base_offset + static_cast<int64_t>(tidx_y) * max_length + s2;
+    out_buffer[out_offset] = tile[tidx_x][tidx_y];
+  }
+}
+
+// Host-side launcher for ReorderPastStatesKernel.
+//   out_buffer / in_buffer: device buffers holding B * N * max_length * head_size elements
+//                           (float4 access: both must be 16-byte aligned and must not overlap)
+//   head_size:  elements per head; must be a multiple of chunk_size
+//   chunk_size: elements per 16-byte chunk, 4 for float or 8 for half
+//   stream:     CUDA stream the kernel is enqueued on (asynchronous)
+// Throws (ORT_THROW) on an unsupported chunk_size or head_size.
+void ReorderPastStatesKernelLauncher(void* out_buffer,
+                                     const void* in_buffer,
+                                     int batch_size,
+                                     int num_heads,
+                                     int max_length,
+                                     int head_size,
+                                     int chunk_size,
+                                     cudaStream_t stream) {
+  // [B, N, max_length, H2(head_size/chunk_size), equiv_chunk_size] -> [B, N, H2(head_size/chunk_size), max_length, equiv_chunk_size]
+  // Reject an unsupported chunk_size first so the division below can never see zero.
+  if (chunk_size != 4 && chunk_size != 8) {
+    ORT_THROW("ReorderPastStatesKernelLauncher only support float or half");
+  }
+
+  // Nothing to reorder; a grid or block with a zero extent is an invalid launch configuration.
+  if (batch_size <= 0 || num_heads <= 0 || max_length <= 0 || head_size <= 0) {
+    return;
+  }
+
+  if (head_size % chunk_size != 0) {
+    ORT_THROW("ReorderPastStatesKernelLauncher: head_size (", head_size,
+              ") must be a multiple of chunk_size (", chunk_size, ")");
+  }
+
+  const int chunked_head_size = head_size / chunk_size;
+  // The kernel's shared-memory tile holds at most kTileSize chunks per sequence position.
+  if (chunked_head_size > static_cast<int>(kTileSize)) {
+    ORT_THROW("ReorderPastStatesKernelLauncher: head_size / chunk_size (", chunked_head_size,
+              ") must not exceed ", kTileSize);
+  }
+
+  const dim3 block(chunked_head_size, kSeqTileSize);
+  const dim3 grid((max_length + kSeqTileSize - 1) / kSeqTileSize, num_heads, batch_size);
+  ReorderPastStatesKernel<<<grid, block, 0, stream>>>(reinterpret_cast<float4*>(out_buffer),
+                                                      reinterpret_cast<const float4*>(in_buffer),
+                                                      batch_size,
+                                                      num_heads,
+                                                      max_length,
+                                                      chunked_head_size);
+}
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
index 8c52f6fd52..2c3662fb18 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
@@ -213,6 +213,17 @@ void BufferExpansionKernelLauncher(const T* input,
                                    int chunk_size,
                                    cudaStream_t stream);
 
+// Reorders a past-state buffer from [B, N, max_length, head_size / chunk_size, chunk_size] to
+// [B, N, head_size / chunk_size, max_length, chunk_size] on the given stream. chunk_size is the
+// number of elements per 16 bytes (4 for float, 8 for half); other values are rejected.
+void ReorderPastStatesKernelLauncher(void* out_buffer,
+                                     const void* in_buffer,
+                                     int batch_size,
+                                     int num_heads,
+                                     int max_length,
+                                     int head_size,
+                                     int chunk_size,
+                                     cudaStream_t stream);
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index e4de33499c..121cd05956 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -56,19 +56,29 @@ namespace GenerationCudaDeviceHelper {
 // It might be better to forcefully require the same type since cast node generates
 // extra overhead.
 Status ReorderPastState(
-    const void* cuda_device_prop,
+    const void*,
     Tensor& past_state,
     Tensor& past_state_staging,
     Stream* stream) {
   ORT_ENFORCE(stream);
   cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle());
-  cublasHandle_t cublas_handle = static_cast<CudaStream*>(stream)->cublas_handle_;
 
   const auto& past_state_shape = past_state.Shape();
 
   const auto& past_state_dims = past_state_shape.GetDims();
   const bool packed_past = past_state_dims.size() == 5;
 
+  // The reorder below indexes four trailing dimensions, so any other rank would read out of bounds.
+  ORT_ENFORCE(packed_past || past_state_dims.size() == 4,
+              "ReorderPastState expects a rank-4 past state or a rank-5 packed past state");
+
+  // A packed past state carries one extra leading dimension ahead of [B, N, max_length, head_size].
+  const size_t dim_offset = packed_past ? 1 : 0;
+  const size_t batch_size = static_cast<size_t>(past_state_dims[dim_offset]);
+  const size_t num_heads = static_cast<size_t>(past_state_dims[dim_offset + 1]);
+  const size_t max_length = static_cast<size_t>(past_state_dims[dim_offset + 2]);
+  const size_t head_size = static_cast<size_t>(past_state_dims[dim_offset + 3]);
+
   // Copy the 'K' values into the temp staging buffer
   size_t past_state_size = packed_past ? past_state.SizeInBytes() / 2 : past_state.SizeInBytes();
   void* past_state_staging_buffer = past_state_staging.MutableDataRaw();
@@ -79,27 +89,18 @@ Status ReorderPastState(
   // [B, N, head_size / x, max_length, x], where x = 16 / sizeof(T)
   int64_t chunk_size = static_cast<int64_t>(16 / past_state.DataType()->Size());
 
-  std::vector<size_t> permutation_vector = {0, 1, 3, 2, 4};
-  gsl::span<size_t> permutation(permutation_vector.data(), 5);
+  // Read the staged 'K' copy and write the reordered layout back into past_state on the same
+  // stream as the copy above; the launcher throws on an unsupported chunk_size or head_size.
+  cuda::ReorderPastStatesKernelLauncher(past_state.MutableDataRaw(),
+                                        past_state_staging_buffer,
+                                        static_cast<int>(batch_size),
+                                        static_cast<int>(num_heads),
+                                        static_cast<int>(max_length),
+                                        static_cast<int>(head_size),
+                                        static_cast<int>(chunk_size),
+                                        cuda_stream);
 
-  // "Fake" the shapes of the input and output tensors of the Transpose operation to suit our need
-  size_t offset = packed_past ? 1 : 0;
-  TensorShape transpose_input_shape_override = {past_state_shape[offset],
-                                                past_state_shape[offset + 1],
-                                                past_state_shape[offset + 2],
-                                                past_state_shape[offset + 3] / chunk_size,
-                                                chunk_size};
-
-  TensorShape transpose_output_shape_override = {past_state_shape[offset], past_state_shape[offset + 1],
-                                                 past_state_shape[offset + 3] / chunk_size, past_state_shape[offset + 2],
-                                                 chunk_size};
-
-  // TODO(hasesh): Explore perf tuning for this Transpose operation
-  return onnxruntime::cuda::Transpose::DoTranspose(*static_cast<const cudaDeviceProp*>(cuda_device_prop), cuda_stream,
-                                                   cublas_handle, permutation,
-                                                   past_state_staging, past_state,
-                                                   &transpose_input_shape_override,
-                                                   &transpose_output_shape_override);
+  return Status::OK();
 }
 
 Status InitCacheIndir(Tensor& cache_indir, Stream* stream) {