From a461608409a67f7eb4f7f9badc12ca81cd20f52d Mon Sep 17 00:00:00 2001 From: Dipanjan Sengupta Date: Thu, 13 Jul 2023 11:19:49 -0700 Subject: [PATCH] Amx flag removal (#16527) ### Description 1. Replacing AMX intrinsics with machine code macros in QGEMM kernel. 2. Removing AMX build flags for GCC in cmake file. 3. Fixing the link time optimization (LTO) issue introduced with asm .include of an assembly file. I have moved the AMX instruction macro definitions from QgemmU8S8KernelAmxCommon.S to the amx_common.h to fix the LTO issue. Note that I am also pushing the macros defined in QgemmU8S8KernelAmxCommon.S for future reference. A special thanks to @laxmansole who helped in the development of the instruction macro definitions for AMX intrinsics and fixing the LTO issue. ### Motivation and Context The additional AMX flag in cmake adds an extra layer of dependency on GCC version to use the feature.These changes should allow the usage of the AMX feature with just the CPU ID check. --- cmake/onnxruntime_mlas.cmake | 39 +-- onnxruntime/core/mlas/lib/amx_common.h | 80 ++++++ onnxruntime/core/mlas/lib/mlasi.h | 2 - onnxruntime/core/mlas/lib/platform.cpp | 4 +- .../core/mlas/lib/qgemm_kernel_amx.cpp | 264 +++++++++--------- .../lib/x86_64/QgemmU8S8KernelAmxCommon.S | 234 ++++++++++++++++ 6 files changed, 461 insertions(+), 162 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/amx_common.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index ab05164c33..db40dee554 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -3,28 +3,6 @@ set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) - -set(MLAS_AMX_SUPPORTED FALSE) - -if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11) - # match assembler version, AMX instructions are supported from 2.38 - if (CMAKE_ASM_COMPILER_ID STREQUAL "GNU") - execute_process( - COMMAND as --version - OUTPUT_VARIABLE _as_version - ) - # 2.38 or later - if (_as_version MATCHES "GNU.[Aa]ssembler.*(2\\.38|2\\.39|2\\.[4-9][0-9]|[3-9]\\.[0-9][0-9])") - set(MLAS_AMX_SUPPORTED TRUE) - endif() - endif() -endif() - -if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - set(MLAS_AMX_SUPPORTED TRUE) -endif() - - # # All hardware agnostic source files here # hardware specific files would cause trouble in @@ -57,12 +35,6 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp ) -if(MLAS_AMX_SUPPORTED) - target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_AMX_SUPPORTED) -else() - message(WARNING "AMX instructions NOT supported due to lack of compiler tool chain!") -endif() - set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) #TODO: set MASM flags properly @@ -550,15 +522,16 @@ else() ${mlas_platform_srcs_avx512core} ) - if(MLAS_AMX_SUPPORTED) + if(NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl") - endif() + ) + set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h new file mode 100644 index 0000000000..3eb0700932 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amx_common.h @@ -0,0 +1,80 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + amx_common.h + +Abstract: + + Intrinsic and inline functions for amx processing. + +--*/ + +#pragma once + +#include "mlasi.h" + +#ifdef WIN32 +#define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) + +#define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) + +#define tile_dpbusd(dst, src1, src2) _tile_dpbusd(dst, src1, src2) + +#define tile_dpbuud(dst, src1, src2) _tile_dpbuud(dst, src1, src2) + +#define tile_loadd(dst, base, stride) _tile_loadd(dst, base, stride) + +#define tile_stream_loadd(dst, base, stride) _tile_stream_loadd(dst, base, stride) + +#define tile_stored(dst, base, stride) _tile_stored(dst, base, stride) + +#define tile_loadconfig(config) \ + _tile_loadconfig(config) + +#define tile_storeconfig(config) _tile_storeconfig(config) + +#else + +#define tile_dpbusd_internal(dst,src1,src2) \ +__asm__ volatile (".set Payload1, 0x01\n\t" \ + ".set Payload1, Payload1 + (("#src2" & 15) ^ 15) << 3\n\t" \ + ".set ModRMByte, 0xC0\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".set ModRMByte, ModRMByte + ("#src1")\n\t" \ + ".byte 0xC4, 0xE2, Payload1, 0x5E, ModRMByte\n\t") + +#define tile_dpbusd(dst,src1,src2) \ +tile_dpbusd_internal(dst,src1,src2) + +#define tile_loadd_internal1(dst,base,stride) \ + __asm__ volatile (".set ModRMByte, 0x04\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".byte 0xC4, 0xE2, 0x7B, 0x4B, ModRMByte, 0x18\n\t" \ + :: "a" ((const void*) (base)), "b" ((long) (stride))) + +#define tile_loadd(dst,base,stride) \ + tile_loadd_internal1(dst, base, stride) + + +#define tile_stored_internal1(dst,base,stride) \ + __asm__ volatile (".set ModRMByte, 0x04\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".byte 0xC4, 0xE2, 0x7A, 0x4B, ModRMByte, 0x18\n\t" \ + :: "a" ((const void*) (base)), "b" ((long) (stride))) + +#define tile_stored(dst,base,stride) \ +tile_stored_internal1(dst, base, stride) + + +#define tile_loadconfig(config) \ +__asm__ volatile (".byte 0xC4, 0xE2, 0x78, 0x49, 0x00" :: "a" (((const void *)config))) \ + +#define tile_storeconfig(config) \ +__asm__ volatile (".byte 0xC4, 0xE2, 0x79, 0x49, 0x00" :: "a" (((const void *)config))) \ + +#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6eb4cc446d..5a0ca3d5a9 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -824,9 +824,7 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; -#ifdef MLAS_AMX_SUPPORTED extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAmx; -#endif extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 5be38975f8..6446007610 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -408,7 +408,7 @@ Return Value: } } -#ifdef MLAS_AMX_SUPPORTED +#ifndef __APPLE__ // // Check if the processor supports AMX-TILE and AMX-INT8 // features. @@ -419,7 +419,7 @@ Return Value: this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; } } -#endif // MLAS_AMX_SUPPORTED +#endif // __APPLE__ #endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp index 7c8743026b..479a82e712 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp @@ -16,6 +16,7 @@ Abstract: #include "mlasi.h" #include "qgemm.h" +#include "amx_common.h" #define TMM0 0 @@ -202,7 +203,7 @@ MlasGemmQuantThreadInit() static thread_local struct tileconfig_t tc = {0}; struct tileconfig_t current_tc = {0}; - _tile_storeconfig(¤t_tc); + tile_storeconfig(¤t_tc); if (tc.palette_id == 0 || (std::memcmp(¤t_tc.colb, &tc.colb, sizeof(uint16_t) * 8) != 0 && std::memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { @@ -212,7 +213,8 @@ MlasGemmQuantThreadInit() tc.rows[t] = 16; tc.colb[t] = 64; } - _tile_loadconfig(&tc); + + tile_loadconfig(&tc); } } @@ -238,14 +240,14 @@ InitHalfTileWithRowColSums( row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6])); row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7])); if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7)); + row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); + row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); + row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); + row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); + row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); + row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); + row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); + row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); } _mm512_storeu_si512(Tile, row0); _mm512_storeu_si512(Tile+16, row1); @@ -290,14 +292,14 @@ InitHalfTileWithRowColSumsZeroPoints( row6 = _mm512_add_epi32(colsum, row6); row7 = _mm512_add_epi32(colsum, row7); if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7)); + row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); + row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); + row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); + row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); + row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); + row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); + row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); + row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); } _mm512_storeu_si512(Tile, row0); _mm512_storeu_si512(Tile+16, row1); @@ -435,58 +437,58 @@ MlasGemmQuantKernel( size_t n = CountN; for (; n >= 2 * TILE_N; n -= 2 * TILE_N) { - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; if (ZeroPointB != nullptr){ - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( Tile4, m0, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSumsZeroPoints( Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile4, m0, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSums( Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } - colsum = _mm512_loadu_epi32(col_sum_ptr); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; if (ZeroPointB != nullptr) { - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( Tile6, m0, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSumsZeroPoints( Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile6, m0, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSums( Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } @@ -494,33 +496,36 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM4, TMM2, TMM0); - _tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + + tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); if (m1 > 0){ - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM7, TMM3, TMM1); } b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } if (m0 == TILE_M) { - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + MoveTile(Tile4, m0, FullMask, c_blk, ldc); MoveTile(Tile6, m0, FullMask, c_blk + TILE_N, ldc); } if (m1 != 0){ - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); MoveTile(Tile5, m1, FullMask, c16_blk, ldc); - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile7, m1, FullMask, c16_blk + TILE_N, ldc); } c_blk += 2 * TILE_N; @@ -539,23 +544,23 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSumsZeroPoints( Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSums( Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } if (nmask_high != 0){ @@ -565,23 +570,23 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile6, m0, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSumsZeroPoints( Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile6, m0, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSums( Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } } @@ -589,18 +594,19 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + + tile_dpbusd(TMM4, TMM2, TMM0); if (m1 > 0){ - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM5, TMM3, TMM0); } if (nmask_high != 0){ - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); if (m1 > 0){ - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); } } b_blk += TILE_N * TILE_K; @@ -608,20 +614,20 @@ MlasGemmQuantKernel( a_next_blk += TILE_K; } if ((static_cast(nmasks) & 0x8000) != 0 && m0 == TILE_M){ - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); MoveTile(Tile4, m0, static_cast(nmasks), c_blk, ldc); } if (m1 > 0){ - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); MoveTile(Tile5, m1, static_cast(nmasks), c16_blk, ldc); } if (nmask_high != 0){ - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); MoveTile(Tile6, m0, nmask_high, c_blk + TILE_N, ldc); if (m1 > 0){ - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile7, m1, nmask_high, c16_blk + TILE_N, ldc); } } @@ -643,76 +649,78 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; if (ZeroPointB != nullptr){ - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; - _tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM0, b_blk, TILE_K); InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_epi32(col_sum_ptr); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - zeropoint = _mm512_loadu_epi32(zp_ptr); + zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM4, TMM2, TMM0); InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode); InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); } else { - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - _tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM0, b_blk, TILE_K); InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_epi32(col_sum_ptr); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM4, TMM2, TMM0); InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode); InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); } - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) { b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM7, TMM3, TMM1); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM4, TMM2, TMM0); } - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); + b_blk += PackedCountK * TILE_N + TILE_N * TILE_K; - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + c_blk += 2 * TILE_N; - _tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); c16_blk += 2 * TILE_N; } @@ -726,20 +734,20 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitTileWithRowColSumsZeroPoints( Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } else { InitTileWithRowColSums( Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitTileWithRowColSums( Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } if (nmask_high != 0){ colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr); @@ -748,52 +756,58 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); InitTileWithRowColSumsZeroPoints( Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } else { InitTileWithRowColSums( Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); InitTileWithRowColSums( Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM4, TMM2, TMM0); - _tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + + tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM5, TMM3, TMM0); + if (nmask_high != 0){ - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); + } b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } if ((static_cast(nmasks) & 0x8000) != 0){ - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + MoveTile(Tile4, TILE_M, static_cast(nmasks), c_blk, ldc); MoveTile(Tile5, TILE_M, static_cast(nmasks), c16_blk, ldc); } if (nmask_high != 0){ - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc); MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc); } diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S new file mode 100644 index 0000000000..7d042e2d8f --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S @@ -0,0 +1,234 @@ +/*++ + +Copyright (c) 2023 Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + AssembleAmx.h + +Abstract: + + This module contains macros to build AMX instructions for toolchains that + do not natively support this newer instruction set extension. + +--*/ + +// +// Map friendly register names to the encoded register index. +// + + .equ .LTmmIndex_tmm0, 0 + .equ .LTmmIndex_tmm1, 1 + .equ .LTmmIndex_tmm2, 2 + .equ .LTmmIndex_tmm3, 3 + .equ .LTmmIndex_tmm4, 4 + .equ .LTmmIndex_tmm5, 5 + .equ .LTmmIndex_tmm6, 6 + .equ .LTmmIndex_tmm7, 7 + +/*++ + +Macro Description: + + This macro builds a AMX instruction of the form: + + instr tmm1,tmm2,tmm3 + +Arguments: + + prefix - Specifies the opcode for the AMX instruction. + + DestReg - Specifies the destination AMX tile. + + Src1Reg - Specifies the first source AMX tile. + + Src2Reg - Specifies the second source AMX tile. + +--*/ + + .macro DPTmmTmmTmm prefix, DestReg, Src1Reg, Src2Reg + + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + ((((.LTmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) + .set Payload0, Payload0 + (1 << 6) + .set Payload0, Payload0 + ((((.LTmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) + + .set Payload1, \prefix\() + .set Payload1, Payload1 + (((.LTmmIndex_\Src2Reg\() & 15) ^ 15) << 3) + + .set ModRMByte, 0xC0 # register form + .set ModRMByte, ModRMByte + ((.LTmmIndex_\DestReg\() & 7) << 3) + .set ModRMByte, ModRMByte + (.LTmmIndex_\Src1Reg\() & 7) + + .byte 0xC4, Payload0, Payload1, 0x5E, ModRMByte + + .endm + + + .macro TdpbssdTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + + .macro TdpbsudTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + + .macro TdpbusdTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x01, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + + .macro TdpbuudTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + +/*++ + +Macro Description: + + This macro builds a AMX tile release instruction. + +Arguments: + + + +--*/ + +// .macro TileReleaseMacro + +// .byte 0xC4, 0xE2, 0x78, 0x49, 0xC0 + +// .endm + + +/*++ + +Macro Description: + + This macro builds an AMX tile zero instruction of the form: + + instr tmm1 + +Arguments: + + SrcReg - Specifies the source AMX tile. + +--*/ + + .macro TileZeroMacro SrcReg + + .set ModRMByte, 0xC0 # register form + .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3) + .byte 0xC4, 0xE2, 0x7B, 0x49, ModRMByte + + .endm + +/*++ + +Macro Description: + + This macro builds an AMX memory instruction of the form: + + instr tmm, base, stride + +Arguments: + + instr - Specifies the opcode for the AMX instruction. + + SrcReg - Specifies the target AMX tile. + + BaseReg - Specifies the base address of memory location. + + Stride - Specifies the stride for the memory instruction + +--*/ + + .macro TileLoadMacro instr, SrcReg, BaseReg, Stride + + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + ((((.LTmmIndex_\SrcReg\() >> 3) & 1) ^ 1) << 7) + .set Payload0, Payload0 + ((((3 >> 3) & 1) ^ 1) << 6) + .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5) + + .set ModRMByte, 0x00 # memory form + .set ModRMByte, ModRMByte + (1 << 2) # SibBye required + .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3) + + .set SibByte, 0x00 # scale factor 1(SS) + .set SibByte, SibByte + ((3 & 7) << 3) + .set SibByte, SibByte + (0 & 7) + + .byte 0xC4, Payload0, \instr\(), 0x4B, ModRMByte, SibByte + + .endm + + + .macro TileloaddTmmMem DstReg, BaseReg, Stride + TileLoadMacro 0x7B, \DstReg\(), \BaseReg\(), \Stride\() + .endm + + .macro TileloaddT1TmmMem DstReg, BaseReg, Stride + TileLoadMacro 0x79, \DstReg\(), \BaseReg\(), \Stride\() + .endm + + + .macro TileStoredMemTmm SrcReg, BaseReg, Stride + TileLoadMacro 0x7A, \SrcReg\(), \BaseReg\(), \Stride\() + .endm + + +/*++ + +Macro Description: + + This macro builds an AMX tile configuration instruction of the form: + + instr base + +Arguments: + + instr - Specifies the opcode for the AMX instruction. + + BaseReg - Specifies the memory address of the tile configuration. + +--*/ + + .macro tilecfgMacro instr, BaseReg + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + (1 << 7) + .set Payload0, Payload0 + (1 << 6) + .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5) + + .set ModRMByte, 0x00 # memory form & no reg + .set ModRMByte, ModRMByte + (0 & 7) + + .byte 0xC4, Payload0, \instr\(), 0x49, ModRMByte + + .endm + + + .macro ldtilecfgMacro BaseReg + + tilecfgMacro 0x78, \BaseReg\() + + .endm + + + .macro sttilecfgMacro BaseReg + + tilecfgMacro 0x79, \BaseReg\() + + .endm +