From 8de885fdb1b070dc85e19c92f1dccdeaee482131 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Tue, 7 Feb 2023 09:03:14 -0800 Subject: [PATCH] reduce cuda library binary size (#14555) ### Description Reduce the cuda library size by: 1. refactoring beam_search_top_k to reduce template instantiation. It saves ~56MB 2. opt out TopK for type uint*, int8_t and int16_t. It saves ~50MB. ### Motivation and Context --- cmake/CMakeLists.txt | 1 + docs/OperatorKernels.md | 6 +-- .../cuda/transformers/beam_search_topk.cu | 53 +------------------ .../cuda/transformers/beam_search_topk.h | 12 ----- .../transformers/generation_device_helper.cc | 47 +++++++++------- onnxruntime/core/providers/cuda/math/topk.cc | 41 +++++++++----- .../core/providers/cuda/math/topk_impl_i16.cu | 5 -- .../core/providers/cuda/math/topk_impl_i8.cu | 5 -- .../core/providers/cuda/math/topk_impl_u16.cu | 5 -- .../core/providers/cuda/math/topk_impl_u32.cu | 5 -- .../core/providers/cuda/math/topk_impl_u64.cu | 5 -- .../core/providers/cuda/math/topk_impl_u8.cu | 5 -- 12 files changed, 61 insertions(+), 129 deletions(-) delete mode 100644 onnxruntime/core/providers/cuda/math/topk_impl_i16.cu delete mode 100644 onnxruntime/core/providers/cuda/math/topk_impl_i8.cu delete mode 100644 onnxruntime/core/providers/cuda/math/topk_impl_u16.cu delete mode 100644 onnxruntime/core/providers/cuda/math/topk_impl_u32.cu delete mode 100644 onnxruntime/core/providers/cuda/math/topk_impl_u64.cu delete mode 100644 onnxruntime/core/providers/cuda/math/topk_impl_u8.cu diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 57abcb04ba..9eab5a2a0f 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -603,6 +603,7 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) endif() + endif() if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 286cad61d5..618214acae 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -768,9 +768,9 @@ Do not modify directly.* |||1+|**T** = tensor(double), tensor(float), tensor(float16)| |Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| -|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||10|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 9]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu index 5c54c03a05..dcbc733f2a 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu @@ -291,7 +291,7 @@ void LaunchBatchTopKKernel(const T* topk_scores, int32_t num_beams, int32_t k, cudaStream_t stream) { - ORT_ENFORCE(k <= 256, "LaunchBatchTopKKernel doesn't support k >= 256"); + ORT_ENFORCE(k <= 64, "LaunchBatchTopKKernel doesn't support k >= 64"); #define BatchTopKKernelLauncher(K) \ BatchTopKKernel<<>>(topk_scores, \ @@ -311,12 +311,8 @@ void LaunchBatchTopKKernel(const T* topk_scores, BatchTopKKernelLauncher(16); } else if (k <= 32) { BatchTopKKernelLauncher(32); - } else if (k <= 64) { - BatchTopKKernelLauncher(64); - } else if (k <= 128) { - BatchTopKKernelLauncher(128); } else { - BatchTopKKernelLauncher(256); + BatchTopKKernelLauncher(64); } } @@ -330,36 +326,6 @@ template void LaunchBatchTopKKernel(const float* topk_scores, int32_t k, cudaStream_t stream); -template void LaunchBatchTopKKernel(const float* topk_scores, - const int64_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - float* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - -template void LaunchBatchTopKKernel(const half* topk_scores, - const int32_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - half* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - -template void LaunchBatchTopKKernel(const half* topk_scores, - const int64_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - half* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - template void BeamSearchTopK( const T* input, @@ -426,21 +392,6 @@ template void BeamSearchTopK( int32_t* output_indices, cudaStream_t stream); -template void BeamSearchTopK( - const half* input, - int32_t batch_size, - int32_t num_beams, - int32_t vocab_size, - int32_t k, - half* tmp_values_1st_stage, - int32_t* tmp_indices_1st_stage, - half* tmp_values_2st_stage, - int32_t* tmp_indices_2st_stage, - half* output_values, - int32_t* output_tokens, - int32_t* output_indices, - cudaStream_t stream); - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h index 5e338b417e..096448c002 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h @@ -11,18 +11,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template -void LaunchBatchTopKKernel( - const T* topk_scores, - const I* topk_indices, - int32_t* next_indices, - int32_t* next_tokens, - T* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - template void BeamSearchTopK( const T* input, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 703bd6a0e9..3895d16d4d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -440,12 +440,16 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams); dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams); #endif + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), + beam_state->next_scores.data(), + beam_state->next_scores.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); } else { // Apply top-k selection like the following: // next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) // next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) - // int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size}; - int64_t next_token_scores_dims[] = {batch_size * num_beams, vocab_size}; + int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size}; TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); auto element_type = DataTypeImpl::GetType(); @@ -460,31 +464,36 @@ Status ProcessLogits(const OrtValue& logits, // constexpr bool sorted = true; // results returned in sorted order. std::unique_ptr topk_scores = Tensor::CreateDefault(); - std::unique_ptr topk_tokens = Tensor::CreateDefault(); + std::unique_ptr topk_indices = Tensor::CreateDefault(); ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, ort_stream, thread_pool, - *topk_scores, *topk_tokens)); + *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION dumper->Print("topk_scores", *(topk_scores.get())); - dumper->Print("topk_tokens", *(topk_tokens.get())); + dumper->Print("topk_indices", *(topk_indices.get())); #endif - cuda::LaunchBatchTopKKernel(topk_scores->Data(), - topk_tokens->Data(), - beam_state->next_indices.data(), - beam_state->next_tokens.data(), - beam_state->next_scores.data(), - batch_size, - num_beams, - 2 * num_beams, - cuda_stream); + // Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following: + // next_indices = (next_tokens / vocab_size).long() + // next_tokens = next_tokens % vocab_size + const int64_t* next_token_indices = topk_indices->Data(); + cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), + batch_size, top_k, vocab_size, cuda_stream); + + const float* data = topk_scores->Data(); +#ifdef DEBUG_GENERATION + dumper->Print("next_scores before scorer", data, batch_size, top_k); + dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k); + dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k); +#endif + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), + data, + topk_scores->SizeInBytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); } - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), - beam_state->next_scores.data(), - beam_state->next_scores.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream)); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), beam_state->next_tokens.data(), beam_state->next_tokens.size_bytes(), diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index 7ea165c611..3b0edaa559 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -12,7 +12,12 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 1, 9, kCudaExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + (*KernelDefBuilder::Create()) + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), TopK); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -20,7 +25,14 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 10, 10, kCudaExecutionProvider, - (*KernelDefBuilder::Create()).InputMemoryType(OrtMemTypeCPUInput, 1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("I", DataTypeImpl::GetTensorType()), + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), TopK); ONNX_OPERATOR_KERNEL_EX( @@ -28,7 +40,14 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 11, kCudaExecutionProvider, - (*KernelDefBuilder::Create()).InputMemoryType(OrtMemTypeCPUInput, 1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("I", DataTypeImpl::GetTensorType()), + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), TopK); template @@ -42,11 +61,11 @@ TopK::TopK(const OpKernelInfo& info) : CudaKernel(info) { } #define IS_PRIM_TYPE(T) utils::IsPrimitiveDataType(prim_type) -#define TOPKIMPL(T) TopKImpl(this, ctx->GetComputeStream(), tensor_X->Data(), \ - static_cast(tensor_V->MutableDataRaw()), \ - static_cast(tensor_I->MutableDataRaw()), \ - elem_nums_cuda, \ - elem_nums.size(), \ +#define TOPKIMPL(T) TopKImpl(this, ctx->GetComputeStream(), tensor_X->Data(), \ + static_cast(tensor_V->MutableDataRaw()), \ + static_cast(tensor_I->MutableDataRaw()), \ + elem_nums_cuda, \ + elem_nums.size(), \ axis, K_, largest_, sorted_, N, dimension) template @@ -87,12 +106,6 @@ Status TopK::ComputeInternal(OpKernelContext* ctx) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator"); } - if (IS_PRIM_TYPE(uint8_t)) return TOPKIMPL(uint8_t); - if (IS_PRIM_TYPE(uint16_t)) return TOPKIMPL(uint16_t); - if (IS_PRIM_TYPE(uint32_t)) return TOPKIMPL(uint32_t); - if (IS_PRIM_TYPE(uint64_t)) return TOPKIMPL(uint64_t); - if (IS_PRIM_TYPE(int8_t)) return TOPKIMPL(int8_t); - if (IS_PRIM_TYPE(int16_t)) return TOPKIMPL(int16_t); if (IS_PRIM_TYPE(int32_t)) return TOPKIMPL(int32_t); if (IS_PRIM_TYPE(int64_t)) return TOPKIMPL(int64_t); if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16); diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu b/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu deleted file mode 100644 index e194bd1bfd..0000000000 --- a/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define TOPK_IMPL_TYPE int16_t -#include "topk_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu b/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu deleted file mode 100644 index db32e9e433..0000000000 --- a/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define TOPK_IMPL_TYPE int8_t -#include "topk_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_u16.cu b/onnxruntime/core/providers/cuda/math/topk_impl_u16.cu deleted file mode 100644 index c9ed54e832..0000000000 --- a/onnxruntime/core/providers/cuda/math/topk_impl_u16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define TOPK_IMPL_TYPE uint16_t -#include "topk_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_u32.cu b/onnxruntime/core/providers/cuda/math/topk_impl_u32.cu deleted file mode 100644 index fceb367e7e..0000000000 --- a/onnxruntime/core/providers/cuda/math/topk_impl_u32.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define TOPK_IMPL_TYPE uint32_t -#include "topk_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_u64.cu b/onnxruntime/core/providers/cuda/math/topk_impl_u64.cu deleted file mode 100644 index 1a7b3f2aed..0000000000 --- a/onnxruntime/core/providers/cuda/math/topk_impl_u64.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define TOPK_IMPL_TYPE uint64_t -#include "topk_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu b/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu deleted file mode 100644 index 7fcd4b81b3..0000000000 --- a/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define TOPK_IMPL_TYPE uint8_t -#include "topk_impl.cuh"