Amx flag removal (#16527)

### Description
1. Replacing AMX intrinsics with machine code macros in QGEMM kernel.
2. Removing AMX build flags for GCC in cmake file.
3. Fixing the link time optimization (LTO) issue introduced with asm
.include of an assembly file.

I have moved the AMX instruction macro definitions from
QgemmU8S8KernelAmxCommon.S to amx_common.h to fix the LTO issue.
Note that I am also pushing the macros defined in
QgemmU8S8KernelAmxCommon.S for future reference.

A special thanks to @laxmansole who helped in the development of the
instruction macro definitions for AMX intrinsics and fixing the LTO
issue.

### Motivation and Context
The additional AMX flag in cmake adds an extra layer of dependency on
the GCC version to use the feature. These changes should allow the usage
of the AMX feature with just the CPU ID check.
This commit is contained in:
Dipanjan Sengupta 2023-07-13 11:19:49 -07:00 committed by GitHub
parent c07a3b869c
commit a461608409
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 461 additions and 162 deletions

View file

@ -3,28 +3,6 @@
set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
set(MLAS_AMX_SUPPORTED FALSE)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
# match assembler version, AMX instructions are supported from 2.38
if (CMAKE_ASM_COMPILER_ID STREQUAL "GNU")
execute_process(
COMMAND as --version
OUTPUT_VARIABLE _as_version
)
# 2.38 or later
if (_as_version MATCHES "GNU.[Aa]ssembler.*(2\\.38|2\\.39|2\\.[4-9][0-9]|[3-9]\\.[0-9][0-9])")
set(MLAS_AMX_SUPPORTED TRUE)
endif()
endif()
endif()
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
set(MLAS_AMX_SUPPORTED TRUE)
endif()
#
# All hardware agnostic source files here
# hardware specific files would cause trouble in
@ -57,12 +35,6 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
)
if(MLAS_AMX_SUPPORTED)
target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_AMX_SUPPORTED)
else()
message(WARNING "AMX instructions NOT supported due to lack of compiler tool chain!")
endif()
set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
#TODO: set MASM flags properly
@ -550,15 +522,16 @@ else()
${mlas_platform_srcs_avx512core}
)
if(MLAS_AMX_SUPPORTED)
if(NOT APPLE)
set(mlas_platform_srcs
${mlas_platform_srcs}
${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl")
endif()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})

View file

