Enable AVX NE CONVERT for FP16 to FP32 cast (#21183)

### Description
Implementation of a new cast assembly kernel that uses AVX_NE_CONVERT
instructions to accelerate casting from FP16 to FP32. Added CPUID checks
to determine support of the ISA.

### Motivation and Context
Currently FP16 models executed on systems that lack complete FP16
operator support use single precision on every node to run the model,
this means the original FP16 weights have to be casted to FP32 in order
to run the model properly, this change aims to accelerate the casting by
using upconvert instructions and therefore improve performance.
This commit is contained in:
Erick Muñoz 2024-09-09 22:19:31 -06:00 committed by GitHub
parent d4d419f789
commit 7489bfee53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 537 additions and 11 deletions

View file

@ -40,6 +40,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
)
target_sources(onnxruntime_mlas PRIVATE
@ -212,6 +213,12 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
)
if(MSVC_VERSION GREATER_EQUAL 1933)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm
)
endif()
if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
@ -522,6 +529,12 @@ else()
${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S
${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S
)
if(NOT APPLE)
set(mlas_platform_srcs_sse2
${mlas_platform_srcs_sse2}
${MLAS_SRC_DIR}/x86_64/cvtfp16a.S
)
endif()
set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2")
set(mlas_platform_srcs_avx
@ -555,6 +568,12 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE))
set(mlas_platform_srcs_avx2
${mlas_platform_srcs_avx2}
${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S
)
endif()
message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")

View file

@ -1029,14 +1029,13 @@ MlasComputeTanh(
// Half-precision floating-point routines.
//
extern "C"
void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
);
);
//
// Transpose routines.

View file

@ -0,0 +1,151 @@
;++
;
; Copyright (c) Intel Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; cvtfp16Avx2.asm
;
; Abstract:
;
; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA.
;
;--
.xlist
INCLUDE mlasi.inc
.list
.const
SINGLE_SIZE equ 4
HALF_SIZE equ 2
LOW_SELECTOR equ 00100000b
HIGH_SELECTOR equ 00110001b
SUBTTL "Convert buffer of half-precision floats to single-precision floats"
;++
;
; Routine Description:
;
; This routine converts the source buffer of half-precision floats to the
; destination buffer of single-precision floats.
;
; This implementation uses AVX2 instructions.
;
; Arguments:
;
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
;
; Return Value:
;
; None.
;
;--
LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT
test r8, r8 ; Check if we have any elements to convert
jz ExitRoutine
cmp r8, 8
jb ConvertMaskedVectors
cmp r8, 16
jb Convert128Vectors
Convert256Vectors:
vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes
vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes
vunpcklps ymm2, ymm0, ymm1 ; Interleave low part
vunpckhps ymm1, ymm0, ymm1 ; Interleave high part
vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order
vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order
vmovups ymmword PTR [rdx], ymm0 ; Store the low part
vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part
add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements
add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements
sub r8, 16 ; Reduce the counter by 16 elements
jz ExitRoutine ; If we are done, exit
cmp r8, 16 ; If the vector is big enough, we go again
jae Convert256Vectors
Convert128Vectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part
add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements
add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements
sub r8, 8 ; Reduce the counter by 8 elements
jz ExitRoutine ; If we are done, exit
ConvertMaskedVectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order
cmp r8, 4 ; Check if we can store the complete lower vector
jae ConvertLowerVector
vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertLower1
ja ConvertLower3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values
jmp ConvertLowerMaskedVector
ConvertLower1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value
jmp ConvertLowerMaskedVector
ConvertLower3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values
ConvertLowerMaskedVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data, the shift is done in 8bit multiples
jmp ExitRoutine ; If we ran into any of the cases above, means we are done after storing
ConvertLowerVector:
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
sub r8, 4 ; Check if we still need to convert
jz ExitRoutine
add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements
vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertUpper1
ja ConvertUpper3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values
jmp ConvertMaskedUpperVector
ConvertUpper1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value
jmp ConvertMaskedUpperVector
ConvertUpper3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values
ConvertMaskedUpperVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data, the shift is done in 8bit multiples
ExitRoutine:
ret
LEAF_END MlasCastF16ToF32KernelAvx, _TEXT
END

