mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Refactor SkipLayerNorm and handle beta properly (#22862)
Signed-off-by: Liqun Fu <liqfu@microsoft.com> Signed-off-by: Liqun Fu <liqun.fu@microsoft.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
5928009553
commit
101ed10e5e
1 changed file with 78 additions and 108 deletions
|
|
@ -96,79 +96,6 @@ void ComputeJob(
|
|||
}
|
||||
}
|
||||
|
||||
void ComputeJob(
|
||||
const MLFloat16* input_data,
|
||||
const MLFloat16* skip_data,
|
||||
const float* prepacked_skip_fp32_data,
|
||||
const float* gamma_float_ptr,
|
||||
const float* beta_float_ptr,
|
||||
const float* bias_float_ptr,
|
||||
float* output_float_ptr,
|
||||
ptrdiff_t task_idx,
|
||||
int hidden_size,
|
||||
int64_t skip_size,
|
||||
float epsilon,
|
||||
bool simplified,
|
||||
MLFloat16* output_data,
|
||||
MLFloat16* skip_input_bias_add_output_data,
|
||||
AllocatorPtr alloc) {
|
||||
auto offset = task_idx * hidden_size;
|
||||
const MLFloat16* p_input = input_data + offset;
|
||||
MLFloat16* p_output = output_data + offset;
|
||||
MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;
|
||||
|
||||
float mean(0.0f);
|
||||
float mean_square(0.0f);
|
||||
const size_t num_elems = static_cast<size_t>(hidden_size);
|
||||
|
||||
IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
|
||||
MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);
|
||||
|
||||
IAllocatorUniquePtr<float> skip_float_uptr = nullptr;
|
||||
if (prepacked_skip_fp32_data == nullptr && skip_data) {
|
||||
const MLFloat16* p_skip = skip_data + (offset % skip_size);
|
||||
skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
|
||||
MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
|
||||
}
|
||||
|
||||
const float* input_float_ptr = input_float_uptr.get();
|
||||
const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data : skip_float_uptr.get();
|
||||
for (size_t h = 0; h < num_elems; h++) {
|
||||
float val = input_float_ptr[h] + skip_float_ptr[h];
|
||||
|
||||
if (bias_float_ptr) {
|
||||
val += bias_float_ptr[h];
|
||||
}
|
||||
|
||||
output_float_ptr[h] = val;
|
||||
mean += val;
|
||||
mean_square += val * val;
|
||||
}
|
||||
|
||||
if (nullptr != p_skip_input_bias_add_output) {
|
||||
MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems);
|
||||
}
|
||||
|
||||
mean = mean / hidden_size;
|
||||
if (simplified) {
|
||||
mean_square = sqrt(mean_square / hidden_size + epsilon);
|
||||
} else {
|
||||
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
|
||||
}
|
||||
|
||||
for (size_t h = 0; h < num_elems; h++) {
|
||||
if (simplified) {
|
||||
output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
|
||||
} else if (nullptr == beta_float_ptr) {
|
||||
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
|
||||
} else {
|
||||
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
|
||||
}
|
||||
}
|
||||
|
||||
MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems);
|
||||
}
|
||||
|
||||
void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
|
||||
if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
|
||||
auto tensor_data_ptr = tensor.Data<MLFloat16>();
|
||||
|
|
@ -200,8 +127,8 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
|
|||
const Tensor* input = p_ctx->Input<Tensor>(0);
|
||||
const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
|
||||
const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
|
||||
const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
|
||||
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
|
||||
const Tensor* beta = simplified ? nullptr : (prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3));
|
||||
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(simplified ? 3 : 4);
|
||||
Tensor* output = p_ctx->Output(0, input->Shape());
|
||||
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
|
||||
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
|
||||
|
|
@ -232,56 +159,93 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
|
|||
|
||||
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
|
||||
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();
|
||||
|
||||
const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_;
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
IAllocatorUniquePtr<float> output_fp32;
|
||||
IAllocatorUniquePtr<float> gamma_fp32;
|
||||
IAllocatorUniquePtr<float> beta_fp32;
|
||||
IAllocatorUniquePtr<float> bias_fp32;
|
||||
|
||||
if constexpr (std::is_same_v<T, MLFloat16>) {
|
||||
const size_t total_data_size = static_cast<size_t>(input->Shape().Size());
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
IAllocatorUniquePtr<float> input_fp32;
|
||||
IAllocatorUniquePtr<float> output_fp32;
|
||||
IAllocatorUniquePtr<float> skip_input_bias_add_output_fp32;
|
||||
IAllocatorUniquePtr<float> skip_fp32;
|
||||
IAllocatorUniquePtr<float> gamma_fp32;
|
||||
IAllocatorUniquePtr<float> beta_fp32;
|
||||
IAllocatorUniquePtr<float> bias_fp32;
|
||||
|
||||
const float* input_data_f = nullptr;
|
||||
const float* skip_data_f = nullptr;
|
||||
const float* gamma_data_f = nullptr;
|
||||
const float* beta_data_f = nullptr;
|
||||
const float* bias_data_f = nullptr;
|
||||
float* output_data_f = nullptr;
|
||||
float* skip_input_bias_add_output_data_f = nullptr;
|
||||
|
||||
const size_t num_elems = static_cast<size_t>(hidden_size);
|
||||
|
||||
output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
|
||||
input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
|
||||
MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size);
|
||||
input_data_f = input_fp32.get();
|
||||
|
||||
if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
|
||||
output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
|
||||
output_data_f = output_fp32.get();
|
||||
|
||||
skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
|
||||
skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get();
|
||||
|
||||
if (skip_data) {
|
||||
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(skip_size));
|
||||
MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), static_cast<size_t>(skip_size));
|
||||
skip_data_f = skip_fp32.get();
|
||||
} else if (prepacked_skip_fp32_data_) {
|
||||
skip_data_f = prepacked_skip_fp32_data_.get();
|
||||
}
|
||||
|
||||
if (gamma_data) {
|
||||
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
|
||||
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
|
||||
gamma_data_f = gamma_fp32.get();
|
||||
} else if (prepacked_gamma_fp32_data_) {
|
||||
gamma_data_f = prepacked_gamma_fp32_data_.get();
|
||||
}
|
||||
|
||||
if (prepacked_beta_fp32_data_ == nullptr && beta_data) {
|
||||
if (beta_data) {
|
||||
beta_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
|
||||
MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems);
|
||||
beta_data_f = beta_fp32.get();
|
||||
} else if (prepacked_beta_fp32_data_) {
|
||||
beta_data_f = prepacked_beta_fp32_data_.get();
|
||||
}
|
||||
|
||||
if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
|
||||
if (bias_data) {
|
||||
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
|
||||
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
|
||||
bias_data_f = bias_fp32.get();
|
||||
} else if (prepacked_bias_fp32_data_) {
|
||||
bias_data_f = prepacked_bias_fp32_data_.get();
|
||||
}
|
||||
}
|
||||
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
|
||||
[&](ptrdiff_t task_idx) {
|
||||
if constexpr (std::is_same_v<T, MLFloat16>) {
|
||||
ComputeJob(input_data, skip_data,
|
||||
prepacked_skip_fp32_data_.get(),
|
||||
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
|
||||
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
|
||||
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
|
||||
output_fp32.get(),
|
||||
task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
|
||||
skip_input_bias_add_output_data, alloc);
|
||||
} else {
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
|
||||
[&](ptrdiff_t task_idx) {
|
||||
ComputeJob(input_data_f, skip_data_f, gamma_data_f, beta_data_f, bias_data_f, task_idx, hidden_size, skip_size,
|
||||
epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f);
|
||||
},
|
||||
0);
|
||||
MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size);
|
||||
if (skip_input_bias_add_output_data != nullptr)
|
||||
MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size);
|
||||
} else {
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
|
||||
[&](ptrdiff_t task_idx) {
|
||||
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
|
||||
epsilon_, simplified, output_data, skip_input_bias_add_output_data);
|
||||
}
|
||||
},
|
||||
0);
|
||||
},
|
||||
0);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -290,16 +254,22 @@ template <typename T, bool simplified>
|
|||
Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
||||
bool& is_packed, PrePackedWeights* prepacked_weights) {
|
||||
ORT_UNUSED_PARAMETER(prepacked_weights);
|
||||
|
||||
is_packed = false;
|
||||
if (input_idx == 1) { // skip
|
||||
prepacked_skip_fp32_size_ = tensor.Shape().Size();
|
||||
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
|
||||
} else if (input_idx == 2) { // gamma
|
||||
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
|
||||
} else if (input_idx == 3) { // beta
|
||||
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
|
||||
} else if (input_idx == 3) {
|
||||
if constexpr (simplified) {
|
||||
// bias
|
||||
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
|
||||
} else {
|
||||
// beta
|
||||
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
|
||||
}
|
||||
} else if (input_idx == 4) { // bias
|
||||
ORT_ENFORCE(!simplified, "SkipSimplifiedLayerNormalization should only has 4 inputs (input, skip, gamma, and beta). Got 5.");
|
||||
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue