mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
### Description
Implementation of a new cast assembly kernel that uses AVX_NE_CONVERT instructions to accelerate casting from FP16 to FP32. Added CPUID checks to determine whether the ISA is supported.
### Motivation and Context
Currently, FP16 models executed on systems that lack complete FP16 operator support run every node in single precision. This means the original FP16 weights have to be cast to FP32 for the model to run properly. This change accelerates that cast by using up-convert instructions, thereby improving performance.
59 lines
1.6 KiB
C++
59 lines
1.6 KiB
C++
/*++
|
|
|
|
Copyright (c) Intel Corporation. All rights reserved.
|
|
|
|
Licensed under the MIT License.
|
|
|
|
Module Name:
|
|
|
|
cast.cpp
|
|
|
|
Abstract:
|
|
|
|
This module implements Half (F16) to Single (F32) precision casting.
|
|
|
|
--*/
|
|
#include "mlasi.h"

#include <cstring>
|
|
|
|
//
// 32-bit overlay of an IEEE-754 single-precision value and its raw bit
// pattern. Member order is significant: aggregate initialization such as
// {113 << 23} initializes the first member, u.
//
// NOTE: ISO C++ only defines reading the member that was most recently
// written; reading the other one relies on compiler-specific behavior.
//
union fp32_bits
{
    uint32_t u;  // raw bit pattern
    float f;     // the same 32 bits viewed as binary32
};
|
|
|
|
void
|
|
MLASCALL
|
|
MlasConvertHalfToFloatBuffer(
|
|
const unsigned short* Source,
|
|
float* Destination,
|
|
size_t Count
|
|
)
|
|
{
|
|
|
|
if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) {
|
|
// If there is no kernel use the reference implementation, adapted from mlas_float16.h.
|
|
constexpr fp32_bits magic = {113 << 23};
|
|
constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
|
|
|
for (size_t i = 0; i < Count; ++i) {
|
|
fp32_bits o;
|
|
o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits
|
|
uint32_t exp = shifted_exp & o.u; // just the exponent
|
|
o.u += (127 - 15) << 23; // exponent adjust
|
|
|
|
// handle exponent special cases
|
|
if (exp == shifted_exp) { // Inf/NaN?
|
|
o.u += (128 - 16) << 23; // extra exp adjust
|
|
} else if (exp == 0) { // Zero/Denormal?
|
|
o.u += 1 << 23; // extra exp adjust
|
|
o.f -= magic.f; // renormalize
|
|
}
|
|
|
|
o.u |= (Source[i] & 0x8000) << 16; // sign bit
|
|
Destination[i] = o.f;
|
|
}
|
|
|
|
} else {
|
|
// If the kernel is available, use it to perform the conversion.
|
|
GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count);
|
|
}
|
|
}
|