@ -0,0 +1,80 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
amx_common.h
Abstract:
Intrinsic and inline functions for amx processing.
--*/
#pragma once
#include "mlasi.h"
// When WIN32 is defined, every tile_* wrapper forwards its arguments unchanged
// to the compiler intrinsic of the same name with a leading underscore
// (_tile_dpbusd, _tile_loadd, ...).
// NOTE(review): this tests WIN32 rather than the compiler-predefined _WIN32;
// presumably the build always defines WIN32 on Windows targets - confirm.
#ifdef WIN32
#define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2)
#define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2)
#define tile_dpbusd(dst, src1, src2) _tile_dpbusd(dst, src1, src2)
#define tile_dpbuud(dst, src1, src2) _tile_dpbuud(dst, src1, src2)
#define tile_loadd(dst, base, stride) _tile_loadd(dst, base, stride)
#define tile_stream_loadd(dst, base, stride) _tile_stream_loadd(dst, base, stride)
#define tile_stored(dst, base, stride) _tile_stored(dst, base, stride)
#define tile_loadconfig(config) \
_tile_loadconfig(config)
#define tile_storeconfig(config) _tile_storeconfig(config)
#else
// Otherwise the wrappers are implemented with inline assembly that emits the
// instruction bytes directly. Only tile_dpbusd, tile_loadd, tile_stored,
// tile_loadconfig and tile_storeconfig are provided in this branch;
// tile_dpbssd, tile_dpbsud, tile_dpbuud and tile_stream_loadd have no
// counterpart here.
// tile_dpbusd(dst, src1, src2)
//
// Emits the five bytes of a register-form AMX dot-product instruction with
// ".byte" instead of using the mnemonic. The byte layout mirrors the
// TdpbusdTmmTmmTmm / DPTmmTmmTmm assembler macros (prefix value 0x01,
// opcode byte 0x5E):
//   0xC4, 0xE2  fixed leading bytes
//   Payload1    0x01 + (((src2 & 15) ^ 15) << 3); in GNU as "<<" binds tighter
//               than "+", so the shift applies to the inverted index only
//   0x5E        opcode byte
//   ModRMByte   0xC0 + (dst << 3) + src1
//
// dst, src1 and src2 are pasted into the assembler text with '#', so they must
// be integer constants (or macros expanding to them, e.g. TMM4). No "& 7"
// masking is applied to dst or src1 here, so the values must already be valid
// tile indices in the range 0-7.
//
// The extra _internal level exists so that macro arguments are fully expanded
// before they are stringified.
//
// The asm statement declares no operands and no clobbers.
#define tile_dpbusd_internal(dst,src1,src2) \
__asm__ volatile (".set Payload1, 0x01\n\t" \
".set Payload1, Payload1 + (("#src2" & 15) ^ 15) << 3\n\t" \
".set ModRMByte, 0xC0\n\t" \
".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \
".set ModRMByte, ModRMByte + ("#src1")\n\t" \
".byte 0xC4, 0xE2, Payload1, 0x5E, ModRMByte\n\t")
#define tile_dpbusd(dst,src1,src2) \
tile_dpbusd_internal(dst,src1,src2)
// tile_loadd(dst, base, stride)
//
// Loads AMX tile `dst` from memory by emitting the instruction bytes directly:
//   0xC4, 0xE2, 0x7B, 0x4B  fixed bytes (same values as TileloaddTmmMem in the
//                           assembler macros)
//   ModRMByte               0x04 + (dst << 3): memory form with a SIB byte
//   0x18                    SIB byte: scale 1, index register 3, base register 0
//
// The SIB byte is fixed, so the operands are pinned with register
// constraints: `base` goes in "a" (rax, register 0) and `stride` in
// "b" (rbx, register 3).
//
// dst  - tile index; pasted into the assembler text, so it must be an integer
//        constant or a macro expanding to one (e.g. TMM4), in the range 0-7.
// base - address of the first row to load.
// stride - distance in bytes between consecutive rows.
//
// The "memory" clobber is required: the instruction reads the buffer behind
// `base`, which the compiler cannot see from the operand list. Without it the
// compiler may reorder or discard the stores that filled the buffer.
//
// The _internal1 level exists so that macro arguments are fully expanded
// before `dst` is stringified.
#define tile_loadd_internal1(dst,base,stride) \
    __asm__ volatile (".set ModRMByte, 0x04\n\t" \
                      ".set ModRMByte, ModRMByte + (" #dst " << 3)\n\t" \
                      ".byte 0xC4, 0xE2, 0x7B, 0x4B, ModRMByte, 0x18\n\t" \
                      :: "a" ((const void*) (base)), "b" ((long) (stride)) \
                      : "memory")
#define tile_loadd(dst,base,stride) \
    tile_loadd_internal1(dst, base, stride)
// tile_stored(dst, base, stride)
//
// Stores AMX tile `dst` to memory by emitting the instruction bytes directly:
//   0xC4, 0xE2, 0x7A, 0x4B  fixed bytes (same values as TileStoredMemTmm in the
//                           assembler macros)
//   ModRMByte               0x04 + (dst << 3): memory form with a SIB byte
//   0x18                    SIB byte: scale 1, index register 3, base register 0
//
// The SIB byte is fixed, so the operands are pinned with register
// constraints: `base` goes in "a" (rax, register 0) and `stride` in
// "b" (rbx, register 3).
//
// dst  - index of the tile to store (despite the name, it is the source of the
//        data); pasted into the assembler text, so it must be an integer
//        constant or a macro expanding to one, in the range 0-7.
// base - address that receives the first row.
// stride - distance in bytes between consecutive rows.
//
// The "memory" clobber is required: the instruction writes the buffer behind
// `base`, which the compiler cannot see from the operand list. Without it
// later reads of that buffer may be served from stale values or moved ahead of
// the store.
//
// The _internal1 level exists so that macro arguments are fully expanded
// before `dst` is stringified.
#define tile_stored_internal1(dst,base,stride) \
    __asm__ volatile (".set ModRMByte, 0x04\n\t" \
                      ".set ModRMByte, ModRMByte + (" #dst " << 3)\n\t" \
                      ".byte 0xC4, 0xE2, 0x7A, 0x4B, ModRMByte, 0x18\n\t" \
                      :: "a" ((const void*) (base)), "b" ((long) (stride)) \
                      : "memory")
#define tile_stored(dst,base,stride) \
    tile_stored_internal1(dst, base, stride)
// tile_loadconfig(config)
//
// Loads the AMX tile configuration from the structure at `config`. Emits the
// bytes 0xC4, 0xE2, 0x78, 0x49, 0x00 (same values as ldtilecfgMacro in the
// assembler macros); the final ModRM byte 0x00 addresses memory through
// register 0, so `config` is pinned to rax with the "a" constraint.
//
// The "memory" clobber tells the compiler that the instruction reads the
// structure, so pending writes to it are completed first.
//
// The macro body must not end in a line-continuation backslash: that would
// splice whatever directive follows into the macro.
#define tile_loadconfig(config) \
    __asm__ volatile (".byte 0xC4, 0xE2, 0x78, 0x49, 0x00" :: "a" ((const void *)(config)) : "memory")
// tile_storeconfig(config)
//
// Stores the current AMX tile configuration into the structure at `config`.
// Emits the bytes 0xC4, 0xE2, 0x79, 0x49, 0x00 (same values as sttilecfgMacro
// in the assembler macros); `config` is pinned to rax with the "a" constraint.
//
// The "memory" clobber tells the compiler that the instruction writes the
// structure, so callers that inspect it afterwards see the stored values
// rather than whatever the structure was initialised with.
#define tile_storeconfig(config) \
    __asm__ volatile (".byte 0xC4, 0xE2, 0x79, 0x49, 0x00" :: "a" ((const void *)(config)) : "memory")
#endif

View file

@ -824,9 +824,7 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2;
#ifdef MLAS_AMX_SUPPORTED
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAmx;
#endif
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot;

View file

@ -408,7 +408,7 @@ Return Value:
}
}
#ifdef MLAS_AMX_SUPPORTED
#ifndef __APPLE__
//
// Check if the processor supports AMX-TILE and AMX-INT8
// features.
@ -419,7 +419,7 @@ Return Value:
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx;
}
}
#endif // MLAS_AMX_SUPPORTED
#endif // __APPLE__
#endif // ORT_MINIMAL_BUILD

