[BeamSearch]optimize key cache reordering (#17771)

### Description
<!-- Describe your changes. --> 

Replace
onnxruntime::cuda::Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim()
with custom transpose kernel in ReorderPastState(). The original
implementation doesn't benefit from vectorized loading and coalesced
accessing(write). and not fully utilize threads in the block.

benchmarked with TNLGv4 model(batch=4, seq_len=4K)
transpose kernel speed up: ~1.9X (392 μs -> 206 μs)
overall reordering speedup: ~1.48X

Latency:
before:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/34c7ab73-3da1-4c41-a036-e9fb6a966891)
after:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/337818ec-9598-4d8a-9e9b-7215b6862498)

GPU matrix:
before:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/4962248f-703c-49bd-8586-deaeccd9bce0)
after:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/a795a892-4c5d-432d-8375-0bb67385d2bc)


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Ye Wang 2023-10-05 17:29:11 +00:00 committed by GitHub
parent e1a089c23c
commit 0e988239cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 22 deletions

View file

@ -1315,6 +1315,67 @@ template void BufferExpansionKernelLauncher(const int32_t* input,
int chunk_size,
cudaStream_t stream);
// Support head_size up to 128
constexpr unsigned int kTileSize = 32;
constexpr unsigned int kSeqTileSize = 16;
__global__ void ReorderPastStatesKernel(float4* out_buffer,
const float4* in_buffer,
int batch_size,
int num_heads,
int max_length,
int chunked_head_size) {
__shared__ float4 tile[kSeqTileSize][kTileSize + 1];
const int b = blockIdx.z;
const int n = blockIdx.y;
const int s_base = blockIdx.x * kSeqTileSize;
const int s = s_base + threadIdx.y;
const int base_offset = (b * num_heads + n) * max_length * chunked_head_size;
if (s < max_length) {
const int in_offset = base_offset + s * chunked_head_size + threadIdx.x;
tile[threadIdx.y][threadIdx.x] = in_buffer[in_offset];
}
__syncthreads();
const int tidx = threadIdx.x + threadIdx.y * chunked_head_size;
const int tidx_x = tidx % kSeqTileSize;
const int tidx_y = tidx / kSeqTileSize;
const int s2 = s_base + tidx_x;
if (s2 < max_length) {
const int out_offset = base_offset + tidx_y * max_length + s2;
out_buffer[out_offset] = tile[tidx_x][tidx_y];
}
}
void ReorderPastStatesKernelLauncher(void* out_buffer,
const void* in_buffer,
int batch_size,
int num_heads,
int max_length,
int head_size,
int chunk_size,
cudaStream_t stream) {
//[B, N, max_length, H2(head_size/chunk_size), equv_chunk_size] -> [B, N, H2(head_size/chunk_size), max_length, equv_chunk_size]
const int chunked_head_size = head_size / chunk_size;
const dim3 block(chunked_head_size, kSeqTileSize);
const dim3 grid((max_length + kSeqTileSize - 1) / kSeqTileSize, num_heads, batch_size);
if (chunk_size == 4 || chunk_size == 8) {
ReorderPastStatesKernel<<<grid, block, 0, stream>>>(reinterpret_cast<float4*>(out_buffer),
reinterpret_cast<const float4*>(in_buffer),
batch_size,
num_heads,
max_length,
chunked_head_size);
} else {
ORT_THROW("ReorderPastStatesKernelLauncher only support float or half");
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -213,6 +213,14 @@ void BufferExpansionKernelLauncher(const T* input,
int chunk_size,
cudaStream_t stream);
void ReorderPastStatesKernelLauncher(void* out_buffer,
const void* in_buffer,
int batch_size,
int num_heads,
int max_length,
int head_size,
int chunk_size,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -56,19 +56,23 @@ namespace GenerationCudaDeviceHelper {
// It might be better to forcefully require the same type since cast node generates
// extra overhead.
Status ReorderPastState(
const void* cuda_device_prop,
const void*,
Tensor& past_state,
Tensor& past_state_staging,
Stream* stream) {
ORT_ENFORCE(stream);
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle());
cublasHandle_t cublas_handle = static_cast<CudaStream*>(stream)->cublas_handle_;
const auto& past_state_shape = past_state.Shape();
const auto& past_state_dims = past_state_shape.GetDims();
const bool packed_past = past_state_dims.size() == 5;
size_t batch_size = packed_past ? past_state_dims[1] : past_state_dims[0];
size_t num_heads = packed_past ? past_state_dims[2] : past_state_dims[1];
size_t max_length = packed_past ? past_state_dims[3] : past_state_dims[2];
size_t head_size = packed_past ? past_state_dims[4] : past_state_dims[3];
// Copy the 'K' values into the temp staging buffer
size_t past_state_size = packed_past ? past_state.SizeInBytes() / 2 : past_state.SizeInBytes();
void* past_state_staging_buffer = past_state_staging.MutableDataRaw();
@ -79,27 +83,16 @@ Status ReorderPastState(
// [B, N, head_size / x, max_length, x], where x = 16 / sizeof(T)
int64_t chunk_size = static_cast<int64_t>(16 / past_state.DataType()->Size());
std::vector<size_t> permutation_vector = {0, 1, 3, 2, 4};
gsl::span<size_t> permutation(permutation_vector.data(), 5);
cuda::ReorderPastStatesKernelLauncher(past_state.MutableDataRaw(),
past_state_staging_buffer,
static_cast<int>(batch_size),
static_cast<int>(num_heads),
static_cast<int>(max_length),
static_cast<int>(head_size),
static_cast<int>(chunk_size),
cuda_stream);
// "Fake" the shapes of the input and output tensors of the Transpose operation to suit our need
size_t offset = packed_past ? 1 : 0;
TensorShape transpose_input_shape_override = {past_state_shape[offset],
past_state_shape[offset + 1],
past_state_shape[offset + 2],
past_state_shape[offset + 3] / chunk_size,
chunk_size};
TensorShape transpose_output_shape_override = {past_state_shape[offset], past_state_shape[offset + 1],
past_state_shape[offset + 3] / chunk_size, past_state_shape[offset + 2],
chunk_size};
// TODO(hasesh): Explore perf tuning for this Transpose operation
return onnxruntime::cuda::Transpose::DoTranspose(*static_cast<const cudaDeviceProp*>(cuda_device_prop), cuda_stream,
cublas_handle, permutation,
past_state_staging, past_state,
&transpose_input_shape_override,
&transpose_output_shape_override);
return Status::OK();
}
Status InitCacheIndir(Tensor& cache_indir, Stream* stream) {