Fix SkipLayerNorm for 2D input (#17014)

Fix an obvious bug:
(1) In packing mode, the input for SLN has two dimensions (introduced by
#15283): [token_count, hidden_size]. The current code, `element_count =
input_dims[0] * sequence_length * hidden_size`, reads sequence_length from
input_dims[1], which is hidden_size for a 2D input, so it computes
element_count = token_count * hidden_size * hidden_size. This causes an
invalid memory write in the CUDA kernel and crashes ORT.

and two minor issues:
(2) potential integer overflow in `static_cast<int>(element_count)`
(3) some dead code after `return LaunchSkipLayerNormKernel` that will
never have a chance to run.
This commit is contained in:
Tianlei Wu 2023-08-08 14:04:03 -07:00 committed by GitHub
parent 73037978f8
commit fb11c67368
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 49 deletions

View file

@ -79,9 +79,10 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false;
const int skip_size = static_cast<int>(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]);
int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
typedef typename ToCudaType<T>::MappedType CudaT;
if (strict_) {
int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
typedef typename ToCudaType<T>::MappedType CudaT;
HostApplyLayerNorm<CudaT, float, CudaT, Simplified>(
GetDeviceProp(),
Stream(ctx),
@ -96,34 +97,24 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr);
} else {
LaunchSkipLayerNormKernel<CudaT, Simplified>(
Stream(ctx),
reinterpret_cast<CudaT*>(output->MutableData<T>()),
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
epsilon_,
hidden_size,
row_count,
skip_broadcasted,
skip_size);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
int sequence_length = static_cast<int>(input_dims[1]);
int64_t element_count = input_dims[0] * sequence_length * hidden_size;
size_t element_size = sizeof(T);
typedef typename ToCudaType<T>::MappedType CudaT;
return LaunchSkipLayerNormKernel<CudaT, Simplified>(
Stream(ctx),
reinterpret_cast<CudaT*>(output->MutableData<T>()),
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
epsilon_,
hidden_size,
static_cast<int>(element_count),
element_size,
skip_broadcasted,
skip_size);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}

View file

@ -174,22 +174,18 @@ __global__ void SkipLayerNormKernelSmall(
}
template <typename T, bool Simplified>
Status LaunchSkipLayerNormKernel(
void LaunchSkipLayerNormKernel(
cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
size_t element_size, const bool skip_broadcasted, const int skip_size) {
// this must be true because n is the total size of the tensor
if (element_count == 0) {
return Status::OK();
const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) {
if (row_count == 0) {
return;
}
assert(element_count % ld == 0);
bool hasBias = (bias == nullptr) ? false : true;
bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true;
const int next_size = NextSize(ld);
const int grid_size = element_count / ld;
const int grid_size = row_count;
bool flag_vec2 =
CanVectorized<T, 2>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
bool flag_vec4 =
@ -231,18 +227,14 @@ Status LaunchSkipLayerNormKernel(
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL
}
return CUDA_CALL(cudaGetLastError());
}
#define SKIPLAYERNORM_IMPL(T, Simplified) \
template Status LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, \
T * skip_input_bias_add_output, \
const T* input, const T* skip, const T* gamma, \
const T* beta, const T* bias, float epsilon, \
const int ld, const int element_count, \
size_t element_size, const bool skip_broadcasted, \
const int skip_size);
#define SKIPLAYERNORM_IMPL(T, Simplified) \
template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, \
T * skip_input_bias_add_output, \
const T* input, const T* skip, const T* gamma, \
const T* beta, const T* bias, float epsilon, \
int ld, int row_count, bool skip_broadcasted, int skip_size);
SKIPLAYERNORM_IMPL(float, true);
SKIPLAYERNORM_IMPL(float, false);
SKIPLAYERNORM_IMPL(half, true);

View file

@ -9,7 +9,7 @@ namespace contrib {
namespace cuda {
template <typename T, bool Simplified>
Status LaunchSkipLayerNormKernel(
void LaunchSkipLayerNormKernel(
cudaStream_t stream,
T* output, // normalized output tensor
T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output
@ -20,10 +20,9 @@ Status LaunchSkipLayerNormKernel(
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count, // number of elements in input tensor
size_t element_size,
const bool skip_broadcasted, // determines if broadcasting should be implemented
const int skip_size); // determines size of the skip tensor
int row_count, // number of rows. That is total number of elements divided by hidden size.
bool skip_broadcasted, // determines if broadcasting should be implemented
int skip_size); // determines size of the skip tensor
} // namespace cuda
} // namespace contrib