mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Fix SkipLayerNorm for 2D input (#17014)
Fix an obvious bug: (1) In packing mode, the input for SLN has two dimensions (introduced by #15283): [token_count, hidden_size]. Current code of `element_count = input_dims[0] * sequence_length * hidden_size` will use element_size = token_count * hidden_size * hidden_size, and causes invalid memory write in cuda kernel and ORT crash and two minor issues: (2) potential integer overflow in `static_cast<int>(element_count)` (3) some dead code after `return LaunchSkipLayerNormKernel` that will never have chance to run.
This commit is contained in:
parent
73037978f8
commit
fb11c67368
3 changed files with 31 additions and 49 deletions
|
|
@ -79,9 +79,10 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
|
|||
const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false;
|
||||
const int skip_size = static_cast<int>(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]);
|
||||
|
||||
int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
if (strict_) {
|
||||
int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
HostApplyLayerNorm<CudaT, float, CudaT, Simplified>(
|
||||
GetDeviceProp(),
|
||||
Stream(ctx),
|
||||
|
|
@ -96,34 +97,24 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
|
|||
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
|
||||
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
|
||||
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
|
||||
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr);
|
||||
} else {
|
||||
LaunchSkipLayerNormKernel<CudaT, Simplified>(
|
||||
Stream(ctx),
|
||||
reinterpret_cast<CudaT*>(output->MutableData<T>()),
|
||||
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
|
||||
reinterpret_cast<const CudaT*>(input->Data<T>()),
|
||||
reinterpret_cast<const CudaT*>(skip->Data<T>()),
|
||||
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
|
||||
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
|
||||
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
|
||||
epsilon_,
|
||||
hidden_size,
|
||||
row_count,
|
||||
skip_broadcasted,
|
||||
skip_size);
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int sequence_length = static_cast<int>(input_dims[1]);
|
||||
int64_t element_count = input_dims[0] * sequence_length * hidden_size;
|
||||
size_t element_size = sizeof(T);
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
return LaunchSkipLayerNormKernel<CudaT, Simplified>(
|
||||
Stream(ctx),
|
||||
reinterpret_cast<CudaT*>(output->MutableData<T>()),
|
||||
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
|
||||
reinterpret_cast<const CudaT*>(input->Data<T>()),
|
||||
reinterpret_cast<const CudaT*>(skip->Data<T>()),
|
||||
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
|
||||
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
|
||||
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
|
||||
epsilon_,
|
||||
hidden_size,
|
||||
static_cast<int>(element_count),
|
||||
element_size,
|
||||
skip_broadcasted,
|
||||
skip_size);
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,22 +174,18 @@ __global__ void SkipLayerNormKernelSmall(
|
|||
}
|
||||
|
||||
template <typename T, bool Simplified>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
void LaunchSkipLayerNormKernel(
|
||||
cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
|
||||
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
|
||||
size_t element_size, const bool skip_broadcasted, const int skip_size) {
|
||||
// this must be true because n is the total size of the tensor
|
||||
|
||||
if (element_count == 0) {
|
||||
return Status::OK();
|
||||
const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) {
|
||||
if (row_count == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(element_count % ld == 0);
|
||||
bool hasBias = (bias == nullptr) ? false : true;
|
||||
bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true;
|
||||
|
||||
const int next_size = NextSize(ld);
|
||||
const int grid_size = element_count / ld;
|
||||
const int grid_size = row_count;
|
||||
bool flag_vec2 =
|
||||
CanVectorized<T, 2>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
|
||||
bool flag_vec4 =
|
||||
|
|
@ -231,18 +227,14 @@ Status LaunchSkipLayerNormKernel(
|
|||
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL
|
||||
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL
|
||||
}
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
#define SKIPLAYERNORM_IMPL(T, Simplified) \
|
||||
template Status LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, \
|
||||
T * skip_input_bias_add_output, \
|
||||
const T* input, const T* skip, const T* gamma, \
|
||||
const T* beta, const T* bias, float epsilon, \
|
||||
const int ld, const int element_count, \
|
||||
size_t element_size, const bool skip_broadcasted, \
|
||||
const int skip_size);
|
||||
|
||||
#define SKIPLAYERNORM_IMPL(T, Simplified) \
|
||||
template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, \
|
||||
T * skip_input_bias_add_output, \
|
||||
const T* input, const T* skip, const T* gamma, \
|
||||
const T* beta, const T* bias, float epsilon, \
|
||||
int ld, int row_count, bool skip_broadcasted, int skip_size);
|
||||
SKIPLAYERNORM_IMPL(float, true);
|
||||
SKIPLAYERNORM_IMPL(float, false);
|
||||
SKIPLAYERNORM_IMPL(half, true);
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ namespace contrib {
|
|||
namespace cuda {
|
||||
|
||||
template <typename T, bool Simplified>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
void LaunchSkipLayerNormKernel(
|
||||
cudaStream_t stream,
|
||||
T* output, // normalized output tensor
|
||||
T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output
|
||||
|
|
@ -20,10 +20,9 @@ Status LaunchSkipLayerNormKernel(
|
|||
const T* bias, // Layer normalization beta tensor
|
||||
float epsilon, // Layer normalization epsilon
|
||||
int hidden_size, // hidden size, it is the leading dimension (ld)
|
||||
int element_count, // number of elements in input tensor
|
||||
size_t element_size,
|
||||
const bool skip_broadcasted, // determines if broadcasting should be implemented
|
||||
const int skip_size); // determines size of the skip tensor
|
||||
int row_count, // number of rows. That is total number of elements divided by hidden size.
|
||||
bool skip_broadcasted, // determines if broadcasting should be implemented
|
||||
int skip_size); // determines size of the skip tensor
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
|||
Loading…
Reference in a new issue