View file

@ -42,7 +42,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h)
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (edx) - Supplies the address of the destination buffer of
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
@ -53,7 +53,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h)
;
;--
LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT
LEAF_ENTRY MlasCastF16ToF32KernelSse, _TEXT
test r8,r8
jz ExitRoutine
@ -119,6 +119,6 @@ StoreLastElement:
ExitRoutine:
ret
LEAF_END MlasConvertHalfToFloatBuffer, _TEXT
LEAF_END MlasCastF16ToF32KernelSse, _TEXT
END

View file

@ -0,0 +1,59 @@
/*++
Copyright (c) Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
cast.cpp
Abstract:
This module implements Half (F16) to Single (F32) precision casting.
--*/
#include "mlasi.h"
union fp32_bits {
uint32_t u;
float f;
};
void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
)
{
if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) {
// If there is no kernel use the reference implementation, adapted from mlas_float16.h.
constexpr fp32_bits magic = {113 << 23};
constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
for (size_t i = 0; i < Count; ++i) {
fp32_bits o;
o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits
uint32_t exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust
// handle exponent special cases
if (exp == shifted_exp) { // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust
} else if (exp == 0) { // Zero/Denormal?
o.u += 1 << 23; // extra exp adjust
o.f -= magic.f; // renormalize
}
o.u |= (Source[i] & 0x8000) << 16; // sign bit
Destination[i] = o.f;
}
} else {
// If the kernel is available, use it to perform the conversion.
GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count);
}
}

View file

@ -610,6 +610,13 @@ void
size_t N
);
typedef
void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)(
const unsigned short* Source,
float* Destination,
size_t Count
);
typedef
void
(MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)(
@ -870,6 +877,11 @@ extern "C" {
MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx;
#endif
#if defined(MLAS_TARGET_AMD64)
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse;
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx;
#endif
}
//
@ -1151,6 +1163,8 @@ struct MLAS_PLATFORM {
const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};
const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};
MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
};
inline

View file

@ -244,6 +244,7 @@ Return Value:
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel<uint8_t, uint8_t>;
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel<int8_t, uint8_t>;
this->CastF16ToF32Kernel = nullptr;
#if defined(MLAS_TARGET_AMD64_IX86)
@ -283,6 +284,9 @@ Return Value:
this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel;
this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel;
#ifndef __APPLE__
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse;
#endif // __APPLE__
this->NchwcBlockSize = 8;
this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT;
@ -469,6 +473,16 @@ Return Value:
}
#ifndef __APPLE__
#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))
//
// Check if the processor supports AVX NE CONVERT.
//
if ((Cpuid7_1[3] & (0b1 << 5)) != 0) {
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx;
}
#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))
//
// Check if the processor supports AMX-TILE and AMX-INT8
// features.

View file