View file

@ -16,6 +16,7 @@ Abstract:
#include "mlasi.h"
#include "qgemm.h"
#include "amx_common.h"
#define TMM0 0
@ -202,7 +203,7 @@ MlasGemmQuantThreadInit<MLAS_GEMM_U8S8_KERNEL_AMX>()
static thread_local struct tileconfig_t tc = {0};
struct tileconfig_t current_tc = {0};
_tile_storeconfig(&current_tc);
tile_storeconfig(&current_tc);
if (tc.palette_id == 0 || (std::memcmp(&current_tc.colb, &tc.colb, sizeof(uint16_t) * 8) != 0 &&
std::memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
@ -212,7 +213,8 @@ MlasGemmQuantThreadInit<MLAS_GEMM_U8S8_KERNEL_AMX>()
tc.rows[t] = 16;
tc.colb[t] = 64;
}
_tile_loadconfig(&tc);
tile_loadconfig(&tc);
}
}
@ -238,14 +240,14 @@ InitHalfTileWithRowColSums(
row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6]));
row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7]));
if (!ZeroMode){
row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr));
row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc));
row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2));
row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3));
row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4));
row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5));
row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6));
row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7));
row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr));
row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc));
row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2));
row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3));
row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4));
row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5));
row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6));
row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7));
}
_mm512_storeu_si512(Tile, row0);
_mm512_storeu_si512(Tile+16, row1);
@ -290,14 +292,14 @@ InitHalfTileWithRowColSumsZeroPoints(
row6 = _mm512_add_epi32(colsum, row6);
row7 = _mm512_add_epi32(colsum, row7);
if (!ZeroMode){
row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr));
row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc));
row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2));
row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3));
row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4));
row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5));
row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6));
row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7));
row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr));
row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc));
row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2));
row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3));
row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4));
row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5));
row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6));
row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7));
}
_mm512_storeu_si512(Tile, row0);
_mm512_storeu_si512(Tile+16, row1);
@ -435,58 +437,58 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
size_t n = CountN;
for (; n >= 2 * TILE_N; n -= 2 * TILE_N) {
__m512i colsum = _mm512_loadu_epi32(col_sum_ptr);
__m512i colsum = _mm512_loadu_si512(col_sum_ptr);
col_sum_ptr += TILE_N;
if (ZeroPointB != nullptr){
__m512i zeropoint = _mm512_loadu_epi32(zp_ptr);
__m512i zeropoint = _mm512_loadu_si512(zp_ptr);
zp_ptr += TILE_N;
InitTileWithRowColSumsZeroPoints(
Tile4, m0, FullMask, RowSumBuffer, colsum,
zeropoint, ZeroMode, c_blk, ldc);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
if (m1 != 0){
InitTileWithRowColSumsZeroPoints(
Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum,
zeropoint, ZeroMode, c16_blk, ldc);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
}
} else {
InitTileWithRowColSums(
Tile4, m0, FullMask, RowSumBuffer, colsum,
ZeroMode, c_blk, ldc);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
if (m1 != 0){
InitTileWithRowColSums(
Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum,
ZeroMode, c16_blk, ldc);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
}
}
colsum = _mm512_loadu_epi32(col_sum_ptr);
colsum = _mm512_loadu_si512(col_sum_ptr);
col_sum_ptr += TILE_N;
if (ZeroPointB != nullptr) {
__m512i zeropoint = _mm512_loadu_epi32(zp_ptr);
__m512i zeropoint = _mm512_loadu_si512(zp_ptr);
zp_ptr += TILE_N;
InitTileWithRowColSumsZeroPoints(
Tile6, m0, FullMask, RowSumBuffer, colsum,
zeropoint, ZeroMode, c_blk + TILE_N, ldc);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
if (m1 != 0){
InitTileWithRowColSumsZeroPoints(
Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum,
zeropoint, ZeroMode, c16_blk + TILE_N, ldc);
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
}
} else {
InitTileWithRowColSums(
Tile6, m0, FullMask, RowSumBuffer, colsum,
ZeroMode, c_blk + TILE_N, ldc);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
if (m1 != 0){
InitTileWithRowColSums(
Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum,
ZeroMode, c16_blk + TILE_N, ldc);
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
}
}
@ -494,33 +496,36 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A;
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
for (size_t k = PackedCountK; k > 0; k -=TILE_K) {
_tile_loadd(TMM0, b_blk, TILE_K);
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
_tile_dpbusd(TMM4, TMM2, TMM0);
_tile_dpbusd(TMM6, TMM2, TMM1);
tile_loadd(TMM0, b_blk, TILE_K);
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
tile_dpbusd(TMM4, TMM2, TMM0);
tile_dpbusd(TMM6, TMM2, TMM1);
if (m1 > 0){
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
_tile_dpbusd(TMM5, TMM3, TMM0);
_tile_dpbusd(TMM7, TMM3, TMM1);
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
tile_dpbusd(TMM5, TMM3, TMM0);
tile_dpbusd(TMM7, TMM3, TMM1);
}
b_blk += TILE_N * TILE_K;
a_blk += TILE_K;
a_next_blk += TILE_K;
}
if (m0 == TILE_M) {
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
_tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
} else {
_tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
_tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
MoveTile(Tile4, m0, FullMask, c_blk, ldc);
MoveTile(Tile6, m0, FullMask, c_blk + TILE_N, ldc);
}
if (m1 != 0){
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
MoveTile(Tile5, m1, FullMask, c16_blk, ldc);
_tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
MoveTile(Tile7, m1, FullMask, c16_blk + TILE_N, ldc);
}
c_blk += 2 * TILE_N;
@ -539,23 +544,23 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
InitTileWithRowColSumsZeroPoints(
Tile4, m0, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
zeropoint, ZeroMode, c_blk, ldc);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
if (m1 > 0){
InitTileWithRowColSumsZeroPoints(
Tile5, m1, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
zeropoint, ZeroMode, c16_blk, ldc);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
}
} else {
InitTileWithRowColSums(
Tile4, m0, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
ZeroMode, c_blk, ldc);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
if (m1 > 0){
InitTileWithRowColSums(
Tile5, m1, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
ZeroMode, c16_blk, ldc);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
}
}
if (nmask_high != 0){
@ -565,23 +570,23 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
InitTileWithRowColSumsZeroPoints(
Tile6, m0, nmask_high, RowSumBuffer, colsum,
zeropoint, ZeroMode, c_blk + TILE_N, ldc);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
if (m1 > 0){
InitTileWithRowColSumsZeroPoints(
Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum,
zeropoint, ZeroMode, c16_blk + TILE_N, ldc);
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
}
} else {
InitTileWithRowColSums(
Tile6, m0, nmask_high, RowSumBuffer, colsum,
ZeroMode, c_blk + TILE_N, ldc);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
if (m1 > 0){
InitTileWithRowColSums(
Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum,
ZeroMode, c16_blk + TILE_N, ldc);
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
}
}
}
@ -589,18 +594,19 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A;
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
for (size_t k = PackedCountK; k > 0; k -=TILE_K) {
_tile_loadd(TMM0, b_blk, TILE_K);
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
_tile_dpbusd(TMM4, TMM2, TMM0);
tile_loadd(TMM0, b_blk, TILE_K);
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
tile_dpbusd(TMM4, TMM2, TMM0);
if (m1 > 0){
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
_tile_dpbusd(TMM5, TMM3, TMM0);
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
tile_dpbusd(TMM5, TMM3, TMM0);
}
if (nmask_high != 0){
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
_tile_dpbusd(TMM6, TMM2, TMM1);
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
tile_dpbusd(TMM6, TMM2, TMM1);
if (m1 > 0){
_tile_dpbusd(TMM7, TMM3, TMM1);
tile_dpbusd(TMM7, TMM3, TMM1);
}
}
b_blk += TILE_N * TILE_K;
@ -608,20 +614,20 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
a_next_blk += TILE_K;
}
if ((static_cast<uint16_t>(nmasks) & 0x8000) != 0 && m0 == TILE_M){
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
} else {
_tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
MoveTile(Tile4, m0, static_cast<uint16_t>(nmasks), c_blk, ldc);
}
if (m1 > 0){
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
MoveTile(Tile5, m1, static_cast<uint16_t>(nmasks), c16_blk, ldc);
}
if (nmask_high != 0){
_tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
MoveTile(Tile6, m0, nmask_high, c_blk + TILE_N, ldc);
if (m1 > 0){
_tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
MoveTile(Tile7, m1, nmask_high, c16_blk + TILE_N, ldc);
}
}
@ -643,76 +649,78 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
if (ZeroPointB != nullptr){
__m512i colsum = _mm512_loadu_epi32(col_sum_ptr);
__m512i colsum = _mm512_loadu_si512(col_sum_ptr);
col_sum_ptr += TILE_N;
__m512i zeropoint = _mm512_loadu_epi32(zp_ptr);
__m512i zeropoint = _mm512_loadu_si512(zp_ptr);
zp_ptr += TILE_N;
_tile_loadd(TMM0, b_blk, TILE_K);
tile_loadd(TMM0, b_blk, TILE_K);
InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode);
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode);
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
colsum = _mm512_loadu_epi32(col_sum_ptr);
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
colsum = _mm512_loadu_si512(col_sum_ptr);
col_sum_ptr += TILE_N;
zeropoint = _mm512_loadu_epi32(zp_ptr);
zeropoint = _mm512_loadu_si512(zp_ptr);
zp_ptr += TILE_N;
InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode);
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
_tile_dpbusd(TMM4, TMM2, TMM0);
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_dpbusd(TMM4, TMM2, TMM0);
InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode);
InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode);
} else {
__m512i colsum = _mm512_loadu_epi32(col_sum_ptr);
__m512i colsum = _mm512_loadu_si512(col_sum_ptr);
col_sum_ptr += TILE_N;
_tile_loadd(TMM0, b_blk, TILE_K);
tile_loadd(TMM0, b_blk, TILE_K);
InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode);
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode);
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
colsum = _mm512_loadu_epi32(col_sum_ptr);
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
colsum = _mm512_loadu_si512(col_sum_ptr);
col_sum_ptr += TILE_N;
InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode);
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
_tile_dpbusd(TMM4, TMM2, TMM0);
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_dpbusd(TMM4, TMM2, TMM0);
InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode);
InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode);
}
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) {
b_blk += TILE_N * TILE_K;
a_blk += TILE_K;
a_next_blk += TILE_K;
_tile_dpbusd(TMM5, TMM3, TMM0);
_tile_loadd(TMM0, b_blk, TILE_K);
_tile_dpbusd(TMM6, TMM2, TMM1);
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
_tile_dpbusd(TMM7, TMM3, TMM1);
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
_tile_dpbusd(TMM4, TMM2, TMM0);
tile_dpbusd(TMM5, TMM3, TMM0);
tile_loadd(TMM0, b_blk, TILE_K);
tile_dpbusd(TMM6, TMM2, TMM1);
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
tile_dpbusd(TMM7, TMM3, TMM1);
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
tile_dpbusd(TMM4, TMM2, TMM0);
}
_tile_dpbusd(TMM5, TMM3, TMM0);
_tile_dpbusd(TMM6, TMM2, TMM1);
_tile_dpbusd(TMM7, TMM3, TMM1);
tile_dpbusd(TMM5, TMM3, TMM0);
tile_dpbusd(TMM6, TMM2, TMM1);
tile_dpbusd(TMM7, TMM3, TMM1);
b_blk += PackedCountK * TILE_N + TILE_N * TILE_K;
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
_tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
_tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
c_blk += 2 * TILE_N;
_tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
c16_blk += 2 * TILE_N;
}
@ -726,20 +734,20 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
InitTileWithRowColSumsZeroPoints(
Tile4, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
zeropoint, ZeroMode, c_blk, ldc);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
InitTileWithRowColSumsZeroPoints(
Tile5, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
zeropoint, ZeroMode, c16_blk, ldc);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
} else {
InitTileWithRowColSums(
Tile4, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
ZeroMode, c_blk, ldc);
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
InitTileWithRowColSums(
Tile5, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
ZeroMode, c16_blk, ldc);
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
}
if (nmask_high != 0){
colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr);
@ -748,52 +756,58 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
InitTileWithRowColSumsZeroPoints(
Tile6, TILE_M, nmask_high, RowSumBuffer, colsum,
zeropoint, ZeroMode, c_blk + TILE_N, ldc);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
InitTileWithRowColSumsZeroPoints(
Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum,
zeropoint, ZeroMode, c16_blk + TILE_N, ldc);
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
} else {
InitTileWithRowColSums(
Tile6, TILE_M, nmask_high, RowSumBuffer, colsum,
ZeroMode, c_blk + TILE_N, ldc);
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
InitTileWithRowColSums(
Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum,
ZeroMode, c16_blk + TILE_N, ldc);
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
}
}
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A;
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
for (size_t k = PackedCountK; k > 0; k -=TILE_K) {
_tile_loadd(TMM0, b_blk, TILE_K);
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
_tile_dpbusd(TMM4, TMM2, TMM0);
_tile_dpbusd(TMM5, TMM3, TMM0);
tile_loadd(TMM0, b_blk, TILE_K);
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
tile_dpbusd(TMM4, TMM2, TMM0);
tile_dpbusd(TMM5, TMM3, TMM0);
if (nmask_high != 0){
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
_tile_dpbusd(TMM6, TMM2, TMM1);
_tile_dpbusd(TMM7, TMM3, TMM1);
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
tile_dpbusd(TMM6, TMM2, TMM1);
tile_dpbusd(TMM7, TMM3, TMM1);
}
b_blk += TILE_N * TILE_K;
a_blk += TILE_K;
a_next_blk += TILE_K;
}
if ((static_cast<uint16_t>(nmasks) & 0x8000) != 0){
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
_tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
} else {
_tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
MoveTile(Tile4, TILE_M, static_cast<uint16_t>(nmasks), c_blk, ldc);
MoveTile(Tile5, TILE_M, static_cast<uint16_t>(nmasks), c16_blk, ldc);
}
if (nmask_high != 0){
_tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
_tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc);
MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc);
}

