From fb11c673682cefe70992ecf6dc5c22289ebb24dd Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 8 Aug 2023 14:04:03 -0700 Subject: [PATCH] Fix SkipLayerNorm for 2D input (#17014) Fix an obvious bug: (1) In packing mode, the input for SLN has two dimensions (introduced by #15283): [token_count, hidden_size]. Current code of `element_count = input_dims[0] * sequence_length * hidden_size` will use element_size = token_count * hidden_size * hidden_size, and causes invalid memory write in cuda kernel and ORT crash and two minor issues: (2) potential integer overflow in `static_cast(element_count)` (3) some dead code after `return LaunchSkipLayerNormKernel` that will never have chance to run. --- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 41 ++++++++----------- .../cuda/bert/skip_layer_norm_impl.cu | 30 +++++--------- .../cuda/bert/skip_layer_norm_impl.h | 9 ++-- 3 files changed, 31 insertions(+), 49 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 0d8f887d98..78174181ac 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -79,9 +79,10 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false; const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]); + int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); + typedef typename ToCudaType::MappedType CudaT; + if (strict_) { - int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); - typedef typename ToCudaType::MappedType CudaT; HostApplyLayerNorm( GetDeviceProp(), Stream(ctx), @@ -96,34 +97,24 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add + skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr); + } else { + LaunchSkipLayerNormKernel( + Stream(ctx), + reinterpret_cast(output->MutableData()), skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + epsilon_, + hidden_size, + row_count, skip_broadcasted, skip_size); - - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - return Status::OK(); } - int sequence_length = static_cast(input_dims[1]); - int64_t element_count = input_dims[0] * sequence_length * hidden_size; - size_t element_size = sizeof(T); - typedef typename ToCudaType::MappedType CudaT; - return LaunchSkipLayerNormKernel( - Stream(ctx), - reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - epsilon_, - hidden_size, - static_cast(element_count), - element_size, - skip_broadcasted, - skip_size); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 1d0d97323d..f2ee076a8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -174,22 +174,18 @@ __global__ void SkipLayerNormKernelSmall( } template -Status LaunchSkipLayerNormKernel( +void LaunchSkipLayerNormKernel( cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, const int ld, const int element_count, - size_t element_size, const bool skip_broadcasted, const int skip_size) { - // this must be true because n is the total size of the tensor - - if (element_count == 0) { - return Status::OK(); + const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) { + if (row_count == 0) { + return; } - assert(element_count % ld == 0); bool hasBias = (bias == nullptr) ? false : true; bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true; const int next_size = NextSize(ld); - const int grid_size = element_count / ld; + const int grid_size = row_count; bool flag_vec2 = CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); bool flag_vec4 = @@ -231,18 +227,14 @@ Status LaunchSkipLayerNormKernel( #undef LAUNCH_SKIP_LAYER_NORM_KERNEL #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL } - return CUDA_CALL(cudaGetLastError()); } -#define SKIPLAYERNORM_IMPL(T, Simplified) \ - template Status LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, \ - T * skip_input_bias_add_output, \ - const T* input, const T* skip, const T* gamma, \ - const T* beta, const T* bias, float epsilon, \ - const int ld, const int element_count, \ - size_t element_size, const bool skip_broadcasted, \ - const int skip_size); - +#define SKIPLAYERNORM_IMPL(T, Simplified) \ + template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, \ + T * skip_input_bias_add_output, \ + const T* input, const T* skip, const T* gamma, \ + const T* beta, const T* bias, float epsilon, \ + int ld, int row_count, bool skip_broadcasted, int skip_size); SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); SKIPLAYERNORM_IMPL(half, true); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index 4900e64be1..ffb5850c82 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -9,7 +9,7 @@ namespace contrib { namespace cuda { template -Status LaunchSkipLayerNormKernel( +void LaunchSkipLayerNormKernel( cudaStream_t stream, T* output, // normalized output tensor T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output @@ -20,10 +20,9 @@ Status LaunchSkipLayerNormKernel( const T* bias, // Layer normalization beta tensor float epsilon, // Layer normalization epsilon int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count, // number of elements in input tensor - size_t element_size, - const bool skip_broadcasted, // determines if broadcasting should be implemented - const int skip_size); // determines size of the skip tensor + int row_count, // number of rows. That is total number of elements divided by hidden size. + bool skip_broadcasted, // determines if broadcasting should be implemented + int skip_size); // determines size of the skip tensor } // namespace cuda } // namespace contrib