@ -0,0 +1,143 @@
/*++
Copyright (c) Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
cvtfp16Avx2.asm
Abstract:
This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA.
--*/
#include "asmmacro.h"
.data
.equ SINGLE_SIZE, 4
.equ HALF_SIZE, 2
.equ LOW_SELECTOR, 0b00100000
.equ HIGH_SELECTOR, 0b00110001
.text
.intel_syntax noprefix
/*++ Routine Description:
This routine converts the source buffer of half-precision floats to the
destination buffer of single-precision floats.
This implementation uses AVX2 instructions.
Arguments:
Source (rdi) - Supplies the address of the source buffer of half-precision
floats.
Destination (rsi) - Supplies the address of the destination buffer of
single-precision floats.
Count (rdx) - Supplies the number of elements to convert.
Return Value:
None.
--*/
FUNCTION_ENTRY MlasCastF16ToF32KernelAvx
test rdx, rdx // Check if we have any elements to convert
jz ExitRoutine
AVX_NE_CONVERT:
cmp rdx, 8
jb ConvertMaskedVectors
cmp rdx, 16
jb Convert128Vectors
Convert256Vectors:
vcvtneeph2ps ymm0, ymmword PTR [rdi] // Load even indexes
vcvtneoph2ps ymm1, ymmword PTR [rdi] // Load odd indexes
vunpcklps ymm2, ymm0, ymm1 // Interleave low part
vunpckhps ymm1, ymm0, ymm1 // Interleave high part
vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR // Fix the order
vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR // Fix the order
vmovups ymmword PTR [rsi], ymm0 // Store the low part
vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part
add rdi, 16*HALF_SIZE // Advance src ptr by 16 elements
add rsi, 16*SINGLE_SIZE // Advance dest ptr by 16 elements
sub rdx, 16 // Reduce the counter by 16 elements
jz ExitRoutine // If we are done, exit
cmp rdx, 16 // If the vector is big enough, we go again
jae Convert256Vectors
Convert128Vectors:
vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes
vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order
vmovups xmmword PTR [rsi], xmm0 // Store the low part
vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part
add rdi, 8*HALF_SIZE // Advance src ptr by 8 elements
add rsi, 8*SINGLE_SIZE // Advance dest ptr by 8 elements
sub rdx, 8 // Reduce the counter by 8 elements
jz ExitRoutine // If we are done, exit
ConvertMaskedVectors:
vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes
vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order
cmp rdx, 4 // Check if we can store the complete lower vector
jae ConvertLowerVector
vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones
cmp rdx, 2 // Check how many converts we need
jb ConvertLower1
ja ConvertLower3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values
jmp ConvertLowerMaskedVector
ConvertLower1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value
jmp ConvertLowerMaskedVector
ConvertLower3:
vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values
ConvertLowerMaskedVector:
vmaskmovps xmmword PTR [rsi], xmm2, xmm0 // Store the masked data, the shift is done in 8bit multiples
jmp ExitRoutine // If we ran into any of the cases above, means we are done after storing
ConvertLowerVector:
vmovups xmmword PTR [rsi], xmm0 // Store the low part
sub rdx, 4 // Check if we still need to convert
jz ExitRoutine
add rsi, 4*SINGLE_SIZE // Advance dest ptr by 4 elements
vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones
cmp rdx, 2 // Check how many converts we need
jb ConvertUpper1
ja ConvertUpper3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values
jmp ConvertMaskedUpperVector
ConvertUpper1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value
jmp ConvertMaskedUpperVector
ConvertUpper3:
vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values
ConvertMaskedUpperVector:
vmaskmovps xmmword PTR [rsi], xmm2, xmm1 // Store the masked data, the shift is done in 8bit multiples
jmp ExitRoutine
ExitRoutine:
ret

View file

