mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Amx flag removal (#16527)
### Description 1. Replacing AMX intrinsics with machine code macros in QGEMM kernel. 2. Removing AMX build flags for GCC in cmake file. 3. Fixing the link time optimization (LTO) issue introduced with asm .include of an assembly file. I have moved the AMX instruction macro definitions from QgemmU8S8KernelAmxCommon.S to the amx_common.h to fix the LTO issue. Note that I am also pushing the macros defined in QgemmU8S8KernelAmxCommon.S for future reference. A special thanks to @laxmansole who helped in the development of the instruction macro definitions for AMX intrinsics and fixing the LTO issue. ### Motivation and Context The additional AMX flag in cmake adds an extra layer of dependency on GCC version to use the feature.These changes should allow the usage of the AMX feature with just the CPU ID check.
This commit is contained in:
parent
c07a3b869c
commit
a461608409
6 changed files with 461 additions and 162 deletions
|
|
@ -3,28 +3,6 @@
|
|||
|
||||
set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
|
||||
|
||||
|
||||
set(MLAS_AMX_SUPPORTED FALSE)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
|
||||
# match assembler version, AMX instructions are supported from 2.38
|
||||
if (CMAKE_ASM_COMPILER_ID STREQUAL "GNU")
|
||||
execute_process(
|
||||
COMMAND as --version
|
||||
OUTPUT_VARIABLE _as_version
|
||||
)
|
||||
# 2.38 or later
|
||||
if (_as_version MATCHES "GNU.[Aa]ssembler.*(2\\.38|2\\.39|2\\.[4-9][0-9]|[3-9]\\.[0-9][0-9])")
|
||||
set(MLAS_AMX_SUPPORTED TRUE)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
|
||||
set(MLAS_AMX_SUPPORTED TRUE)
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# All hardware agnostic source files here
|
||||
# hardware specific files would cause trouble in
|
||||
|
|
@ -57,12 +35,6 @@ onnxruntime_add_static_library(onnxruntime_mlas
|
|||
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
|
||||
)
|
||||
|
||||
if(MLAS_AMX_SUPPORTED)
|
||||
target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_AMX_SUPPORTED)
|
||||
else()
|
||||
message(WARNING "AMX instructions NOT supported due to lack of compiler tool chain!")
|
||||
endif()
|
||||
|
||||
set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
|
||||
|
||||
#TODO: set MASM flags properly
|
||||
|
|
@ -550,15 +522,16 @@ else()
|
|||
${mlas_platform_srcs_avx512core}
|
||||
)
|
||||
|
||||
if(MLAS_AMX_SUPPORTED)
|
||||
if(NOT APPLE)
|
||||
set(mlas_platform_srcs
|
||||
${mlas_platform_srcs}
|
||||
${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S
|
||||
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
|
||||
${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl")
|
||||
endif()
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
|
||||
endif()
|
||||
|
||||
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
|
||||
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})
|
||||
|
|
|
|||
80
onnxruntime/core/mlas/lib/amx_common.h
Normal file
80
onnxruntime/core/mlas/lib/amx_common.h
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
amx_common.h
|
||||
|
||||
Abstract:
|
||||
|
||||
Intrinsic and inline functions for amx processing.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
#ifdef WIN32
|
||||
#define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2)
|
||||
|
||||
#define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2)
|
||||
|
||||
#define tile_dpbusd(dst, src1, src2) _tile_dpbusd(dst, src1, src2)
|
||||
|
||||
#define tile_dpbuud(dst, src1, src2) _tile_dpbuud(dst, src1, src2)
|
||||
|
||||
#define tile_loadd(dst, base, stride) _tile_loadd(dst, base, stride)
|
||||
|
||||
#define tile_stream_loadd(dst, base, stride) _tile_stream_loadd(dst, base, stride)
|
||||
|
||||
#define tile_stored(dst, base, stride) _tile_stored(dst, base, stride)
|
||||
|
||||
#define tile_loadconfig(config) \
|
||||
_tile_loadconfig(config)
|
||||
|
||||
#define tile_storeconfig(config) _tile_storeconfig(config)
|
||||
|
||||
#else
|
||||
|
||||
#define tile_dpbusd_internal(dst,src1,src2) \
|
||||
__asm__ volatile (".set Payload1, 0x01\n\t" \
|
||||
".set Payload1, Payload1 + (("#src2" & 15) ^ 15) << 3\n\t" \
|
||||
".set ModRMByte, 0xC0\n\t" \
|
||||
".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \
|
||||
".set ModRMByte, ModRMByte + ("#src1")\n\t" \
|
||||
".byte 0xC4, 0xE2, Payload1, 0x5E, ModRMByte\n\t")
|
||||
|
||||
#define tile_dpbusd(dst,src1,src2) \
|
||||
tile_dpbusd_internal(dst,src1,src2)
|
||||
|
||||
#define tile_loadd_internal1(dst,base,stride) \
|
||||
__asm__ volatile (".set ModRMByte, 0x04\n\t" \
|
||||
".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \
|
||||
".byte 0xC4, 0xE2, 0x7B, 0x4B, ModRMByte, 0x18\n\t" \
|
||||
:: "a" ((const void*) (base)), "b" ((long) (stride)))
|
||||
|
||||
#define tile_loadd(dst,base,stride) \
|
||||
tile_loadd_internal1(dst, base, stride)
|
||||
|
||||
|
||||
#define tile_stored_internal1(dst,base,stride) \
|
||||
__asm__ volatile (".set ModRMByte, 0x04\n\t" \
|
||||
".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \
|
||||
".byte 0xC4, 0xE2, 0x7A, 0x4B, ModRMByte, 0x18\n\t" \
|
||||
:: "a" ((const void*) (base)), "b" ((long) (stride)))
|
||||
|
||||
#define tile_stored(dst,base,stride) \
|
||||
tile_stored_internal1(dst, base, stride)
|
||||
|
||||
|
||||
#define tile_loadconfig(config) \
|
||||
__asm__ volatile (".byte 0xC4, 0xE2, 0x78, 0x49, 0x00" :: "a" (((const void *)config))) \
|
||||
|
||||
#define tile_storeconfig(config) \
|
||||
__asm__ volatile (".byte 0xC4, 0xE2, 0x79, 0x49, 0x00" :: "a" (((const void *)config))) \
|
||||
|
||||
#endif
|
||||
|
|
@ -824,9 +824,7 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse;
|
|||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2;
|
||||
#ifdef MLAS_AMX_SUPPORTED
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAmx;
|
||||
#endif
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon;
|
||||
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot;
|
||||
|
|
|
|||
|
|
@ -408,7 +408,7 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef MLAS_AMX_SUPPORTED
|
||||
#ifndef __APPLE__
|
||||
//
|
||||
// Check if the processor supports AMX-TILE and AMX-INT8
|
||||
// features.
|
||||
|
|
@ -419,7 +419,7 @@ Return Value:
|
|||
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx;
|
||||
}
|
||||
}
|
||||
#endif // MLAS_AMX_SUPPORTED
|
||||
#endif // __APPLE__
|
||||
|
||||
#endif // ORT_MINIMAL_BUILD
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ Abstract:
|
|||
|
||||
#include "mlasi.h"
|
||||
#include "qgemm.h"
|
||||
#include "amx_common.h"
|
||||
|
||||
|
||||
#define TMM0 0
|
||||
|
|
@ -202,7 +203,7 @@ MlasGemmQuantThreadInit<MLAS_GEMM_U8S8_KERNEL_AMX>()
|
|||
|
||||
static thread_local struct tileconfig_t tc = {0};
|
||||
struct tileconfig_t current_tc = {0};
|
||||
_tile_storeconfig(¤t_tc);
|
||||
tile_storeconfig(¤t_tc);
|
||||
|
||||
if (tc.palette_id == 0 || (std::memcmp(¤t_tc.colb, &tc.colb, sizeof(uint16_t) * 8) != 0 &&
|
||||
std::memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
|
||||
|
|
@ -212,7 +213,8 @@ MlasGemmQuantThreadInit<MLAS_GEMM_U8S8_KERNEL_AMX>()
|
|||
tc.rows[t] = 16;
|
||||
tc.colb[t] = 64;
|
||||
}
|
||||
_tile_loadconfig(&tc);
|
||||
|
||||
tile_loadconfig(&tc);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -238,14 +240,14 @@ InitHalfTileWithRowColSums(
|
|||
row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6]));
|
||||
row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7]));
|
||||
if (!ZeroMode){
|
||||
row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr));
|
||||
row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc));
|
||||
row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2));
|
||||
row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3));
|
||||
row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4));
|
||||
row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5));
|
||||
row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6));
|
||||
row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7));
|
||||
row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr));
|
||||
row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc));
|
||||
row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2));
|
||||
row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3));
|
||||
row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4));
|
||||
row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5));
|
||||
row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6));
|
||||
row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7));
|
||||
}
|
||||
_mm512_storeu_si512(Tile, row0);
|
||||
_mm512_storeu_si512(Tile+16, row1);
|
||||
|
|
@ -290,14 +292,14 @@ InitHalfTileWithRowColSumsZeroPoints(
|
|||
row6 = _mm512_add_epi32(colsum, row6);
|
||||
row7 = _mm512_add_epi32(colsum, row7);
|
||||
if (!ZeroMode){
|
||||
row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr));
|
||||
row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc));
|
||||
row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2));
|
||||
row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3));
|
||||
row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4));
|
||||
row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5));
|
||||
row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6));
|
||||
row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7));
|
||||
row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr));
|
||||
row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc));
|
||||
row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2));
|
||||
row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3));
|
||||
row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4));
|
||||
row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5));
|
||||
row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6));
|
||||
row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7));
|
||||
}
|
||||
_mm512_storeu_si512(Tile, row0);
|
||||
_mm512_storeu_si512(Tile+16, row1);
|
||||
|
|
@ -435,58 +437,58 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
|
||||
size_t n = CountN;
|
||||
for (; n >= 2 * TILE_N; n -= 2 * TILE_N) {
|
||||
__m512i colsum = _mm512_loadu_epi32(col_sum_ptr);
|
||||
__m512i colsum = _mm512_loadu_si512(col_sum_ptr);
|
||||
col_sum_ptr += TILE_N;
|
||||
if (ZeroPointB != nullptr){
|
||||
__m512i zeropoint = _mm512_loadu_epi32(zp_ptr);
|
||||
__m512i zeropoint = _mm512_loadu_si512(zp_ptr);
|
||||
zp_ptr += TILE_N;
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile4, m0, FullMask, RowSumBuffer, colsum,
|
||||
zeropoint, ZeroMode, c_blk, ldc);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
if (m1 != 0){
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum,
|
||||
zeropoint, ZeroMode, c16_blk, ldc);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
InitTileWithRowColSums(
|
||||
Tile4, m0, FullMask, RowSumBuffer, colsum,
|
||||
ZeroMode, c_blk, ldc);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
if (m1 != 0){
|
||||
InitTileWithRowColSums(
|
||||
Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum,
|
||||
ZeroMode, c16_blk, ldc);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
colsum = _mm512_loadu_epi32(col_sum_ptr);
|
||||
colsum = _mm512_loadu_si512(col_sum_ptr);
|
||||
col_sum_ptr += TILE_N;
|
||||
if (ZeroPointB != nullptr) {
|
||||
__m512i zeropoint = _mm512_loadu_epi32(zp_ptr);
|
||||
__m512i zeropoint = _mm512_loadu_si512(zp_ptr);
|
||||
zp_ptr += TILE_N;
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile6, m0, FullMask, RowSumBuffer, colsum,
|
||||
zeropoint, ZeroMode, c_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
if (m1 != 0){
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum,
|
||||
zeropoint, ZeroMode, c16_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
InitTileWithRowColSums(
|
||||
Tile6, m0, FullMask, RowSumBuffer, colsum,
|
||||
ZeroMode, c_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
if (m1 != 0){
|
||||
InitTileWithRowColSums(
|
||||
Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum,
|
||||
ZeroMode, c16_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -494,33 +496,36 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A;
|
||||
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
|
||||
for (size_t k = PackedCountK; k > 0; k -=TILE_K) {
|
||||
_tile_loadd(TMM0, b_blk, TILE_K);
|
||||
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
_tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
_tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
tile_loadd(TMM0, b_blk, TILE_K);
|
||||
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
|
||||
tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
if (m1 > 0){
|
||||
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
_tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
_tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
}
|
||||
b_blk += TILE_N * TILE_K;
|
||||
a_blk += TILE_K;
|
||||
a_next_blk += TILE_K;
|
||||
}
|
||||
if (m0 == TILE_M) {
|
||||
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
_tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
|
||||
|
||||
} else {
|
||||
_tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
_tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
|
||||
MoveTile(Tile4, m0, FullMask, c_blk, ldc);
|
||||
MoveTile(Tile6, m0, FullMask, c_blk + TILE_N, ldc);
|
||||
}
|
||||
if (m1 != 0){
|
||||
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
MoveTile(Tile5, m1, FullMask, c16_blk, ldc);
|
||||
_tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
MoveTile(Tile7, m1, FullMask, c16_blk + TILE_N, ldc);
|
||||
}
|
||||
c_blk += 2 * TILE_N;
|
||||
|
|
@ -539,23 +544,23 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile4, m0, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
|
||||
zeropoint, ZeroMode, c_blk, ldc);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
if (m1 > 0){
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile5, m1, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
|
||||
zeropoint, ZeroMode, c16_blk, ldc);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
InitTileWithRowColSums(
|
||||
Tile4, m0, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
|
||||
ZeroMode, c_blk, ldc);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
if (m1 > 0){
|
||||
InitTileWithRowColSums(
|
||||
Tile5, m1, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
|
||||
ZeroMode, c16_blk, ldc);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
if (nmask_high != 0){
|
||||
|
|
@ -565,23 +570,23 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile6, m0, nmask_high, RowSumBuffer, colsum,
|
||||
zeropoint, ZeroMode, c_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
if (m1 > 0){
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum,
|
||||
zeropoint, ZeroMode, c16_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
InitTileWithRowColSums(
|
||||
Tile6, m0, nmask_high, RowSumBuffer, colsum,
|
||||
ZeroMode, c_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
if (m1 > 0){
|
||||
InitTileWithRowColSums(
|
||||
Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum,
|
||||
ZeroMode, c16_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -589,18 +594,19 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A;
|
||||
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
|
||||
for (size_t k = PackedCountK; k > 0; k -=TILE_K) {
|
||||
_tile_loadd(TMM0, b_blk, TILE_K);
|
||||
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
_tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
tile_loadd(TMM0, b_blk, TILE_K);
|
||||
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
|
||||
tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
if (m1 > 0){
|
||||
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
_tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
}
|
||||
if (nmask_high != 0){
|
||||
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
_tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
if (m1 > 0){
|
||||
_tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
}
|
||||
}
|
||||
b_blk += TILE_N * TILE_K;
|
||||
|
|
@ -608,20 +614,20 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
a_next_blk += TILE_K;
|
||||
}
|
||||
if ((static_cast<uint16_t>(nmasks) & 0x8000) != 0 && m0 == TILE_M){
|
||||
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
} else {
|
||||
_tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
MoveTile(Tile4, m0, static_cast<uint16_t>(nmasks), c_blk, ldc);
|
||||
}
|
||||
if (m1 > 0){
|
||||
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
MoveTile(Tile5, m1, static_cast<uint16_t>(nmasks), c16_blk, ldc);
|
||||
}
|
||||
if (nmask_high != 0){
|
||||
_tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
MoveTile(Tile6, m0, nmask_high, c_blk + TILE_N, ldc);
|
||||
if (m1 > 0){
|
||||
_tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
MoveTile(Tile7, m1, nmask_high, c16_blk + TILE_N, ldc);
|
||||
}
|
||||
}
|
||||
|
|
@ -643,76 +649,78 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
|
||||
|
||||
if (ZeroPointB != nullptr){
|
||||
__m512i colsum = _mm512_loadu_epi32(col_sum_ptr);
|
||||
__m512i colsum = _mm512_loadu_si512(col_sum_ptr);
|
||||
col_sum_ptr += TILE_N;
|
||||
__m512i zeropoint = _mm512_loadu_epi32(zp_ptr);
|
||||
__m512i zeropoint = _mm512_loadu_si512(zp_ptr);
|
||||
zp_ptr += TILE_N;
|
||||
_tile_loadd(TMM0, b_blk, TILE_K);
|
||||
tile_loadd(TMM0, b_blk, TILE_K);
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode);
|
||||
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode);
|
||||
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
colsum = _mm512_loadu_epi32(col_sum_ptr);
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
colsum = _mm512_loadu_si512(col_sum_ptr);
|
||||
col_sum_ptr += TILE_N;
|
||||
zeropoint = _mm512_loadu_epi32(zp_ptr);
|
||||
zeropoint = _mm512_loadu_si512(zp_ptr);
|
||||
zp_ptr += TILE_N;
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode);
|
||||
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
_tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode);
|
||||
InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode);
|
||||
} else {
|
||||
__m512i colsum = _mm512_loadu_epi32(col_sum_ptr);
|
||||
__m512i colsum = _mm512_loadu_si512(col_sum_ptr);
|
||||
col_sum_ptr += TILE_N;
|
||||
_tile_loadd(TMM0, b_blk, TILE_K);
|
||||
tile_loadd(TMM0, b_blk, TILE_K);
|
||||
InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode);
|
||||
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode);
|
||||
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
colsum = _mm512_loadu_epi32(col_sum_ptr);
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
colsum = _mm512_loadu_si512(col_sum_ptr);
|
||||
col_sum_ptr += TILE_N;
|
||||
InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode);
|
||||
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
_tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode);
|
||||
InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode);
|
||||
}
|
||||
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
|
||||
for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) {
|
||||
b_blk += TILE_N * TILE_K;
|
||||
a_blk += TILE_K;
|
||||
a_next_blk += TILE_K;
|
||||
_tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
_tile_loadd(TMM0, b_blk, TILE_K);
|
||||
_tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
_tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
_tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
tile_loadd(TMM0, b_blk, TILE_K);
|
||||
tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
}
|
||||
_tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
_tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
_tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
|
||||
b_blk += PackedCountK * TILE_N + TILE_N * TILE_K;
|
||||
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
_tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
_tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
|
||||
|
||||
c_blk += 2 * TILE_N;
|
||||
_tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast<int>(ldc * sizeof(int32_t)));
|
||||
c16_blk += 2 * TILE_N;
|
||||
}
|
||||
|
||||
|
|
@ -726,20 +734,20 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile4, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
|
||||
zeropoint, ZeroMode, c_blk, ldc);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile5, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
|
||||
zeropoint, ZeroMode, c16_blk, ldc);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
} else {
|
||||
InitTileWithRowColSums(
|
||||
Tile4, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer, colsum,
|
||||
ZeroMode, c_blk, ldc);
|
||||
_tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
InitTileWithRowColSums(
|
||||
Tile5, TILE_M, static_cast<uint16_t>(nmasks), RowSumBuffer + TILE_M, colsum,
|
||||
ZeroMode, c16_blk, ldc);
|
||||
_tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
if (nmask_high != 0){
|
||||
colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr);
|
||||
|
|
@ -748,52 +756,58 @@ MlasGemmQuantKernel<MLAS_GEMM_U8S8_KERNEL_AMX>(
|
|||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile6, TILE_M, nmask_high, RowSumBuffer, colsum,
|
||||
zeropoint, ZeroMode, c_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
InitTileWithRowColSumsZeroPoints(
|
||||
Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum,
|
||||
zeropoint, ZeroMode, c16_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
} else {
|
||||
InitTileWithRowColSums(
|
||||
Tile6, TILE_M, nmask_high, RowSumBuffer, colsum,
|
||||
ZeroMode, c_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
InitTileWithRowColSums(
|
||||
Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum,
|
||||
ZeroMode, c16_blk + TILE_N, ldc);
|
||||
_tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A;
|
||||
const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M;
|
||||
for (size_t k = PackedCountK; k > 0; k -=TILE_K) {
|
||||
_tile_loadd(TMM0, b_blk, TILE_K);
|
||||
_tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
_tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
_tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
_tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
tile_loadd(TMM0, b_blk, TILE_K);
|
||||
tile_loadd(TMM2, a_blk, static_cast<int>(PackedCountK));
|
||||
tile_loadd(TMM3, a_next_blk, static_cast<int>(PackedCountK));
|
||||
|
||||
tile_dpbusd(TMM4, TMM2, TMM0);
|
||||
tile_dpbusd(TMM5, TMM3, TMM0);
|
||||
|
||||
if (nmask_high != 0){
|
||||
_tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
_tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
_tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K);
|
||||
tile_dpbusd(TMM6, TMM2, TMM1);
|
||||
tile_dpbusd(TMM7, TMM3, TMM1);
|
||||
|
||||
}
|
||||
b_blk += TILE_N * TILE_K;
|
||||
a_blk += TILE_K;
|
||||
a_next_blk += TILE_K;
|
||||
}
|
||||
if ((static_cast<uint16_t>(nmasks) & 0x8000) != 0){
|
||||
_tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
_tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM4, c_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
tile_stored(TMM5, c16_blk, static_cast<int>(ldc * sizeof(int32_t)));
|
||||
|
||||
} else {
|
||||
_tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
||||
|
||||
MoveTile(Tile4, TILE_M, static_cast<uint16_t>(nmasks), c_blk, ldc);
|
||||
MoveTile(Tile5, TILE_M, static_cast<uint16_t>(nmasks), c16_blk, ldc);
|
||||
}
|
||||
if (nmask_high != 0){
|
||||
_tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
_tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
||||
tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
||||
|
||||
MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc);
|
||||
MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc);
|
||||
}
|
||||
|
|
|
|||
234
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S
Normal file
234
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) 2023 Intel Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
AssembleAmx.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains macros to build AMX instructions for toolchains that
|
||||
do not natively support this newer instruction set extension.
|
||||
|
||||
--*/
|
||||
|
||||
//
|
||||
// Map friendly register names to the encoded register index.
|
||||
//
|
||||
|
||||
.equ .LTmmIndex_tmm0, 0
|
||||
.equ .LTmmIndex_tmm1, 1
|
||||
.equ .LTmmIndex_tmm2, 2
|
||||
.equ .LTmmIndex_tmm3, 3
|
||||
.equ .LTmmIndex_tmm4, 4
|
||||
.equ .LTmmIndex_tmm5, 5
|
||||
.equ .LTmmIndex_tmm6, 6
|
||||
.equ .LTmmIndex_tmm7, 7
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro builds a AMX instruction of the form:
|
||||
|
||||
instr tmm1,tmm2,tmm3
|
||||
|
||||
Arguments:
|
||||
|
||||
prefix - Specifies the opcode for the AMX instruction.
|
||||
|
||||
DestReg - Specifies the destination AMX tile.
|
||||
|
||||
Src1Reg - Specifies the first source AMX tile.
|
||||
|
||||
Src2Reg - Specifies the second source AMX tile.
|
||||
|
||||
--*/
|
||||
|
||||
.macro DPTmmTmmTmm prefix, DestReg, Src1Reg, Src2Reg
|
||||
|
||||
.set Payload0, 0x02 # "0F 38" prefix
|
||||
.set Payload0, Payload0 + ((((.LTmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7)
|
||||
.set Payload0, Payload0 + (1 << 6)
|
||||
.set Payload0, Payload0 + ((((.LTmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5)
|
||||
|
||||
.set Payload1, \prefix\()
|
||||
.set Payload1, Payload1 + (((.LTmmIndex_\Src2Reg\() & 15) ^ 15) << 3)
|
||||
|
||||
.set ModRMByte, 0xC0 # register form
|
||||
.set ModRMByte, ModRMByte + ((.LTmmIndex_\DestReg\() & 7) << 3)
|
||||
.set ModRMByte, ModRMByte + (.LTmmIndex_\Src1Reg\() & 7)
|
||||
|
||||
.byte 0xC4, Payload0, Payload1, 0x5E, ModRMByte
|
||||
|
||||
.endm
|
||||
|
||||
|
||||
.macro TdpbssdTmmTmmTmm DestReg, Src1Reg, Src2Reg
|
||||
|
||||
DPTmmTmmTmm 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\()
|
||||
|
||||
.endm
|
||||
|
||||
|
||||
.macro TdpbsudTmmTmmTmm DestReg, Src1Reg, Src2Reg
|
||||
|
||||
DPTmmTmmTmm 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\()
|
||||
|
||||
.endm
|
||||
|
||||
|
||||
.macro TdpbusdTmmTmmTmm DestReg, Src1Reg, Src2Reg
|
||||
|
||||
DPTmmTmmTmm 0x01, \DestReg\(), \Src1Reg\(), \Src2Reg\()
|
||||
|
||||
.endm
|
||||
|
||||
|
||||
.macro TdpbuudTmmTmmTmm DestReg, Src1Reg, Src2Reg
|
||||
|
||||
DPTmmTmmTmm 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\()
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro builds a AMX tile release instruction.
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
|
||||
--*/
|
||||
|
||||
// .macro TileReleaseMacro
|
||||
|
||||
// .byte 0xC4, 0xE2, 0x78, 0x49, 0xC0
|
||||
|
||||
// .endm
|
||||
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro builds an AMX tile zero instruction of the form:
|
||||
|
||||
instr tmm1
|
||||
|
||||
Arguments:
|
||||
|
||||
SrcReg - Specifies the source AMX tile.
|
||||
|
||||
--*/
|
||||
|
||||
.macro TileZeroMacro SrcReg
|
||||
|
||||
.set ModRMByte, 0xC0 # register form
|
||||
.set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3)
|
||||
.byte 0xC4, 0xE2, 0x7B, 0x49, ModRMByte
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro builds an AMX memory instruction of the form:
|
||||
|
||||
instr tmm, base, stride
|
||||
|
||||
Arguments:
|
||||
|
||||
instr - Specifies the opcode for the AMX instruction.
|
||||
|
||||
SrcReg - Specifies the target AMX tile.
|
||||
|
||||
BaseReg - Specifies the base address of memory location.
|
||||
|
||||
Stride - Specifies the stride for the memory instruction
|
||||
|
||||
--*/
|
||||
|
||||
.macro TileLoadMacro instr, SrcReg, BaseReg, Stride
|
||||
|
||||
.set Payload0, 0x02 # "0F 38" prefix
|
||||
.set Payload0, Payload0 + ((((.LTmmIndex_\SrcReg\() >> 3) & 1) ^ 1) << 7)
|
||||
.set Payload0, Payload0 + ((((3 >> 3) & 1) ^ 1) << 6)
|
||||
.set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5)
|
||||
|
||||
.set ModRMByte, 0x00 # memory form
|
||||
.set ModRMByte, ModRMByte + (1 << 2) # SibBye required
|
||||
.set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3)
|
||||
|
||||
.set SibByte, 0x00 # scale factor 1(SS)
|
||||
.set SibByte, SibByte + ((3 & 7) << 3)
|
||||
.set SibByte, SibByte + (0 & 7)
|
||||
|
||||
.byte 0xC4, Payload0, \instr\(), 0x4B, ModRMByte, SibByte
|
||||
|
||||
.endm
|
||||
|
||||
|
||||
.macro TileloaddTmmMem DstReg, BaseReg, Stride
|
||||
TileLoadMacro 0x7B, \DstReg\(), \BaseReg\(), \Stride\()
|
||||
.endm
|
||||
|
||||
.macro TileloaddT1TmmMem DstReg, BaseReg, Stride
|
||||
TileLoadMacro 0x79, \DstReg\(), \BaseReg\(), \Stride\()
|
||||
.endm
|
||||
|
||||
|
||||
.macro TileStoredMemTmm SrcReg, BaseReg, Stride
|
||||
TileLoadMacro 0x7A, \SrcReg\(), \BaseReg\(), \Stride\()
|
||||
.endm
|
||||
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro builds an AMX tile configuration instruction of the form:
|
||||
|
||||
instr base
|
||||
|
||||
Arguments:
|
||||
|
||||
instr - Specifies the opcode for the AMX instruction.
|
||||
|
||||
BaseReg - Specifies the memory address of the tile configuration.
|
||||
|
||||
--*/
|
||||
|
||||
.macro tilecfgMacro instr, BaseReg
|
||||
.set Payload0, 0x02 # "0F 38" prefix
|
||||
.set Payload0, Payload0 + (1 << 7)
|
||||
.set Payload0, Payload0 + (1 << 6)
|
||||
.set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5)
|
||||
|
||||
.set ModRMByte, 0x00 # memory form & no reg
|
||||
.set ModRMByte, ModRMByte + (0 & 7)
|
||||
|
||||
.byte 0xC4, Payload0, \instr\(), 0x49, ModRMByte
|
||||
|
||||
.endm
|
||||
|
||||
|
||||
.macro ldtilecfgMacro BaseReg
|
||||
|
||||
tilecfgMacro 0x78, \BaseReg\()
|
||||
|
||||
.endm
|
||||
|
||||
|
||||
.macro sttilecfgMacro BaseReg
|
||||
|
||||
tilecfgMacro 0x79, \BaseReg\()
|
||||
|
||||
.endm
|
||||
|
||||
Loading…
Reference in a new issue