View file

@ -0,0 +1,234 @@
/*++
Copyright (c) 2023 Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
AssembleAmx.h
Abstract:
This module contains macros to build AMX instructions for toolchains that
do not natively support this newer instruction set extension.
--*/
//
// Map friendly register names to the encoded register index.
//
// The encoding macros below look an operand up by pasting its name onto the
// ".LTmmIndex_" prefix, so tile operands must be spelled exactly as one of
// tmm0 .. tmm7.
//
.equ .LTmmIndex_tmm0, 0
.equ .LTmmIndex_tmm1, 1
.equ .LTmmIndex_tmm2, 2
.equ .LTmmIndex_tmm3, 3
.equ .LTmmIndex_tmm4, 4
.equ .LTmmIndex_tmm5, 5
.equ .LTmmIndex_tmm6, 6
.equ .LTmmIndex_tmm7, 7
/*++
Macro Description:
This macro builds an AMX instruction of the form:
instr tmm1,tmm2,tmm3
by emitting five bytes: 0xC4, Payload0, Payload1, 0x5E, ModRMByte.
Payload0 starts at 0x02, always has bit 6 set, and has bits 7 and 5 set
when bit 3 of the DestReg / Src2Reg index is clear (always the case for
tmm0-tmm7, giving 0xE2).
Payload1 is the prefix value plus the inverted low four bits of the
Src2Reg index shifted left by 3.
ModRMByte is 0xC0 plus the DestReg index in bits 5:3 and the Src1Reg index
in bits 2:0.
NOTE(review): bit 5 of Payload0 is derived from Src2Reg, while the low
ModRM field holds Src1Reg. This makes no difference for tmm0-tmm7 -
confirm which operand was intended.
Arguments:
prefix - Specifies the value placed in the low bits of Payload1; it
selects which dot-product instruction is encoded (the opcode byte
itself is always 0x5E).
DestReg - Specifies the destination AMX tile.
Src1Reg - Specifies the first source AMX tile.
Src2Reg - Specifies the second source AMX tile.
--*/
.macro DPTmmTmmTmm prefix, DestReg, Src1Reg, Src2Reg
.set Payload0, 0x02 # "0F 38" prefix
.set Payload0, Payload0 + ((((.LTmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7)
.set Payload0, Payload0 + (1 << 6)
.set Payload0, Payload0 + ((((.LTmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5)
.set Payload1, \prefix\()
.set Payload1, Payload1 + (((.LTmmIndex_\Src2Reg\() & 15) ^ 15) << 3)
.set ModRMByte, 0xC0 # register form
.set ModRMByte, ModRMByte + ((.LTmmIndex_\DestReg\() & 7) << 3)
.set ModRMByte, ModRMByte + (.LTmmIndex_\Src1Reg\() & 7)
.byte 0xC4, Payload0, Payload1, 0x5E, ModRMByte
.endm
//
// Named wrappers around DPTmmTmmTmm. Each one forwards its three tile
// operands unchanged and differs only in the prefix value it passes:
//   TdpbssdTmmTmmTmm - 0x03
//   TdpbsudTmmTmmTmm - 0x02
//   TdpbusdTmmTmmTmm - 0x01
//   TdpbuudTmmTmmTmm - 0x00
//
.macro TdpbssdTmmTmmTmm DestReg, Src1Reg, Src2Reg
DPTmmTmmTmm 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\()
.endm
.macro TdpbsudTmmTmmTmm DestReg, Src1Reg, Src2Reg
DPTmmTmmTmm 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\()
.endm
.macro TdpbusdTmmTmmTmm DestReg, Src1Reg, Src2Reg
DPTmmTmmTmm 0x01, \DestReg\(), \Src1Reg\(), \Src2Reg\()
.endm
.macro TdpbuudTmmTmmTmm DestReg, Src1Reg, Src2Reg
DPTmmTmmTmm 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\()
.endm
/*++
Macro Description:
This macro builds a AMX tile release instruction.
Arguments:
--*/
// .macro TileReleaseMacro
// .byte 0xC4, 0xE2, 0x78, 0x49, 0xC0
// .endm
/*++
Macro Description:
This macro builds an AMX tile zero instruction of the form:
instr tmm1
by emitting five bytes: 0xC4, 0xE2, 0x7B, 0x49, ModRMByte, where ModRMByte
is 0xC0 plus the tile index in bits 5:3.
Arguments:
SrcReg - Specifies the AMX tile operand of the instruction (the only
operand; presumably the tile that is zeroed, despite the "Src" name).
--*/
.macro TileZeroMacro SrcReg
.set ModRMByte, 0xC0 # register form
.set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3)
.byte 0xC4, 0xE2, 0x7B, 0x49, ModRMByte
.endm
/*++
Macro Description:
This macro builds an AMX memory instruction of the form:
instr tmm, base, stride
by emitting six bytes: 0xC4, Payload0, instr, 0x4B, ModRMByte, SibByte.
The BaseReg and Stride arguments are accepted but never referenced in the
body: the base and index register numbers are hard-coded as the literals
0 and 3 (rax and rbx in x86-64 register numbering), which yields
SibByte = 0x18. Callers must therefore place the base address and the
stride in those registers themselves.
Arguments:
instr - Specifies the byte emitted immediately before the fixed 0x4B
byte; it selects which load/store instruction is encoded.
SrcReg - Specifies the target AMX tile.
BaseReg - Specifies the base address of memory location (currently
unused, see above).
Stride - Specifies the stride for the memory instruction (currently
unused, see above).
--*/
.macro TileLoadMacro instr, SrcReg, BaseReg, Stride
.set Payload0, 0x02 # "0F 38" prefix
.set Payload0, Payload0 + ((((.LTmmIndex_\SrcReg\() >> 3) & 1) ^ 1) << 7)
.set Payload0, Payload0 + ((((3 >> 3) & 1) ^ 1) << 6)
.set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5)
.set ModRMByte, 0x00 # memory form
.set ModRMByte, ModRMByte + (1 << 2) # SibByte required
.set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3)
.set SibByte, 0x00 # scale factor 1(SS)
.set SibByte, SibByte + ((3 & 7) << 3)
.set SibByte, SibByte + (0 & 7)
.byte 0xC4, Payload0, \instr\(), 0x4B, ModRMByte, SibByte
.endm
//
// Named wrappers around TileLoadMacro. Each one forwards its operands
// unchanged and differs only in the instr byte it passes:
//   TileloaddTmmMem   - 0x7B
//   TileloaddT1TmmMem - 0x79
//   TileStoredMemTmm  - 0x7A
// TileLoadMacro ignores BaseReg and Stride, so the values passed here have
// no effect on the emitted bytes.
//
.macro TileloaddTmmMem DstReg, BaseReg, Stride
TileLoadMacro 0x7B, \DstReg\(), \BaseReg\(), \Stride\()
.endm
.macro TileloaddT1TmmMem DstReg, BaseReg, Stride
TileLoadMacro 0x79, \DstReg\(), \BaseReg\(), \Stride\()
.endm
.macro TileStoredMemTmm SrcReg, BaseReg, Stride
TileLoadMacro 0x7A, \SrcReg\(), \BaseReg\(), \Stride\()
.endm
/*++
Macro Description:
This macro builds an AMX tile configuration instruction of the form:
instr base
by emitting five bytes: 0xC4, Payload0, instr, 0x49, ModRMByte. With the
constants used below Payload0 evaluates to 0xE2 and ModRMByte to 0x00.
The BaseReg argument is accepted but never referenced in the body: the
register number in the low ModRM field is hard-coded as the literal 0
(rax in x86-64 register numbering), so callers must place the address of
the tile configuration in that register themselves.
Arguments:
instr - Specifies the byte emitted immediately before the fixed 0x49
byte; it selects which configuration instruction is encoded.
BaseReg - Specifies the memory address of the tile configuration
(currently unused, see above).
--*/
.macro tilecfgMacro instr, BaseReg
.set Payload0, 0x02 # "0F 38" prefix
.set Payload0, Payload0 + (1 << 7)
.set Payload0, Payload0 + (1 << 6)
.set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5)
.set ModRMByte, 0x00 # memory form & no reg
.set ModRMByte, ModRMByte + (0 & 7)
.byte 0xC4, Payload0, \instr\(), 0x49, ModRMByte
.endm
//
// Named wrappers around tilecfgMacro. They differ only in the instr byte
// they pass:
//   ldtilecfgMacro - 0x78
//   sttilecfgMacro - 0x79
// tilecfgMacro ignores BaseReg, so the value passed here has no effect on
// the emitted bytes.
//
.macro ldtilecfgMacro BaseReg
tilecfgMacro 0x78, \BaseReg\()
.endm
.macro sttilecfgMacro BaseReg
tilecfgMacro 0x79, \BaseReg\()
.endm