diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index df0f170317..c1d383064a 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -87,11 +87,6 @@ elseif(MSVC) ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") - if (onnxruntime_MINIMAL_BUILD) - # exclude AVX512 in minimal build - set_source_files_properties(${mlas_common_srcs} PROPERTIES COMPILE_FLAGS "-DMLAS_AVX512F_UNSUPPORTED") - endif() - set(mlas_platform_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/dgemm.cpp ${mlas_platform_srcs_avx} @@ -323,81 +318,22 @@ else() ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") - # Some toolchains do not support AVX512 compiler flags but are still able - # to build the sources. Other toolchains require the AVX512 compiler flags - # to be specified. - check_cxx_compiler_flag("-mavx512f" HAS_AVX512F) - if(HAS_AVX512F) - set(CMAKE_REQUIRED_FLAGS "-mavx512f") - else() - set(CMAKE_REQUIRED_FLAGS "") - endif() - check_cxx_source_compiles(" - int main() { - asm(\"vpxord %zmm0,%zmm0,%zmm0\"); - return 0; - }" - COMPILES_AVX512F + set(mlas_platform_srcs_avx512f + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/DgemmKernelAvx512F.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelAvx512F.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelAvx512F.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SpoolKernelAvx512F.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/TransKernelAvx512F.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp ) + set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") - if(COMPILES_AVX512F AND NOT onnxruntime_MINIMAL_BUILD) - set(mlas_platform_srcs_avx512f - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/DgemmKernelAvx512F.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelAvx512F.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelAvx512F.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SpoolKernelAvx512F.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/TransKernelAvx512F.S - ) - - check_cxx_source_compiles(" - #include - int main() { - __m512 zeros = _mm512_set1_ps(0.f); - (void)zeros; - return 0; - }" - COMPILES_AVX512F_INTRINSICS - ) - if(COMPILES_AVX512F_INTRINSICS) - set(mlas_platform_srcs_avx512f - ${ONNXRUNTIME_ROOT}/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp - ${mlas_platform_srcs_avx512f} - ) - else() - set_source_files_properties(${mlas_common_srcs} PROPERTIES COMPILE_FLAGS "-DMLAS_AVX512F_INTRINSICS_UNSUPPORTED") - endif() - if(HAS_AVX512F) - set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") - endif() - - check_cxx_compiler_flag("-mavx512bw -mavx512dq -mavx512vl" HAS_AVX512CORE) - if(HAS_AVX512CORE) - set(CMAKE_REQUIRED_FLAGS "-mavx512bw -mavx512dq -mavx512vl") - endif() - check_cxx_source_compiles(" - int main() { - asm(\"vpmaddwd %zmm0,%zmm0,%zmm0\"); // AVX512BW feature - asm(\"vandnps %xmm31,%xmm31,%xmm31\"); // AVX512DQ/AVX512VL feature - return 0; - }" - COMPILES_AVX512CORE - ) - - if(COMPILES_AVX512CORE) - set(mlas_platform_srcs_avx512core - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S - ) - if(HAS_AVX512CORE) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") - endif() - else() - set_source_files_properties(${mlas_common_srcs} PROPERTIES COMPILE_FLAGS "-DMLAS_AVX512CORE_UNSUPPORTED") - endif() - else() - set_source_files_properties(${mlas_common_srcs} PROPERTIES COMPILE_FLAGS "-DMLAS_AVX512F_UNSUPPORTED") - endif() + set(mlas_platform_srcs_avx512core + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S + ) + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") set(mlas_platform_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/dgemm.cpp diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 9fe37259c0..499401a210 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -281,7 +281,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; } -#if !defined(MLAS_AVX512F_UNSUPPORTED) +#if !defined(ORT_MINIMAL_BUILD) // // Check if the processor supports AVX512F features and the @@ -301,21 +301,16 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx512F; this->ComputeExpF32Kernel = MlasComputeExpF32KernelAvx512F; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelAvx512F; - this->NchwcBlockSize = 16; - this->PreferredBufferAlignment = 64; - -#if !defined(MLAS_AVX512F_INTRINSICS_UNSUPPORTED) this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelAvx512F; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelAvx512F; -#endif + this->NchwcBlockSize = 16; + this->PreferredBufferAlignment = 64; // // Check if the processor supports AVX512 core features // (AVX512BW/AVX512DQ/AVX512VL). // -#if !defined(MLAS_AVX512CORE_UNSUPPORTED) - if ((Cpuid7[1] & 0xC0020000) == 0xC0020000) { this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Core; @@ -333,12 +328,9 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; } } - -#endif // MLAS_AVX512CORE_UNSUPPORTED - } -#endif // MLAS_AVX512F_UNSUPPORTED +#endif // ORT_MINIMAL_BUILD }