@ -0,0 +1,129 @@
/*++
Copyright (c) Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
cvtfp16a.S
Abstract:
This module implements routines to convert between FP16 and FP32 formats using SSE2 isntructions.
--*/
#include "asmmacro.h"
// We use RIP relative addressing to avoid reallication related errors
.section .rodata
MlasFp16MaskSign: .long 0x00007FFF
MlasFp16CompareInfinity: .long 0x00007C00
MlasFp16CompareSmallest: .long 0x00000400
MlasFp16AdjustExponent: .long 0x38000000
MlasFp16MagicDenormal: .long 0x38800000
.text
.intel_syntax noprefix
/*++ Routine Description:
This routine converts the source buffer of half-precision floats to the
destination buffer of single-precision floats.
This implementation uses SSE2 instructions.
Arguments:
Source (rdi) - Supplies the address of the source buffer of half-precision
floats.
Destination (rsi) - Supplies the address of the destination buffer of
single-precision floats.
Count (rdx) - Supplies the number of elements to convert.
Return Value:
None.
--*/
FUNCTION_ENTRY MlasCastF16ToF32KernelSse
test rdx,rdx
jz ExitRoutine
// Load xmm constants
movd xmm5, DWORD PTR [rip + MlasFp16MaskSign]
pshufd xmm5, xmm5, 0x00
movd xmm6, DWORD PTR [rip + MlasFp16AdjustExponent]
pshufd xmm6, xmm6, 0x00
movd xmm7, DWORD PTR [rip + MlasFp16MagicDenormal]
pshufd xmm7, xmm7, 0x00
cmp rdx,4
jb LoadPartialVector
LoadFullVector:
movq xmm0,QWORD PTR [rdi]
add rdi,4*2 // advance S by 4 elements
ConvertHalfToFloat:
punpcklwd xmm0,xmm0 // duplicate 4 WORDs to 4 DWORDs
movaps xmm1,xmm0 // isolate exponent/mantissa
pand xmm1,xmm5
pxor xmm0,xmm1 // isolate sign bit
movd xmm2, DWORD PTR [rip + MlasFp16CompareInfinity]
pshufd xmm2, xmm2, 0x00
pcmpgtd xmm2,xmm1 // test for infinity/NaNs
movd xmm3, DWORD PTR [rip + MlasFp16CompareSmallest]
pshufd xmm3, xmm3, 0x00
pcmpgtd xmm3,xmm1 // test for denormals
pandn xmm2,xmm6
pslld xmm1,13 // shift exponent/mask into place
movaps xmm4,xmm1
paddd xmm1,xmm6
paddd xmm1,xmm2 // adjust exponent again for infinity/NaNs
paddd xmm4,xmm7
pslld xmm0,16 // shift sign into place
subps xmm4,xmm7
pand xmm4,xmm3 // select elements that are denormals
pandn xmm3,xmm1 // select elements that are not denormals
por xmm3,xmm4 // blend the selected values together
por xmm0,xmm3 // merge sign into exponent/mantissa
cmp rdx,4 // storing full vector?
jb StorePartialVector
movups XMMWORD PTR [rsi],xmm0
add rsi,4*4 // advance D by 4 elements
sub rdx,4
jz ExitRoutine
cmp rdx,4
jae LoadFullVector
LoadPartialVector:
pxor xmm0,xmm0
pinsrw xmm0,WORD PTR [rdi],0
cmp rdx,2
jb ConvertHalfToFloat
pinsrw xmm0,WORD PTR [rdi+2],1
je ConvertHalfToFloat
pinsrw xmm0,WORD PTR [rdi+4],2
jmp ConvertHalfToFloat
StorePartialVector:
cmp rdx,2
jb StoreLastElement
movsd QWORD PTR [rsi],xmm0
je ExitRoutine
movhlps xmm0,xmm0 // shift third element down
add rsi,4*2 // advance D by 2 elements
StoreLastElement:
movss DWORD PTR [rsi],xmm0
ExitRoutine:
ret

View file

@ -22,9 +22,8 @@
#include "Eigen/src/Core/arch/Default/BFloat16.h"
#include "Eigen/src/Core/arch/Default/Half.h"
#if defined(_M_AMD64) && !defined(_M_ARM64EC)
#include "core/mlas/inc/mlas.h"
#endif
#include "core/common/cpuid_info.h"
namespace onnxruntime {
@ -252,10 +251,6 @@ struct TensorCasterNoSat<std::string, DstType> {
#endif
#if defined(_M_AMD64) && !defined(_M_ARM64EC)
// specializations to use optimized and Windows x64-specific
// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion
// tensor MLFloat16 -> float
template <>
struct TensorCaster<MLFloat16, float> {
@ -267,6 +262,9 @@ struct TensorCaster<MLFloat16, float> {
}
};
#if defined(_M_AMD64) && !defined(_M_ARM64EC)
// specializations to use optimized and Windows x64-specific
Tensor GetIntermediateMLFloat16ToFloatTensor(
const OpKernelContext& context, const TensorShape& shape, const Tensor& in) {
AllocatorPtr allocator;