diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index ab05164c33..db40dee554 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -3,28 +3,6 @@ set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) - -set(MLAS_AMX_SUPPORTED FALSE) - -if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11) - # match assembler version, AMX instructions are supported from 2.38 - if (CMAKE_ASM_COMPILER_ID STREQUAL "GNU") - execute_process( - COMMAND as --version - OUTPUT_VARIABLE _as_version - ) - # 2.38 or later - if (_as_version MATCHES "GNU.[Aa]ssembler.*(2\\.38|2\\.39|2\\.[4-9][0-9]|[3-9]\\.[0-9][0-9])") - set(MLAS_AMX_SUPPORTED TRUE) - endif() - endif() -endif() - -if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - set(MLAS_AMX_SUPPORTED TRUE) -endif() - - # # All hardware agnostic source files here # hardware specific files would cause trouble in @@ -57,12 +35,6 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp ) -if(MLAS_AMX_SUPPORTED) - target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_AMX_SUPPORTED) -else() - message(WARNING "AMX instructions NOT supported due to lack of compiler tool chain!") -endif() - set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) #TODO: set MASM flags properly @@ -550,15 +522,16 @@ else() ${mlas_platform_srcs_avx512core} ) - if(MLAS_AMX_SUPPORTED) + if(NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl") - endif() + ) + set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h new file mode 100644 index 0000000000..3eb0700932 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amx_common.h @@ -0,0 +1,80 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + amx_common.h + +Abstract: + + Intrinsic and inline functions for amx processing. + +--*/ + +#pragma once + +#include "mlasi.h" + +#ifdef WIN32 +#define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) + +#define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) + +#define tile_dpbusd(dst, src1, src2) _tile_dpbusd(dst, src1, src2) + +#define tile_dpbuud(dst, src1, src2) _tile_dpbuud(dst, src1, src2) + +#define tile_loadd(dst, base, stride) _tile_loadd(dst, base, stride) + +#define tile_stream_loadd(dst, base, stride) _tile_stream_loadd(dst, base, stride) + +#define tile_stored(dst, base, stride) _tile_stored(dst, base, stride) + +#define tile_loadconfig(config) \ + _tile_loadconfig(config) + +#define tile_storeconfig(config) _tile_storeconfig(config) + +#else + +#define tile_dpbusd_internal(dst,src1,src2) \ +__asm__ volatile (".set Payload1, 0x01\n\t" \ + ".set Payload1, Payload1 + (("#src2" & 15) ^ 15) << 3\n\t" \ + ".set ModRMByte, 0xC0\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".set ModRMByte, ModRMByte + ("#src1")\n\t" \ + ".byte 0xC4, 0xE2, Payload1, 0x5E, ModRMByte\n\t") + +#define tile_dpbusd(dst,src1,src2) \ +tile_dpbusd_internal(dst,src1,src2) + +#define tile_loadd_internal1(dst,base,stride) \ + __asm__ volatile (".set ModRMByte, 0x04\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".byte 0xC4, 0xE2, 0x7B, 0x4B, ModRMByte, 0x18\n\t" \ + :: "a" ((const void*) (base)), "b" ((long) (stride))) + +#define tile_loadd(dst,base,stride) \ + tile_loadd_internal1(dst, base, stride) + + +#define tile_stored_internal1(dst,base,stride) \ + __asm__ volatile (".set ModRMByte, 0x04\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".byte 0xC4, 0xE2, 0x7A, 0x4B, ModRMByte, 0x18\n\t" \ + :: "a" ((const void*) (base)), "b" ((long) (stride))) + +#define tile_stored(dst,base,stride) \ +tile_stored_internal1(dst, base, stride) + + +#define tile_loadconfig(config) \ +__asm__ volatile (".byte 0xC4, 0xE2, 0x78, 0x49, 0x00" :: "a" (((const void *)config))) \ + +#define tile_storeconfig(config) \ +__asm__ volatile (".byte 0xC4, 0xE2, 0x79, 0x49, 0x00" :: "a" (((const void *)config))) \ + +#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6eb4cc446d..5a0ca3d5a9 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -824,9 +824,7 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; -#ifdef MLAS_AMX_SUPPORTED extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAmx; -#endif extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 5be38975f8..6446007610 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -408,7 +408,7 @@ Return Value: } } -#ifdef MLAS_AMX_SUPPORTED +#ifndef __APPLE__ // // Check if the processor supports AMX-TILE and AMX-INT8 // features. @@ -419,7 +419,7 @@ Return Value: this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; } } -#endif // MLAS_AMX_SUPPORTED +#endif // __APPLE__ #endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp index 7c8743026b..479a82e712 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp @@ -16,6 +16,7 @@ Abstract: #include "mlasi.h" #include "qgemm.h" +#include "amx_common.h" #define TMM0 0 @@ -202,7 +203,7 @@ MlasGemmQuantThreadInit() static thread_local struct tileconfig_t tc = {0}; struct tileconfig_t current_tc = {0}; - _tile_storeconfig(¤t_tc); + tile_storeconfig(¤t_tc); if (tc.palette_id == 0 || (std::memcmp(¤t_tc.colb, &tc.colb, sizeof(uint16_t) * 8) != 0 && std::memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { @@ -212,7 +213,8 @@ MlasGemmQuantThreadInit() tc.rows[t] = 16; tc.colb[t] = 64; } - _tile_loadconfig(&tc); + + tile_loadconfig(&tc); } } @@ -238,14 +240,14 @@ InitHalfTileWithRowColSums( row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6])); row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7])); if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7)); + row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); + row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); + row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); + row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); + row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); + row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); + row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); + row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); } _mm512_storeu_si512(Tile, row0); _mm512_storeu_si512(Tile+16, row1); @@ -290,14 +292,14 @@ InitHalfTileWithRowColSumsZeroPoints( row6 = _mm512_add_epi32(colsum, row6); row7 = _mm512_add_epi32(colsum, row7); if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7)); + row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); + row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); + row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); + row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); + row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); + row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); + row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); + row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); } _mm512_storeu_si512(Tile, row0); _mm512_storeu_si512(Tile+16, row1); @@ -435,58 +437,58 @@ MlasGemmQuantKernel( size_t n = CountN; for (; n >= 2 * TILE_N; n -= 2 * TILE_N) { - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; if (ZeroPointB != nullptr){ - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( Tile4, m0, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSumsZeroPoints( Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile4, m0, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSums( Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } - colsum = _mm512_loadu_epi32(col_sum_ptr); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; if (ZeroPointB != nullptr) { - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( Tile6, m0, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSumsZeroPoints( Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile6, m0, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSums( Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } @@ -494,33 +496,36 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM4, TMM2, TMM0); - _tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + + tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); if (m1 > 0){ - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM7, TMM3, TMM1); } b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } if (m0 == TILE_M) { - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + MoveTile(Tile4, m0, FullMask, c_blk, ldc); MoveTile(Tile6, m0, FullMask, c_blk + TILE_N, ldc); } if (m1 != 0){ - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); MoveTile(Tile5, m1, FullMask, c16_blk, ldc); - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile7, m1, FullMask, c16_blk + TILE_N, ldc); } c_blk += 2 * TILE_N; @@ -539,23 +544,23 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSumsZeroPoints( Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSums( Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } if (nmask_high != 0){ @@ -565,23 +570,23 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile6, m0, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSumsZeroPoints( Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile6, m0, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSums( Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } } @@ -589,18 +594,19 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + + tile_dpbusd(TMM4, TMM2, TMM0); if (m1 > 0){ - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM5, TMM3, TMM0); } if (nmask_high != 0){ - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); if (m1 > 0){ - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); } } b_blk += TILE_N * TILE_K; @@ -608,20 +614,20 @@ MlasGemmQuantKernel( a_next_blk += TILE_K; } if ((static_cast(nmasks) & 0x8000) != 0 && m0 == TILE_M){ - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); MoveTile(Tile4, m0, static_cast(nmasks), c_blk, ldc); } if (m1 > 0){ - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); MoveTile(Tile5, m1, static_cast(nmasks), c16_blk, ldc); } if (nmask_high != 0){ - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); MoveTile(Tile6, m0, nmask_high, c_blk + TILE_N, ldc); if (m1 > 0){ - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile7, m1, nmask_high, c16_blk + TILE_N, ldc); } } @@ -643,76 +649,78 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; if (ZeroPointB != nullptr){ - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; - _tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM0, b_blk, TILE_K); InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_epi32(col_sum_ptr); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - zeropoint = _mm512_loadu_epi32(zp_ptr); + zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM4, TMM2, TMM0); InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode); InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); } else { - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - _tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM0, b_blk, TILE_K); InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_epi32(col_sum_ptr); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM4, TMM2, TMM0); InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode); InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); } - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) { b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM7, TMM3, TMM1); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM4, TMM2, TMM0); } - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); + b_blk += PackedCountK * TILE_N + TILE_N * TILE_K; - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + c_blk += 2 * TILE_N; - _tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); c16_blk += 2 * TILE_N; } @@ -726,20 +734,20 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitTileWithRowColSumsZeroPoints( Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } else { InitTileWithRowColSums( Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitTileWithRowColSums( Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } if (nmask_high != 0){ colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr); @@ -748,52 +756,58 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); InitTileWithRowColSumsZeroPoints( Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } else { InitTileWithRowColSums( Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); InitTileWithRowColSums( Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM4, TMM2, TMM0); - _tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + + tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM5, TMM3, TMM0); + if (nmask_high != 0){ - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); + } b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } if ((static_cast(nmasks) & 0x8000) != 0){ - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + MoveTile(Tile4, TILE_M, static_cast(nmasks), c_blk, ldc); MoveTile(Tile5, TILE_M, static_cast(nmasks), c16_blk, ldc); } if (nmask_high != 0){ - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc); MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc); } diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S new file mode 100644 index 0000000000..7d042e2d8f --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S @@ -0,0 +1,234 @@ +/*++ + +Copyright (c) 2023 Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + AssembleAmx.h + +Abstract: + + This module contains macros to build AMX instructions for toolchains that + do not natively support this newer instruction set extension. + +--*/ + +// +// Map friendly register names to the encoded register index. +// + + .equ .LTmmIndex_tmm0, 0 + .equ .LTmmIndex_tmm1, 1 + .equ .LTmmIndex_tmm2, 2 + .equ .LTmmIndex_tmm3, 3 + .equ .LTmmIndex_tmm4, 4 + .equ .LTmmIndex_tmm5, 5 + .equ .LTmmIndex_tmm6, 6 + .equ .LTmmIndex_tmm7, 7 + +/*++ + +Macro Description: + + This macro builds a AMX instruction of the form: + + instr tmm1,tmm2,tmm3 + +Arguments: + + prefix - Specifies the opcode for the AMX instruction. + + DestReg - Specifies the destination AMX tile. + + Src1Reg - Specifies the first source AMX tile. + + Src2Reg - Specifies the second source AMX tile. + +--*/ + + .macro DPTmmTmmTmm prefix, DestReg, Src1Reg, Src2Reg + + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + ((((.LTmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) + .set Payload0, Payload0 + (1 << 6) + .set Payload0, Payload0 + ((((.LTmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) + + .set Payload1, \prefix\() + .set Payload1, Payload1 + (((.LTmmIndex_\Src2Reg\() & 15) ^ 15) << 3) + + .set ModRMByte, 0xC0 # register form + .set ModRMByte, ModRMByte + ((.LTmmIndex_\DestReg\() & 7) << 3) + .set ModRMByte, ModRMByte + (.LTmmIndex_\Src1Reg\() & 7) + + .byte 0xC4, Payload0, Payload1, 0x5E, ModRMByte + + .endm + + + .macro TdpbssdTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + + .macro TdpbsudTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + + .macro TdpbusdTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x01, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + + .macro TdpbuudTmmTmmTmm DestReg, Src1Reg, Src2Reg + + DPTmmTmmTmm 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + +/*++ + +Macro Description: + + This macro builds a AMX tile release instruction. + +Arguments: + + + +--*/ + +// .macro TileReleaseMacro + +// .byte 0xC4, 0xE2, 0x78, 0x49, 0xC0 + +// .endm + + +/*++ + +Macro Description: + + This macro builds an AMX tile zero instruction of the form: + + instr tmm1 + +Arguments: + + SrcReg - Specifies the source AMX tile. + +--*/ + + .macro TileZeroMacro SrcReg + + .set ModRMByte, 0xC0 # register form + .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3) + .byte 0xC4, 0xE2, 0x7B, 0x49, ModRMByte + + .endm + +/*++ + +Macro Description: + + This macro builds an AMX memory instruction of the form: + + instr tmm, base, stride + +Arguments: + + instr - Specifies the opcode for the AMX instruction. + + SrcReg - Specifies the target AMX tile. + + BaseReg - Specifies the base address of memory location. + + Stride - Specifies the stride for the memory instruction + +--*/ + + .macro TileLoadMacro instr, SrcReg, BaseReg, Stride + + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + ((((.LTmmIndex_\SrcReg\() >> 3) & 1) ^ 1) << 7) + .set Payload0, Payload0 + ((((3 >> 3) & 1) ^ 1) << 6) + .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5) + + .set ModRMByte, 0x00 # memory form + .set ModRMByte, ModRMByte + (1 << 2) # SibBye required + .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3) + + .set SibByte, 0x00 # scale factor 1(SS) + .set SibByte, SibByte + ((3 & 7) << 3) + .set SibByte, SibByte + (0 & 7) + + .byte 0xC4, Payload0, \instr\(), 0x4B, ModRMByte, SibByte + + .endm + + + .macro TileloaddTmmMem DstReg, BaseReg, Stride + TileLoadMacro 0x7B, \DstReg\(), \BaseReg\(), \Stride\() + .endm + + .macro TileloaddT1TmmMem DstReg, BaseReg, Stride + TileLoadMacro 0x79, \DstReg\(), \BaseReg\(), \Stride\() + .endm + + + .macro TileStoredMemTmm SrcReg, BaseReg, Stride + TileLoadMacro 0x7A, \SrcReg\(), \BaseReg\(), \Stride\() + .endm + + +/*++ + +Macro Description: + + This macro builds an AMX tile configuration instruction of the form: + + instr base + +Arguments: + + instr - Specifies the opcode for the AMX instruction. + + BaseReg - Specifies the memory address of the tile configuration. + +--*/ + + .macro tilecfgMacro instr, BaseReg + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + (1 << 7) + .set Payload0, Payload0 + (1 << 6) + .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5) + + .set ModRMByte, 0x00 # memory form & no reg + .set ModRMByte, ModRMByte + (0 & 7) + + .byte 0xC4, Payload0, \instr\(), 0x49, ModRMByte + + .endm + + + .macro ldtilecfgMacro BaseReg + + tilecfgMacro 0x78, \BaseReg\() + + .endm + + + .macro sttilecfgMacro BaseReg + + tilecfgMacro 0x79, \BaseReg\() + + .endm +