MLAS: quantized GEMM update (#6916)

Various updates to the int8_t GEMMs:

1) Add ARM64 udot kernel to take advantage of dot product instructions available in newer cores. Some models run 4x faster than the stock implementation we used before.
2) Refactor the x64 kernels to share common code for AVX2(u8u8/u8s8/avxvnni) vs AVX512(u8u8/u8s8/avx512vnni) to reduce binary size.
3) Extend kernels to support per-column zero points for matrix B. This is not currently wired to an operator.
This commit is contained in:
Tracy Sharpe 2021-03-10 09:54:43 -08:00 committed by GitHub
parent bc319bd7aa
commit a8b897f710
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
45 changed files with 6646 additions and 5723 deletions

View file

@ -28,6 +28,7 @@ if(MSVC)
if(onnxruntime_target_platform STREQUAL "ARM64")
set(mlas_platform_preprocess_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/SgemmKernelNeon.asm
)
@ -81,15 +82,13 @@ if(MSVC)
${mlas_platform_srcs_avx2}
${ONNXRUNTIME_ROOT}/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Core.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvxVnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Core.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelSse2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelAvx.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelFma3.asm
@ -183,6 +182,7 @@ else()
set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemvKernelNeon.S
)
@ -246,8 +246,8 @@ else()
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx2.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvxVnni.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvxVnni.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/DgemmKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S
@ -322,11 +322,9 @@ else()
if(COMPILES_AVX512CORE)
set(mlas_platform_srcs_avx512core
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Core.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Core.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S
)
if(HAS_AVX512CORE)
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl")

View file

@ -22,10 +22,7 @@ class QAttention : public OpKernel, public AttentionCPUBase {
QAttention(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
#endif
private:
BufferUniquePtr packed_weights_;
@ -51,7 +48,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
template <typename T>
QAttention<T>::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info) {}
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
template <typename T>
Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_packed) {
is_packed = false;
@ -98,7 +94,6 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
is_packed = true;
return Status::OK();
}
#endif
template <typename T>
Status QAttention<T>::Compute(OpKernelContext* context) const {
@ -217,44 +212,29 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
head_size,
&dequant_scale,
bias_data + weights_offset);
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
gemm_params.M = sequence_length;
gemm_params.N = head_size;
gemm_params.K = input_hidden_size;
gemm_params.A = input_data + input_offset;
gemm_params.lda = input_hidden_size;
gemm_params.ZeroPointA = input_zero_point;
if (packed_weights_) {
const auto* packed_weight =
static_cast<const uint8_t*>(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size);
MlasGemm(
sequence_length, // M = S
head_size, // N = H
input_hidden_size, // K = D
input_data + input_offset, // A
input_hidden_size, // lda = D
input_zero_point, // input zero point
packed_weight, // B
weight_zero_point, // weight zero point
weights_is_signed, // weight data type
reinterpret_cast<int32_t*>(qkv_dest + qkv_offset), // C
head_size, // ldc
nullptr, // use single-thread
&scale_bias_processor); // output processor
continue;
gemm_params.B = packed_weight;
gemm_params.BIsPacked = true;
} else {
gemm_params.B = weights_data + weights_offset;
gemm_params.ldb = 3 * hidden_size;
}
#endif
MlasGemm(
sequence_length, // M = S
head_size, // N = H
input_hidden_size, // K = D
input_data + input_offset, // A
input_hidden_size, // lda = D
input_zero_point, // input zero point
weights_data + weights_offset, // B
3 * hidden_size, // ldb = 3NH
weight_zero_point, // weight zero point
weights_is_signed, // weight data type
reinterpret_cast<int32_t*>(qkv_dest + qkv_offset), // C
head_size, // ldc
nullptr, // use single-thread
&scale_bias_processor); // post processor
gemm_params.ZeroPointB = &weight_zero_point;
gemm_params.BIsSigned = weights_is_signed;
gemm_params.C = reinterpret_cast<int32_t*>(qkv_dest + qkv_offset);
gemm_params.ldc = head_size;
gemm_params.OutputProcessor = &scale_bias_processor;
MlasGemm(&gemm_params, nullptr);
}
});
}

View file

@ -10,19 +10,14 @@ using namespace rnn::detail;
class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
public:
DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {}
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
#endif
Status Compute(OpKernelContext* context) const override;
~DynamicQuantizeLSTM() override = default;
private:
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed);
#endif
template <typename T>
Status ComputeImpl(OpKernelContext& context) const;
@ -33,7 +28,6 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
bool is_R_signed_;
};
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status DynamicQuantizeLSTM::TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed) {
const auto& shape = weights.Shape();
if (shape.NumDimensions() != 3) {
@ -83,7 +77,6 @@ Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, bool& i
return Status::OK();
}
#endif
#define WeightCheck(weight_shape, weight_name) \
if (weight_shape.NumDimensions() != 1 && weight_shape.NumDimensions() != 2 || \

View file

@ -44,11 +44,19 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
if (y->Shape().Size() == 0)
return Status::OK();
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
gemm_params.M = static_cast<size_t>(helper.M());
gemm_params.N = static_cast<size_t>(helper.N());
gemm_params.K = static_cast<size_t>(helper.K());
gemm_params.lda = gemm_params.K;
gemm_params.ZeroPointA = a_zero_point;
gemm_params.ldb = gemm_params.N;
gemm_params.ZeroPointB = &b_zero_point;
gemm_params.ldc = gemm_params.N;
auto* y_data = y->template MutableData<float>();
const auto* bias_data = bias_tensor != nullptr ? bias_tensor->Data<float>() : nullptr;
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(
y_data + helper.OutputOffsets()[i],
@ -56,40 +64,18 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
&multiplier,
bias_data);
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
gemm_params.A = a_data + helper.LeftOffsets()[i];
if (packed_b_) {
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
a_zero_point,
packed_b_.get(),
b_zero_point,
b_is_signed_,
reinterpret_cast<int32_t*>(y_data + helper.OutputOffsets()[i]),
static_cast<size_t>(helper.N()),
thread_pool,
&scale_bias_processor);
continue;
gemm_params.B = packed_b_.get();
gemm_params.BIsPacked = true;
gemm_params.BIsSigned = b_is_signed_;
} else {
gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + + helper.RightOffsets()[i];
gemm_params.BIsSigned = b->IsDataType<int8_t>();
}
#endif
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
const bool b_is_signed = b->IsDataType<int8_t>();
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
a_zero_point,
b_data + helper.RightOffsets()[i],
static_cast<size_t>(helper.N()),
b_zero_point,
b_is_signed,
reinterpret_cast<int32_t*>(y_data + helper.OutputOffsets()[i]),
static_cast<size_t>(helper.N()),
thread_pool,
&scale_bias_processor);
gemm_params.C = reinterpret_cast<int32_t*>(y_data) + helper.OutputOffsets()[i];
gemm_params.OutputProcessor = &scale_bias_processor;
MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
}
return Status::OK();

View file

@ -61,10 +61,6 @@ Abstract:
#define MLAS_SUPPORTS_GEMM_DOUBLE
#endif
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_ARM64) || (defined(MLAS_TARGET_ARM) && !defined(_MSC_VER))
#define MLAS_SUPPORTS_PACKED_GEMM_U8X8
#endif
//
// Basic Linear Algebra Subprograms (BLAS) types.
//
@ -273,41 +269,29 @@ private:
MLAS_QUANTIZATION_GRANULARITY QuantGran_;
};
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
uint8_t offb,
bool BIsSigned,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool,
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
);
//
// Supplies the arguments for a quantized integer GEMM (uint8_t matrix A times
// uint8_t or int8_t matrix B, accumulated into int32_t matrix C). Every member
// has a default, so callers only assign the fields that apply to their call.
//
struct MLAS_GEMM_U8X8_PARAMETERS {
    // Dimensions of the operation: A is M x K, B is K x N and C is M x N.
    size_t M = 0;
    size_t N = 0;
    size_t K = 0;

    // Matrix A data, its leading dimension and its zero point offset.
    const uint8_t* A = nullptr;
    size_t lda = 0;
    uint8_t ZeroPointA = 0;

    // Matrix B data. Typed as void* because the buffer is either the raw
    // matrix or an opaque pre-packed buffer (see BIsPacked).
    const void* B = nullptr;

    // Leading dimension of matrix B.
    // NOTE(review): callers visible in this change leave ldb at its default
    // when BIsPacked is true; presumably it is ignored for packed B - confirm.
    size_t ldb = 0;

    // Zero point offset(s) of matrix B, passed by address.
    // NOTE(review): presumably points to a single value for per-tensor
    // quantization and to N values when PerColumnZeroPoints is true - confirm
    // against the implementation.
    const uint8_t* ZeroPointB = nullptr;

    // True if B refers to a pre-packed buffer instead of the raw matrix.
    bool BIsPacked = false;

    // True if the elements of matrix B are int8_t instead of uint8_t.
    bool BIsSigned = false;

    // Selects per-column zero points for matrix B.
    bool PerColumnZeroPoints = false;

    // Output matrix C and its leading dimension.
    int32_t* C = nullptr;
    size_t ldc = 0;

    // Optional processor applied to the output; nullptr if not used.
    const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr;
};
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const void* PackedB,
uint8_t offb,
bool BIsSigned,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool,
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
const MLAS_GEMM_U8X8_PARAMETERS* Parameters,
MLAS_THREADPOOL* ThreadPool
);
//

View file

@ -25,7 +25,7 @@ Abstract:
// Stack frame layout for the U8X8 kernel.
//
.equ .LGemmU8X8KernelFrame_SavedGeneralRegisters, (6 * 4)
.equ .LGemmU8X8KernelFrame_SavedGeneralRegisters, (7 * 4)
.equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (8 * 8)
.equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedGeneralRegisters + .LGemmU8X8KernelFrame_SavedNeonRegisters
.equ .LGemmU8X8KernelFrame_CountM, 0 + .LGemmU8X8KernelFrame_SavedRegisters
@ -33,7 +33,7 @@ Abstract:
.equ .LGemmU8X8KernelFrame_ldc, 8 + .LGemmU8X8KernelFrame_SavedRegisters
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 12 + .LGemmU8X8KernelFrame_SavedRegisters
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 16 + .LGemmU8X8KernelFrame_SavedRegisters
.equ .LGemmU8X8KernelFrame_DepthValue, 20 + .LGemmU8X8KernelFrame_SavedRegisters
.equ .LGemmU8X8KernelFrame_ZeroPointB, 20 + .LGemmU8X8KernelFrame_SavedRegisters
.equ .LGemmU8X8KernelFrame_ZeroMode, 24 + .LGemmU8X8KernelFrame_SavedRegisters
.text
@ -48,10 +48,10 @@ Routine Description:
Arguments:
A (r0) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackANeon.
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>.
B (r1) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackBNeon.
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_NEON>.
C (r2) - Supplies the address of matrix C.
@ -67,17 +67,18 @@ Arguments:
ldc - Supplies the first dimension of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by
the zero point offset of matrix B. These values are accumulated into every
row of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A. These values have
been pre-scaled by the zero point offset of matrix B if the offset is
per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
scaled by the per-column zero point offsets of matrix B. These values are
accumulated into every row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multiplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
B, else nullptr if the matrix B is using per-tensor quantization.
ZeroMode - Supplies true if the output matrix must be zero initialized, else
false if the output matrix is accumulated into.
@ -95,26 +96,24 @@ Return Value:
//
// q0-q1 (d0-d3) matrix B data
// q2-q3 (d4-d7) matrix A data
// q4 (d8-d9) packed matrix B data
// q5 (d10-d11) RowSumBufferData + DepthValue
// q4 (d8-d9) packed matrix A data
// q5 (d10-d11) RowSumBuffer data
// q6-q7 (d12-d15) ColumnSumBuffer data
// q8-q15 accumulators[4][2]
//
push {r4,r5,r6,r7,r8,r10}
push {r4,r5,r6,r7,r8,r9,r10}
vpush {d8-d15}
ldr r4,[sp,#.LGemmU8X8KernelFrame_CountM]
ldr r5,[sp,#.LGemmU8X8KernelFrame_ZeroMode]
ldr r7,[sp,#.LGemmU8X8KernelFrame_RowSumBuffer]
ldr r7,[sp,#.LGemmU8X8KernelFrame_ZeroPointB]
ldr r8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer]
vldr s0,[sp,#.LGemmU8X8KernelFrame_DepthValue]
ldr r9,[sp,#.LGemmU8X8KernelFrame_RowSumBuffer]
ldr r10,[sp,#.LGemmU8X8KernelFrame_ldc]
ldr r12,[sp,#.LGemmU8X8KernelFrame_CountN]
vdup.32 q5,d0[0] // broadcast DepthValue
vld1.32 {d12-d13},[r7]
vld1.32 {d10-d11},[r9] // load RowSumBuffer
mov r6,r0
mov r7,r3
vadd.u32 q5,q5,q6 // add row fixups and DepthValue
mov r9,r3
cmp r4,#1 // CountM == 1?
beq .LGemmU8X8.M1.ProcessNextColumnLoop
cmp r4,#4 // CountM < 4?
@ -128,12 +127,36 @@ Return Value:
vldr d0,[r1] // load packed B0
mov r0,r6 // reload matrix A
vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer
mov r3,r7 // reload PackedCountK
mov r3,r9 // reload PackedCountK
vmovl.u8 q0,d0
vdup.32 q9,d10[0]
vdup.32 q11,d10[1]
vdup.32 q13,d11[0]
vdup.32 q15,d11[1]
cbz r7,.LGemmU8X8.M4.SkipScaleByZeroPointB
vld1.32 {d8-d9},[r7]! // load ZeroPointB0
vmul.u32 q8,q9,q4
vmul.u32 q10,q11,q4
vmul.u32 q12,q13,q4
vmul.u32 q14,q15,q4
vld1.32 {d8-d9},[r7]! // load ZeroPointB1
vmul.u32 q9,q9,q4
vmul.u32 q11,q11,q4
vmul.u32 q13,q13,q4
vmul.u32 q15,q15,q4
vldr d8,[r0] // load first packed A0
vadd.u32 q8,q8,q6
vadd.u32 q9,q9,q7
vadd.u32 q10,q10,q6
vadd.u32 q11,q11,q7
vldr d9,[r0,#8] // load first packed A1
vadd.u32 q12,q12,q6
vadd.u32 q13,q13,q7
vadd.u32 q14,q14,q6
vadd.u32 q15,q15,q7
b .LGemmU8X8.M4.ComputeBlockLoop
.LGemmU8X8.M4.SkipScaleByZeroPointB:
vldr d8,[r0] // load first packed A0
vadd.u32 q8,q9,q6
vadd.u32 q9,q9,q7
@ -244,7 +267,7 @@ Return Value:
.LGemmU8X8.M4.ExitKernel:
mov r0,#4 // return number of rows handled
vpop {d8-d15}
pop {r4,r5,r6,r7,r8,r10}
pop {r4,r5,r6,r7,r8,r9,r10}
bx lr
//
@ -353,10 +376,24 @@ Return Value:
vldr d0,[r1] // load packed B0
mov r0,r6 // reload matrix A
vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer
mov r3,r7 // reload PackedCountK
mov r3,r9 // reload PackedCountK
vmovl.u8 q0,d0
vdup.32 q9,d10[0]
vdup.32 q11,d10[1]
cbz r7,.LGemmU8X8.M2.SkipScaleByZeroPointB
vld1.32 {d28-d31},[r7]! // load ZeroPointB
vmul.u32 q8,q9,q14
vmul.u32 q9,q9,q15
vmul.u32 q10,q11,q14
vmul.u32 q11,q11,q15
vld1.32 d8,[r0]! // load first packed A0
vadd.u32 q8,q8,q6
vadd.u32 q9,q9,q7
vadd.u32 q10,q10,q6
vadd.u32 q11,q11,q7
b .LGemmU8X8.M2.ComputeBlockLoop
.LGemmU8X8.M2.SkipScaleByZeroPointB:
vld1.32 d8,[r0]! // load first packed A0
vadd.u32 q8,q9,q6
vadd.u32 q9,q9,q7
@ -425,7 +462,7 @@ Return Value:
.LGemmU8X8.M2.ExitKernel:
mov r0,#2 // return number of rows handled
vpop {d8-d15}
pop {r4,r5,r6,r7,r8,r10}
pop {r4,r5,r6,r7,r8,r9,r10}
bx lr
//
@ -502,9 +539,19 @@ Return Value:
vldr d0,[r1] // load packed B0
mov r0,r6 // reload matrix A
vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer
mov r3,r7 // reload PackedCountK
mov r3,r9 // reload PackedCountK
vmovl.u8 q0,d0
vdup.32 q9,d10[0]
cbz r7,.LGemmU8X8.M1.SkipScaleByZeroPointB
vld1.32 {d28-d31},[r7]! // load ZeroPointB
vmul.u32 q8,q9,q14
vmul.u32 q9,q9,q15
vld1.32 d8[0],[r0]! // load first packed A0
vadd.u32 q8,q8,q6
vadd.u32 q9,q9,q7
b .LGemmU8X8.M1.ComputeBlockLoop
.LGemmU8X8.M1.SkipScaleByZeroPointB:
vld1.32 d8[0],[r0]! // load first packed A0
vadd.u32 q8,q9,q6
vadd.u32 q9,q9,q7
@ -554,7 +601,7 @@ Return Value:
.LGemmU8X8.M1.ExitKernel:
mov r0,#1 // return number of rows handled
vpop {d8-d15}
pop {r4,r5,r6,r7,r8,r10}
pop {r4,r5,r6,r7,r8,r9,r10}
bx lr
//

View file

@ -0,0 +1,52 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
AssembleDotProduct.h
Abstract:
This module contains macros to build Advanced SIMD dot product instructions
for toolchains that do not natively support this newer instruction set
extension.
This implementation uses ARM v8.4 dot product instructions.
--*/
/*++
Macro Description:
This macro builds a UDOT instruction of the form:
UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index]
Arguments:
DestReg - Specifies the destination register.
Src1Reg - Specifies the first source register.
Src2Reg - Specifies the second source register.
Index - Specifies the element index of the second source register.
--*/
//
// The instruction word is accumulated in the assembler symbol Instruction:
// start from the fixed opcode bits, add each operand at its bit position,
// then emit the result as a raw instruction word. All four arguments must be
// plain numbers (register numbers and element index), not register names.
//
.macro UdotByElement DestReg, Src1Reg, Src2Reg, Index
.set Instruction, 0x6F80E000 // fixed opcode bits with all operand fields zero
.set Instruction, Instruction + (\DestReg\() << 0) // destination register number starts at bit 0
.set Instruction, Instruction + (\Src1Reg\() << 5) // first source register number starts at bit 5
.set Instruction, Instruction + (\Src2Reg\() << 16) // second source register number starts at bit 16
.set Instruction, Instruction + ((\Index\() & 2) << 10) // Index bit 1 lands in instruction bit 11
.set Instruction, Instruction + ((\Index\() & 1) << 21) // Index bit 0 lands in instruction bit 21
.inst Instruction // emit the assembled instruction word
.endm

View file

@ -22,12 +22,8 @@ Abstract:
//
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0
.equ .LGemmU8X8KernelFrame_DepthValue, 8
#if defined(__APPLE__)
.equ .LGemmU8X8KernelFrame_ZeroMode, 12
#else
.equ .LGemmU8X8KernelFrame_ZeroPointB, 8
.equ .LGemmU8X8KernelFrame_ZeroMode, 16
#endif
.text
@ -41,10 +37,10 @@ Routine Description:
Arguments:
A (x0) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackANeon.
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>.
B (x1) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackBNeon.
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_NEON>.
C (x2) - Supplies the address of matrix C.
@ -60,17 +56,18 @@ Arguments:
ldc (x6) - Supplies the first dimension of matrix C.
RowSumBuffer (x7) - Supplies the sum of each row from matrix A multiplied by
the zero point offset of matrix B. These values are accumulated into every
row of matrix C.
RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values
have been pre-scaled by the zero point offset of matrix B if the offset
is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
scaled by the per-column zero point offsets of matrix B. These values are
accumulated into every row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multiplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
B, else nullptr if the matrix B is using per-tensor quantization.
ZeroMode - Supplies true if the output matrix must be zero initialized, else
false if the output matrix is accumulated into.
@ -84,13 +81,11 @@ Return Value:
FUNCTION_ENTRY MlasGemmU8X8KernelNeon
ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer]
ldr s27,[sp,#.LGemmU8X8KernelFrame_DepthValue]
ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB]
ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode]
dup v27.4s,v27.s[0]
mov x14,x0
ld1 {v0.4s},[x7]
ld1 {v27.4s},[x7]
mov x15,x3
add v27.4s,v27.4s,v0.4s // broadcast add DepthValue
dup v24.4s,v27.s[0] // broadcast row fixups
cmp x4,#1 // CountM == 1?
beq .LGemmU8X8.M1.ProcessNextColumnLoop
@ -111,6 +106,30 @@ Return Value:
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
uxtl v0.8h,v0.8b
cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v24.4s,v28.4s
mul v17.4s,v24.4s,v29.4s
mul v18.4s,v25.4s,v28.4s
mul v19.4s,v25.4s,v29.4s
mul v20.4s,v26.4s,v28.4s
mul v21.4s,v26.4s,v29.4s
mul v22.4s,v27.4s,v28.4s
mul v23.4s,v27.4s,v29.4s
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
add v18.4s,v2.4s,v18.4s
add v19.4s,v3.4s,v19.4s
ld1 {v5.8b},[x0],#8 // load first packed A1
add v20.4s,v2.4s,v20.4s
add v21.4s,v3.4s,v21.4s
add v22.4s,v2.4s,v22.4s
add v23.4s,v3.4s,v23.4s
b .LGemmU8X8.M4.ComputeBlockLoop
.LGemmU8X8.M4.SkipScaleByZeroPointB:
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v24.4s
add v17.4s,v3.4s,v24.4s
@ -322,6 +341,21 @@ Return Value:
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
uxtl v0.8h,v0.8b
cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v24.4s,v28.4s
mul v17.4s,v24.4s,v29.4s
mul v18.4s,v25.4s,v28.4s
mul v19.4s,v25.4s,v29.4s
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
add v18.4s,v2.4s,v18.4s
add v19.4s,v3.4s,v19.4s
b .LGemmU8X8.M2.ComputeBlockLoop
.LGemmU8X8.M2.SkipScaleByZeroPointB:
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v24.4s
add v17.4s,v3.4s,v24.4s
@ -460,6 +494,17 @@ Return Value:
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
uxtl v0.8h,v0.8b
cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v24.4s,v28.4s
mul v17.4s,v24.4s,v29.4s
ldr s4,[x0],#4 // load first packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
b .LGemmU8X8.M1.ComputeBlockLoop
.LGemmU8X8.M1.SkipScaleByZeroPointB:
ldr s4,[x0],#4 // load first packed A0
add v16.4s,v2.4s,v24.4s
add v17.4s,v3.4s,v24.4s

View file

@ -0,0 +1,585 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelUdot.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses ARM v8.4 dot product instructions.
--*/
#include "asmmacro.h"
#include "AssembleDotProduct.h"
//
// Stack frame layout for the U8X8 kernel.
//
.equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (4 * 8)
.equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedNeonRegisters
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0 + .LGemmU8X8KernelFrame_SavedRegisters
.equ .LGemmU8X8KernelFrame_ZeroPointB, 8 + .LGemmU8X8KernelFrame_SavedRegisters
.equ .LGemmU8X8KernelFrame_ZeroMode, 16 + .LGemmU8X8KernelFrame_SavedRegisters
.text
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (x0) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>.
B (x1) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_UDOT>.
C (x2) - Supplies the address of matrix C.
PackedCountK (x3) - Supplies the number of packed columns from matrix A and
the number of packed rows from matrix B to iterate over.
CountM (x4) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (x5) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc (x6) - Supplies the first dimension of matrix C.
RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values
have been pre-scaled by the zero point offset of matrix B if the offset
is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
scaled by the per-column zero point offsets of matrix B. These values are
accumulated into every row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
B, else nullptr if the matrix B is using per-tensor quantization.
ZeroMode - Supplies true if the output matrix must be zero initialized, else
false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
FUNCTION_ENTRY MlasGemmU8X8KernelUdot
stp d8,d9,[sp,#-32]!
stp d10,d11,[sp,#16]
ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer]
ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB]
ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode]
mov x14,x0
ld1 {v11.4s},[x7]
mov x15,x3
dup v8.4s,v11.s[0] // broadcast row fixups
cmp x4,#1 // CountM == 1?
beq .LGemmU8X8.M1.ProcessNextColumnLoop
dup v9.4s,v11.s[1]
cmp x4,#4 // CountM < 4?
blo .LGemmU8X8.M2.ProcessNextColumnLoop
dup v10.4s,v11.s[2]
dup v11.4s,v11.s[3]
//
// Process 4 rows of the matrices.
//
.LGemmU8X8.M4.ProcessNextColumnLoop:
ld1 {v0.16b},[x1],#16 // load packed B0
mov x0,x14 // reload matrix A
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
mul v16.4s,v30.4s,v8.4s
mul v18.4s,v30.4s,v9.4s
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
mul v20.4s,v30.4s,v10.4s
mul v22.4s,v30.4s,v11.4s
mul v17.4s,v31.4s,v8.4s
mul v19.4s,v31.4s,v9.4s
mul v21.4s,v31.4s,v10.4s
mul v23.4s,v31.4s,v11.4s
add v16.4s,v2.4s,v16.4s
add v18.4s,v2.4s,v18.4s
add v20.4s,v2.4s,v20.4s
add v22.4s,v2.4s,v22.4s
add v17.4s,v3.4s,v17.4s
add v19.4s,v3.4s,v19.4s
add v21.4s,v3.4s,v21.4s
add v23.4s,v3.4s,v23.4s
b .LGemmU8X8.M4.ComputeBlockLoopStart
.LGemmU8X8.M4.SkipScaleByZeroPointB:
add v16.4s,v2.4s,v8.4s
add v18.4s,v2.4s,v9.4s
add v20.4s,v2.4s,v10.4s
add v22.4s,v2.4s,v11.4s
add v17.4s,v3.4s,v8.4s
add v19.4s,v3.4s,v9.4s
add v21.4s,v3.4s,v10.4s
add v23.4s,v3.4s,v11.4s
//
// The packing layout is set up to have a pair of four quad vectors from
// packed matrix A and a pair of eight quad vectors from packed matrix B.
// With this scheme, alternating loads from the packed matrices can be
// interleaved with the dot product instructions.
//
// One negative consequence of using four rows here is that the accumulator
// register tile is too small for processors with high out of order execution
// windows (such as the Apple M1). The dot product instructions for a given
// cell are too close to each other to avoid dependencies. To work around this,
// the below loop uses a pair of accumulator registers that are then added
// together when the loop finishes.
//
// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for
// packed matrix A. At the time of this implementation, using a wider 128-bit
// load didn't affect performance for higher end cores.
//
.LGemmU8X8.M4.ComputeBlockLoopStart:
ldr d4,[x0],#32 // load packed A0.l
movi v24.4s,#0
movi v25.4s,#0
ldur d5,[x0,#-24] // load packed A0.h
movi v26.4s,#0
movi v27.4s,#0
ldur d6,[x0,#-16] // load packed A1.l
movi v28.4s,#0
movi v29.4s,#0
movi v30.4s,#0
movi v31.4s,#0
.LGemmU8X8.M4.ComputeBlockLoop:
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 16, 0, 4, 0
UdotByElement 18, 0, 4, 1
ldur d7,[x0,#-8] // load packed A1.h
UdotByElement 20, 0, 5, 0
UdotByElement 22, 0, 5, 1
ld1 {v0.16b},[x1],#16 // load packed B0
UdotByElement 17, 1, 4, 0
UdotByElement 19, 1, 4, 1
sub x3,x3,#1
cbz x3,.LGemmU8X8.M4.ComputeBlockLoopFinish
ldr d4,[x0],#32 // load packed A0.l
UdotByElement 21, 1, 5, 0
UdotByElement 23, 1, 5, 1
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 24, 0, 6, 0
UdotByElement 26, 0, 6, 1
ldur d5,[x0,#-24] // load packed A0.h
UdotByElement 28, 0, 7, 0
UdotByElement 30, 0, 7, 1
ld1 {v0.16b},[x1],#16 // load packed B0
UdotByElement 25, 1, 6, 0
UdotByElement 27, 1, 6, 1
ldur d6,[x0,#-16] // load packed A1.l
UdotByElement 29, 1, 7, 0
UdotByElement 31, 1, 7, 1
b .LGemmU8X8.M4.ComputeBlockLoop
.LGemmU8X8.M4.ComputeBlockLoopFinish:
UdotByElement 21, 1, 5, 0
UdotByElement 23, 1, 5, 1
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 24, 0, 6, 0
UdotByElement 26, 0, 6, 1
UdotByElement 28, 0, 7, 0
UdotByElement 30, 0, 7, 1
UdotByElement 25, 1, 6, 0
UdotByElement 27, 1, 6, 1
UdotByElement 29, 1, 7, 0
UdotByElement 31, 1, 7, 1
add x10,x2,x6,lsl #2 // compute output row 2
add v16.4s,v16.4s,v24.4s // fold high results into low results
add v18.4s,v18.4s,v26.4s
add v20.4s,v20.4s,v28.4s
add v22.4s,v22.4s,v30.4s
add x11,x10,x6,lsl #2 // compute output row 3
add v17.4s,v17.4s,v25.4s
add v19.4s,v19.4s,v27.4s
add v21.4s,v21.4s,v29.4s
add v23.4s,v23.4s,v31.4s
add x12,x11,x6,lsl #2 // compute output row 4
subs x5,x5,#8 // adjust CountN remaining
blo .LGemmU8X8.M4.StoreOutputPartial
cbnz x13,.LGemmU8X8.M4.SkipAccumulateOutput
ldp q0,q1,[x2]
ldp q2,q3,[x10]
add v16.4s,v16.4s,v0.4s
add v17.4s,v17.4s,v1.4s
ldp q4,q5,[x11]
add v18.4s,v18.4s,v2.4s
add v19.4s,v19.4s,v3.4s
ldp q6,q7,[x12]
add v20.4s,v20.4s,v4.4s
add v21.4s,v21.4s,v5.4s
add v22.4s,v22.4s,v6.4s
add v23.4s,v23.4s,v7.4s
.LGemmU8X8.M4.SkipAccumulateOutput:
stp q16,q17,[x2],#32
stp q18,q19,[x10]
stp q20,q21,[x11]
stp q22,q23,[x12]
cbnz x5,.LGemmU8X8.M4.ProcessNextColumnLoop
.LGemmU8X8.M4.ExitKernel:
mov x0,#4 // return number of rows handled
ldp d10,d11,[sp,#16]
ldp d8,d9,[sp],#32
ret
//
// Store the partial 1 to 7 columns either overwriting the output matrix or
// accumulating into the existing contents of the output matrix.
//
// This path is entered after "subs x5,x5,#8" went below zero, so x5 holds
// (remaining CountN - 8). The low three bits of that negative value still
// equal the number of remaining columns, so bits 2, 1 and 0 of x5 select a
// 4-column, 2-column and 1-column store respectively.
//
// x2, x10, x11 and x12 address the four output rows. Each row has its first
// four columns in v16/v18/v20/v22 and its next four in v17/v19/v21/v23.
// x13 is nonzero for ZeroMode (overwrite) and zero for accumulate mode.
//
.LGemmU8X8.M4.StoreOutputPartial:
cbz x13,.LGemmU8X8.M4.StoreOutputPartial.AddMode
.LGemmU8X8.M4.StoreOutputPartial.ZeroMode:
// Bit 2 set: at least four columns remain, so store four per row.
tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
st1 {v20.4s},[x11],#16
mov v20.16b,v21.16b
st1 {v22.4s},[x12],#16
mov v22.16b,v23.16b
.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode:
// Bit 1 set: store two columns per row, then broadcast lane 2 so the next
// unstored column ends up in lane 0.
tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
st1 {v20.2s},[x11],#8
dup v20.4s,v20.s[2]
st1 {v22.2s},[x12],#8
dup v22.4s,v22.s[2]
.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode:
// Bit 0 set: store the final single column from lane 0 of each row.
tbz x5,#0,.LGemmU8X8.M4.ExitKernel
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
st1 {v20.s}[0],[x11]
st1 {v22.s}[0],[x12]
b .LGemmU8X8.M4.ExitKernel
//
// Accumulate mode: same column selection as above, but each store is preceded
// by a load of the existing output values and an add.
//
.LGemmU8X8.M4.StoreOutputPartial.AddMode:
tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.AddMode
ld1 {v0.4s},[x2]
ld1 {v1.4s},[x10]
ld1 {v2.4s},[x11]
ld1 {v3.4s},[x12]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
add v20.4s,v20.4s,v2.4s
add v22.4s,v22.4s,v3.4s
st1 {v20.4s},[x11],#16
mov v20.16b,v21.16b
st1 {v22.4s},[x12],#16
mov v22.16b,v23.16b
.LGemmU8X8.M4.StoreOutputPartial2.AddMode:
tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.AddMode
// The 64-bit loads clear the upper half of v0-v3, so the 4-lane adds below
// leave lanes 2-3 of the accumulators intact for the later dup.
ld1 {v0.2s},[x2]
ld1 {v1.2s},[x10]
ld1 {v2.2s},[x11]
ld1 {v3.2s},[x12]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
add v20.4s,v20.4s,v2.4s
add v22.4s,v22.4s,v3.4s
st1 {v20.2s},[x11],#8
dup v20.4s,v20.s[2]
st1 {v22.2s},[x12],#8
dup v22.4s,v22.s[2]
.LGemmU8X8.M4.StoreOutputPartial1.AddMode:
tbz x5,#0,.LGemmU8X8.M4.ExitKernel
// Only lane 0 is loaded and only lane 0 is stored; the other lanes of the
// sums are never written to memory.
ld1 {v0.s}[0],[x2]
ld1 {v1.s}[0],[x10]
add v16.4s,v16.4s,v0.4s
ld1 {v2.s}[0],[x11]
add v18.4s,v18.4s,v1.4s
ld1 {v3.s}[0],[x12]
add v20.4s,v20.4s,v2.4s
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
add v22.4s,v22.4s,v3.4s
st1 {v20.s}[0],[x11]
st1 {v22.s}[0],[x12]
b .LGemmU8X8.M4.ExitKernel
//
// Process 2 rows of the matrices.
//
// Each pass of the outer loop produces one 2x8 block of the output matrix.
//
// Register usage visible in this section:
//   x1      - packed matrix B pointer, advanced 16 bytes per load
//   x2      - first output row pointer (x10 is derived for the second row)
//   x5      - CountN remaining
//   x8      - ColumnSumBuffer pointer
//   x9      - ZeroPointB pointer; when zero the ZeroPointB scaling is skipped
//   x13     - nonzero to overwrite the output, zero to accumulate into it
//   x14     - matrix A base, copied to x0 for every column block
//   x15     - PackedCountK, copied to x3 for every column block
//   v8/v9   - per-row values for the first/second row that seed the
//             accumulators (presumably the broadcast row sums set up before
//             this section -- confirm against the kernel prologue)
//   v16/v17 - first row accumulators (columns 0-3 / 4-7)
//   v18/v19 - second row accumulators (columns 0-3 / 4-7)
//
// UdotByElement is a macro defined outside this section; the comments below
// assume its operands are (accumulator, B vector, A vector, A element index).
//
.LGemmU8X8.M2.ProcessNextColumnLoop:
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
mov x0,x14 // reload matrix A
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB
// Per-column zero points: accumulator = ColumnSum + ZeroPointB * row value.
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
mul v16.4s,v30.4s,v8.4s
mul v18.4s,v30.4s,v9.4s
mul v17.4s,v31.4s,v8.4s
mul v19.4s,v31.4s,v9.4s
ld1 {v4.16b},[x0],#16 // load packed A0
add v16.4s,v2.4s,v16.4s
add v18.4s,v2.4s,v18.4s
add v17.4s,v3.4s,v17.4s
add v19.4s,v3.4s,v19.4s
b .LGemmU8X8.M2.ComputeBlockLoop
.LGemmU8X8.M2.SkipScaleByZeroPointB:
// No per-column zero points: accumulator = ColumnSum + row value.
ld1 {v4.16b},[x0],#16 // load packed A0
add v16.4s,v2.4s,v8.4s
add v18.4s,v2.4s,v9.4s
add v17.4s,v3.4s,v8.4s
add v19.4s,v3.4s,v9.4s
//
// Each iteration consumes 16 bytes of packed A (v4) and 64 bytes of packed B.
// Elements 0 and 1 of v4 feed the first and second row with the first B pair;
// elements 2 and 3 do the same with the second B pair.
//
.LGemmU8X8.M2.ComputeBlockLoop:
UdotByElement 16, 0, 4, 0
UdotByElement 17, 1, 4, 0
UdotByElement 18, 0, 4, 1
UdotByElement 19, 1, 4, 1
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 16, 0, 4, 2
UdotByElement 17, 1, 4, 2
UdotByElement 18, 0, 4, 3
UdotByElement 19, 1, 4, 3
sub x3,x3,#1
cbz x3,.LGemmU8X8.M2.ComputeBlockLoopFinish
// More packed K remains: fetch the operands for the next iteration.
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
ld1 {v4.16b},[x0],#16 // load packed A0
b .LGemmU8X8.M2.ComputeBlockLoop
.LGemmU8X8.M2.ComputeBlockLoopFinish:
add x10,x2,x6,lsl #2 // compute output row 2
subs x5,x5,#8 // adjust CountN remaining
// Fewer than 8 columns remained: take the partial store path.
blo .LGemmU8X8.M2.StoreOutputPartial
cbnz x13,.LGemmU8X8.M2.SkipAccumulateOutput
// Accumulate mode: add the existing output values before storing.
ldp q0,q1,[x2]
ldp q2,q3,[x10]
add v16.4s,v16.4s,v0.4s
add v17.4s,v17.4s,v1.4s
add v18.4s,v18.4s,v2.4s
add v19.4s,v19.4s,v3.4s
.LGemmU8X8.M2.SkipAccumulateOutput:
// Only x2 is advanced; x10 is recomputed from x2 for every column block.
stp q16,q17,[x2],#32
stp q18,q19,[x10]
cbnz x5,.LGemmU8X8.M2.ProcessNextColumnLoop
.LGemmU8X8.M2.ExitKernel:
mov x0,#2 // return number of rows handled
// Reload d8-d11 from the stack and release the 32-byte save area.
ldp d10,d11,[sp,#16]
ldp d8,d9,[sp],#32
ret
//
// Store the partial 1 to 7 columns either overwriting the output matrix or
// accumulating into the existing contents of the output matrix.
//
// x5 holds (remaining CountN - 8), which is negative here; its low three bits
// still equal the number of remaining columns, so bits 2, 1 and 0 select a
// 4-column, 2-column and 1-column store. x2 and x10 address the two output
// rows, held in v16/v17 and v18/v19. x13 is nonzero to overwrite the output
// and zero to accumulate into it.
//
.LGemmU8X8.M2.StoreOutputPartial:
cbz x13,.LGemmU8X8.M2.StoreOutputPartial.AddMode
.LGemmU8X8.M2.StoreOutputPartial.ZeroMode:
// Bit 2 set: store four columns per row and move the upper four down.
tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode:
// Bit 1 set: store two columns per row, then broadcast lane 2 so the next
// unstored column ends up in lane 0.
tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode:
// Bit 0 set: store the final single column from lane 0.
tbz x5,#0,.LGemmU8X8.M2.ExitKernel
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
b .LGemmU8X8.M2.ExitKernel
//
// Accumulate mode: same column selection, with a load and add before each
// store.
//
.LGemmU8X8.M2.StoreOutputPartial.AddMode:
tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.AddMode
ld1 {v0.4s},[x2]
ld1 {v1.4s},[x10]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
.LGemmU8X8.M2.StoreOutputPartial2.AddMode:
tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.AddMode
// The 64-bit loads clear the upper half of v0/v1, so the 4-lane adds leave
// lanes 2-3 of the accumulators intact for the later dup.
ld1 {v0.2s},[x2]
ld1 {v1.2s},[x10]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
.LGemmU8X8.M2.StoreOutputPartial1.AddMode:
tbz x5,#0,.LGemmU8X8.M2.ExitKernel
// Only lane 0 is loaded and only lane 0 is stored.
ld1 {v0.s}[0],[x2]
ld1 {v1.s}[0],[x10]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
b .LGemmU8X8.M2.ExitKernel
//
// Process 1 row of the matrices.
//
// Each pass of the outer loop produces one 1x8 block of the output matrix in
// v16 (columns 0-3) and v17 (columns 4-7).
//
// Register usage visible in this section:
//   x1  - packed matrix B pointer, advanced 16 bytes per load
//   x2  - output row pointer
//   x5  - CountN remaining
//   x8  - ColumnSumBuffer pointer
//   x9  - ZeroPointB pointer; when zero the ZeroPointB scaling is skipped
//   x13 - nonzero to overwrite the output, zero to accumulate into it
//   x14 - matrix A base, copied to x0 for every column block
//   x15 - PackedCountK, copied to x3 for every column block
//   v8  - per-row value that seeds the accumulators (presumably the broadcast
//         row sum set up before this section -- confirm against the prologue)
//
// UdotByElement is a macro defined outside this section; the comments below
// assume its operands are (accumulator, B vector, A vector, A element index).
//
.LGemmU8X8.M1.ProcessNextColumnLoop:
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
mov x0,x14 // reload matrix A
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB
// Per-column zero points: accumulator = ColumnSum + ZeroPointB * row value.
ld1 {v30.4s},[x9],#16 // load ZeroPointB0
ld1 {v31.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v30.4s,v8.4s
mul v17.4s,v31.4s,v8.4s
ldr d4,[x0],#8 // load packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
b .LGemmU8X8.M1.ComputeBlockLoop
.LGemmU8X8.M1.SkipScaleByZeroPointB:
// No per-column zero points: accumulator = ColumnSum + row value.
ldr d4,[x0],#8 // load packed A0
add v16.4s,v2.4s,v8.4s
add v17.4s,v3.4s,v8.4s
//
// Each iteration consumes 8 bytes of packed A (d4) and 64 bytes of packed B.
// Element 0 of v4 is combined with the first B pair, element 1 with the
// second B pair.
//
.LGemmU8X8.M1.ComputeBlockLoop:
UdotByElement 16, 0, 4, 0
UdotByElement 17, 1, 4, 0
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 16, 0, 4, 1
UdotByElement 17, 1, 4, 1
sub x3,x3,#1
cbz x3,.LGemmU8X8.M1.ComputeBlockLoopFinish
// More packed K remains: fetch the operands for the next iteration.
ldr d4,[x0],#8 // load packed A0
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
b .LGemmU8X8.M1.ComputeBlockLoop
.LGemmU8X8.M1.ComputeBlockLoopFinish:
subs x5,x5,#8 // adjust CountN remaining
// Fewer than 8 columns remained: take the partial store path.
blo .LGemmU8X8.M1.StoreOutputPartial
cbnz x13,.LGemmU8X8.M1.SkipAccumulateOutput
// Accumulate mode: add the existing output values before storing.
ldp q0,q1,[x2]
add v16.4s,v16.4s,v0.4s
add v17.4s,v17.4s,v1.4s
.LGemmU8X8.M1.SkipAccumulateOutput:
stp q16,q17,[x2],#32
cbnz x5,.LGemmU8X8.M1.ProcessNextColumnLoop
.LGemmU8X8.M1.ExitKernel:
mov x0,#1 // return number of rows handled
// Reload d8-d11 from the stack and release the 32-byte save area.
ldp d10,d11,[sp,#16]
ldp d8,d9,[sp],#32
ret
//
// Store the partial 1 to 7 columns either overwriting the output matrix or
// accumulating into the existing contents of the output matrix.
//
// x5 holds (remaining CountN - 8), which is negative here; its low three bits
// still equal the number of remaining columns, so bits 2, 1 and 0 select a
// 4-column, 2-column and 1-column store. x2 addresses the output row, held in
// v16/v17. x13 is nonzero to overwrite the output and zero to accumulate.
//
.LGemmU8X8.M1.StoreOutputPartial:
cbz x13,.LGemmU8X8.M1.StoreOutputPartial.AddMode
.LGemmU8X8.M1.StoreOutputPartial.ZeroMode:
// Bit 2 set: store four columns and move the upper four down.
tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode:
// Bit 1 set: store two columns, then broadcast lane 2 so the next unstored
// column ends up in lane 0.
tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode:
// Bit 0 set: store the final single column from lane 0.
tbz x5,#0,.LGemmU8X8.M1.ExitKernel
st1 {v16.s}[0],[x2]
b .LGemmU8X8.M1.ExitKernel
//
// Accumulate mode: same column selection, with a load and add before each
// store.
//
.LGemmU8X8.M1.StoreOutputPartial.AddMode:
tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.AddMode
ld1 {v0.4s},[x2]
add v16.4s,v16.4s,v0.4s
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
.LGemmU8X8.M1.StoreOutputPartial2.AddMode:
tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.AddMode
// The 64-bit load clears the upper half of v0, so the 4-lane add leaves
// lanes 2-3 of v16 intact for the later dup.
ld1 {v0.2s},[x2]
add v16.4s,v16.4s,v0.4s
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
.LGemmU8X8.M1.StoreOutputPartial1.AddMode:
tbz x5,#0,.LGemmU8X8.M1.ExitKernel
// Only lane 0 is loaded and only lane 0 is stored.
ld1 {v0.s}[0],[x2]
add v16.4s,v16.4s,v0.4s
st1 {v16.s}[0],[x2]
b .LGemmU8X8.M1.ExitKernel
.end

View file

@ -19,9 +19,10 @@
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8X8KernelAvx2Common.inc
.list
EXTERN MlasMaskMoveTableAvx:NEAR
;
; Stack frame layout for the U8S8 CopyPackA routine.
;
@ -142,9 +143,9 @@ GemmU8S8CopyPackBFrame ENDS
and eax,15 ; isolate unaligned count
add eax,3
shr eax,2 ; align unaligned count to quad count
mov DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp],eax
vpbroadcastd xmm10,DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp]
vpcmpgtd xmm10,xmm10,XMMWORD PTR [MlasMaskMoveAvx]
neg rax
lea rbx,MlasMaskMoveTableAvx+8*4
vmovdqu xmm10,XMMWORD PTR [rbx+rax*4]
;
; Zero initialize the padded stack buffers.
@ -776,281 +777,4 @@ StoreColumnSumBufferNUnaligned:
NESTED_END MlasGemmU8S8CopyPackBAvx2, _TEXT
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single row of the
; output block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
; is 16).
;
; Vec2Reg - Supplies the low block accumulator register.
;
; Implicit Arguments:
;
; ymm0 - Supplies the first vector loaded from matrix B.
;
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
; is 16).
;
; ymm2 - Supplies the broadcast value loaded from matrix A.
;
; ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
;
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
; vpmaddubsw multiplies the unsigned bytes of ymm2 (matrix A broadcast) by the
; signed bytes of ymm0 (matrix B) and adds adjacent products into saturated
; 16-bit words. vpmaddwd against the word ones in ymm12 then adds adjacent
; words into 32-bit sums.
vpmaddubsw ymm3,ymm2,ymm0
vpmaddwd ymm3,ymm3,ymm12
IF ColumnCount EQ 16
vpaddd Vec1Reg,Vec1Reg,ymm3
; The second B vector is processed in place, so ymm2 no longer holds the
; broadcast matrix A value once this macro completes.
vpmaddubsw ymm2,ymm2,ymm1
vpmaddwd ymm2,ymm2,ymm12
vpaddd Vec2Reg,Vec2Reg,ymm2
ELSE
; Only one B vector is in use: its result goes to Vec2Reg and Vec1Reg is
; left untouched.
vpaddd Vec2Reg,Vec2Reg,ymm3
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm11 - Supplies the block accumulators.
;
; ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
IF RowCount EQ 1
; Single row: matrix B is used directly as a memory operand instead of being
; loaded into ymm0/ymm1 first. The accumulators are ymm4/ymm5 for 16 columns
; or ymm5 alone otherwise, matching the first-row case of the general path.
vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset]
vpmaddwd ymm3,ymm3,ymm12
IF ColumnCount EQ 16
vpaddd ymm4,ymm4,ymm3
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32]
vpmaddwd ymm2,ymm2,ymm12
vpaddd ymm5,ymm5,ymm2
ELSE
vpaddd ymm5,ymm5,ymm3
ENDIF
ELSE
; Multiple rows: load the matrix B vectors once and reuse them for each row.
; Every row broadcasts four bytes of matrix A into ymm2 and accumulates into
; its own register pair (ymm4/5, ymm6/7, ymm8/9, ymm10/11).
; EmitIfCountGE is defined outside this section; presumably it emits the
; bracketed statement only when the first argument is >= the second.
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset]
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]>
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
; The fourth row is addressed through rbx (matrix A plus 3 rows).
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm11 - Supplies the block accumulators.
;
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockBy1Loop
; rsi counts the bytes of the matrix A row still to be processed. The loop
; exits only when rsi reaches exactly zero, so r9 must be a nonzero multiple
; of 4 on entry.
mov rsi,r9 ; reload row length remaining
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,64 ; advance matrix B
sub rsi,4 ; decrement row length remaining by 1 quad
jnz ComputeBlockBy1Loop
ENDM
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8S8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8S8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
; the number of packed rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
; zero point offset of matrix B. These values are accumulated into every
; row of matrix C.
;
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; DepthValue - Supplies the value CountK multiplied by the zero point offset
; of matrix A multiplied by the zero point offset of matrix B. This value is
; accumulated into every element of matrix C.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemmU8S8KernelAvx2, _TEXT
; Save the non-volatile integer registers and xmm6-xmm12 used by the kernel.
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
alloc_stack (GemmU8X8KernelFrame.SavedR13)
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
END_PROLOGUE
; Load the stack arguments into registers:
;   rdi = matrix A base, rbp = CountN, rax = ldc in bytes,
;   r9 = matrix A row length in bytes, r10 = ZeroMode, r11 = CountM,
;   r12 = RowSumBuffer, r13 = ColumnSumBuffer.
mov rdi,rcx
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.CountM[rsp]
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp]
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF]
vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001]
;
; Process CountM rows of the matrices.
;
; CountM above 3 is handled as 4 rows, 3 and 1 have their own paths, and any
; other value falls into the 2 row path (CountM is assumed to be at least 1).
; ProcessCountM is defined outside this section; presumably each expansion
; branches to ExitKernel when done unless Fallthrough is specified.
;
cmp r11,3
ja ProcessCountM4
je ProcessCountM3
cmp r11,1
je ProcessCountM1
ProcessCountM2:
ProcessCountM 2
ProcessCountM4:
mov r11d,4 ; return 4 rows handled
ProcessCountM 4, Fallthrough
;
; Restore non-volatile registers and return.
;
ExitKernel:
; r11 holds the number of rows handled: the original CountM for 1 to 3 rows,
; or 4 as set by the 4 row path.
mov eax,r11d
vzeroupper
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR13)
BEGIN_EPILOGUE
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
; The odd row count paths are placed after the epilogue.
ProcessCountM1:
ProcessCountM 1
ProcessCountM3:
ProcessCountM 3
NESTED_END MlasGemmU8S8KernelAvx2, _TEXT
END

View file

@ -1,91 +0,0 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8S8KernelAvx512Common.inc
;
; Abstract:
;
; This module contains common kernel macros and structures for the quantized
; integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
; AVX512VNNI kernels.
;
;--
INCLUDE QgemmU8X8KernelAvx512Common.inc
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockBy4Loop
LOCAL ProcessRemainingBlocks
LOCAL ComputeBlockBy1Loop
LOCAL ComputeBlockLoopExit
; rsi counts the bytes of the matrix A row still to be processed. r9 must be
; a nonzero multiple of 4 for the by-1 loop below to terminate.
mov rsi,r9 ; reload row length remaining
; The unrolled by-4 loop is only generated for even row counts; odd row
; counts go straight to the by-1 loop.
IF ((RowCount AND 1) EQ 0)
sub rsi,4*4
jb ProcessRemainingBlocks
ComputeBlockBy4Loop:
; Four quads per iteration: matrix B steps by 64 bytes and matrix A by 4 bytes
; for each ComputeBlock expansion.
ComputeBlock ColumnCount, RowCount, 0*64, 0
ComputeBlock ColumnCount, RowCount, 1*64, 4
ComputeBlock ColumnCount, RowCount, 2*64, 8
ComputeBlock ColumnCount, RowCount, 3*64, 12
add rcx,4*4 ; advance matrix A by 4 quads
IF RowCount GT 3
add rbx,4*4 ; advance matrix A plus 3 rows by 4 quads
ENDIF
add rdx,4*64 ; advance matrix B
sub rsi,4*4 ; decrement row length remaining by 4 quads
jae ComputeBlockBy4Loop
ProcessRemainingBlocks:
add rsi,4*4 ; correct for over-subtract above
jz ComputeBlockLoopExit
ENDIF
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,64 ; advance matrix B
sub rsi,4 ; decrement quads remaining
jnz ComputeBlockBy1Loop
ComputeBlockLoopExit:
ENDM

View file

@ -1,130 +0,0 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8S8KernelAvx512Core.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVX512 core instructions (BW/DQ/VL).
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8S8KernelAvx512Common.inc
.list
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single cell of the
; output block.
;
; Arguments:
;
; AccumReg - Supplies the register to accumulate into.
;
; Mult1Reg - Supplies the first multiplication operand register.
;
; Mult2Reg - Supplies the second multiplication operand register.
;
; Implicit Arguments:
;
; zmm4 - Supplies a scratch register for intermediate results.
;
; zmm5 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
;
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg
; vpmaddubsw multiplies the unsigned bytes of Mult1Reg by the signed bytes of
; Mult2Reg and adds adjacent products into saturated 16-bit words. vpmaddwd
; against the word ones in zmm5 then adds adjacent words into 32-bit sums,
; which are added to AccumReg. zmm4 is overwritten as scratch.
vpmaddubsw zmm4,Mult1Reg,Mult2Reg
vpmaddwd zmm4,zmm4,zmm5
vpaddd AccumReg,AccumReg,zmm4
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
; Load one to three 16-column blocks of matrix B, spaced r14 bytes apart. The
; last block loaded always lands in zmm2, the one before it in zmm1 and the
; one before that in zmm0.
IF ColumnCount GE 48
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset]
ELSEIF ColumnCount GE 32
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset]
ELSE
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset]
ENDIF
; For each row, broadcast four bytes of matrix A into zmm3 and accumulate
; against each loaded B block. Rows 1 to 6 accumulate into zmm26-zmm31 for
; zmm0, zmm20-zmm25 for zmm1 and zmm14-zmm19 for zmm2. Rows 4 to 6 are
; addressed through rbx (matrix A plus 3 rows).
; EmitIfCountGE and EmitIfCount2GE are defined outside this section;
; presumably they emit the bracketed statement only when each count argument
; is >= the value that follows it.
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
ENDM
;
; Generate the GEMM kernel.
;
GemmU8X8KernelAvx512Function U8S8, Avx512Core
END

View file

@ -1,102 +0,0 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8S8KernelAvx512Vnni.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVX512VNNI instructions.
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8S8KernelAvx512Common.inc
INCLUDE AssembleAvx512Vnni.inc
.list
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
; Load one to three 16-column blocks of matrix B, spaced r14 bytes apart. The
; last block loaded always lands in zmm2, the one before it in zmm1 and the
; one before that in zmm0.
IF ColumnCount GE 48
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset]
ELSEIF ColumnCount GE 32
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset]
ELSE
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset]
ENDIF
; For each row, broadcast four bytes of matrix A into zmm3 and accumulate
; against each loaded B block. Rows 1 to 6 accumulate into zmm26-zmm31 for
; zmm0, zmm20-zmm25 for zmm1 and zmm14-zmm19 for zmm2. Rows 4 to 6 are
; addressed through rbx (matrix A plus 3 rows).
; VpdpbusdsZmmZmmZmm, EmitIfCountGE and EmitIfCount2GE are defined outside
; this section; the first presumably encodes the AVX512VNNI vpdpbusds
; instruction (accumulator, unsigned bytes, signed bytes) -- confirm against
; AssembleAvx512Vnni.inc.
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2>
ENDM
;
; Generate the GEMM kernel.
;
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
END

View file

@ -1,298 +0,0 @@
;++
;
; Copyright (c) Intel Corporation 2020. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8S8KernelAvxVnni.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVXVNNI instructions.
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8X8KernelAvx2Common.inc
INCLUDE AssembleAvxVnni.inc
.list
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single row of the
; output block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
; is 16).
;
; Vec2Reg - Supplies the low block accumulator register.
;
; Implicit Arguments:
;
; ymm0 - Supplies the first vector loaded from matrix B.
;
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
; is 16).
;
; ymm2 - Supplies the broadcast value loaded from matrix A.
;
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
; VpdpbusdsYmmYmmYmm is defined outside this section; presumably it encodes
; the AVXVNNI vpdpbusds instruction (accumulator, unsigned bytes, signed
; bytes) -- confirm against AssembleAvxVnni.inc.
IF ColumnCount EQ 16
; Two B vectors: ymm0 accumulates into Vec1Reg and ymm1 into Vec2Reg.
VpdpbusdsYmmYmmYmm Vec1Reg,ymm2,ymm0
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm1
ELSE
; One B vector: ymm0 accumulates into Vec2Reg and Vec1Reg is left untouched.
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm0
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
; Load the matrix B vectors once and reuse them for each row. Every row
; broadcasts four bytes of matrix A into ymm2 and accumulates into its own
; register pair (ymm4/5 through ymm14/15). Rows 4 to 6 are addressed through
; rbx (matrix A plus 3 rows).
; EmitIfCountGE is defined outside this section; presumably it emits the
; bracketed statement only when the first argument is >= the second.
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset]
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]>
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, ymm12, ymm13>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, ymm14, ymm15>
ENDM
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockBy1Loop
; rsi counts the bytes of the matrix A row still to be processed. The loop
; exits only when rsi reaches exactly zero, so r9 must be a nonzero multiple
; of 4 on entry.
mov rsi,r9 ; reload row length remaining
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,64 ; advance matrix B
sub rsi,4 ; decrement row length remaining by 1 quad
jnz ComputeBlockBy1Loop
ENDM
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8S8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8S8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
; the number of packed rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
; zero point offset of matrix B. These values are accumulated into every
; row of matrix C.
;
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; DepthValue - Supplies the value CountK multiplied by the zero point offset
; of matrix A multiplied by the zero point offset of matrix B. This value is
; accumulated into every element of matrix C.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemmU8S8KernelAvxVnni, _TEXT
; Save the non-volatile integer registers and xmm6-xmm15 used by the kernel.
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
alloc_stack (GemmU8X8KernelFrame.SavedR13)
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
END_PROLOGUE
; Load the stack arguments into registers:
;   rdi = matrix A base, rbp = CountN, rax = ldc in bytes,
;   r9 = matrix A row length in bytes, r10 = ZeroMode, r11 = CountM,
;   r12 = RowSumBuffer, r13 = ColumnSumBuffer.
mov rdi,rcx
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.CountM[rsp]
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp]
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
;
; Process CountM rows of the matrices.
;
; CountM above 5 is handled as 6 rows; 5, 4, 3 and 1 have their own paths, and
; any other value falls into the 2 row path (CountM is assumed to be at least
; 1). ProcessCountM is defined outside this section; presumably each expansion
; branches to ExitKernel when done unless Fallthrough is specified.
;
cmp r11,5
ja ProcessCountM6
je ProcessCountM5
cmp r11,3
ja ProcessCountM4
je ProcessCountM3
cmp r11,1
je ProcessCountM1
ProcessCountM2:
ProcessCountM 2
ProcessCountM4:
ProcessCountM 4
ProcessCountM6:
mov r11d,6 ; return 6 rows handled
ProcessCountM 6, Fallthrough
;
; Restore non-volatile registers and return.
;
ExitKernel:
; r11 holds the number of rows handled: the original CountM for 1 to 5 rows,
; or 6 as set by the 6 row path.
mov eax,r11d
vzeroupper
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR13)
BEGIN_EPILOGUE
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
; The odd row count paths are placed after the epilogue.
ProcessCountM1:
ProcessCountM 1
ProcessCountM3:
ProcessCountM 3
ProcessCountM5:
ProcessCountM 5
NESTED_END MlasGemmU8S8KernelAvxVnni, _TEXT
END

View file

@ -19,9 +19,10 @@
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8X8KernelAvx2Common.inc
.list
EXTERN MlasMaskMoveTableAvx:NEAR
;
; Stack frame layout for the U8U8 CopyPackA routine.
;
@ -136,9 +137,9 @@ GemmU8U8CopyPackBFrame ENDS
and eax,15 ; isolate unaligned count
inc eax
shr eax,1 ; align unaligned count to pair count
mov DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp],eax
vpbroadcastd ymm9,DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp]
vpcmpgtd ymm9,ymm9,YMMWORD PTR [MlasMaskMoveAvx]
neg rax
lea rbx,MlasMaskMoveTableAvx+8*4
vmovdqu ymm9,YMMWORD PTR [rbx+rax*4]
;
; Zero initialize the padded stack buffers.
@ -663,305 +664,4 @@ StoreColumnSumBufferNUnaligned:
NESTED_END MlasGemmU8U8CopyPackBAvx2, _TEXT
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single row of the
; output block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
; is 16). This register is not referenced for other column counts.
;
; Vec2Reg - Supplies the low block accumulator register.
;
; Implicit Arguments:
;
; ymm0 - Supplies the first vector loaded from matrix B.
;
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
; is 16).
;
; ymm2 - Supplies the broadcast value loaded from matrix A. This register is
; overwritten with an intermediate product when ColumnCount is 16.
;
; ymm3 - Used as a scratch register for the intermediate product.
;
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
; vpmaddwd multiplies the packed words and sums adjacent pairs into dwords.
vpmaddwd ymm3,ymm2,ymm0
IF ColumnCount EQ 16
vpaddd Vec1Reg,Vec1Reg,ymm3
; The broadcast value is no longer needed, so ymm2 is reused as scratch.
vpmaddwd ymm2,ymm2,ymm1
vpaddd Vec2Reg,Vec2Reg,ymm2
ELSE
; Only one vector of matrix B is live: accumulate into the low block.
vpaddd Vec2Reg,Vec2Reg,ymm3
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators. Each row uses a pair of
; registers: (ymm4,ymm5) for row 1 through (ymm14,ymm15) for row 6.
;
; ymm0-ymm3 - Used as scratch registers for the matrix B vectors, the matrix
; A broadcast value and the intermediate products.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
; Zero extend 16 bytes of matrix B to 16 words per vector.
vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset]
EmitIfCountGE ColumnCount, 16, <vpmovzxbw ymm1,XMMWORD PTR [rdx+VectorOffset+16]>
; Rows 1-3 are addressed from rcx and rows 4-6 from rbx (matrix A plus 3
; rows), using r9 as the row stride.
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, ymm12, ymm13>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, ymm14, ymm15>
ENDM
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; rsi - Used as the loop counter (row length in bytes remaining).
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
; Notes:
;
; The loop is unrolled by two only for 16 column blocks with an even row
; count; every other combination steps one pair at a time. On the unrolled
; path, the trailing single block advances matrix B but not the matrix A
; pointers.
;
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockBy2Loop
LOCAL ProcessRemainingBlocks
LOCAL ComputeBlockBy1Loop
LOCAL ComputeBlockLoopExit
mov rsi,r9 ; reload row length remaining
IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0)
sub rsi,2*4
jb ProcessRemainingBlocks ; fewer than 2 pairs remaining
ComputeBlockBy2Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
ComputeBlock ColumnCount, RowCount, 32, 4
add rcx,2*4 ; advance matrix A by 2 pairs
IF RowCount GT 3
add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs
ENDIF
add rdx,2*32 ; advance matrix B
sub rsi,2*4
jae ComputeBlockBy2Loop
ProcessRemainingBlocks:
add rsi,2*4 ; correct for over-subtract above
jz ComputeBlockLoopExit ; no odd pair remaining
ComputeBlock ColumnCount, RowCount, 0, 0
add rdx,32 ; advance matrix B
ELSE
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 pair
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
ENDIF
add rdx,32 ; advance matrix B
sub rsi,4
jnz ComputeBlockBy1Loop
ENDIF
ComputeBlockLoopExit:
ENDM
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8U8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8U8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
; the number of packed rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
; zero point offset of matrix B. These values are accumulated into every
; row of matrix C.
;
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; DepthValue - Supplies the value CountK multiplied by the zero point offset
; of matrix A multiplied by the zero point offset of matrix B. This value is
; accumulated into every element of matrix C.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemmU8U8KernelAvx2, _TEXT
;
; Save the non-volatile general purpose and vector registers used by the
; kernel. The epilogue below restores them in the reverse order.
;
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
alloc_stack (GemmU8X8KernelFrame.SavedR13)
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
END_PROLOGUE
;
; Load the stack based arguments into registers for the ProcessCountM macro.
;
; NOTE(review): ProcessCountM and GemmU8X8KernelFrame are not defined in the
; visible part of this module; presumably they are supplied by
; QgemmU8X8KernelAvx2Common.inc - confirm the register contract there.
;
mov rdi,rcx ; keep a copy of the matrix A address
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.CountM[rsp]
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp]
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
;
; Process CountM rows of the matrices.
;
; CountM above 5 selects the 6 row path; 5, 4, 3 and 1 select the matching
; path; every other value (2, and also 0) reaches the 2 row path below.
;
cmp r11,5
ja ProcessCountM6
je ProcessCountM5
cmp r11,3
ja ProcessCountM4
je ProcessCountM3
cmp r11,1
je ProcessCountM1
ProcessCountM2:
ProcessCountM 2
ProcessCountM4:
ProcessCountM 4
ProcessCountM6:
mov r11d,6 ; return 6 rows handled
; NOTE(review): the Fallthrough argument presumably lets this expansion
; run straight into ExitKernel instead of jumping to it - confirm.
ProcessCountM 6, Fallthrough
;
; Restore non-volatile registers and return.
;
ExitKernel:
mov eax,r11d ; return the number of rows handled
vzeroupper
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR13)
BEGIN_EPILOGUE
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
;
; The odd row count paths are placed after the return instruction.
;
ProcessCountM1:
ProcessCountM 1
ProcessCountM3:
ProcessCountM 3
ProcessCountM5:
ProcessCountM 5
NESTED_END MlasGemmU8U8KernelAvx2, _TEXT
END

View file

@ -1,172 +0,0 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8U8KernelAvx512Core.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVX512 core instructions (BW/DQ/VL).
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8X8KernelAvx512Common.inc
.list
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single cell of the
; output block.
;
; Arguments:
;
; AccumReg - Supplies the register to accumulate into.
;
; Mult1Reg - Supplies the first multiplication operand register.
;
; Mult2Reg - Supplies the second multiplication operand register.
;
; Implicit Arguments:
;
; zmm4 - Supplies a scratch register for intermediate results.
;
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg
; vpmaddwd multiplies the packed words and sums adjacent pairs into dwords.
vpmaddwd zmm4,Mult1Reg,Mult2Reg
vpaddd AccumReg,AccumReg,zmm4
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators. Rows 1-6 accumulate into
; zmm14-zmm19 for the last vector of matrix B (zmm2), zmm20-zmm25 for the
; vector in zmm1 and zmm26-zmm31 for the vector in zmm0.
;
; zmm0-zmm4 - Used as scratch registers for the matrix B vectors, the matrix
; A broadcast value and the intermediate products.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
;
; Load and zero extend the matrix B vectors. The vectors are assigned so that
; zmm2 always holds the last block loaded, which lets the 16 column
; accumulators (zmm14-zmm19) be used for every supported column count.
;
IF ColumnCount GE 48
vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset]
vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset]
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset]
ELSEIF ColumnCount GE 32
vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset]
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset]
ELSE
vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset]
ENDIF
; Rows 1-3 are addressed from rcx and rows 4-6 from rbx (matrix A plus 3
; rows), using r9 as the row stride.
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
ENDM
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; rsi - Used as the loop counter (row length in bytes remaining).
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockBy1Loop
mov rsi,r9 ; reload row length remaining
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 pair
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
ENDIF
add rdx,32 ; advance matrix B
sub rsi,4
jnz ComputeBlockBy1Loop ; loop until the row is consumed
ENDM
;
; Generate the GEMM kernel.
;
; NOTE(review): GemmU8X8KernelAvx512Function is not defined in this module;
; presumably it is supplied by QgemmU8X8KernelAvx512Common.inc and expands to
; the U8U8 kernel using the macros defined above - confirm.
;
GemmU8X8KernelAvx512Function U8U8, Avx512Core
END

View file

@ -0,0 +1,940 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8X8KernelAvx2.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVX2 and AVX VNNI instructions.
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE AssembleAvxVnni.inc
.list
EXTERN MlasMaskMoveTableAvx:NEAR
;
; Stack frame layout for the U8X8 kernel.
;
; The structure is laid out from the lowest stack address upwards: the save
; area for xmm6-xmm15, the saved general purpose registers, the return
; address, the four parameter home slots and the stack based arguments.
;
; PreviousP1Home is read by ProduceOutputBlock as a DWORD selector for the
; compute loop to run (positive: U8U8 Avx2, zero: U8S8 Avx2, negative: U8S8
; AvxVnni).
;
; NOTE(review): the Padding field presumably keeps the xmm save area 16 byte
; aligned for the aligned vector saves - confirm against the kernel prologue.
;
GemmU8X8KernelFrame STRUCT
SavedXmm6 OWORD ?
SavedXmm7 OWORD ?
SavedXmm8 OWORD ?
SavedXmm9 OWORD ?
SavedXmm10 OWORD ?
SavedXmm11 OWORD ?
SavedXmm12 OWORD ?
SavedXmm13 OWORD ?
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
Padding QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
ReturnAddress QWORD ?
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumBuffer QWORD ?
ColumnSumBuffer QWORD ?
ZeroPointB QWORD ?
ZeroMode QWORD ?
GemmU8X8KernelFrame ENDS
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single row of the
; output block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
; is 16). This register is not referenced for other column counts.
;
; Vec2Reg - Supplies the low block accumulator register.
;
; Implicit Arguments:
;
; ymm0 - Supplies the first vector loaded from matrix B.
;
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
; is 16).
;
; ymm2 - Supplies the broadcast value loaded from matrix A. This register is
; overwritten with an intermediate product when ColumnCount is 16.
;
; ymm3 - Used as a scratch register for the intermediate product.
;
; ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
;
MultiplyAccumulateRowU8S8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg
; vpmaddubsw multiplies the unsigned bytes of matrix A with the signed bytes
; of matrix B and sums adjacent pairs into words; vpmaddwd against the vector
; of ones then sums adjacent words into dwords.
vpmaddubsw ymm3,ymm2,ymm0
vpmaddwd ymm3,ymm3,ymm12
IF ColumnCount EQ 16
vpaddd Vec1Reg,Vec1Reg,ymm3
; The broadcast value is no longer needed, so ymm2 is reused as scratch.
vpmaddubsw ymm2,ymm2,ymm1
vpmaddwd ymm2,ymm2,ymm12
vpaddd Vec2Reg,Vec2Reg,ymm2
ELSE
; Only one vector of matrix B is live: accumulate into the low block.
vpaddd Vec2Reg,Vec2Reg,ymm3
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce. This variant emits code
; for at most 4 rows.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm11 - Supplies the block accumulators.
;
; ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
;
; ymm0-ymm3 - Used as scratch registers for the matrix B vectors, the matrix
; A broadcast value and the intermediate products.
;
ComputeBlockU8S8Avx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
IF RowCount EQ 1
;
; Single row: each matrix B vector is used only once, so it is consumed
; directly as a memory operand instead of being loaded into ymm0/ymm1.
;
vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset]
vpmaddwd ymm3,ymm3,ymm12
IF ColumnCount EQ 16
vpaddd ymm4,ymm4,ymm3
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32]
vpmaddwd ymm2,ymm2,ymm12
vpaddd ymm5,ymm5,ymm2
ELSE
vpaddd ymm5,ymm5,ymm3
ENDIF
ELSE
;
; Multiple rows: load the matrix B vectors once and reuse them for each row.
; Rows 1-3 are addressed from rcx and row 4 from rbx (matrix A plus 3 rows).
;
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset]
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]>
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm4, ymm5>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm6, ymm7>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm8, ymm9>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm10, ymm11>
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single row of the
; output block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
; is 16). This register is not referenced for other column counts.
;
; Vec2Reg - Supplies the low block accumulator register.
;
; Implicit Arguments:
;
; ymm0 - Supplies the first vector loaded from matrix B.
;
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
; is 16).
;
; ymm2 - Supplies the broadcast value loaded from matrix A.
;
; NOTE(review): VpdpbusdsYmmYmmYmm is not defined in this module; presumably
; it is supplied by AssembleAvxVnni.inc and emits the VEX encoded vpdpbusds
; instruction, which multiplies and accumulates without a scratch register -
; confirm.
;
MultiplyAccumulateRowU8S8AvxVnni MACRO ColumnCount, Vec1Reg, Vec2Reg
IF ColumnCount EQ 16
VpdpbusdsYmmYmmYmm Vec1Reg,ymm2,ymm0
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm1
ELSE
; Only one vector of matrix B is live: accumulate into the low block.
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm0
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators. Each row uses a pair of
; registers: (ymm4,ymm5) for row 1 through (ymm14,ymm15) for row 6.
;
; ymm0-ymm2 - Used as scratch registers for the matrix B vectors and the
; matrix A broadcast value.
;
ComputeBlockU8S8AvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
; Load the matrix B vectors once and reuse them for each row.
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset]
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]>
; Rows 1-3 are addressed from rcx and rows 4-6 from rbx (matrix A plus 3
; rows), using r9 as the row stride.
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm4, ymm5>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm6, ymm7>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm8, ymm9>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm10, ymm11>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm12, ymm13>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm14, ymm15>
ENDM
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple times
; and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; Isa - Supplies the instruction set architecture string. The string is
; appended to ComputeBlockU8S8 to select the block compute macro (Avx2 or
; AvxVnni).
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; rsi - Used as the loop counter (row length in bytes remaining).
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators. The Avx2 variant only uses
; ymm4-ymm11 and expects the word value 0x0001 broadcast in ymm12.
;
ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount
LOCAL ComputeBlockBy1Loop
mov rsi,r9 ; reload row length remaining
ComputeBlockBy1Loop:
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,64 ; advance matrix B
sub rsi,4
jnz ComputeBlockBy1Loop ; loop until the row is consumed
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single row of the
; output block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
; is 16). This register is not referenced for other column counts.
;
; Vec2Reg - Supplies the low block accumulator register.
;
; Implicit Arguments:
;
; ymm0 - Supplies the first vector loaded from matrix B.
;
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
; is 16).
;
; ymm2 - Supplies the broadcast value loaded from matrix A. This register is
; overwritten with an intermediate product when ColumnCount is 16.
;
; ymm3 - Used as a scratch register for the intermediate product.
;
MultiplyAccumulateRowU8U8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg
; vpmaddwd multiplies the packed words and sums adjacent pairs into dwords.
vpmaddwd ymm3,ymm2,ymm0
IF ColumnCount EQ 16
vpaddd Vec1Reg,Vec1Reg,ymm3
; The broadcast value is no longer needed, so ymm2 is reused as scratch.
vpmaddwd ymm2,ymm2,ymm1
vpaddd Vec2Reg,Vec2Reg,ymm2
ELSE
; Only one vector of matrix B is live: accumulate into the low block.
vpaddd Vec2Reg,Vec2Reg,ymm3
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators. Each row uses a pair of
; registers: (ymm4,ymm5) for row 1 through (ymm14,ymm15) for row 6.
;
; ymm0-ymm3 - Used as scratch registers for the matrix B vectors, the matrix
; A broadcast value and the intermediate products.
;
ComputeBlockU8U8Avx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
; Zero extend 16 bytes of matrix B to 16 words per vector.
vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset]
EmitIfCountGE ColumnCount, 16, <vpmovzxbw ymm1,XMMWORD PTR [rdx+VectorOffset+16]>
; Rows 1-3 are addressed from rcx and rows 4-6 from rbx (matrix A plus 3
; rows), using r9 as the row stride.
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm4, ymm5>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm6, ymm7>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm8, ymm9>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm10, ymm11>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm12, ymm13>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm14, ymm15>
ENDM
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple times
; and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; Isa - Supplies the instruction set architecture string. The string is
; appended to ComputeBlockU8U8 to select the block compute macro.
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; rsi - Used as the loop counter (row length in bytes remaining).
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
; Notes:
;
; The loop is unrolled by two only for 16 column blocks with an even row
; count; every other combination steps one pair at a time. On the unrolled
; path, the trailing single block advances matrix B but not the matrix A
; pointers.
;
ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount
LOCAL ComputeBlockBy2Loop
LOCAL ProcessRemainingBlocks
LOCAL ComputeBlockBy1Loop
LOCAL ExitComputeBlockLoop
mov rsi,r9 ; reload row length remaining
IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0)
sub rsi,2*4
jb ProcessRemainingBlocks ; fewer than 2 pairs remaining
ComputeBlockBy2Loop:
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 32, 4
add rcx,2*4 ; advance matrix A by 2 pairs
IF RowCount GT 3
add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs
ENDIF
add rdx,2*32 ; advance matrix B
sub rsi,2*4
jae ComputeBlockBy2Loop
ProcessRemainingBlocks:
add rsi,2*4 ; correct for over-subtract above
jz ExitComputeBlockLoop ; no odd pair remaining
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0
add rdx,32 ; advance matrix B
ELSE
ComputeBlockBy1Loop:
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 pair
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
ENDIF
add rdx,32 ; advance matrix B
sub rsi,4
jnz ComputeBlockBy1Loop
ENDIF
ExitComputeBlockLoop:
ENDM
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rbx - Used as scratch for the matrix A data plus 3 rows during the compute
; loop. On exit, holds the address of matrix C plus 3 rows (when RowCount
; is greater than 3).
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r8 - Supplies the address into the matrix C data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r11 - Supplies the address of the row sum buffer.
;
; r12 - Supplies the address of the column sum buffer. This pointer is
; advanced by 16 columns when ColumnCount is 16.
;
; r13 - Optionally supplies the address of the matrix B zero point buffer.
; When non-zero, this pointer is advanced by 16 columns when ColumnCount
; is 16.
;
; PreviousP1Home (stack) - Supplies the DWORD selector for the compute loop:
; positive for U8U8 Avx2, zero for U8S8 Avx2 and negative for U8S8
; AvxVnni.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
ProduceOutputBlock MACRO ColumnCount, RowCount
LOCAL SkipScaleByZeroPointB
LOCAL AccumulatorsInitialized
LOCAL ProduceWithU8S8AvxVnni
LOCAL ProduceWithU8U8Avx2
LOCAL ExitProduceOutputBlock
;
; Initialize the accumulators with the row and column sums.
;
; The row sums are broadcast into the low block accumulator of each row and
; the column sums are loaded into ymm0/ymm1 (only ymm1 for a partial block).
;
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r11]>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r11+4]>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r11+8]>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r11+12]>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r11+16]>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r11+20]>
IF ColumnCount EQ 16
vmovdqu ymm0,YMMWORD PTR [r12]
vmovdqu ymm1,YMMWORD PTR [r12+32]
add r12,16*4 ; advance ColumnSumBuffer by 16 columns
ELSE
vmovdqu ymm1,YMMWORD PTR [r12]
ENDIF
test r13,r13 ; per column zero points?
jz SkipScaleByZeroPointB
;
; Per column zero points: each accumulator starts as the column sum plus the
; row sum multiplied by the zero point of the column.
;
IF ColumnCount EQ 16
vmovdqu ymm2,YMMWORD PTR [r13]
vmovdqu ymm3,YMMWORD PTR [r13+32]
add r13,16*4 ; advance ZeroPointB by 16 columns
ELSE
vmovdqu ymm3,YMMWORD PTR [r13]
ENDIF
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpmulld ymm4,ymm5,ymm2>
EmitIfCountGE RowCount, 1, <vpmulld ymm5,ymm5,ymm3>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm0,ymm4>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm1,ymm5>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpmulld ymm6,ymm7,ymm2>
EmitIfCountGE RowCount, 2, <vpmulld ymm7,ymm7,ymm3>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm0,ymm6>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm1,ymm7>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpmulld ymm8,ymm9,ymm2>
EmitIfCountGE RowCount, 3, <vpmulld ymm9,ymm9,ymm3>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm0,ymm8>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm1,ymm9>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpmulld ymm10,ymm11,ymm2>
EmitIfCountGE RowCount, 4, <vpmulld ymm11,ymm11,ymm3>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm0,ymm10>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm1,ymm11>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpmulld ymm12,ymm13,ymm2>
EmitIfCountGE RowCount, 5, <vpmulld ymm13,ymm13,ymm3>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm0,ymm12>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm1,ymm13>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpmulld ymm14,ymm15,ymm2>
EmitIfCountGE RowCount, 6, <vpmulld ymm15,ymm15,ymm3>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm0,ymm14>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm1,ymm15>
jmp AccumulatorsInitialized
SkipScaleByZeroPointB:
;
; No per column zero points: each accumulator starts as the column sum plus
; the row sum. The high block accumulator is written first because the low
; block accumulator still holds the broadcast row sum.
;
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm0,ymm5>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm1,ymm5>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm0,ymm7>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm1,ymm7>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm0,ymm9>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm1,ymm9>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm0,ymm11>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm1,ymm11>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm0,ymm13>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm1,ymm13>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm0,ymm15>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm1,ymm15>
AccumulatorsInitialized:
;
; Iterate over the length of a matrix A row to produce the output accumulators.
;
IF RowCount GT 3
lea rbx,[r9*2+r9]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
;
; Select the compute loop from the selector stored in the P1 home slot. The
; flags from the single compare are used by both conditional jumps. The U8S8
; Avx2 loop is only emitted for up to 4 rows; for larger row counts, a zero
; selector runs the AvxVnni loop.
;
; NOTE(review): the U8S8 Avx2 loop expects the word value 0x0001 broadcast in
; ymm12, which is not initialized here - presumably the kernel entry point
; sets it up; confirm.
;
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0
jg ProduceWithU8U8Avx2
IF RowCount LE 4
jl ProduceWithU8S8AvxVnni
ComputeBlockLoopU8S8 Avx2, ColumnCount, RowCount
jmp ExitProduceOutputBlock
ENDIF
ProduceWithU8S8AvxVnni:
ComputeBlockLoopU8S8 AvxVnni, ColumnCount, RowCount
jmp ExitProduceOutputBlock
ProduceWithU8U8Avx2:
ComputeBlockLoopU8U8 Avx2, ColumnCount, RowCount
ExitProduceOutputBlock:
; rbx is no longer needed for matrix A: repurpose it to address rows 4-6 of
; matrix C.
IF RowCount GT 3
lea rbx,[rax*2+rax]
add rbx,r8 ; compute matrix C plus 3 rows
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the address of matrix A.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10b - Supplies the zero mode flag.
;
; r11 - Supplies the address of the row sum buffer.
;
; r12 - Supplies the address of the column sum buffer.
;
; r13 - Optionally supplies the address of the matrix B zero point buffer.
;
ProcessCountM MACRO RowCount, Fallthrough
LOCAL ProcessNextColumnLoop16xN
LOCAL SkipAccumulateOutput16xNBlock
LOCAL OutputMasked16xNBlock
LOCAL ExitProcessCountM
LOCAL ProcessRemainingCountN
LOCAL SkipAccumulateOutput8xNBlock
LOCAL SkipAccumulateOutputMasked16xNBlock
LOCAL OutputMasked8xNBlock
LOCAL SkipAccumulateOutputMasked8xNBlock
;
; N.B. The Fallthrough argument is not referenced by this macro body. Every
; path leaves through ExitProcessCountM, which loads RowCount into eax and
; jumps to ExitKernel.
;
; Blocks of 16 columns are only produced while more than 8 columns remain;
; 8 or fewer columns are handled by the single 8 column block below.
;
cmp rbp,8
jbe ProcessRemainingCountN
ProcessNextColumnLoop16xN:
ProduceOutputBlock 16, RowCount
;
; The even numbered accumulators (ymm4,ymm6,...) hold columns 0-7 and the odd
; numbered accumulators (ymm5,ymm7,...) hold columns 8-15 of each row. If the
; subtraction borrows, then only 9 to 15 columns remained and the upper half
; of the block must be written with a mask.
;
sub rbp,16
jb OutputMasked16xNBlock
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlock
;
; Accumulate the existing contents of matrix C into the block. rbx addresses
; matrix C plus 3 rows (computed at the end of ProduceOutputBlock).
;
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
SkipAccumulateOutput16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
cmp rbp,8
ja ProcessNextColumnLoop16xN
test rbp,rbp
jnz ProcessRemainingCountN
;
; All columns have been produced: return RowCount as the number of rows
; handled.
;
ExitProcessCountM:
mov eax,RowCount
jmp ExitKernel
;
; Produce a single block of up to 8 columns. Only the odd numbered
; accumulators (ymm5,ymm7,...) are written to matrix C on this path.
;
ProcessRemainingCountN:
ProduceOutputBlock 8, RowCount
cmp rbp,8
jb OutputMasked8xNBlock
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput8xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput8xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
jmp ExitProcessCountM
;
; A 16 column block was produced but fewer than 16 columns remained. Write the
; lower 8 columns (even numbered accumulators) unmasked, then fall through to
; write the upper columns (odd numbered accumulators) with a mask.
;
OutputMasked16xNBlock:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutputMasked16xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutputMasked16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
add r8,8*4 ; advance matrix C by 8 columns
IF RowCount GT 3
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
ENDIF
add rbp,8 ; correct for over-subtract above
;
; Write the remaining rbp columns (fewer than 8) using a mask. The mask is
; loaded from MlasMaskMoveTableAvx at a dword offset of (8 - rbp); the table
; is defined elsewhere and presumably enables the first rbp dword lanes.
; rcx is reused as a scratch pointer here, which is safe because all paths
; from this point exit the macro.
;
OutputMasked8xNBlock:
neg rbp
lea rcx,MlasMaskMoveTableAvx+8*4
vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4]
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutputMasked8xNBlock
;
; The even numbered accumulators are no longer live, so use them as scratch
; registers for the masked loads of the existing matrix C contents.
;
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
SkipAccumulateOutputMasked8xNBlock:
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
jmp ExitProcessCountM
ENDM
;
; Reduce code size for the various types of kernels by sharing the outer logic
; and switching on the selector codes (using sign bit to discriminate).
;
;
; Entry point for the U8S8 AVX-VNNI kernel: passes a negative selector code in
; eax to the shared kernel body. The argument registers and stack are left
; untouched, so the shared body sees the caller's original arguments.
;
LEAF_ENTRY MlasGemmU8S8KernelAvxVnni, _TEXT
mov eax,-1
jmp MlasGemmU8X8KernelAvx2
LEAF_END MlasGemmU8S8KernelAvxVnni, _TEXT
;
; Entry point for the U8U8 AVX2 kernel: passes a positive selector code in eax
; to the shared kernel body. The argument registers and stack are left
; untouched, so the shared body sees the caller's original arguments.
;
LEAF_ENTRY MlasGemmU8U8KernelAvx2, _TEXT
mov eax,1
jmp MlasGemmU8X8KernelAvx2
LEAF_END MlasGemmU8U8KernelAvx2, _TEXT
;
; Entry point for the U8S8 AVX2 kernel: passes a zero selector code in eax to
; the shared kernel body. The shared body limits this variant to at most four
; rows per invocation.
;
LEAF_ENTRY MlasGemmU8S8KernelAvx2, _TEXT
xor eax,eax
jmp MlasGemmU8X8KernelAvx2
LEAF_END MlasGemmU8S8KernelAvx2, _TEXT
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8X8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8X8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
; the number of packed rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumBuffer - Supplies the sum of each row from matrix A. These values have
; been pre-scaled by the zero point offset of matrix B if the offset is
; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
; scaled by the per-column zero point offsets of matrix B. These values are
; accumulated into every row of matrix C.
;
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
; B, else nullptr if the matrix B is using per-tensor quantization.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemmU8X8KernelAvx2, _TEXT
;
; Save the non-volatile integer and vector registers used by the kernel. The
; order of the pushes matches the layout of GemmU8X8KernelFrame.
;
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
alloc_stack (GemmU8X8KernelFrame.SavedR13)
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
END_PROLOGUE
;
; eax holds the kernel selector code supplied by the entry point that jumped
; here (negative: U8S8 AVX-VNNI, zero: U8S8 AVX2, positive: U8U8 AVX2). Stash
; it in the first parameter home slot so that eax can be reused below.
;
mov DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],eax
;
; Load the remaining arguments into the registers expected by ProcessCountM.
; rdi keeps the base address of matrix A so that rcx can be reloaded.
;
mov rdi,rcx
mov rbx,GemmU8X8KernelFrame.CountM[rsp]
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.RowSumBuffer[rsp]
mov r12,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
mov r13,GemmU8X8KernelFrame.ZeroPointB[rsp]
vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF]
vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001]
;
; A zero selector (U8S8 AVX2) skips the five and six row cases, so that
; variant handles at most four rows per invocation.
;
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0
je CheckCountM4OrMore ; U8S8 AVX2 kernel requires extra registers
;
; Process CountM rows of the matrices.
;
CheckCountM6OrMore:
cmp rbx,5
ja ProcessCountM6
je ProcessCountM5
CheckCountM4OrMore:
cmp rbx,3
ja ProcessCountM4
je ProcessCountM3
cmp rbx,1
je ProcessCountM1
;
; Any other row count falls into the two row case. Each ProcessCountM
; expansion ends with a jump to ExitKernel (with eax set to the number of rows
; handled), so the expansions below do not fall through into each other.
;
ProcessCountM2:
ProcessCountM 2
ProcessCountM4:
ProcessCountM 4
ProcessCountM6:
ProcessCountM 6
;
; Restore non-volatile registers and return.
;
ExitKernel:
vzeroupper
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR13)
BEGIN_EPILOGUE
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
;
; The odd row count cases are placed after the epilogue.
;
ProcessCountM1:
ProcessCountM 1
ProcessCountM3:
ProcessCountM 3
ProcessCountM5:
ProcessCountM 5
NESTED_END MlasGemmU8X8KernelAvx2, _TEXT
END

View file

@ -1,302 +0,0 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8X8KernelAvx2Common.inc
;
; Abstract:
;
; This module contains common kernel macros and structures for the quantized
; integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
;
;--
EXTERN MlasMaskMoveAvx:NEAR
;
; Stack frame layout for the U8S8 and U8U8 kernels.
;
GemmU8X8KernelFrame STRUCT
;
; Save area for the non-volatile vector registers xmm6-xmm15.
;
SavedXmm6 OWORD ?
SavedXmm7 OWORD ?
SavedXmm8 OWORD ?
SavedXmm9 OWORD ?
SavedXmm10 OWORD ?
SavedXmm11 OWORD ?
SavedXmm12 OWORD ?
SavedXmm13 OWORD ?
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
;
; NOTE(review): Padding presumably keeps the vector save area above 16-byte
; aligned relative to the stack pointer after the register pushes -- confirm
; against the kernel prologue that uses this frame.
;
Padding QWORD ?
;
; Saved non-volatile integer registers, followed by the return address.
;
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
ReturnAddress QWORD ?
;
; Home slots for the four register parameters (presumably A, B, C and
; PackedCountK per the kernel routine descriptions), followed by the stack
; passed parameters.
;
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumBuffer QWORD ?
ColumnSumBuffer QWORD ?
DepthValue QWORD ?
ZeroMode QWORD ?
GemmU8X8KernelFrame ENDS
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r12 - Supplies the address of the row sum buffer.
;
; r13 - Supplies the address of the column sum buffer.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
ProduceOutputBlock MACRO ColumnCount, RowCount
;
; Initialize the accumulators with the sum of the global depth value constant,
; the column sums, and the row sums.
;
; For 16 column blocks, ymm0 receives the depth value plus the sums of
; columns 0-7 and ymm1 the depth value plus the sums of columns 8-15. For 8
; column blocks, only ymm1 is used.
;
vpbroadcastd ymm1,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
IF ColumnCount EQ 16
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
add r13,16*4 ; advance ColumnSumBuffer by 16 columns
ELSE
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
ENDIF
;
; Broadcast each row sum into the odd numbered accumulator of that row.
;
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r12]>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r12+4]>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r12+8]>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r12+12]>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r12+16]>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r12+20]>
;
; Combine the row sums with the column sums. The even numbered accumulator
; must be formed first because the odd numbered accumulator still holds the
; broadcasted row sum that both results are derived from.
;
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm5,ymm0>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm7,ymm0>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm9,ymm0>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm11,ymm0>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm13,ymm0>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm15,ymm0>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm1>
;
; Iterate over the length of a matrix A row to produce the output accumulators.
;
; rbx is used as the matrix A plus 3 rows pointer while computing and is then
; repurposed as the matrix C plus 3 rows pointer for the caller's stores.
; ComputeBlockLoop is supplied by the file that includes these macros.
;
IF RowCount GT 3
lea rbx,[r9*2+r9]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
ComputeBlockLoop ColumnCount, RowCount
IF RowCount GT 3
lea rbx,[rax*2+rax]
add rbx,r8 ; compute matrix C plus 3 rows
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Fallthrough - Supplies a non-blank value if the macro may fall through to
; the ExitKernel label.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the base address of matrix A. rcx is reloaded from this
; register after each block of 16 columns has been written.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10b - Supplies the zero mode flag.
;
; r12 - Supplies the address of the row sum buffer.
;
; r13 - Supplies the address of the column sum buffer.
;
ProcessCountM MACRO RowCount, Fallthrough
LOCAL ProcessNextColumnLoop16xN
LOCAL SkipAccumulateOutput16xNBlock
LOCAL OutputMasked16xNBlock
LOCAL ProcessRemainingCountN
LOCAL SkipAccumulateOutput8xNBlock
LOCAL SkipAccumulateOutputMasked16xNBlock
LOCAL OutputMasked8xNBlock
LOCAL SkipAccumulateOutputMasked8xNBlock
;
; Blocks of 16 columns are only produced while more than 8 columns remain;
; 8 or fewer columns are handled by the single 8 column block below.
;
cmp rbp,8
jbe ProcessRemainingCountN
ProcessNextColumnLoop16xN:
ProduceOutputBlock 16, RowCount
;
; The even numbered accumulators (ymm4,ymm6,...) hold columns 0-7 and the odd
; numbered accumulators (ymm5,ymm7,...) hold columns 8-15 of each row. If the
; subtraction borrows, then only 9 to 15 columns remained and the upper half
; of the block must be written with a mask.
;
sub rbp,16
jb OutputMasked16xNBlock
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
SkipAccumulateOutput16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
cmp rbp,8
ja ProcessNextColumnLoop16xN
test rbp,rbp
jz ExitKernel
;
; Produce a single block of up to 8 columns. Only the odd numbered
; accumulators (ymm5,ymm7,...) are written to matrix C on this path.
;
ProcessRemainingCountN:
ProduceOutputBlock 8, RowCount
cmp rbp,8
jb OutputMasked8xNBlock
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput8xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput8xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
jmp ExitKernel
;
; A 16 column block was produced but fewer than 16 columns remained. Write the
; lower 8 columns (even numbered accumulators) unmasked, then fall through to
; write the upper columns (odd numbered accumulators) with a mask.
;
OutputMasked16xNBlock:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutputMasked16xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutputMasked16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
add r8,8*4 ; advance matrix C by 8 columns
IF RowCount GT 3
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
ENDIF
add rbp,8 ; correct for over-subtract above
;
; Write the remaining rbp columns (fewer than 8) using a mask. The remaining
; count is spilled to the CountN stack slot so that it can be broadcast, then
; compared (signed greater than) against the vector at MlasMaskMoveAvx. That
; vector is defined elsewhere and presumably holds the lane indices 0-7, so
; the first rbp dword lanes are enabled.
;
OutputMasked8xNBlock:
mov DWORD PTR GemmU8X8KernelFrame.CountN[rsp],ebp
vpbroadcastd ymm0,DWORD PTR GemmU8X8KernelFrame.CountN[rsp]
vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx]
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutputMasked8xNBlock
;
; The even numbered accumulators are no longer live, so use them as scratch
; registers for the masked loads of the existing matrix C contents.
;
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
SkipAccumulateOutputMasked8xNBlock:
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
;
; The final jump is only emitted when the Fallthrough argument is blank; a
; non-blank value lets the expansion fall through to the code that follows.
;
IFB <Fallthrough>
jmp ExitKernel
ENDIF
ENDM

View file

@ -1,438 +0,0 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8X8KernelAvx512Common.inc
;
; Abstract:
;
; This module contains common kernel macros and structures for the quantized
; integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
; AVX512VNNI kernels.
;
;--
;
; Stack frame layout for the U8S8 and U8U8 kernels.
;
GemmU8X8KernelFrame STRUCT
;
; Save area for the non-volatile vector registers saved by the kernel
; prologue.
;
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
;
; Saved non-volatile integer registers in the reverse order of the prologue
; pushes, followed by the return address.
;
SavedR14 QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
ReturnAddress QWORD ?
;
; Home slots for the four register parameters (A, B, C and PackedCountK per
; the routine description), followed by the stack passed parameters.
;
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumBuffer QWORD ?
ColumnSumBuffer QWORD ?
DepthValue QWORD ?
ZeroMode QWORD ?
GemmU8X8KernelFrame ENDS
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r12 - Supplies the address of the row sum buffer.
;
; r13 - Supplies the address of the column sum buffer.
;
ProduceOutputBlock MACRO ColumnCount, RowCount
;
; Initialize the accumulators with the sum of the global depth value constant,
; the column sums, and the row sums.
;
; The column vectors are assigned so that zmm0 always holds the last 16
; columns of the block: for 48 columns the order is zmm2,zmm1,zmm0, for 32
; columns zmm1,zmm0 and for 16 columns only zmm0.
;
vpbroadcastd zmm3,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
IF ColumnCount GE 32
IF ColumnCount GE 48
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
ELSE
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
ENDIF
add_immed r13,ColumnCount*4 ; advance ColumnSumBuffer by N columns
ELSE
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
ENDIF
;
; Add the broadcasted row sum of each row. The accumulators derived from zmm0
; are zmm14-zmm19, from zmm1 are zmm20-zmm25 and from zmm2 are zmm26-zmm31.
;
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,DWORD BCST [r12]>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,DWORD BCST [r12]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,DWORD BCST [r12]>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,DWORD BCST [r12+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,DWORD BCST [r12+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,DWORD BCST [r12+4]>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,DWORD BCST [r12+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,DWORD BCST [r12+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,DWORD BCST [r12+8]>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,DWORD BCST [r12+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,DWORD BCST [r12+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,DWORD BCST [r12+12]>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,DWORD BCST [r12+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,DWORD BCST [r12+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,DWORD BCST [r12+16]>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,DWORD BCST [r12+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,DWORD BCST [r12+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,DWORD BCST [r12+20]>
;
; Iterate over the length of a matrix A row to produce the output accumulators.
;
; rbx is used as the matrix A plus 3 rows pointer while computing and is then
; repurposed as the matrix C plus 3 rows pointer for the caller's stores.
; ComputeBlockLoop is supplied by the file that includes these macros.
;
IF RowCount GT 3
lea rbx,[r9*2+r9]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
ComputeBlockLoop ColumnCount, RowCount
IF RowCount GT 3
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
add rbx,rax
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the base address of matrix A. rcx is reloaded from this
; register after each block of columns has been written.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10b - Supplies the zero mode flag.
;
; r12 - Supplies the address of the row sum buffer.
;
; r13 - Supplies the address of the column sum buffer.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
ProcessCountM MACRO RowCount
LOCAL ProcessNextColumnLoop32xN
LOCAL Output32xNBlock
LOCAL SkipAccumulateOutput32xNBlock
LOCAL Output16xNBlock
LOCAL Output16xNBlockWithMask
LOCAL SkipAccumulateOutput16xNBlockWithMask
LOCAL ProcessRemainingCountN
LOCAL ProcessNextColumnLoop48xN
LOCAL SkipAccumulateOutput48xNBlock
;
; Select the widest block that the remaining column count requires: more than
; 32 columns uses a 48 column block, more than 16 columns uses a 32 column
; block, otherwise a single 16 column block is produced.
;
cmp rbp,32
ja ProcessNextColumnLoop48xN
cmp rbp,16
jbe ProcessRemainingCountN
ProcessNextColumnLoop32xN:
ProduceOutputBlock 32, RowCount
add rdx,r14 ; advance matrix B by packed block stride
;
; Write the 16 columns held in zmm20-zmm25. This is also the continuation of
; the 48 column path after its first 16 columns have been written.
;
Output32xNBlock:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput32xNBlock
EmitIfCountGE RowCount, 1, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm21,zmm21,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm22,zmm22,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput32xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm20>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm21>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm22>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm23>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25>
add r8,16*4 ; advance matrix C by 16 columns
IF RowCount GT 3
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
ENDIF
sub rbp,16
;
; Write the final 16 columns of the block, held in zmm14-zmm19, under the
; write mask k1. If fewer than 16 columns remain, build the mask
; (1 << remaining) - 1 and zero the column count. ecx is used as a scratch
; register for the shift count; rcx is reloaded from rdi below.
;
Output16xNBlock:
sub rbp,16
jae Output16xNBlockWithMask
lea ecx,[ebp+16] ; correct for over-subtract above
mov esi,1
shl esi,cl
dec esi
kmovw k1,esi ; update mask for remaining columns
xor ebp,ebp ; no more columns remaining
Output16xNBlockWithMask:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlockWithMask
EmitIfCountGE RowCount, 1, <vpaddd zmm14{k1},zmm14,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm15{k1},zmm15,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm16{k1},zmm16,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput16xNBlockWithMask:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm14>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm15>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm16>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
cmp rbp,32
ja ProcessNextColumnLoop48xN
cmp rbp,16
ja ProcessNextColumnLoop32xN
test rbp,rbp
jz ExitKernel
;
; Produce a single block of up to 16 columns and write it through the masked
; path above.
;
ProcessRemainingCountN:
ProduceOutputBlock 16, RowCount
jmp Output16xNBlock
;
; Produce a block of 48 columns. The first 16 columns, held in zmm26-zmm31,
; are written here; the remaining 32 columns are written by rejoining the 32
; column path.
;
ProcessNextColumnLoop48xN:
ProduceOutputBlock 48, RowCount
lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput48xNBlock
EmitIfCountGE RowCount, 1, <vpaddd zmm26,zmm26,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm27,zmm27,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm28,zmm28,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput48xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm26>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm27>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm28>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm29>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31>
add r8,16*4 ; advance matrix C by 16 columns
IF RowCount GT 3
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
ENDIF
sub rbp,16
jmp Output32xNBlock
ENDM
;
; Macro Description:
;
; This macro generates the common AVX512 code for the inner kernel to compute
; matrix multiplication.
;
; Arguments:
;
; Type - Supplies the kernel type string for function tags.
;
; Isa - Supplies the instruction set architecture string for function tags.
;
GemmU8X8KernelAvx512Function MACRO Type, Isa
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8X8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8X8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
; the number of packed rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
; zero point offset of matrix B. These values are accumulated into every
; row of matrix C.
;
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; DepthValue - Supplies the value CountK multiplied by the zero point offset
; of matrix A multiplied by the zero point offset of matrix B. This value is
; accumulated into every element of matrix C.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemm&Type&Kernel&Isa&, _TEXT
;
; Save the non-volatile integer and vector registers used by the kernel. The
; order of the pushes matches the layout of GemmU8X8KernelFrame.
;
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
push_reg r14
alloc_stack (GemmU8X8KernelFrame.SavedR14)
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
END_PROLOGUE
;
; Load the arguments into the registers expected by ProcessCountM. rdi keeps
; the base address of matrix A so that rcx can be reloaded. r11 holds CountM
; and doubles as the return value (see ExitKernel).
;
mov rdi,rcx
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.CountM[rsp]
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp]
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
mov esi,-1
kmovw k1,esi ; update mask to write all columns
;
; Type and Isa specific setup, resolved at assembly time. For the U8S8
; Avx512Core variant, esi is negated from -1 to 1 to build the word vector of
; ones in zmm5. The matrix B packed block stride is 16 times the row length
; for U8S8 and 8 times the row length otherwise.
;
IFIDNI <Type>, <U8S8>
IFIDNI <Isa>, <Avx512Core>
neg esi
vpbroadcastw zmm5,esi ; generate 512-bit word vector [0x0001]
ENDIF
mov r14,r9
shl r14,4 ; compute matrix B packed stride
ELSE
lea r14,[r9*8] ; compute matrix B packed stride
ENDIF
;
; Process CountM rows of the matrices.
;
cmp r11,5
ja ProcessCountM6
je ProcessCountM5
cmp r11,3
ja ProcessCountM4
je ProcessCountM3
cmp r11,1
je ProcessCountM1
;
; Any other row count falls into the two row case. Each ProcessCountM
; expansion ends with a jump, so the expansions below do not fall through
; into each other.
;
ProcessCountM2:
ProcessCountM 2
ProcessCountM4:
ProcessCountM 4
ProcessCountM6:
mov r11d,6 ; return 6 rows handled
ProcessCountM 6
;
; Restore non-volatile registers and return.
;
ExitKernel:
mov eax,r11d
vzeroupper
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR14)
BEGIN_EPILOGUE
pop r14
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
;
; The odd row count cases are placed after the epilogue.
;
ProcessCountM1:
ProcessCountM 1
ProcessCountM3:
ProcessCountM 3
ProcessCountM5:
ProcessCountM 5
NESTED_END MlasGemm&Type&Kernel&Isa&, _TEXT
ENDM

View file

@ -0,0 +1,764 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8X8KernelAvx512Core.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions.
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE AssembleAvx512Vnni.inc
.list
;
; Stack frame layout for the U8X8 kernel.
;
GemmU8X8KernelFrame STRUCT
;
; Save area for the vector registers xmm13-xmm15. The kernel prologue fills
; these slots with save_xmm128 and the epilogue reloads them with movaps. The
; offset of SavedR14 (the end of this area) is the size passed to alloc_stack.
;
SavedXmm13 OWORD ?
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
;
; General purpose registers pushed by the kernel prologue. The fields appear
; in reverse push order: rbp is pushed first and r14 last, so r14 ends up at
; the lowest address.
;
SavedR14 QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
;
; Return address followed by the four home slots that belong to the register
; parameters. The kernel reuses PreviousP1Home to hold the kernel selector
; code that the entry stubs pass in eax.
;
ReturnAddress QWORD ?
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
;
; Arguments that the kernel reads from the stack.
;
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumBuffer QWORD ?
ColumnSumBuffer QWORD ?
ZeroPointB QWORD ?
ZeroMode QWORD ?
GemmU8X8KernelFrame ENDS
;
; Macro Description:
;
; This macro generates code to load packed data from matrix B.
;
; Arguments:
;
; VecReg - Supplies the register to load the data into.
;
; AddressOperand - Supplies the address operand.
;
LoadPackedMatrixBU8S8 MACRO VecReg, AddressOperand
;
; U8S8 variant: unaligned load of a full 64-byte vector of packed bytes.
;
vmovdqu32 VecReg,ZMMWORD PTR AddressOperand
ENDM
LoadPackedMatrixBU8U8 MACRO VecReg, AddressOperand
;
; U8U8 variant: load 32 bytes and zero extend each byte to a 16-bit word so
; that the destination holds 32 words.
;
vpmovzxbw VecReg,YMMWORD PTR AddressOperand
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate a single cell of the
; output block.
;
; Arguments:
;
; AccumReg - Supplies the register to accumulate into.
;
; Mult1Reg - Supplies the first multiplication operand register.
;
; Mult2Reg - Supplies the second multiplication operand register.
;
; Implicit Arguments:
;
; zmm4 - Supplies a scratch register for intermediate results.
;
; zmm13 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
;
MultiplyAccumulateCellU8S8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg
;
; Multiply the unsigned bytes of Mult1Reg by the signed bytes of Mult2Reg and
; add adjacent products into 16-bit words (with signed saturation).
;
vpmaddubsw zmm4,Mult1Reg,Mult2Reg
;
; Multiplying by the word vector of ones in zmm13 sums adjacent words into
; 32-bit values, which are then added to the accumulator. zmm4 is clobbered.
;
vpmaddwd zmm4,zmm4,zmm13
vpaddd AccumReg,AccumReg,zmm4
ENDM
MultiplyAccumulateCellU8S8Avx512Vnni MACRO AccumReg, Mult1Reg, Mult2Reg
;
; NOTE(review): VpdpbusdsZmmZmmZmm is not defined in this file; presumably it
; comes from AssembleAvx512Vnni.inc and emits a single vpdpbusds that performs
; the multiply and accumulate without the zmm4 scratch register -- confirm.
;
VpdpbusdsZmmZmmZmm AccumReg,Mult1Reg,Mult2Reg
ENDM
MultiplyAccumulateCellU8U8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg
;
; Multiply the 16-bit words of both operands, add adjacent products into
; 32-bit values and add those to the accumulator. zmm4 is clobbered and zmm13
; is not used by this variant.
;
vpmaddwd zmm4,Mult1Reg,Mult2Reg
vpaddd AccumReg,AccumReg,zmm4
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; Type - Supplies the type of kernel to generate (U8S8 or U8U8).
;
; Isa - Supplies the instruction set architecture string.
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
; zmm13 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlock MACRO Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset
;
; Load one packed vector of matrix B for each group of 16 columns; the groups
; are r14 bytes apart. zmm2 always receives the last group, so narrower blocks
; leave zmm0 (and zmm1) unused.
;
IF ColumnCount GE 48
LoadPackedMatrixB&Type& zmm0,[rdx+VectorOffset]
LoadPackedMatrixB&Type& zmm1,[rdx+r14+VectorOffset]
LoadPackedMatrixB&Type& zmm2,[rdx+r14*2+VectorOffset]
ELSEIF ColumnCount GE 32
LoadPackedMatrixB&Type& zmm1,[rdx+VectorOffset]
LoadPackedMatrixB&Type& zmm2,[rdx+r14+VectorOffset]
ELSE
LoadPackedMatrixB&Type& zmm2,[rdx+VectorOffset]
ENDIF
;
; For each row, broadcast four bytes of matrix A into zmm3 and accumulate the
; product with every loaded matrix B vector. Rows 1-3 are addressed from rcx
; and rows 4-6 from rbx, each with offsets of 0, r9 and r9*2. The accumulator
; for row R (1-based) is zmm(13+R) for the group in zmm2, zmm(19+R) for the
; group in zmm1 and zmm(25+R) for the group in zmm0.
;
; NOTE(review): EmitIfCountGE and EmitIfCount2GE are not defined in this file;
; presumably they emit the instruction only when each count is at least the
; threshold that follows it -- confirm against mlasi.inc.
;
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm19,zmm3,zmm2>
ENDM
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple times
; and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; Isa - Supplies the instruction set architecture string.
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount
LOCAL ComputeBlockBy4Loop
LOCAL ProcessRemainingBlocks
LOCAL ComputeBlockBy1Loop
LOCAL ComputeBlockLoopExit
;
; rsi counts the bytes of the matrix A row that are still to be processed.
; Each step ("quad") consumes 4 bytes of matrix A per row and 64 bytes of
; packed matrix B per group of 16 columns.
;
mov rsi,r9 ; reload row length remaining
;
; The loop unrolled by four quads is only generated for even row counts.
;
IF ((RowCount AND 1) EQ 0)
sub rsi,4*4
jb ProcessRemainingBlocks
ComputeBlockBy4Loop:
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0*64, 0
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 1*64, 4
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 2*64, 8
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 3*64, 12
add rcx,4*4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4*4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,4*64 ; advance matrix B
sub rsi,4*4 ; decrement quads remaining
jae ComputeBlockBy4Loop
ProcessRemainingBlocks:
add rsi,4*4 ; correct for over-subtract above
jz ComputeBlockLoopExit
ENDIF
;
; Process the remaining quads one at a time.
;
; NOTE(review): when RowCount is odd this loop is entered without a zero
; check, and it only exits when rsi reaches exactly zero; this assumes the row
; length in r9 is a nonzero multiple of 4 -- confirm against the callers.
;
ComputeBlockBy1Loop:
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,64 ; advance matrix B
sub rsi,4 ; decrement quads remaining
jnz ComputeBlockBy1Loop
ComputeBlockLoopExit:
ENDM
ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount
LOCAL ComputeBlockBy1Loop
;
; rsi counts the bytes of the matrix A row that are still to be processed.
; Each iteration consumes 4 bytes of matrix A per row and 32 bytes of packed
; matrix B per group of 16 columns (the bytes are widened to words on load).
;
; NOTE(review): the loop is entered without a zero check and exits only when
; rsi reaches exactly zero; this assumes the row length in r9 is a nonzero
; multiple of 4 -- confirm against the callers.
;
mov rsi,r9 ; reload row length remaining
ComputeBlockBy1Loop:
ComputeBlock U8U8, Isa, ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 pair
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
ENDIF
add rdx,32 ; advance matrix B
sub rsi,4
jnz ComputeBlockBy1Loop
ENDM
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r11 - Supplies the address of the row sum buffer.
;
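; r12 - Supplies the address of the column sum buffer.
;
; r13 - Optionally supplies the address of the per-column zero point offsets
; of matrix B, else zero if matrix B uses a per-tensor zero point.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.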
;
ProduceOutputBlock MACRO ColumnCount, RowCount
LOCAL SkipScaleByZeroPointB
LOCAL AccumulatorsInitialized
LOCAL ProduceWithU8S8Avx512Core
LOCAL ProduceWithU8U8Avx512Core
LOCAL ExitProduceOutputBlock
;
; Initialize the accumulators with the row and column sums.
;
; Load the column sums so that zmm0 always holds the last group of 16 columns,
; matching the accumulator assignment used by ComputeBlock (zmm14-zmm19 for
; the last group, zmm20-zmm25 and zmm26-zmm31 for the earlier groups). The
; buffer pointer is only advanced for the 32 and 48 column blocks.
;
IF ColumnCount GE 32
IF ColumnCount GE 48
vmovdqu32 zmm2,ZMMWORD PTR [r12]
vmovdqu32 zmm1,ZMMWORD PTR [r12+64]
vmovdqu32 zmm0,ZMMWORD PTR [r12+128]
ELSE
vmovdqu32 zmm1,ZMMWORD PTR [r12]
vmovdqu32 zmm0,ZMMWORD PTR [r12+64]
ENDIF
add_immed r12,ColumnCount*4 ; advance ColumnSumBuffer by N columns
ELSE
vmovdqu32 zmm0,ZMMWORD PTR [r12]
ENDIF
test r13,r13 ; per column zero points?
jz SkipScaleByZeroPointB
;
; Per-column zero points: load them with the same ordering as the column sums
; (zmm3 holds the last group of 16 columns).
;
IF ColumnCount GE 32
IF ColumnCount GE 48
vmovdqu32 zmm5,ZMMWORD PTR [r13]
vmovdqu32 zmm4,ZMMWORD PTR [r13+64]
vmovdqu32 zmm3,ZMMWORD PTR [r13+128]
ELSE
vmovdqu32 zmm4,ZMMWORD PTR [r13]
vmovdqu32 zmm3,ZMMWORD PTR [r13+64]
ENDIF
add_immed r13,ColumnCount*4 ; advance ZeroPointB by N columns
ELSE
vmovdqu32 zmm3,ZMMWORD PTR [r13]
ENDIF
;
; For each row, accumulator = column sum + (row sum * per-column zero point).
; The row sum of row R (1-based) is broadcast from [r11+(R-1)*4].
;
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpmulld zmm14,zmm3,DWORD BCST [r11]>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpmulld zmm20,zmm4,DWORD BCST [r11]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpmulld zmm26,zmm5,DWORD BCST [r11]>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,zmm14>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,zmm20>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,zmm26>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpmulld zmm15,zmm3,DWORD BCST [r11+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpmulld zmm21,zmm4,DWORD BCST [r11+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpmulld zmm27,zmm5,DWORD BCST [r11+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,zmm15>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,zmm21>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,zmm27>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpmulld zmm16,zmm3,DWORD BCST [r11+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpmulld zmm22,zmm4,DWORD BCST [r11+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpmulld zmm28,zmm5,DWORD BCST [r11+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,zmm16>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,zmm22>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,zmm28>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpmulld zmm17,zmm3,DWORD BCST [r11+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpmulld zmm23,zmm4,DWORD BCST [r11+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpmulld zmm29,zmm5,DWORD BCST [r11+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,zmm17>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,zmm23>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,zmm29>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpmulld zmm18,zmm3,DWORD BCST [r11+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpmulld zmm24,zmm4,DWORD BCST [r11+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpmulld zmm30,zmm5,DWORD BCST [r11+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,zmm18>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,zmm24>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,zmm30>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpmulld zmm19,zmm3,DWORD BCST [r11+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpmulld zmm25,zmm4,DWORD BCST [r11+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpmulld zmm31,zmm5,DWORD BCST [r11+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,zmm19>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,zmm25>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,zmm31>
jmp AccumulatorsInitialized
;
; No per-column zero points: accumulator = column sum + row sum, with the row
; sum broadcast directly from the row sum buffer.
;
SkipScaleByZeroPointB:
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,DWORD BCST [r11]>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,DWORD BCST [r11]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,DWORD BCST [r11]>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,DWORD BCST [r11+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,DWORD BCST [r11+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,DWORD BCST [r11+4]>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,DWORD BCST [r11+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,DWORD BCST [r11+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,DWORD BCST [r11+8]>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,DWORD BCST [r11+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,DWORD BCST [r11+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,DWORD BCST [r11+12]>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,DWORD BCST [r11+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,DWORD BCST [r11+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,DWORD BCST [r11+16]>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,DWORD BCST [r11+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,DWORD BCST [r11+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,DWORD BCST [r11+20]>
AccumulatorsInitialized:
;
; Iterate over the length of a matrix A row to produce the output accumulators.
;
IF RowCount GT 3
lea rbx,[r9*2+r9]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
;
; Select the compute loop from the selector code that the kernel entry stored
; in PreviousP1Home: zero selects U8S8 AVX512 core, a positive value selects
; U8U8 AVX512 core and a negative value falls through to U8S8 AVX512 VNNI.
;
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0
je ProduceWithU8S8Avx512Core
jg ProduceWithU8U8Avx512Core
ComputeBlockLoopU8S8 Avx512Vnni, ColumnCount, RowCount
jmp ExitProduceOutputBlock
ProduceWithU8U8Avx512Core:
ComputeBlockLoopU8U8 Avx512Core, ColumnCount, RowCount
jmp ExitProduceOutputBlock
ProduceWithU8S8Avx512Core:
ComputeBlockLoopU8S8 Avx512Core, ColumnCount, RowCount
ExitProduceOutputBlock:
;
; rbx is no longer needed as a matrix A pointer; repurpose it to address
; matrix C plus 3 rows for the output code that follows this macro.
;
IF RowCount GT 3
lea rbx,[rax*2+rax]
add rbx,r8 ; compute matrix C plus 3 rows
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the address of matrix A.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10b - Supplies the zero mode flag.
;
; r11 - Supplies the address of the row sum buffer.
;
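; r12 - Supplies the address of the column sum buffer.
;
; r13 - Optionally supplies the address of the per-column zero point offsets
; of matrix B, else zero if matrix B uses a per-tensor zero point.
;
; r14 - Supplies the stride in bytes between packed blocks of matrix B.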
;
ProcessCountM MACRO RowCount
LOCAL ProcessNextColumnLoop32xN
LOCAL Output32xNBlock
LOCAL SkipAccumulateOutput32xNBlock
LOCAL Output16xNBlock
LOCAL Output16xNBlockWithMask
LOCAL SkipAccumulateOutput16xNBlockWithMask
LOCAL ProcessRemainingCountN
LOCAL ProcessNextColumnLoop48xN
LOCAL SkipAccumulateOutput48xNBlock
;
; Pick the widest block for the columns remaining in rbp: more than 32 columns
; uses the 48 column block, 17 to 32 columns uses the 32 column block and 16
; or fewer columns uses the 16 column block. The final group of 16 columns of
; every block width is written through mask k1, so the 32 and 48 column
; blocks also cover a partial last group.
;
cmp rbp,32
ja ProcessNextColumnLoop48xN
cmp rbp,16
jbe ProcessRemainingCountN
ProcessNextColumnLoop32xN:
ProduceOutputBlock 32, RowCount
add rdx,r14 ; advance matrix B by packed block stride
;
; Store the group of 16 columns held in zmm20-zmm25. This label is also the
; continuation of the 48 column block after its first group has been stored.
;
Output32xNBlock:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput32xNBlock
EmitIfCountGE RowCount, 1, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm21,zmm21,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm22,zmm22,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput32xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm20>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm21>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm22>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm23>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25>
add r8,16*4 ; advance matrix C by 16 columns
IF RowCount GT 3
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
ENDIF
sub rbp,16
;
; Store the last group of 16 columns held in zmm14-zmm19. If fewer than 16
; columns remain, build a mask with one bit per remaining column. rcx is used
; as scratch for the shift count and is reloaded from rdi further below.
;
Output16xNBlock:
sub rbp,16
jae Output16xNBlockWithMask
lea ecx,[ebp+16] ; correct for over-subtract above
mov esi,1
shl esi,cl
dec esi
kmovw k1,esi ; update mask for remaining columns
xor ebp,ebp ; no more columns remaining
Output16xNBlockWithMask:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlockWithMask
EmitIfCountGE RowCount, 1, <vpaddd zmm14{k1},zmm14,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm15{k1},zmm15,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm16{k1},zmm16,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput16xNBlockWithMask:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm14>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm15>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm16>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
;
; Continue with the next block of columns, or return once none remain. rbx
; is not advanced here because ProduceOutputBlock recomputes it from r8.
;
cmp rbp,32
ja ProcessNextColumnLoop48xN
cmp rbp,16
ja ProcessNextColumnLoop32xN
test rbp,rbp
jnz ProcessRemainingCountN
mov eax,RowCount
jmp ExitKernel
ProcessRemainingCountN:
ProduceOutputBlock 16, RowCount
jmp Output16xNBlock
ProcessNextColumnLoop48xN:
ProduceOutputBlock 48, RowCount
lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride
;
; Store the first group of 16 columns held in zmm26-zmm31, then share the
; 32 column output code for the two remaining groups.
;
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput48xNBlock
EmitIfCountGE RowCount, 1, <vpaddd zmm26,zmm26,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm27,zmm27,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm28,zmm28,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput48xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm26>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm27>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm28>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm29>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31>
add r8,16*4 ; advance matrix C by 16 columns
IF RowCount GT 3
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
ENDIF
sub rbp,16
jmp Output32xNBlock
ENDM
;
; Reduce code size for the various types of kernels by sharing the outer logic
; and switching on the selector codes (using sign bit to discriminate).
;
LEAF_ENTRY MlasGemmU8S8KernelAvx512Vnni, _TEXT
;
; Selector -1 (negative): the shared kernel keeps the U8S8 packed matrix B
; stride and falls through both selector branches to the VNNI compute loop.
;
mov eax,-1
jmp MlasGemmU8X8KernelAvx512Core
LEAF_END MlasGemmU8S8KernelAvx512Vnni, _TEXT
LEAF_ENTRY MlasGemmU8U8KernelAvx512Core, _TEXT
;
; Selector 1 (positive): the shared kernel selects the smaller U8U8 packed
; matrix B stride and the U8U8 compute loop.
;
mov eax,1
jmp MlasGemmU8X8KernelAvx512Core
LEAF_END MlasGemmU8U8KernelAvx512Core, _TEXT
LEAF_ENTRY MlasGemmU8S8KernelAvx512Core, _TEXT
;
; Selector 0: the shared kernel keeps the U8S8 packed matrix B stride and
; selects the U8S8 AVX512 core compute loop.
;
xor eax,eax
jmp MlasGemmU8X8KernelAvx512Core
LEAF_END MlasGemmU8S8KernelAvx512Core, _TEXT
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8X8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8X8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
; the number of packed rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumBuffer - Supplies the sum of each row from matrix A. These values have
; been pre-scaled by the zero point offset of matrix B if the offset is
; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
; scaled by the per-column zero point offsets of matrix B. These values are
; accumulated into every row of matrix C.
;
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
; B, else nullptr if the matrix B is using per-tensor quantization.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemmU8X8KernelAvx512Core, _TEXT
;
; Shared kernel body for the three entry stubs above, which pass a selector
; code in eax. Save the registers this kernel modifies; the stack layout after
; the prologue is described by GemmU8X8KernelFrame.
;
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
push_reg r14
alloc_stack (GemmU8X8KernelFrame.SavedR14)
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
END_PROLOGUE
;
; Stash the selector in the first parameter home slot so that eax/rax can be
; reused, keep a copy of the matrix A pointer in rdi and load the stack
; arguments into registers. rbx temporarily holds CountM for the dispatch
; below; the ProcessCountM macro later reuses it as a pointer.
;
mov DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],eax
mov rdi,rcx
mov rbx,GemmU8X8KernelFrame.CountM[rsp]
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.RowSumBuffer[rsp]
mov r12,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
mov r13,GemmU8X8KernelFrame.ZeroPointB[rsp]
mov esi,-1
kmovw k1,esi ; update mask to write all columns
neg esi
vpbroadcastw zmm13,esi ; generate 512-bit word vector [0x0001]
;
; The packed matrix B stride is 8 times the row length for U8U8 and 16 times
; the row length for U8S8; a positive selector identifies the U8U8 kernel.
;
lea rsi,[r9*8] ; compute matrix B packed stride (U8U8)
lea r14,[rsi*2] ; compute matrix B packed stride (U8S8)
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0
cmovg r14,rsi ; select matrix B packed stride
;
; Process CountM rows of the matrices.
;
; Any CountM above 5 is handled as 6 rows. Each ProcessCountM expansion ends
; with an unconditional jump and leaves through ExitKernel with the number of
; rows handled in eax, so the expansions do not fall through into each other.
;
cmp rbx,5
ja ProcessCountM6
je ProcessCountM5
cmp rbx,3
ja ProcessCountM4
je ProcessCountM3
cmp rbx,1
ja ProcessCountM2
ProcessCountM1:
ProcessCountM 1
ProcessCountM2:
ProcessCountM 2
ProcessCountM3:
ProcessCountM 3
ProcessCountM4:
ProcessCountM 4
ProcessCountM5:
ProcessCountM 5
ProcessCountM6:
ProcessCountM 6
;
; Restore non-volatile registers and return.
;
; eax already holds the return value (rows handled) and is left untouched.
;
ExitKernel:
vzeroupper
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR14)
BEGIN_EPILOGUE
pop r14
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
NESTED_END MlasGemmU8X8KernelAvx512Core, _TEXT
END

View file

@ -0,0 +1,46 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
AssembleDotProduct.h
Abstract:
This module contains macros to build Advanced SIMD dot product instructions
for toolchains that do not natively support this newer instruction set
extension.
This implementation uses ARM v8.4 dot product instructions.
--*/
/*++
Macro Description:
This macro builds a UDOT instruction of the form:
UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index]
Arguments:
DestReg - Specifies the destination register.
Src1Reg - Specifies the first source register.
Src2Reg - Specifies the second source register.
Index - Specifies the element index of the second source register.
--*/
MACRO
UdotByElement $DestReg, $Src1Reg, $Src2Reg, $Index
// Emit the instruction as a raw 32-bit data word. 0x6F80E000 is the base
// opcode; the destination register number is OR'd into the low bits, the
// first source register number is shifted to bit 5 and the second source
// register number to bit 16. The element index is split across two fields:
// bit 1 of $Index is placed at bit 11 and bit 0 of $Index at bit 21.
// All arguments are used in a constant expression, so they must be numbers
// (register numbers, not register names).
DCD 0x6F80E000:OR:($DestReg):OR:($Src1Reg:SHL:5):OR:($Src2Reg:SHL:16):OR:(($Index:AND:2):SHL:10):OR:(($Index:AND:1):SHL:21)
MEND

View file

@ -22,7 +22,7 @@ Abstract:
//
#define GemmU8X8KernelFrame_ColumnSumBuffer 0
#define GemmU8X8KernelFrame_DepthValue 8
#define GemmU8X8KernelFrame_ZeroPointB 8
#define GemmU8X8KernelFrame_ZeroMode 16
//
@ -75,10 +75,6 @@ Arguments:
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized, else
false if the output matrix is accumulated into.
@ -91,18 +87,16 @@ Return Value:
LEAF_ENTRY MlasGemmU8X8KernelNeon
ldr x8,[sp,#GemmU8X8KernelFrame_ColumnSumBuffer]
ldr s27,[sp,#GemmU8X8KernelFrame_DepthValue]
ldr x9,[sp,#GemmU8X8KernelFrame_ZeroPointB]
ldrb w13,[sp,#GemmU8X8KernelFrame_ZeroMode]
dup v27.4s,v27.s[0]
mov x14,x0
ld1 {v0.4s},[x7]
ld1 {v27.4s},[x7]
mov x15,x3
add v27.4s,v27.4s,v0.4s // broadcast add DepthValue
dup v24.4s,v27.s[0] // broadcast row fixups
cmp x4,#1 // CountM == 1?
beq ProcessNextColumnLoopM1
dup v25.4s,v27.s[1]
cmp x4,#4 // CountM < 4 ?
cmp x4,#4 // CountM < 4?
blo ProcessNextColumnLoopM2
dup v26.4s,v27.s[2]
dup v27.4s,v27.s[3]
@ -118,6 +112,30 @@ ProcessNextColumnLoopM4
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
uxtl v0.8h,v0.8b
cbz x9,SkipScaleByZeroPointBM4
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v24.4s,v28.4s
mul v17.4s,v24.4s,v29.4s
mul v18.4s,v25.4s,v28.4s
mul v19.4s,v25.4s,v29.4s
mul v20.4s,v26.4s,v28.4s
mul v21.4s,v26.4s,v29.4s
mul v22.4s,v27.4s,v28.4s
mul v23.4s,v27.4s,v29.4s
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
add v18.4s,v2.4s,v18.4s
add v19.4s,v3.4s,v19.4s
ld1 {v5.8b},[x0],#8 // load first packed A1
add v20.4s,v2.4s,v20.4s
add v21.4s,v3.4s,v21.4s
add v22.4s,v2.4s,v22.4s
add v23.4s,v3.4s,v23.4s
b ComputeBlockLoopM4
SkipScaleByZeroPointBM4
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v24.4s
add v17.4s,v3.4s,v24.4s
@ -329,6 +347,21 @@ ProcessNextColumnLoopM2
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
uxtl v0.8h,v0.8b
cbz x9,SkipScaleByZeroPointBM2
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v24.4s,v28.4s
mul v17.4s,v24.4s,v29.4s
mul v18.4s,v25.4s,v28.4s
mul v19.4s,v25.4s,v29.4s
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
add v18.4s,v2.4s,v18.4s
add v19.4s,v3.4s,v19.4s
b ComputeBlockLoopM2
SkipScaleByZeroPointBM2
ld1 {v4.8b},[x0],#8 // load first packed A0
add v16.4s,v2.4s,v24.4s
add v17.4s,v3.4s,v24.4s
@ -467,6 +500,17 @@ ProcessNextColumnLoopM1
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
uxtl v0.8h,v0.8b
cbz x9,SkipScaleByZeroPointBM1
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v24.4s,v28.4s
mul v17.4s,v24.4s,v29.4s
ldr s4,[x0],#4 // load first packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
b ComputeBlockLoopM1
SkipScaleByZeroPointBM1
ldr s4,[x0],#4 // load first packed A0
add v16.4s,v2.4s,v24.4s
add v17.4s,v3.4s,v24.4s

View file

@ -0,0 +1,587 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelUdot.asm
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses ARM v8.4 dot product instructions.
--*/
#include "kxarm64.h"
#include "AssembleDotProduct.h"
//
// Stack frame layout for the U8X8 kernel.
//
// Size in bytes of the NEON register save area: the kernel prologue stores
// the four 8-byte registers d8-d11 here.
#define GemmU8XKernelFrame_SavedNeonRegisters (4 * 8)
// Total size in bytes of the registers saved by the prologue; only NEON
// registers are saved.
#define GemmU8XKernelFrame_SavedRegisters GemmU8XKernelFrame_SavedNeonRegisters
// Offsets from sp, after the prologue, of the arguments passed on the stack.
// They sit directly above the saved register area.
#define GemmU8XKernelFrame_ColumnSumBuffer (0 + GemmU8XKernelFrame_SavedRegisters)
#define GemmU8XKernelFrame_ZeroPointB (8 + GemmU8XKernelFrame_SavedRegisters)
#define GemmU8XKernelFrame_ZeroMode (16 + GemmU8XKernelFrame_SavedRegisters)
TEXTAREA
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (x0) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>.
B (x1) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_UDOT>.
C (x2) - Supplies the address of matrix C.
PackedCountK (x3) - Supplies the number of packed columns from matrix A and
the number of packed rows from matrix B to iterate over.
CountM (x4) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (x5) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc (x6) - Supplies the first dimension of matrix C.
RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values
have been pre-scaled by the zero point offset of matrix B if the offset
is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
scaled by the per-column zero point offsets of matrix B. These values are
accumulated into every row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
B, else nullptr if the matrix B is using per-tensor quantization.
ZeroMode - Supplies true if the output matrix must be zero initialized, else
false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
NESTED_ENTRY MlasGemmU8X8KernelUdot
PROLOG_SAVE_REG_PAIR d8,d9,#-32!
PROLOG_SAVE_REG_PAIR d10,d11,#16
ldr x8,[sp,#GemmU8XKernelFrame_ColumnSumBuffer]
ldr x9,[sp,#GemmU8XKernelFrame_ZeroPointB]
ldrb w13,[sp,#GemmU8XKernelFrame_ZeroMode]
mov x14,x0
ld1 {v11.4s},[x7]
mov x15,x3
dup v8.4s,v11.s[0] // broadcast row fixups
cmp x4,#1 // CountM == 1?
beq ProcessNextColumnLoopM1
dup v9.4s,v11.s[1]
cmp x4,#4 // CountM < 4?
blo ProcessNextColumnLoopM2
dup v10.4s,v11.s[2]
dup v11.4s,v11.s[3]
//
// Process 4 rows of the matrices.
//
ProcessNextColumnLoopM4
//
// Each pass of this outer loop produces one tile of 4 rows by 8 columns.
// The tile accumulators are v16/v17 (row stored via x2), v18/v19 (x10),
// v20/v21 (x11) and v22/v23 (x12).
//
ld1 {v0.16b},[x1],#16 // load packed B0
mov x0,x14 // reload matrix A
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
cbz x9,SkipScaleByZeroPointBM4
//
// Per-column zero points for matrix B were supplied: initialize each
// accumulator to ColumnSum + ZeroPointB * RowFixup.
//
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
mul v16.4s,v30.4s,v8.4s
mul v18.4s,v30.4s,v9.4s
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
mul v20.4s,v30.4s,v10.4s
mul v22.4s,v30.4s,v11.4s
mul v17.4s,v31.4s,v8.4s
mul v19.4s,v31.4s,v9.4s
mul v21.4s,v31.4s,v10.4s
mul v23.4s,v31.4s,v11.4s
add v16.4s,v2.4s,v16.4s
add v18.4s,v2.4s,v18.4s
add v20.4s,v2.4s,v20.4s
add v22.4s,v2.4s,v22.4s
add v17.4s,v3.4s,v17.4s
add v19.4s,v3.4s,v19.4s
add v21.4s,v3.4s,v21.4s
add v23.4s,v3.4s,v23.4s
b ComputeBlockLoopStartM4
SkipScaleByZeroPointBM4
//
// No per-column zero points (x9 is null): initialize each accumulator to
// ColumnSum + RowFixup.
//
add v16.4s,v2.4s,v8.4s
add v18.4s,v2.4s,v9.4s
add v20.4s,v2.4s,v10.4s
add v22.4s,v2.4s,v11.4s
add v17.4s,v3.4s,v8.4s
add v19.4s,v3.4s,v9.4s
add v21.4s,v3.4s,v10.4s
add v23.4s,v3.4s,v11.4s
//
// The packing layout is set up to have a pair of four quad vectors from
// packed matrix A and a pair of eight quad vectors from packed matrix B.
// With this scheme, alternating loads from the packed matrices can be
// interleaved with the dot product instructions.
//
// One negative consequence of using four rows here is that the accumulator
// register tile is too small for processors with high out of order execution
// windows (such as the Apple M1). The dot product instructions for a given
// cell are too close to each other to avoid dependencies. To work around this,
// the below loop uses a pair of accumulator registers that are then added
// together when the loop finishes.
//
// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for
// packed matrix A. At the time of this implementation, using a wider 128-bit
// load didn't affect performance for higher end cores.
//
ComputeBlockLoopStartM4
//
// Preload the first three 64-bit pieces of packed A (x0 is advanced by 32
// bytes up front, so the remaining pieces use negative offsets) and clear
// the secondary accumulators v24-v31.
//
ldr d4,[x0],#32 // load packed A0.l
movi v24.4s,#0
movi v25.4s,#0
ldur d5,[x0,#-24] // load packed A0.h
movi v26.4s,#0
movi v27.4s,#0
ldur d6,[x0,#-16] // load packed A1.l
movi v28.4s,#0
movi v29.4s,#0
movi v30.4s,#0
movi v31.4s,#0
ComputeBlockLoopM4
//
// One iteration consumes 32 bytes of packed A and 64 bytes of packed B.
// The first half (A0 in v4/v5) accumulates into v16-v23 and the second half
// (A1 in v6/v7) accumulates into v24-v31. Loads for the next step are
// interleaved with the dot products of the current step.
//
// NOTE(review): UdotByElement is a macro defined outside this excerpt; from
// its use here the operands appear to be (accumulator, B vector, A vector,
// A element index) - confirm against the macro definition.
//
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 16, 0, 4, 0
UdotByElement 18, 0, 4, 1
ldur d7,[x0,#-8] // load packed A1.h
UdotByElement 20, 0, 5, 0
UdotByElement 22, 0, 5, 1
ld1 {v0.16b},[x1],#16 // load packed B0
UdotByElement 17, 1, 4, 0
UdotByElement 19, 1, 4, 1
sub x3,x3,#1
// The loop exits mid-iteration so that the final pass does not issue the
// loads of packed A and packed B0 intended for a following iteration.
cbz x3,ComputeBlockLoopFinishM4
ldr d4,[x0],#32 // load packed A0.l
UdotByElement 21, 1, 5, 0
UdotByElement 23, 1, 5, 1
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 24, 0, 6, 0
UdotByElement 26, 0, 6, 1
ldur d5,[x0,#-24] // load packed A0.h
UdotByElement 28, 0, 7, 0
UdotByElement 30, 0, 7, 1
ld1 {v0.16b},[x1],#16 // load packed B0
UdotByElement 25, 1, 6, 0
UdotByElement 27, 1, 6, 1
ldur d6,[x0,#-16] // load packed A1.l
UdotByElement 29, 1, 7, 0
UdotByElement 31, 1, 7, 1
b ComputeBlockLoopM4
ComputeBlockLoopFinishM4
//
// Tail of the last iteration: the same dot products as the second part of
// the loop body, without the loads for a following iteration.
//
UdotByElement 21, 1, 5, 0
UdotByElement 23, 1, 5, 1
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 24, 0, 6, 0
UdotByElement 26, 0, 6, 1
UdotByElement 28, 0, 7, 0
UdotByElement 30, 0, 7, 1
UdotByElement 25, 1, 6, 0
UdotByElement 27, 1, 6, 1
UdotByElement 29, 1, 7, 0
UdotByElement 31, 1, 7, 1
// The row pointers are derived from x2 on every pass (ldc in x6 is scaled
// by 4 bytes per element), so only x2 is advanced by the full-width store.
add x10,x2,x6,lsl #2 // compute output row 2
add v16.4s,v16.4s,v24.4s // fold high results into low results
add v18.4s,v18.4s,v26.4s
add v20.4s,v20.4s,v28.4s
add v22.4s,v22.4s,v30.4s
add x11,x10,x6,lsl #2 // compute output row 3
add v17.4s,v17.4s,v25.4s
add v19.4s,v19.4s,v27.4s
add v21.4s,v21.4s,v29.4s
add v23.4s,v23.4s,v31.4s
add x12,x11,x6,lsl #2 // compute output row 4
subs x5,x5,#8 // adjust CountN remaining
blo StoreOutputPartialM4 // fewer than 8 columns remain
cbnz x13,SkipAccumulateOutputM4 // ZeroMode: overwrite instead of accumulate
//
// Accumulate mode: add the existing contents of matrix C to the tile.
//
ldp q0,q1,[x2]
ldp q2,q3,[x10]
add v16.4s,v16.4s,v0.4s
add v17.4s,v17.4s,v1.4s
ldp q4,q5,[x11]
add v18.4s,v18.4s,v2.4s
add v19.4s,v19.4s,v3.4s
ldp q6,q7,[x12]
add v20.4s,v20.4s,v4.4s
add v21.4s,v21.4s,v5.4s
add v22.4s,v22.4s,v6.4s
add v23.4s,v23.4s,v7.4s
SkipAccumulateOutputM4
stp q16,q17,[x2],#32 // store row 1 and advance matrix C by 8 columns
stp q18,q19,[x10]
stp q20,q21,[x11]
stp q22,q23,[x12]
cbnz x5,ProcessNextColumnLoopM4 // more columns remain
ExitKernelM4
mov x0,#4 // return number of rows handled
EPILOG_RESTORE_REG_PAIR d10,d11,#16
EPILOG_RESTORE_REG_PAIR d8,d9,#32!
EPILOG_RETURN
//
// Store the partial 1 to 7 columns either overwriting the output matrix or
// accumulating into the existing contents of the output matrix.
//
StoreOutputPartialM4
//
// Reached when fewer than 8 columns remain. x5 holds the remaining column
// count minus 8; its low three bits equal the number of columns left (1 to
// 7), so bits 2, 1 and 0 select a store of 4, 2 and 1 columns respectively.
// After each store the not yet stored elements are moved down to the low
// lanes of v16/v18/v20/v22.
//
cbz x13,StoreOutputPartialAddModeM4 // ZeroMode clear: accumulate into C
StoreOutputPartialZeroModeM4
tbz x5,#2,StoreOutputPartial2ZeroModeM4
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
st1 {v20.4s},[x11],#16
mov v20.16b,v21.16b
st1 {v22.4s},[x12],#16
mov v22.16b,v23.16b
StoreOutputPartial2ZeroModeM4
tbz x5,#1,StoreOutputPartial1ZeroModeM4
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
st1 {v20.2s},[x11],#8
dup v20.4s,v20.s[2]
st1 {v22.2s},[x12],#8
dup v22.4s,v22.s[2]
StoreOutputPartial1ZeroModeM4
tbz x5,#0,ExitKernelM4
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
st1 {v20.s}[0],[x11]
st1 {v22.s}[0],[x12]
b ExitKernelM4
StoreOutputPartialAddModeM4
//
// Same column selection as above, but the existing contents of matrix C
// are loaded and added before each store.
//
tbz x5,#2,StoreOutputPartial2AddModeM4
ld1 {v0.4s},[x2]
ld1 {v1.4s},[x10]
ld1 {v2.4s},[x11]
ld1 {v3.4s},[x12]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
add v20.4s,v20.4s,v2.4s
add v22.4s,v22.4s,v3.4s
st1 {v20.4s},[x11],#16
mov v20.16b,v21.16b
st1 {v22.4s},[x12],#16
mov v22.16b,v23.16b
StoreOutputPartial2AddModeM4
tbz x5,#1,StoreOutputPartial1AddModeM4
// Only the low two lanes are loaded from C; the full-width add leaves the
// upper lanes holding the not yet stored results.
ld1 {v0.2s},[x2]
ld1 {v1.2s},[x10]
ld1 {v2.2s},[x11]
ld1 {v3.2s},[x12]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
add v20.4s,v20.4s,v2.4s
add v22.4s,v22.4s,v3.4s
st1 {v20.2s},[x11],#8
dup v20.4s,v20.s[2]
st1 {v22.2s},[x12],#8
dup v22.4s,v22.s[2]
StoreOutputPartial1AddModeM4
tbz x5,#0,ExitKernelM4
// Final single column: only lane 0 is loaded and stored, so the contents
// of the other lanes do not matter here.
ld1 {v0.s}[0],[x2]
ld1 {v1.s}[0],[x10]
add v16.4s,v16.4s,v0.4s
ld1 {v2.s}[0],[x11]
add v18.4s,v18.4s,v1.4s
ld1 {v3.s}[0],[x12]
add v20.4s,v20.4s,v2.4s
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
add v22.4s,v22.4s,v3.4s
st1 {v20.s}[0],[x11]
st1 {v22.s}[0],[x12]
b ExitKernelM4
//
// Process 2 rows of the matrices.
//
ProcessNextColumnLoopM2
//
// Taken when CountM is below 4 and not 1. Each pass of this outer loop
// produces one tile of 2 rows by 8 columns in v16/v17 (row stored via x2)
// and v18/v19 (row stored via x10).
//
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
mov x0,x14 // reload matrix A
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
cbz x9,SkipScaleByZeroPointBM2
// Per-column zero points supplied: accumulators start at
// ColumnSum + ZeroPointB * RowFixup.
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
mul v16.4s,v30.4s,v8.4s
mul v18.4s,v30.4s,v9.4s
mul v17.4s,v31.4s,v8.4s
mul v19.4s,v31.4s,v9.4s
ld1 {v4.16b},[x0],#16 // load packed A0
add v16.4s,v2.4s,v16.4s
add v18.4s,v2.4s,v18.4s
add v17.4s,v3.4s,v17.4s
add v19.4s,v3.4s,v19.4s
b ComputeBlockLoopM2
SkipScaleByZeroPointBM2
// No per-column zero points: accumulators start at ColumnSum + RowFixup.
ld1 {v4.16b},[x0],#16 // load packed A0
add v16.4s,v2.4s,v8.4s
add v18.4s,v2.4s,v9.4s
add v17.4s,v3.4s,v8.4s
add v19.4s,v3.4s,v9.4s
ComputeBlockLoopM2
//
// One iteration consumes 16 bytes of packed A (v4) and 64 bytes of packed B.
// Elements 0 and 1 of v4 are paired with the first B0/B1 vectors, elements
// 2 and 3 with the second B0/B1 vectors; even elements feed the first row
// (v16/v17) and odd elements feed the second row (v18/v19).
//
UdotByElement 16, 0, 4, 0
UdotByElement 17, 1, 4, 0
UdotByElement 18, 0, 4, 1
UdotByElement 19, 1, 4, 1
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 16, 0, 4, 2
UdotByElement 17, 1, 4, 2
UdotByElement 18, 0, 4, 3
UdotByElement 19, 1, 4, 3
sub x3,x3,#1
cbz x3,ComputeBlockLoopFinishM2
// Loads for the next iteration are issued only when one follows.
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
ld1 {v4.16b},[x0],#16 // load packed A0
b ComputeBlockLoopM2
ComputeBlockLoopFinishM2
add x10,x2,x6,lsl #2 // compute output row 2
subs x5,x5,#8 // adjust CountN remaining
blo StoreOutputPartialM2 // fewer than 8 columns remain
cbnz x13,SkipAccumulateOutputM2 // ZeroMode: overwrite instead of accumulate
// Accumulate mode: add the existing contents of matrix C to the tile.
ldp q0,q1,[x2]
ldp q2,q3,[x10]
add v16.4s,v16.4s,v0.4s
add v17.4s,v17.4s,v1.4s
add v18.4s,v18.4s,v2.4s
add v19.4s,v19.4s,v3.4s
SkipAccumulateOutputM2
stp q16,q17,[x2],#32 // store row 1 and advance matrix C by 8 columns
stp q18,q19,[x10]
cbnz x5,ProcessNextColumnLoopM2 // more columns remain
ExitKernelM2
mov x0,#2 // return number of rows handled
EPILOG_RESTORE_REG_PAIR d10,d11,#16
EPILOG_RESTORE_REG_PAIR d8,d9,#32!
EPILOG_RETURN
//
// Store the partial 1 to 7 columns either overwriting the output matrix or
// accumulating into the existing contents of the output matrix.
//
StoreOutputPartialM2
//
// Reached when fewer than 8 columns remain. x5 holds the remaining column
// count minus 8; its low three bits equal the number of columns left (1 to
// 7), so bits 2, 1 and 0 select a store of 4, 2 and 1 columns respectively.
// After each store the not yet stored elements are moved down to the low
// lanes of v16/v18.
//
cbz x13,StoreOutputPartialAddModeM2 // ZeroMode clear: accumulate into C
StoreOutputPartialZeroModeM2
tbz x5,#2,StoreOutputPartial2ZeroModeM2
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
StoreOutputPartial2ZeroModeM2
tbz x5,#1,StoreOutputPartial1ZeroModeM2
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
StoreOutputPartial1ZeroModeM2
tbz x5,#0,ExitKernelM2
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
b ExitKernelM2
StoreOutputPartialAddModeM2
//
// Same column selection as above, but the existing contents of matrix C
// are loaded and added before each store.
//
tbz x5,#2,StoreOutputPartial2AddModeM2
ld1 {v0.4s},[x2]
ld1 {v1.4s},[x10]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
st1 {v18.4s},[x10],#16
mov v18.16b,v19.16b
StoreOutputPartial2AddModeM2
tbz x5,#1,StoreOutputPartial1AddModeM2
// Only the low two lanes are loaded from C; the full-width add leaves the
// upper lanes holding the not yet stored results.
ld1 {v0.2s},[x2]
ld1 {v1.2s},[x10]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
st1 {v18.2s},[x10],#8
dup v18.4s,v18.s[2]
StoreOutputPartial1AddModeM2
tbz x5,#0,ExitKernelM2
// Final single column: only lane 0 is loaded and stored.
ld1 {v0.s}[0],[x2]
ld1 {v1.s}[0],[x10]
add v16.4s,v16.4s,v0.4s
add v18.4s,v18.4s,v1.4s
st1 {v16.s}[0],[x2]
st1 {v18.s}[0],[x10]
b ExitKernelM2
//
// Process 1 row of the matrices.
//
ProcessNextColumnLoopM1
//
// Taken when CountM is 1. Each pass of this outer loop produces one tile of
// 1 row by 8 columns in v16/v17.
//
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
mov x0,x14 // reload matrix A
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0
mov x3,x15 // reload PackedCountK
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
cbz x9,SkipScaleByZeroPointBM1
// Per-column zero points supplied: accumulators start at
// ColumnSum + ZeroPointB * RowFixup.
ld1 {v30.4s},[x9],#16 // load ZeroPointB0
ld1 {v31.4s},[x9],#16 // load ZeroPointB1
mul v16.4s,v30.4s,v8.4s
mul v17.4s,v31.4s,v8.4s
ldr d4,[x0],#8 // load packed A0
add v16.4s,v2.4s,v16.4s
add v17.4s,v3.4s,v17.4s
b ComputeBlockLoopM1
SkipScaleByZeroPointBM1
// No per-column zero points: accumulators start at ColumnSum + RowFixup.
ldr d4,[x0],#8 // load packed A0
add v16.4s,v2.4s,v8.4s
add v17.4s,v3.4s,v8.4s
ComputeBlockLoopM1
//
// One iteration consumes 8 bytes of packed A (d4) and 64 bytes of packed B.
// Element 0 of v4 is paired with the first B0/B1 vectors and element 1 with
// the second B0/B1 vectors.
//
UdotByElement 16, 0, 4, 0
UdotByElement 17, 1, 4, 0
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
UdotByElement 16, 0, 4, 1
UdotByElement 17, 1, 4, 1
sub x3,x3,#1
cbz x3,ComputeBlockLoopFinishM1
// Loads for the next iteration are issued only when one follows.
ldr d4,[x0],#8 // load packed A0
ld1 {v0.16b},[x1],#16 // load packed B0
ld1 {v1.16b},[x1],#16 // load packed B1
b ComputeBlockLoopM1
ComputeBlockLoopFinishM1
subs x5,x5,#8 // adjust CountN remaining
blo StoreOutputPartialM1 // fewer than 8 columns remain
cbnz x13,SkipAccumulateOutputM1 // ZeroMode: overwrite instead of accumulate
// Accumulate mode: add the existing contents of matrix C to the tile.
ldp q0,q1,[x2]
add v16.4s,v16.4s,v0.4s
add v17.4s,v17.4s,v1.4s
SkipAccumulateOutputM1
stp q16,q17,[x2],#32 // store the row and advance matrix C by 8 columns
cbnz x5,ProcessNextColumnLoopM1 // more columns remain
ExitKernelM1
mov x0,#1 // return number of rows handled
EPILOG_RESTORE_REG_PAIR d10,d11,#16
EPILOG_RESTORE_REG_PAIR d8,d9,#32!
EPILOG_RETURN
//
// Store the partial 1 to 7 columns either overwriting the output matrix or
// accumulating into the existing contents of the output matrix.
//
StoreOutputPartialM1
//
// Reached when fewer than 8 columns remain. x5 holds the remaining column
// count minus 8; its low three bits equal the number of columns left (1 to
// 7), so bits 2, 1 and 0 select a store of 4, 2 and 1 columns respectively.
// After each store the not yet stored elements are moved down to the low
// lanes of v16.
//
cbz x13,StoreOutputPartialAddModeM1 // ZeroMode clear: accumulate into C
StoreOutputPartialZeroModeM1
tbz x5,#2,StoreOutputPartial2ZeroModeM1
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
StoreOutputPartial2ZeroModeM1
tbz x5,#1,StoreOutputPartial1ZeroModeM1
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
StoreOutputPartial1ZeroModeM1
tbz x5,#0,ExitKernelM1
st1 {v16.s}[0],[x2]
b ExitKernelM1
StoreOutputPartialAddModeM1
//
// Same column selection as above, but the existing contents of matrix C
// are loaded and added before each store.
//
tbz x5,#2,StoreOutputPartial2AddModeM1
ld1 {v0.4s},[x2]
add v16.4s,v16.4s,v0.4s
st1 {v16.4s},[x2],#16
mov v16.16b,v17.16b // shift remaining elements down
StoreOutputPartial2AddModeM1
tbz x5,#1,StoreOutputPartial1AddModeM1
// Only the low two lanes are loaded from C; the full-width add leaves the
// upper lanes holding the not yet stored results.
ld1 {v0.2s},[x2]
add v16.4s,v16.4s,v0.4s
st1 {v16.2s},[x2],#8
dup v16.4s,v16.s[2] // shift remaining elements down
StoreOutputPartial1AddModeM1
tbz x5,#0,ExitKernelM1
// Final single column: only lane 0 is loaded and stored.
ld1 {v0.s}[0],[x2]
add v16.4s,v16.4s,v0.4s
st1 {v16.s}[0],[x2]
b ExitKernelM1
NESTED_END MlasGemmU8X8KernelUdot
END

View file

@ -245,14 +245,6 @@ void
typedef MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE;
typedef
void
(MLASCALL MLAS_GEMM_U8X8_OPERATION)(
const struct MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
);
typedef MLAS_GEMM_U8X8_OPERATION* PMLAS_GEMM_U8X8_OPERATION;
typedef
size_t
(MLASCALL MLAS_GEMM_U8S8_KERNEL)(
@ -265,7 +257,7 @@ size_t
size_t ldc,
const int32_t* RowSumVector,
const int32_t* ColumnSumVector,
int32_t DepthValue,
const int32_t* ZeroPointB,
bool ZeroMode
);
@ -296,7 +288,7 @@ size_t
size_t ldc,
const int32_t* RowSumVector,
const int32_t* ColumnSumVector,
int32_t DepthValue,
const int32_t* ZeroPointB,
bool ZeroMode
);
@ -697,26 +689,17 @@ MlasSgemmOperation(
);
//
// Quantized integer matrix/matrix multiply operation.
// Quantized integer matrix/matrix dispatch structure.
//
struct MLAS_GEMM_U8X8_KERNEL_SSE;
struct MLAS_GEMM_U8S8_KERNEL_AVX2;
struct MLAS_GEMM_U8U8_KERNEL_AVX2;
struct MLAS_GEMM_U8X8_DISPATCH;
template<typename KernelType>
void
MLASCALL
MlasGemmU8X8Operation(
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
);
template<typename KernelType>
void
MLASCALL
MlasGemmU8X8PackedOperation(
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
);
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchSse;
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8S8DispatchAvx2;
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8U8DispatchAvx2;
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchNeon;
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchUdot;
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchDefault;
//
// Quantized depthwise convolution kernels.
@ -767,12 +750,10 @@ struct MLAS_PLATFORM {
PMLAS_SGEMM_KERNEL_M1_ROUTINE KernelM1TransposeBRoutine;
PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE TransposePackB16x4Routine;
PMLAS_GEMM_DOUBLE_KERNEL GemmDoubleKernel;
PMLAS_GEMM_U8X8_OPERATION GemmU8S8Operation;
PMLAS_GEMM_U8X8_OPERATION GemmU8S8PackedOperation;
const MLAS_GEMM_U8X8_DISPATCH* GemmU8S8Dispatch;
PMLAS_GEMM_U8S8_KERNEL GemmU8S8Kernel;
PMLAS_GEMV_U8S8_KERNEL GemvU8S8Kernel;
PMLAS_GEMM_U8X8_OPERATION GemmU8U8Operation;
PMLAS_GEMM_U8X8_OPERATION GemmU8U8PackedOperation;
const MLAS_GEMM_U8X8_DISPATCH* GemmU8U8Dispatch;
PMLAS_GEMM_U8U8_KERNEL GemmU8U8Kernel;
PMLAS_CONV_FLOAT_KERNEL ConvNchwFloatKernel;
PMLAS_CONV_FLOAT_KERNEL ConvNchwcFloatKernel;
@ -800,6 +781,10 @@ struct MLAS_PLATFORM {
#else
static constexpr uint32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
#endif
#if defined(MLAS_TARGET_ARM64)
const MLAS_GEMM_U8X8_DISPATCH* GemmU8X8Dispatch;
#endif
};
extern MLAS_PLATFORM MlasPlatform;

View file

@ -17,6 +17,16 @@ Abstract:
#include "mlasi.h"
#if defined(MLAS_TARGET_ARM64) && defined(__linux__)
#include <sys/auxv.h>
#include <asm/hwcap.h>
// N.B. Support building with older versions of asm/hwcap.h that do not define
// this capability bit.
#ifndef HWCAP_ASIMDDP
#define HWCAP_ASIMDDP (1 << 20)
#endif
#endif
//
// Stores the platform information.
//
@ -120,8 +130,8 @@ Return Value:
this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse;
this->GemmDoubleKernel = MlasGemmDoubleKernelSse;
this->GemmU8S8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_SSE>;
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_SSE>;
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchSse;
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchSse;
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse;
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelSse;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelSse;
@ -206,12 +216,10 @@ Return Value:
if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) {
this->GemmU8S8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
this->GemmU8S8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2;
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8U8_KERNEL_AVX2>;
this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8U8_KERNEL_AVX2>;
this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2;
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2;
this->GemmFloatKernel = MlasGemmFloatKernelFma3;
@ -229,7 +237,7 @@ Return Value:
this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernelAvx2<int8_t>;
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernelAvx2<uint8_t>;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
//
// Check if the processor supports Hybrid core architecture.
//
@ -251,8 +259,7 @@ Return Value:
if ((Cpuid7_1[0] & 0x10) != 0) {
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
}
@ -304,8 +311,7 @@ Return Value:
if ((Cpuid7[2] & 0x800) != 0) {
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni;
}
@ -326,6 +332,24 @@ Return Value:
#endif // MLAS_TARGET_AMD64_IX86
#if defined(MLAS_TARGET_ARM64)
this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchNeon;
#if defined(__linux__)
//
// Check if the processor supports ASIMD dot product instructions.
//
if ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0) {
this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchUdot;
}
#endif
#endif
}
size_t

File diff suppressed because it is too large Load diff

View file

@ -18,7 +18,6 @@ Abstract:
--*/
#include "asmmacro.h"
#include "QgemmU8X8KernelAvx2Common.h"
.intel_syntax noprefix
@ -27,7 +26,7 @@ Abstract:
//
.equ .LGemmU8S8CopyPackAFrame_PaddedMatrixAData, -72
.equ .LGemmU8S8CopyPackAFrame_mask, -8
.equ .LGemmU8S8CopyPackAFrame_Padding, -8
.equ .LGemmU8S8CopyPackAFrame_SavedR13, 0
.equ .LGemmU8S8CopyPackAFrame_SavedR12, 8
.equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16
@ -76,8 +75,7 @@ Return Value:
--*/
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2)
C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2):
FUNCTION_ENTRY MlasGemmU8S8CopyPackAAvx2
push rbp
push rbx
@ -101,9 +99,9 @@ C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2):
and eax,15 # isolate unaligned count
add eax,3
shr eax,2 # align unaligned count to quad count
mov DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp],eax
vpbroadcastd xmm10,DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp]
vpcmpgtd xmm10,xmm10,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
neg rax
lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4]
vmovdqu xmm10,XMMWORD PTR [rbx+rax*4]
//
// Zero initialize the padded stack buffers.
@ -430,8 +428,7 @@ Return Value:
--*/
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2)
C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2):
FUNCTION_ENTRY MlasGemmU8S8CopyPackBAvx2
push rbp
push rbx
@ -698,258 +695,4 @@ C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2):
vmovdqu YMMWORD PTR [r9+32],ymm1
jmp .LCopyPackB.ExitRoutine
/*++
Macro Description:
This macro generates code to multiply and accumulate a single row of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 16).
Vec2Reg - Supplies the low block accumulator register.
Implicit Arguments:
ymm0 - Supplies the first vector loaded from matrix B.
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
is 16).
ymm2 - Supplies the broadcast value loaded from matrix A.
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
--*/
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
// Multiply the broadcast matrix A bytes in ymm2 (unsigned operand) with the
// matrix B bytes in ymm0 (signed operand), summing adjacent byte products
// into words; ymm3 is used as scratch.
vpmaddubsw ymm3,ymm2,ymm0
// Multiplying by the word vector of ones in ymm12 sums adjacent word pairs
// into doublewords.
vpmaddwd ymm3,ymm3,ymm12
.if \ColumnCount\() == 16
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
// Second vector of matrix B: ymm2 is reused as scratch here, so the
// broadcast value from matrix A is destroyed by this macro.
vpmaddubsw ymm2,ymm2,ymm1
vpmaddwd ymm2,ymm2,ymm12
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
.else
// Only one vector of matrix B is in use: accumulate into Vec2Reg.
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm11 - Supplies the block accumulators.
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
.if \RowCount\() == 1
// Single row: the matrix B vectors are used directly as memory operands
// instead of first being loaded into ymm0/ymm1. The accumulators are ymm4
// and ymm5 for 16 columns, else ymm5 only.
vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()]
vpmaddwd ymm3,ymm3,ymm12
.if \ColumnCount\() == 16
vpaddd ymm4,ymm4,ymm3
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32]
vpmaddwd ymm2,ymm2,ymm12
vpaddd ymm5,ymm5,ymm2
.else
vpaddd ymm5,ymm5,ymm3
.endif
.else
// Multiple rows: load the matrix B vectors once into ymm0 (and ymm1 for 16
// columns), then for each row broadcast four bytes of matrix A into ymm2
// and accumulate. Rows 1 to 3 are addressed from rdi with a row stride of
// rcx; row 4 is addressed from rbx. The accumulator pairs are ymm4/ymm5,
// ymm6/ymm7, ymm8/ymm9 and ymm10/ymm11.
// NOTE(review): EmitIfCountGE is defined outside this excerpt; presumably
// it emits the quoted instruction only when the first operand is greater
// than or equal to the second - confirm against asmmacro.h.
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
.endif
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm11 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoop ColumnCount, RowCount
// rbp counts the bytes of the matrix A row still to be processed; each
// iteration handles one quad (4 bytes) of A and 64 bytes of packed B.
// The loop tests the count only after the first iteration, so a row length
// of zero is not handled here.
mov rbp,rcx # reload row length remaining
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
.if \RowCount\() > 3
// rbx is only used as the row 4 pointer, so it is only advanced when four
// rows are being produced.
add rbx,4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,64 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endm
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8S8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8S8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
and the number of packed rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multiplied by the zero point offset of matrix B. This value
is accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
.globl C_UNDERSCORE(MlasGemmU8S8KernelAvx2)
C_UNDERSCORE(MlasGemmU8S8KernelAvx2):
// Save the non-volatile registers used by the kernel; they are restored in
// the reverse order at .LExitKernel.
push rbp
push rbx
push r12
push r13
// The .LGemmU8X8KernelFrame_* stack offsets are defined outside this excerpt.
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,rdi
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF]
vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001]
//
// Process CountM rows of the matrices.
//
// CountM above 3 uses the 4 row path, CountM of 3 or 1 use their own paths,
// and any other value falls through to the 2 row path below.
// NOTE(review): ProcessCountM is a macro defined outside this excerpt; the
// paths without the Fallthrough argument presumably end by jumping to
// .LExitKernel - confirm against the macro definition.
//
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
// CountM may exceed 4 here, so the return value is clamped to the 4 rows
// that this invocation handles.
mov r8d,4 # return 4 rows handled
ProcessCountM 4, Fallthrough
//
// Restore non-volatile registers and return.
//
.LExitKernel:
mov eax,r8d
vzeroupper
pop r13
pop r12
pop rbx
pop rbp
ret
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.end

View file

@ -1,88 +0,0 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvx512Common.h
Abstract:
This module contains common kernel macros and structures for the quantized
integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
AVX512VNNI kernels.
--*/
#include "QgemmU8X8KernelAvx512Common.h"
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoop ColumnCount, RowCount
// rbp counts the bytes of the matrix A row still to be processed.
mov rbp,rcx # reload row length remaining
// The loop unrolled by four quads is only emitted for even row counts; odd
// row counts use the single quad loop for the whole row.
.if ((\RowCount\() & 1) == 0)
sub rbp,4*4
jb .LProcessRemainingBlocks\@ # fewer than 4 quads remain
.LComputeBlockBy4Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0*64, 0
ComputeBlock \ColumnCount\(), \RowCount\(), 1*64, 4
ComputeBlock \ColumnCount\(), \RowCount\(), 2*64, 8
ComputeBlock \ColumnCount\(), \RowCount\(), 3*64, 12
add rdi,4*4 # advance matrix A by 4 quads
.if \RowCount\() > 3
add rbx,4*4 # advance matrix A plus 3 rows by 4 quads
.endif
add rsi,4*64 # advance matrix B
sub rbp,4*4 # decrement quads remaining
jae .LComputeBlockBy4Loop\@
.LProcessRemainingBlocks\@:
add rbp,4*4 # correct for over-subtract above
jz .LComputeBlockLoopExit\@
.endif
// Single quad loop: handles the 1 to 3 quads left over by the unrolled
// loop, or the entire row when the unrolled loop was not emitted.
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,64 # advance matrix B
sub rbp,4 # decrement quads remaining
jnz .LComputeBlockBy1Loop\@
.LComputeBlockLoopExit\@:
.endm

View file

@ -1,136 +0,0 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvx512Core.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX512 core instructions (BW/DQ/VL).
--*/
#include "asmmacro.h"
#include "QgemmU8S8KernelAvx512Common.h"
.intel_syntax noprefix
.text
/*++
Macro Description:
This macro generates code to multiply and accumulate a single cell of the
output block.
Arguments:
AccumReg - Supplies the register to accumulate into.
Mult1Reg - Supplies the first multiplication operand register.
Mult2Reg - Supplies the second multiplication operand register.
Implicit Arguments:
zmm4 - Supplies a scratch register for intermediate results.
zmm5 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
--*/
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
// Multiply the bytes of Mult1Reg (unsigned operand) with the bytes of
// Mult2Reg (signed operand), summing adjacent byte products into words in
// the scratch register zmm4.
vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\()
// Multiplying by the word vector of ones in zmm5 sums adjacent word pairs
// into doublewords.
vpmaddwd zmm4,zmm4,zmm5
vpaddd \AccumReg\(),\AccumReg\(),zmm4
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
// Load one 64-byte vector per packed block of matrix B that is in use. The
// loads are arranged so that the last block in use always lands in zmm2,
// the one before it in zmm1 and, for 48 columns, the first block in zmm0.
.if \ColumnCount\() >= 48
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()]
.elseif \ColumnCount\() >= 32
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
.else
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()]
.endif
// For each row, broadcast four bytes of matrix A into zmm3 and accumulate
// against each loaded B vector. Rows 1 to 3 are addressed from rdi and rows
// 4 to 6 from rbx, both with a row stride of rcx. zmm0 pairs with
// accumulators zmm26-zmm31, zmm1 with zmm20-zmm25 and zmm2 with zmm14-zmm19.
// NOTE(review): EmitIfCountGE and EmitIfCount2GE are defined outside this
// excerpt; presumably they emit the quoted instruction only when each count
// is greater than or equal to the value that follows it - confirm against
// asmmacro.h.
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
.endm
//
// Generate the GEMM kernel.
//
GemmU8X8KernelAvx512Function U8S8, Avx512Core
.end

View file

@ -1,106 +0,0 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvx512Vnni.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX512VNNI instructions.
--*/
#include "asmmacro.h"
#include "QgemmU8S8KernelAvx512Common.h"
#include "AssembleAvx512Vnni.h"
.intel_syntax noprefix
.text
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
//
// Load one 64-byte vector of packed matrix B per group of 16 output columns.
// Additional groups are read r14 and r14*2 bytes past the first. The loads
// are assigned so that zmm2 always pairs with accumulators zmm14-zmm19, zmm1
// with zmm20-zmm25 and zmm0 with zmm26-zmm31, whatever the column count is.
//
.if \ColumnCount\() >= 48
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()]
.elseif \ColumnCount\() >= 32
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
.else
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()]
.endif
//
// For each active row, broadcast one DWORD of matrix A into zmm3 and
// accumulate it against every loaded matrix B vector. Rows 1-3 are addressed
// from rdi and rows 4-6 from rbx, stepping by the row length held in rcx.
//
// EmitIfCountGE, EmitIfCount2GE and VpdpbusdsZmmZmmZmm are defined outside
// this file. Going by their names, the first two emit the quoted statement
// only when the count(s) reach the given threshold(s) and the last emits the
// AVX512VNNI vpdpbusds instruction (TODO confirm in asmmacro.h and
// AssembleAvx512Vnni.h).
//
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2"
.endm
//
// Generate the GEMM kernel.
//
// GemmU8X8KernelAvx512Function is not defined in this file. It is expanded
// here with the type argument U8S8 and the ISA argument Avx512Vnni and
// presumably emits the kernel entry point around the ComputeBlock macro
// defined above (TODO confirm against the included AVX512 common header).
//
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
.end

View file

@ -1,268 +0,0 @@
/*++
Copyright (c) 2020 Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvxVnni.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVXVNNI instructions.
--*/
#include "asmmacro.h"
#include "QgemmU8X8KernelAvx2Common.h"
#include "AssembleAvxVnni.h"
.intel_syntax noprefix
/*++
Macro Description:
This macro generates code to multiply and accumulate a single row of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 16).
Vec2Reg - Supplies the low block accumulator register.
Implicit Arguments:
ymm0 - Supplies the first vector loaded from matrix B.
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
is 16).
ymm2 - Supplies the broadcast value loaded from matrix A.
--*/
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
//
// With 16 columns both matrix B vectors (ymm0 and ymm1) are consumed and both
// accumulators are updated. Otherwise only ymm0 is valid and the result goes
// to Vec2Reg alone; Vec1Reg is left untouched.
//
// VpdpbusdsYmmYmmYmm is defined outside this file; presumably it emits the
// AVX VNNI vpdpbusds instruction (TODO confirm in AssembleAvxVnni.h).
//
.if \ColumnCount\() == 16
VpdpbusdsYmmYmmYmm \Vec1Reg\(),ymm2,ymm0
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm1
.else
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm0
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
//
// Load the packed matrix B data: ymm0 always, ymm1 (the next 32 bytes) only
// when 16 columns are produced.
//
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
//
// For each active row, broadcast one DWORD of matrix A into ymm2 and update
// that row's accumulator pair. Rows 1-3 are addressed from rdi and rows 4-6
// from rbx, stepping by the row length held in rcx.
//
// EmitIfCountGE is defined outside this file; going by its name it emits the
// quoted statement only when the count reaches the given threshold (TODO
// confirm in asmmacro.h).
//
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15"
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoop ColumnCount, RowCount
//
// rbp counts down the row length in bytes; each iteration consumes 4 bytes of
// every matrix A row and 64 bytes of packed matrix B. The exit test follows
// the body, so rcx is assumed to be a nonzero multiple of 4 on entry.
//
mov rbp,rcx # reload row length remaining
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
//
// rbx is only read by ComputeBlock for rows 4-6, so it is only advanced when
// more than three rows are produced.
//
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,64 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endm
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8S8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8S8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
and the number of packed rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multiplied by the zero point offset of matrix B. This value
is accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
.globl C_UNDERSCORE(MlasGemmU8S8KernelAvxVnni)
C_UNDERSCORE(MlasGemmU8S8KernelAvxVnni):
//
// Save the non-volatile registers used below. The .LGemmU8X8KernelFrame_*
// offsets used after this point are rsp-relative and are defined outside this
// block.
//
push rbp
push rbx
push r12
push r13
//
// Fetch the stack-passed arguments and scale the element counts to bytes.
// r11 keeps a copy of the incoming matrix A address from rdi.
//
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,rdi
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
//
// Process CountM rows of the matrices.
//
// Dispatch on CountM in r8: above 5 selects the 6 row path, then 5, 4, 3 and
// 1 are matched exactly. Any other value (expected to be 2) falls through to
// the 2 row path.
//
// ProcessCountM is defined outside this file. NOTE(review): the expansions
// without the Fallthrough argument presumably end by jumping to .LExitKernel,
// since the code placed after them belongs to a different row count - confirm
// in the included common header.
//
cmp r8,5
ja .LProcessCountM6
je .LProcessCountM5
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
ProcessCountM 4
.LProcessCountM6:
mov r8d,6 # return 6 rows handled
ProcessCountM 6, Fallthrough
//
// Restore non-volatile registers and return.
//
// The return value in eax is taken from r8d. vzeroupper clears the upper
// halves of the ymm registers before control returns to the caller.
//
.LExitKernel:
mov eax,r8d
vzeroupper
pop r13
pop r12
pop rbx
pop rbp
ret
//
// Odd row counts are placed after the exit sequence.
//
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.LProcessCountM5:
ProcessCountM 5
.end

View file

@ -18,7 +18,6 @@ Abstract:
--*/
#include "asmmacro.h"
#include "QgemmU8X8KernelAvx2Common.h"
.intel_syntax noprefix
@ -27,7 +26,7 @@ Abstract:
//
.equ .LGemmU8U8CopyPackAFrame_PaddedMatrixAData, -72
.equ .LGemmU8U8CopyPackAFrame_mask, -8
.equ .LGemmU8U8CopyPackAFrame_Padding, -8
.equ .LGemmU8U8CopyPackAFrame_SavedR13, 0
.equ .LGemmU8U8CopyPackAFrame_SavedR12, 8
.equ .LGemmU8U8CopyPackAFrame_SavedRbx, 16
@ -79,8 +78,7 @@ Return Value:
--*/
.globl C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2)
C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
FUNCTION_ENTRY MlasGemmU8U8CopyPackAAvx2
push rbp
push rbx
@ -102,9 +100,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
and eax,15 # isolate unaligned count
inc eax
shr eax,1 # align unaligned count to pair count
mov DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp],eax
vpbroadcastd ymm9,DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp]
vpcmpgtd ymm9,ymm9,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
neg rax
lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4]
vmovdqu ymm9,YMMWORD PTR [rbx+rax*4]
//
// Zero initialize the padded stack buffers.
@ -394,8 +392,7 @@ Return Value:
--*/
.globl C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2)
C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
FUNCTION_ENTRY MlasGemmU8U8CopyPackBAvx2
push rbp
push rbx
@ -599,273 +596,4 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
vmovdqu YMMWORD PTR [r9+32],ymm1
jmp .LCopyPackB.ExitRoutine
/*++
Macro Description:
This macro generates code to multiply and accumulate a single row of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 16).
Vec2Reg - Supplies the low block accumulator register.
Implicit Arguments:
ymm0 - Supplies the first vector loaded from matrix B.
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
is 16).
ymm2 - Supplies the broadcast value loaded from matrix A.
--*/
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
//
// vpmaddwd multiplies the 16-bit elements of its sources and adds each
// adjacent pair of products into a 32-bit lane. ymm3 is used as scratch for
// the first product.
//
vpmaddwd ymm3,ymm2,ymm0
.if \ColumnCount\() == 16
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
//
// The second product overwrites ymm2, so the broadcast value from matrix A
// is not preserved past this macro in the 16 column case.
//
vpmaddwd ymm2,ymm2,ymm1
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
.else
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rdi - Supplies the address into the matrix A data.
rbx - Supplies the address into the matrix A data plus 3 rows.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
//
// Load 16 bytes of packed matrix B and zero extend them to 16-bit elements in
// ymm0; the following 16 bytes go to ymm1 only when 16 columns are produced.
//
vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]"
//
// For each active row, broadcast one DWORD of matrix A into ymm2 and update
// that row's accumulator pair. Rows 1-3 are addressed from rdi and rows 4-6
// from rbx, stepping by the row length held in rcx.
//
// EmitIfCountGE is defined outside this file; going by its name it emits the
// quoted statement only when the count reaches the given threshold (TODO
// confirm in asmmacro.h).
//
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15"
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoop ColumnCount, RowCount
mov rbp,rcx # reload row length remaining
//
// The 16 column, even row count variants are unrolled to process two blocks
// (8 bytes of each matrix A row, 64 bytes of matrix B) per iteration. All
// other variants use the single block loop below.
//
.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0)
//
// Pre-subtract one unrolled step so that the borrow flag reports when fewer
// than two blocks remain.
//
sub rbp,2*4
jb .LProcessRemainingBlocks\@
.LComputeBlockBy2Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
ComputeBlock \ColumnCount\(), \RowCount\(), 32, 4
add rdi,2*4 # advance matrix A by 2 pairs
.if \RowCount\() > 3
add rbx,2*4 # advance matrix A plus 3 rows by 2 pairs
.endif
add rsi,2*32 # advance matrix B
sub rbp,2*4
jae .LComputeBlockBy2Loop\@
.LProcessRemainingBlocks\@:
add rbp,2*4 # correct for over-subtract above
jz .LComputeBlockLoopExit\@
//
// One trailing block remains. NOTE(review): only rsi is advanced here; rdi
// and rbx are left pointing at the last block, so the code following this
// macro presumably reloads them before reuse - confirm against the caller.
//
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rsi,32 # advance matrix B
.else
//
// Single block loop. The exit test follows the body, so rcx is assumed to be
// a nonzero multiple of 4 on entry.
//
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 pair
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 pair
.endif
add rsi,32 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endif
.LComputeBlockLoopExit\@:
.endm
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8U8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8U8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
and the number of packed rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multiplied by the zero point offset of matrix B. This value
is accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
.globl C_UNDERSCORE(MlasGemmU8U8KernelAvx2)
C_UNDERSCORE(MlasGemmU8U8KernelAvx2):
//
// Save the non-volatile registers used below. The .LGemmU8X8KernelFrame_*
// offsets used after this point are rsp-relative and are defined outside this
// block.
//
push rbp
push rbx
push r12
push r13
//
// Fetch the stack-passed arguments and scale the element counts to bytes.
// r11 keeps a copy of the incoming matrix A address from rdi.
//
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,rdi
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
//
// Process CountM rows of the matrices.
//
// Dispatch on CountM in r8: above 5 selects the 6 row path, then 5, 4, 3 and
// 1 are matched exactly. Any other value (expected to be 2) falls through to
// the 2 row path.
//
// ProcessCountM is defined outside this file. NOTE(review): the expansions
// without the Fallthrough argument presumably end by jumping to .LExitKernel,
// since the code placed after them belongs to a different row count - confirm
// in the included common header.
//
cmp r8,5
ja .LProcessCountM6
je .LProcessCountM5
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
ProcessCountM 4
.LProcessCountM6:
mov r8d,6 # return 6 rows handled
ProcessCountM 6, Fallthrough
//
// Restore non-volatile registers and return.
//
// The return value in eax is taken from r8d. vzeroupper clears the upper
// halves of the ymm registers before control returns to the caller.
//
.LExitKernel:
mov eax,r8d
vzeroupper
pop r13
pop r12
pop rbx
pop rbp
ret
//
// Odd row counts are placed after the exit sequence.
//
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.LProcessCountM5:
ProcessCountM 5
.end

View file

@ -1,178 +0,0 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8U8KernelAvx512Core.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX512 core instructions (BW/DQ/VL).
--*/
#include "asmmacro.h"
#include "QgemmU8X8KernelAvx512Common.h"
.intel_syntax noprefix
.text
/*++
Macro Description:
This macro generates code to multiply and accumulate a single cell of the
output block.
Arguments:
AccumReg - Supplies the register to accumulate into.
Mult1Reg - Supplies the first multiplication operand register.
Mult2Reg - Supplies the second multiplication operand register.
Implicit Arguments:
zmm4 - Supplies a scratch register for intermediate results.
--*/
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
//
// vpmaddwd multiplies the 16-bit elements of the two operands and adds each
// adjacent pair of products into a 32-bit lane of the zmm4 scratch register,
// which is then added to the accumulator. zmm4 is clobbered.
//
vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\()
vpaddd \AccumReg\(),\AccumReg\(),zmm4
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
//
// Load 32 bytes of packed matrix B per group of 16 output columns and zero
// extend them to 16-bit elements. Additional groups are read r14 and r14*2
// bytes past the first. The loads are assigned so that zmm2 always pairs with
// accumulators zmm14-zmm19, zmm1 with zmm20-zmm25 and zmm0 with zmm26-zmm31.
//
.if \ColumnCount\() >= 48
vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()]
vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()]
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()]
.elseif \ColumnCount\() >= 32
vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()]
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()]
.else
vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()]
.endif
//
// For each active row, broadcast one DWORD of matrix A into zmm3 and
// accumulate it against every loaded matrix B vector. Rows 1-3 are addressed
// from rdi and rows 4-6 from rbx, stepping by the row length held in rcx.
//
// EmitIfCountGE and EmitIfCount2GE are defined outside this file; going by
// their names they emit the quoted statement only when the count(s) reach the
// given threshold(s) (TODO confirm in asmmacro.h).
//
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoop ColumnCount, RowCount
//
// rbp counts down the row length in bytes; each iteration consumes 4 bytes of
// every matrix A row and 32 bytes from each packed matrix B block. The exit
// test follows the body, so rcx is assumed to be a nonzero multiple of 4 on
// entry.
//
mov rbp,rcx # reload row length remaining
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 pair
//
// rbx is only read by ComputeBlock for rows 4-6, so it is only advanced when
// more than three rows are produced.
//
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 pair
.endif
add rsi,32 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endm
//
// Generate the GEMM kernel.
//
// GemmU8X8KernelAvx512Function is not defined in this file. It is expanded
// here with the type argument U8U8 and the ISA argument Avx512Core and
// presumably emits the kernel entry point around the ComputeBlockLoop macro
// defined above (TODO confirm in QgemmU8X8KernelAvx512Common.h).
//
GemmU8X8KernelAvx512Function U8U8, Avx512Core
.end

View file

@ -0,0 +1,868 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelAvx2.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX2 and AVX VNNI instructions.
--*/
#include "asmmacro.h"
#include "AssembleAvxVnni.h"
.intel_syntax noprefix
//
// Stack frame layout for the U8X8 kernel.
//
// Every slot is 8 bytes wide. Offsets 0-24 name the saved registers, offset
// 32 the return address and offsets 40 and up the stack-passed arguments.
// NOTE(review): the offsets are presumably relative to rsp after the kernel
// prologue has pushed rbp, rbx, r12 and r13 in that order - confirm against
// the kernel entry code, which is outside this view. The negative offset for
// the type slot places it below that rsp value.
//
.equ .LGemmU8X8KernelFrame_type, -8
.equ .LGemmU8X8KernelFrame_SavedR13, 0
.equ .LGemmU8X8KernelFrame_SavedR12, 8
.equ .LGemmU8X8KernelFrame_SavedRbx, 16
.equ .LGemmU8X8KernelFrame_SavedRbp, 24
.equ .LGemmU8X8KernelFrame_ReturnAddress, 32
.equ .LGemmU8X8KernelFrame_ldc, 40
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 48
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 56
.equ .LGemmU8X8KernelFrame_ZeroPointB, 64
.equ .LGemmU8X8KernelFrame_ZeroMode, 72
/*++
Macro Description:
This macro generates code to multiply and accumulate a single row of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 16).
Vec2Reg - Supplies the low block accumulator register.
Implicit Arguments:
ymm0 - Supplies the first vector loaded from matrix B.
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
is 16).
ymm2 - Supplies the broadcast value loaded from matrix A.
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
--*/
.macro MultiplyAccumulateRowU8S8Avx2 ColumnCount, Vec1Reg, Vec2Reg
//
// vpmaddubsw multiplies the unsigned bytes of ymm2 by the signed bytes of the
// matrix B vector and adds adjacent products into 16-bit elements with signed
// saturation. The following vpmaddwd against the 0x0001 words in ymm12 adds
// adjacent 16-bit elements into 32-bit lanes. ymm3 is used as scratch.
//
vpmaddubsw ymm3,ymm2,ymm0
vpmaddwd ymm3,ymm3,ymm12
.if \ColumnCount\() == 16
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
//
// The second product overwrites ymm2, so the broadcast value from matrix A
// is not preserved past this macro in the 16 column case.
//
vpmaddubsw ymm2,ymm2,ymm1
vpmaddwd ymm2,ymm2,ymm12
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
.else
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rdi - Supplies the address into the matrix A data.
r8 - Supplies the address into the matrix A data plus 3 rows.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm11 - Supplies the block accumulators.
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
--*/
.macro ComputeBlockU8S8Avx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset
//
// The single row case is written out explicitly: matrix B is used directly as
// a memory operand of vpmaddubsw instead of being loaded into ymm0/ymm1 first.
// With 16 columns the results go to ymm4 and ymm5; otherwise to ymm5 only.
//
.if \RowCount\() == 1
vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()]
vpmaddwd ymm3,ymm3,ymm12
.if \ColumnCount\() == 16
vpaddd ymm4,ymm4,ymm3
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32]
vpmaddwd ymm2,ymm2,ymm12
vpaddd ymm5,ymm5,ymm2
.else
vpaddd ymm5,ymm5,ymm3
.endif
.else
//
// Multiple rows: load the matrix B vectors once and reuse them for every row.
// For each active row, broadcast one DWORD of matrix A into ymm2 and update
// that row's accumulator pair. Rows 1-3 are addressed from rdi and row 4 from
// r8. At most four rows are handled here.
//
// EmitIfCountGE is defined outside this file; going by its name it emits the
// quoted statement only when the count reaches the given threshold (TODO
// confirm in asmmacro.h).
//
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm8, ymm9"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm10, ymm11"
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate a single row of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 16).
Vec2Reg - Supplies the low block accumulator register.
Implicit Arguments:
ymm0 - Supplies the first vector loaded from matrix B.
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
is 16).
ymm2 - Supplies the broadcast value loaded from matrix A.
--*/
.macro MultiplyAccumulateRowU8S8AvxVnni ColumnCount, Vec1Reg, Vec2Reg
//
// With 16 columns both matrix B vectors (ymm0 and ymm1) are consumed and both
// accumulators are updated. Otherwise only ymm0 is valid and the result goes
// to Vec2Reg alone; Vec1Reg is left untouched.
//
// VpdpbusdsYmmYmmYmm is defined outside this file; presumably it emits the
// AVX VNNI vpdpbusds instruction (TODO confirm in AssembleAvxVnni.h).
//
.if \ColumnCount\() == 16
VpdpbusdsYmmYmmYmm \Vec1Reg\(),ymm2,ymm0
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm1
.else
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm0
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rdi - Supplies the address into the matrix A data.
r8 - Supplies the address into the matrix A data plus 3 rows.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ComputeBlockU8S8AvxVnni ColumnCount, RowCount, VectorOffset, BroadcastOffset
//
// Load the packed matrix B data: ymm0 always, ymm1 (the next 32 bytes) only
// when 16 columns are produced.
//
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
//
// For each active row, broadcast one DWORD of matrix A into ymm2 and update
// that row's accumulator pair. Rows 1-3 are addressed from rdi and rows 4-6
// from r8, stepping by the row length held in rcx.
//
// EmitIfCountGE is defined outside this file; going by its name it emits the
// quoted statement only when the count reaches the given threshold (TODO
// confirm in asmmacro.h).
//
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm8, ymm9"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm10, ymm11"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm12, ymm13"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm14, ymm15"
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple times
and advancing the matrix A and matrix B data pointers.
Arguments:
Isa - Supplies the instruction set architecture string.
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
r8 - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm15 - Supplies the block accumulators (only ymm4-ymm11 for the Avx2
variant, where ymm12 holds the broadcasted word value 0x0001).
--*/
.macro ComputeBlockLoopU8S8 Isa, ColumnCount, RowCount
//
// rbp counts down the row length in bytes; each iteration consumes 4 bytes of
// every matrix A row and 64 bytes of packed matrix B. The exit test follows
// the body, so rcx is assumed to be a nonzero multiple of 4 on entry.
//
mov rbp,rcx # reload row length remaining
.LComputeBlockBy1Loop\@:
//
// The Isa argument is pasted onto the macro name, selecting
// ComputeBlockU8S8Avx2 or ComputeBlockU8S8AvxVnni.
//
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
//
// r8 is only read by the block macros for rows 4 and up, so it is only
// advanced when more than three rows are produced.
//
.if \RowCount\() > 3
add r8,4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,64 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate a single row of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 16).
Vec2Reg - Supplies the low block accumulator register.
Implicit Arguments:
ymm0 - Supplies the first vector loaded from matrix B.
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
is 16).
ymm2 - Supplies the broadcast value loaded from matrix A.
--*/
.macro MultiplyAccumulateRowU8U8Avx2 ColumnCount, Vec1Reg, Vec2Reg
//
// vpmaddwd multiplies the 16-bit elements of its sources and adds each
// adjacent pair of products into a 32-bit lane. ymm3 is used as scratch for
// the first product.
//
vpmaddwd ymm3,ymm2,ymm0
.if \ColumnCount\() == 16
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
//
// The second product overwrites ymm2, so the broadcast value from matrix A
// is not preserved past this macro in the 16 column case.
//
vpmaddwd ymm2,ymm2,ymm1
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
.else
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block. Bytes from matrix B are zero extended to words, and one dword from
each row of matrix A is broadcast and multiplied against them.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rdi - Supplies the address into the matrix A data.
r8 - Supplies the address into the matrix A data plus 3 rows.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm0-ymm3 - Used as temporaries (clobbered).
ymm4-ymm15 - Supplies the block accumulators. Each row owns a register pair
(ymm4/ymm5 for row 1 through ymm14/ymm15 for row 6).
--*/
.macro ComputeBlockU8U8Avx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset
// Load 16 bytes of matrix B per vector and widen each byte to a word.
vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]"
// Rows 1-3 are addressed from rdi with strides of 0, rcx and rcx*2.
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm8, ymm9"
// Rows 4-6 are addressed from r8 (matrix A plus 3 rows) with the same strides.
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm10, ymm11"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm12, ymm13"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm14, ymm15"
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple times
while advancing the matrix A and matrix B data pointers. When ColumnCount is
16 and RowCount is even, the loop is unrolled to process two blocks per
iteration, followed by at most one trailing block.
Arguments:
Isa - Supplies the instruction set architecture string. The string is pasted
onto the ComputeBlockU8U8 prefix to select the block compute macro.
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rdi - Supplies the address into the matrix A data.
r8 - Supplies the address into the matrix A data plus 3 rows.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
rbp - Used as the remaining length counter (clobbered).
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount
mov rbp,rcx # reload row length remaining
.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0)
// Pre-subtract two pairs so that the carry flag reports fewer than two left.
sub rbp,2*4
jb .LProcessRemainingBlocks\@
.LComputeBlockBy2Loop\@:
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 32, 4
add rdi,2*4 # advance matrix A by 2 pairs
.if \RowCount\() > 3
add r8,2*4 # advance matrix A plus 3 rows by 2 pairs
.endif
add rsi,2*32 # advance matrix B
sub rbp,2*4
jae .LComputeBlockBy2Loop\@
.LProcessRemainingBlocks\@:
add rbp,2*4 # correct for over-subtract above
jz .LComputeBlockLoopExit\@
// One trailing block remains. Only matrix B is advanced after it; the matrix A
// pointers are left as is on this path.
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
add rsi,32 # advance matrix B
.else
// Non-unrolled form: the count is tested after the first block, so rcx must be
// a nonzero multiple of 4.
.LComputeBlockBy1Loop\@:
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 pair
.if \RowCount\() > 3
add r8,4 # advance matrix A plus 3 rows by 1 pair
.endif
add rsi,32 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endif
.LComputeBlockLoopExit\@:
.endm
/*++
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows. The accumulators are seeded from the row and column sums, then the
block compute loop matching the kernel type stored in the stack frame is run.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdx - Supplies the address of matrix C (used to compute r8 on exit).
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r8 - Used for matrix A plus 3 rows during the compute loop. On exit, holds
matrix C plus 3 rows. Only written when RowCount is greater than 3.
r11 - Supplies the address of the row sum buffer.
r12 - Supplies the address of the column sum buffer. Advanced by 16 columns
when ColumnCount is 16.
r13 - Supplies the address of the per-column zero points of matrix B, or zero
if not used. Advanced by 16 columns when nonzero and ColumnCount is 16.
ymm0-ymm3 - Used as temporaries (clobbered).
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
//
// Initialize the accumulators with the row and column sums.
//
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r11]"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r11+4]"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r11+8]"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r11+12]"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r11+16]"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r11+20]"
// Load the column sums: ymm0/ymm1 for 16 columns, ymm1 only for 8 columns.
.if \ColumnCount\() == 16
vmovdqu ymm0,YMMWORD PTR [r12]
vmovdqu ymm1,YMMWORD PTR [r12+32]
add r12,16*4 # advance ColumnSumBuffer by 16 columns
.else
vmovdqu ymm1,YMMWORD PTR [r12]
.endif
test r13,r13 # per column zero points?
jz .LSkipScaleByZeroPointB\@
// Per-column path: accumulator = row sum * zero point of the column + column sum.
.if \ColumnCount\() == 16
vmovdqu ymm2,YMMWORD PTR [r13]
vmovdqu ymm3,YMMWORD PTR [r13+32]
add r13,16*4 # advance ZeroPointB by 16 columns
.else
vmovdqu ymm3,YMMWORD PTR [r13]
.endif
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld ymm4,ymm5,ymm2"
EmitIfCountGE \RowCount\(), 1, "vpmulld ymm5,ymm5,ymm3"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm0,ymm4"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm1,ymm5"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld ymm6,ymm7,ymm2"
EmitIfCountGE \RowCount\(), 2, "vpmulld ymm7,ymm7,ymm3"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm0,ymm6"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm1,ymm7"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld ymm8,ymm9,ymm2"
EmitIfCountGE \RowCount\(), 3, "vpmulld ymm9,ymm9,ymm3"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm0,ymm8"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm1,ymm9"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld ymm10,ymm11,ymm2"
EmitIfCountGE \RowCount\(), 4, "vpmulld ymm11,ymm11,ymm3"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm0,ymm10"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm1,ymm11"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld ymm12,ymm13,ymm2"
EmitIfCountGE \RowCount\(), 5, "vpmulld ymm13,ymm13,ymm3"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm0,ymm12"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm1,ymm13"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld ymm14,ymm15,ymm2"
EmitIfCountGE \RowCount\(), 6, "vpmulld ymm15,ymm15,ymm3"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm0,ymm14"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm1,ymm15"
jmp .LAccumulatorsInitialized\@
.LSkipScaleByZeroPointB\@:
// Per-tensor path: accumulator = row sum + column sum.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
.LAccumulatorsInitialized\@:
//
// Iterate over the length of a matrix A row to produce the output accumulators.
//
.if \RowCount\() > 3
lea r8,[rcx*2+rcx]
add r8,rdi # compute matrix A plus 3 rows
.endif
// Dispatch on the kernel type selector: positive selects U8U8 AVX2, zero
// selects U8S8 AVX2, negative selects U8S8 AVXVNNI. The jl below reuses the
// flags from this cmp because a not-taken jg leaves them unchanged.
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
jg .LProduceWithU8U8Avx2\@
// The U8S8 AVX2 loop is only generated for 4 or fewer rows. For 5 or 6 rows
// every non-positive selector falls through to the AVXVNNI loop.
.if \RowCount\() <= 4
jl .LProduceWithU8S8AvxVnni\@
ComputeBlockLoopU8S8 Avx2, \ColumnCount\(), \RowCount\()
jmp .LExitProduceOutputBlock\@
.endif
.LProduceWithU8S8AvxVnni\@:
ComputeBlockLoopU8S8 AvxVnni, \ColumnCount\(), \RowCount\()
jmp .LExitProduceOutputBlock\@
.LProduceWithU8U8Avx2\@:
ComputeBlockLoopU8U8 Avx2, \ColumnCount\(), \RowCount\()
.LExitProduceOutputBlock\@:
// r8 is repurposed from a matrix A pointer to a matrix C pointer for the
// output stores emitted by the caller.
.if \RowCount\() > 3
lea r8,[rax*2+rax]
add r8,rdx # compute matrix C plus 3 rows
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows. Columns are produced 16 at a time while more than 8 remain, then a
final block of up to 8 columns is produced. Partial blocks are written with
masked stores. The macro always exits by jumping to the ExitKernel label with
eax set to RowCount.
Arguments:
RowCount - Supplies the number of rows to process.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A. Overwritten with the mask table
address on the masked output path.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
rbx - Supplies the address of matrix A.
r8 - Set by ProduceOutputBlock to matrix C plus 3 rows when RowCount is
greater than 3.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the length in bytes of a row from matrix A.
r10b - Supplies the zero mode flag.
r11 - Supplies the address of the row sum buffer.
r12 - Supplies the address of the column sum buffer.
r13 - Supplies the address of the per-column zero points of matrix B, or zero
if not used.
--*/
.macro ProcessCountM RowCount
cmp r9,8
jbe .LProcessRemainingCountN\@
.LProcessNextColumnLoop16xN\@:
ProduceOutputBlock 16, \RowCount\()
sub r9,16
// A borrow means fewer than 16 columns remained, so the second half of the
// block must be written with a mask.
jb .LOutputMasked16xNBlock\@
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8+32]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax+32]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2+32]"
.LSkipAccumulateOutput16xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8+32],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax+32],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2+32],ymm15"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,rbx # reload matrix A
cmp r9,8
ja .LProcessNextColumnLoop16xN\@
test r9,r9
jnz .LProcessRemainingCountN\@
.LExitProcessCountM\@:
// Report the number of rows handled to the caller.
mov eax,\RowCount\()
jmp .LExitKernel
.LProcessRemainingCountN\@:
// At most 8 columns remain: only the odd accumulators (ymm5, ymm7, ...) are
// produced for an 8 column block.
ProduceOutputBlock 8, \RowCount\()
cmp r9,8
jb .LOutputMasked8xNBlock\@
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput8xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2]"
.LSkipAccumulateOutput8xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm15"
jmp .LExitProcessCountM\@
.LOutputMasked16xNBlock\@:
// Between 9 and 15 columns remained: store the first 8 columns in full from
// the even accumulators, then fall through to mask the rest.
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutputMasked16xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]"
.LSkipAccumulateOutputMasked16xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14"
add rdx,8*4 # advance matrix C by 8 columns
.if \RowCount\() > 3
add r8,8*4 # advance matrix C plus 3 rows by 8 columns
.endif
add r9,8 # correct for over-subtract above
.LOutputMasked8xNBlock\@:
// Build the store mask by loading 8 dwords from the mask table at a position
// chosen by the remaining column count.
// NOTE(review): MlasMaskMoveTableAvx is defined outside this file; presumably
// this window enables exactly r9 leading lanes - confirm against the table.
neg r9
lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4]
vmovdqu ymm0,YMMWORD PTR [rdi+r9*4]
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutputMasked8xNBlock\@
// The even accumulators are free at this point (already stored on the 16
// column path, never produced on the 8 column path), so they hold the existing
// matrix C values loaded under the mask.
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [r8]"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [r8+rax]"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [r8+rax*2]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
.LSkipAccumulateOutputMasked8xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [r8],ymm0,ymm11"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm13"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm15"
jmp .LExitProcessCountM\@
.endm
//
// Reduce code size for the various types of kernels by sharing the outer logic
// and switching on the selector codes (using sign bit to discriminate).
//
// Each entry point loads its selector into eax and jumps to the shared kernel,
// which saves the selector in its stack frame:
//   negative - U8S8 kernel using AVXVNNI
//   positive - U8U8 kernel using AVX2
//   zero     - U8S8 kernel using AVX2
// The argument registers and the stack are left untouched, so the shared
// kernel sees the arguments exactly as passed by the original caller.
//
FUNCTION_ENTRY MlasGemmU8S8KernelAvxVnni
mov eax,-1
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2)
FUNCTION_ENTRY MlasGemmU8U8KernelAvx2
mov eax,1
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2)
FUNCTION_ENTRY MlasGemmU8S8KernelAvx2
xor eax,eax
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2)
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows. It is the shared body behind the MlasGemmU8S8KernelAvxVnni,
MlasGemmU8U8KernelAvx2 and MlasGemmU8S8KernelAvx2 entry points, which pass
the kernel type selector in eax.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
and the number of packed rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A. These values have
been pre-scaled by the zero point offset of matrix B if the offset is
per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
scaled by the per-column zero point offsets of matrix B. These values are
accumulated into every row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
B, else nullptr if the matrix B is using per-tensor quantization.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Type (eax) - Supplies the kernel type selector set by the entry points:
negative for U8S8 AVXVNNI, zero for U8S8 AVX2, positive for U8U8 AVX2.
Return Value:
Returns the number of rows handled (in eax): up to 6 rows for a nonzero type
selector, up to 4 rows for the U8S8 AVX2 kernel.
--*/
FUNCTION_ENTRY MlasGemmU8X8KernelAvx2
push rbp
push rbx
push r12
push r13
// Save the selector before eax is reused for ldc below.
mov DWORD PTR .LGemmU8X8KernelFrame_type[rsp],eax
mov rbx,rdi
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
mov r12,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
mov r13,.LGemmU8X8KernelFrame_ZeroPointB[rsp]
// NOTE(review): ymm12 is presumably the constant consumed by the U8S8 AVX2
// block compute macro, which is why that kernel cannot use ymm12-ymm15 as
// accumulators - confirm against ComputeBlockU8S8Avx2.
vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF]
vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001]
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
je .LCheckCountM4OrMore # U8S8 AVX2 kernel requires extra registers
//
// Process CountM rows of the matrices.
//
// CountM of 6 or more handles 6 rows (4 rows for the U8S8 AVX2 kernel). Any
// count that is not 1, 3, 4 or more falls through to the 2 row case.
//
.LCheckCountM6OrMore:
cmp r8,5
ja .LProcessCountM6
je .LProcessCountM5
.LCheckCountM4OrMore:
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
// Every ProcessCountM expansion ends with a jump to ExitKernel, so the
// expansions below never fall through into each other.
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
ProcessCountM 4
.LProcessCountM6:
ProcessCountM 6
//
// Restore non-volatile registers and return.
//
.LExitKernel:
vzeroupper
pop r13
pop r12
pop rbx
pop rbp
ret
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.LProcessCountM5:
ProcessCountM 5
.end

View file

@ -1,273 +0,0 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelAvx2Common.h
Abstract:
This module contains common kernel macros and structures for the quantized
integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
--*/
//
// Stack frame layout for the U8S8 and U8U8 kernels.
//
// All offsets are relative to rsp as addressed by the macros in this header.
// The mask slot sits below rsp and is used as scratch storage for building
// the masked store vector. The stack based arguments start at offset 40.
// NOTE(review): the Saved* offsets presumably mirror four pushes in the kernel
// prologue (rbp, rbx, r12, r13); the prologue is not part of this header.
//
.equ .LGemmU8X8KernelFrame_mask, -8
.equ .LGemmU8X8KernelFrame_SavedR13, 0
.equ .LGemmU8X8KernelFrame_SavedR12, 8
.equ .LGemmU8X8KernelFrame_SavedRbx, 16
.equ .LGemmU8X8KernelFrame_SavedRbp, 24
.equ .LGemmU8X8KernelFrame_ReturnAddress, 32
.equ .LGemmU8X8KernelFrame_ldc, 40
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 48
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 56
.equ .LGemmU8X8KernelFrame_DepthValue, 64
.equ .LGemmU8X8KernelFrame_ZeroMode, 72
/*++
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows. The accumulators are seeded with the depth value plus the column
sums plus the row sums, then the ComputeBlockLoop macro supplied by the
including kernel is run.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdx - Supplies the address of matrix C (used to compute rbx on exit).
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
rbx - Used for matrix A plus 3 rows during the compute loop. On exit, holds
matrix C plus 3 rows. Only written when RowCount is greater than 3.
r12 - Supplies the address of the row sum buffer.
r13 - Supplies the address of the column sum buffer. Advanced by 16 columns
when ColumnCount is 16.
ymm0-ymm1 - Used as temporaries (clobbered).
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
//
// Initialize the accumulators with the sum of the global depth value constant,
// the column sums, and the row sums.
//
vpbroadcastd ymm1,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
// ymm0/ymm1 receive depth value + column sums (ymm1 only for 8 columns).
.if \ColumnCount\() == 16
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
add r13,16*4 # advance ColumnSumBuffer by 16 columns
.else
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
.endif
// Broadcast each row sum into the odd accumulator of the row pair.
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]"
// The even accumulator must be derived first because the odd accumulator
// still holds the unmodified row sum at that point.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
//
// Iterate over the length of a matrix A row to produce the output accumulators.
//
.if \RowCount\() > 3
lea rbx,[rcx*2+rcx]
add rbx,rdi # compute matrix A plus 3 rows
.endif
ComputeBlockLoop \ColumnCount\(), \RowCount\()
// rbx is repurposed from a matrix A pointer to a matrix C pointer for the
// output stores emitted by the caller.
.if \RowCount\() > 3
lea rbx,[rax*2+rax]
add rbx,rdx # compute matrix C plus 3 rows
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows. Columns are produced 16 at a time while more than 8 remain, then a
final block of up to 8 columns is produced. Partial blocks are written with
masked stores.
Arguments:
RowCount - Supplies the number of rows to process.
Fallthrough - Supplies a non-blank value if the macro may fall through to
the ExitKernel label. When blank, the final path ends with an explicit jump
to ExitKernel.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
rbx - Set by ProduceOutputBlock to matrix C plus 3 rows when RowCount is
greater than 3.
r11 - Supplies the address of matrix A.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the length in bytes of a row from matrix A.
r10b - Supplies the zero mode flag.
r12 - Supplies the address of the row sum buffer.
r13 - Supplies the address of the column sum buffer.
--*/
.macro ProcessCountM RowCount, Fallthrough
cmp r9,8
jbe .LProcessRemainingCountN\@
.LProcessNextColumnLoop16xN\@:
ProduceOutputBlock 16, \RowCount\()
sub r9,16
// A borrow means fewer than 16 columns remained, so the second half of the
// block must be written with a mask.
jb .LOutputMasked16xNBlock\@
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]"
.LSkipAccumulateOutput16xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,r11 # reload matrix A
cmp r9,8
ja .LProcessNextColumnLoop16xN\@
test r9,r9
jz .LExitKernel
.LProcessRemainingCountN\@:
// At most 8 columns remain: only the odd accumulators (ymm5, ymm7, ...) are
// produced for an 8 column block.
ProduceOutputBlock 8, \RowCount\()
cmp r9,8
jb .LOutputMasked8xNBlock\@
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput8xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput8xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15"
jmp .LExitKernel
.LOutputMasked16xNBlock\@:
// Between 9 and 15 columns remained: store the first 8 columns in full from
// the even accumulators, then fall through to mask the rest.
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutputMasked16xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutputMasked16xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
add rdx,8*4 # advance matrix C by 8 columns
.if \RowCount\() > 3
add rbx,8*4 # advance matrix C plus 3 rows by 8 columns
.endif
add r9,8 # correct for over-subtract above
.LOutputMasked8xNBlock\@:
// Build the store mask: spill the remaining column count to the scratch slot
// below rsp, broadcast it, and enable each lane whose table entry is less
// than the count.
// NOTE(review): MlasMaskMoveAvx is defined outside this file; presumably it
// holds the lane indices 0 through 7 - confirm against the table.
mov DWORD PTR .LGemmU8X8KernelFrame_mask[rsp],r9d
vpbroadcastd ymm0,DWORD PTR .LGemmU8X8KernelFrame_mask[rsp]
vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutputMasked8xNBlock\@
// The even accumulators are free at this point (already stored on the 16
// column path, never produced on the 8 column path), so they hold the existing
// matrix C values loaded under the mask.
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
.LSkipAccumulateOutputMasked8xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15"
// Omit the trailing jump when the including kernel places ExitKernel directly
// after this expansion.
.ifb \Fallthrough\()
jmp .LExitKernel
.endif
.endm

View file

@ -1,403 +0,0 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelAvx512Common.h
Abstract:
This module contains common kernel macros and structures for the quantized
integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
AVX512VNNI kernels.
--*/
//
// Stack frame layout for the U8S8 and U8U8 kernels.
//
// All offsets are relative to rsp as addressed by the macros in this header.
// The stack based arguments start at offset 48.
// NOTE(review): the Saved* offsets presumably mirror five pushes in the kernel
// prologue (rbp, rbx, r12, r13, r14); the prologue is not part of this header.
//
.equ .LGemmU8X8KernelFrame_SavedR14, 0
.equ .LGemmU8X8KernelFrame_SavedR13, 8
.equ .LGemmU8X8KernelFrame_SavedR12, 16
.equ .LGemmU8X8KernelFrame_SavedRbx, 24
.equ .LGemmU8X8KernelFrame_SavedRbp, 32
.equ .LGemmU8X8KernelFrame_ReturnAddress, 40
.equ .LGemmU8X8KernelFrame_ldc, 48
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 56
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 64
.equ .LGemmU8X8KernelFrame_DepthValue, 72
.equ .LGemmU8X8KernelFrame_ZeroMode, 80
/*++
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows. The accumulators are seeded with the depth value plus the column
sums plus the row sums, then the ComputeBlockLoop macro supplied by the
including kernel is run.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdx - Supplies the address of matrix C (used to compute rbx on exit).
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
rbx - Used for matrix A plus 3 rows during the compute loop. On exit, holds
matrix C plus 3 rows. Only written when RowCount is greater than 3.
r12 - Supplies the address of the row sum buffer.
r13 - Supplies the address of the column sum buffer. Advanced by ColumnCount
columns when ColumnCount is 32 or more.
zmm0-zmm3 - Used as temporaries (clobbered).
zmm14-zmm31 - Supplies the block accumulators. zmm14-zmm19 always hold the
last 16 columns of the block for rows 1-6, zmm20-zmm25 hold the preceding
16 columns (ColumnCount of 32 or more) and zmm26-zmm31 hold the first 16
columns (ColumnCount of 48).
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
//
// Initialize the accumulators with the sum of the global depth value constant,
// the column sums, and the row sums.
//
vpbroadcastd zmm3,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
// The column sums are loaded so that zmm0 always covers the last 16 columns
// of the block, with zmm1 and zmm2 covering the columns before it.
.if \ColumnCount\() >= 32
.if \ColumnCount\() >= 48
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
.else
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
.endif
add_immed r13,\ColumnCount\()*4 # advance ColumnSumBuffer by N columns
.else
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
.endif
// Add each row sum using an embedded broadcast of the dword from the row sum
// buffer.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r12]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r12]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r12]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r12+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r12+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r12+4]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r12+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r12+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r12+8]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r12+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r12+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r12+12]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r12+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r12+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r12+16]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r12+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r12+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r12+20]{1to16}"
//
// Iterate over the length of a matrix A row to produce the output accumulators.
//
.if \RowCount\() > 3
lea rbx,[rcx*2+rcx]
add rbx,rdi # compute matrix A plus 3 rows
.endif
ComputeBlockLoop \ColumnCount\(), \RowCount\()
// rbx is repurposed from a matrix A pointer to a matrix C pointer for the
// output stores emitted by the caller.
.if \RowCount\() > 3
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
add rbx,rax
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows. Columns are consumed in groups of up to 48 per iteration. Each
group is written to matrix C as up to three 16 column stores; only the last
16 column store of a group is masked with k1 so that a partial block can be
written.
Arguments:
RowCount - Supplies the number of rows to process.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
r11 - Supplies the address of matrix A. This copy is used to reload rdi
after each group of columns has been stored.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the length in bytes of a row from matrix A. This register is
overwritten while building the mask for a final partial block.
r10b - Supplies the zero mode flag.
r12 - Supplies the address of the row sum buffer.
r13 - Supplies the address of the column sum buffer.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
k1 - Supplies the write mask used for the last 16 column store of a group.
rbx - Used as the address of matrix C plus 3 rows when RowCount is greater
than 3.
rbp - Used as scratch to build the write mask.
--*/
.macro ProcessCountM RowCount
//
// Select the widest column path that the remaining column count allows: more
// than 32 columns takes the 48 column path, 17 to 32 columns takes the 32
// column path, and 16 or fewer columns takes the 16 column path.
//
cmp r9,32
ja .LProcessNextColumnLoop48xN\@
cmp r9,16
jbe .LProcessRemainingCountN\@
.LProcessNextColumnLoop32xN\@:
ProduceOutputBlock 32, \RowCount\()
add rsi,r14 # advance matrix B by packed block stride
//
// Store the 16 columns held in zmm20-zmm25. The 48 column path also joins
// here after it has stored its first 16 columns.
//
.LOutput32xNBlock\@:
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput32xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput32xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm23"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25"
add rdx,16*4 # advance matrix C by 16 columns
.if \RowCount\() > 3
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
.endif
sub r9,16
//
// Store the last 16 columns of the group, held in zmm14-zmm19. If fewer than
// 16 columns remain, k1 is rebuilt with one bit set per remaining column and
// the column count is cleared so that the loop exits after this store.
//
.LOutput16xNBlock\@:
sub r9,16
jae .LOutput16xNBlockWithMask\@
lea rcx,[r9+16] # correct for over-subtract above
mov ebp,1
shl ebp,cl
dec ebp
kmovw k1,ebp # update mask for remaining columns
xor r9,r9 # no more columns remaining
.LOutput16xNBlockWithMask\@:
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlockWithMask\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput16xNBlockWithMask\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,r11 # reload matrix A
//
// Pick the path for the next group of columns, or leave through the exit
// label of the enclosing kernel when no columns remain.
//
cmp r9,32
ja .LProcessNextColumnLoop48xN\@
cmp r9,16
ja .LProcessNextColumnLoop32xN\@
test r9,r9
jz .LExitKernel
.LProcessRemainingCountN\@:
ProduceOutputBlock 16, \RowCount\()
jmp .LOutput16xNBlock\@
//
// The 48 column path stores the 16 columns held in zmm26-zmm31 and then
// reuses the 32 column and 16 column store code above for the rest.
//
.LProcessNextColumnLoop48xN\@:
ProduceOutputBlock 48, \RowCount\()
lea rsi,[rsi+r14*2] # advance matrix B by packed block stride
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput48xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput48xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm29"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31"
add rdx,16*4 # advance matrix C by 16 columns
.if \RowCount\() > 3
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
.endif
sub r9,16
jmp .LOutput32xNBlock\@
.endm
/*++
Macro Description:
This macro generates the common AVX512 code for the inner kernel to compute
matrix multiplication. The generated function is named
MlasGemm<Type>Kernel<Isa>.
Arguments:
Type - Supplies the kernel type string for function tags. The value "U8S8"
selects a matrix B packed stride of 16 times the row length; any other
value selects 8 times the row length.
Isa - Supplies the instruction set architecture string for function tags.
For the "U8S8" type, the value "Avx512Core" additionally loads zmm5 with
the broadcasted word value 0x0001.
--*/
.macro GemmU8X8KernelAvx512Function Type, Isa
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PackedCountK (rcx) - Supplies the number of packed columns from matrix A and
the number of packed rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multiplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
.globl C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\())
C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()):
push rbp
push rbx
push r12
push r13
push r14
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,rdi
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
mov ebp,-1
kmovw k1,ebp # update mask to write all columns
.ifeqs "\Type\()", "U8S8"
.ifeqs "\Isa\()", "Avx512Core"
// Negating -1 leaves 1 in ebp, which is then broadcast to every word of zmm5.
neg ebp
vpbroadcastw zmm5,ebp # generate 512-bit word vector [0x0001]
.endif
mov r14,rcx
shl r14,4 # compute matrix B packed stride
.else
lea r14,[rcx*8] # compute matrix B packed stride
.endif
//
// Process CountM rows of the matrices.
//
// Any count above 5 takes the 6 row path. Counts of 5, 4, 3 and 1 take the
// matching path, and any other count falls through to the 2 row path. Each
// ProcessCountM expansion ends in an unconditional jump and only leaves
// through .LExitKernel, so one expansion never runs into the next.
//
cmp r8,5
ja .LProcessCountM6
je .LProcessCountM5
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
ProcessCountM 4
.LProcessCountM6:
mov r8d,6 # return 6 rows handled
ProcessCountM 6
//
// Restore non-volatile registers and return.
//
// The row count held in r8d is returned in eax.
//
.LExitKernel:
mov eax,r8d
vzeroupper
pop r14
pop r13
pop r12
pop rbx
pop rbp
ret
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.LProcessCountM5:
ProcessCountM 5
.endm

View file

@ -0,0 +1,709 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelAvx512Core.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions.
--*/
#include "asmmacro.h"
#include "AssembleAvx512Vnni.h"
.intel_syntax noprefix
//
// Stack frame layout for the U8X8 kernel.
//
// Offsets are relative to rsp after the kernel prologue has pushed rbp, rbx,
// r12, r13 and r14. Offsets of 48 and above are the arguments that the caller
// passed on the stack.
//
// The type slot lies below rsp and holds the kernel type selector that the
// entry stubs pass in eax.
// NOTE(review): this presumably relies on the System V AMD64 red zone and on
// the kernel making no calls - confirm for every target this file is built for.
//
.equ .LGemmU8X8KernelFrame_type, -8
.equ .LGemmU8X8KernelFrame_SavedR14, 0
.equ .LGemmU8X8KernelFrame_SavedR13, 8
.equ .LGemmU8X8KernelFrame_SavedR12, 16
.equ .LGemmU8X8KernelFrame_SavedRbx, 24
.equ .LGemmU8X8KernelFrame_SavedRbp, 32
.equ .LGemmU8X8KernelFrame_ReturnAddress, 40
.equ .LGemmU8X8KernelFrame_ldc, 48
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 56
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 64
.equ .LGemmU8X8KernelFrame_ZeroPointB, 72
.equ .LGemmU8X8KernelFrame_ZeroMode, 80
.text
/*++
Macro Description:
This macro generates code to load packed data from matrix B.
The U8S8 variant loads 64 bytes unchanged. The U8U8 variant loads 32 bytes
and zero extends each byte to a word.
Arguments:
VecReg - Supplies the register to load the data into.
AddressOperand - Supplies the address operand.
--*/
.macro LoadPackedMatrixBU8S8 VecReg, AddressOperand
vmovdqu32 \VecReg\(),ZMMWORD PTR \AddressOperand\()
.endm
.macro LoadPackedMatrixBU8U8 VecReg, AddressOperand
vpmovzxbw \VecReg\(),YMMWORD PTR \AddressOperand\()
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate a single cell of the
output block.
The macro name is formed from the kernel type and the instruction set, so
that ComputeBlock can select a variant by pasting its Type and Isa arguments.
Arguments:
AccumReg - Supplies the register to accumulate into.
Mult1Reg - Supplies the first multiplication operand register. ComputeBlock
passes the broadcasted matrix A elements here.
Mult2Reg - Supplies the second multiplication operand register. ComputeBlock
passes the packed matrix B elements here.
Implicit Arguments:
zmm4 - Supplies a scratch register for intermediate results. The Avx512Vnni
variant does not reference it.
zmm13 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
Only the U8S8 Avx512Core variant uses it.
--*/
.macro MultiplyAccumulateCellU8S8Avx512Core AccumReg, Mult1Reg, Mult2Reg
// Multiply byte pairs into words, then multiply by the word vector of ones to
// sum adjacent words into doublewords before accumulating.
// NOTE(review): vpmaddubsw saturates its word results per the Intel SDM -
// confirm that the packed operand ranges make this acceptable.
vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\()
vpmaddwd zmm4,zmm4,zmm13
vpaddd \AccumReg\(),\AccumReg\(),zmm4
.endm
.macro MultiplyAccumulateCellU8S8Avx512Vnni AccumReg, Mult1Reg, Mult2Reg
// VpdpbusdsZmmZmmZmm is not defined in this file; it presumably comes from
// AssembleAvx512Vnni.h and emits the vpdpbusds instruction - TODO confirm.
VpdpbusdsZmmZmmZmm \AccumReg\(),\Mult1Reg\(),\Mult2Reg\()
.endm
.macro MultiplyAccumulateCellU8U8Avx512Core AccumReg, Mult1Reg, Mult2Reg
// Both operands are treated as words here; LoadPackedMatrixBU8U8 has already
// widened the matrix B bytes to words.
vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\()
vpaddd \AccumReg\(),\AccumReg\(),zmm4
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
The packed matrix B vectors are loaded once and then reused for every row.
For each row, one doubleword of matrix A is broadcast into zmm3 and
multiplied against each loaded matrix B vector.
Arguments:
Type - Supplies the kernel type string (U8S8 or U8U8) used to select the
LoadPackedMatrixB and MultiplyAccumulateCell macro variants.
Isa - Supplies the instruction set architecture string used to select the
MultiplyAccumulateCell macro variant.
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rdi - Supplies the address into the matrix A data.
r8 - Supplies the address into the matrix A data plus 3 rows.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators. zmm14-zmm19 always hold the
last 16 columns of the block, zmm20-zmm25 hold the 16 columns before
those when ColumnCount is at least 32, and zmm26-zmm31 hold the first 16
columns when ColumnCount is at least 48.
zmm0-zmm3 - Used as scratch for the matrix B vectors and the matrix A
broadcast.
--*/
.macro ComputeBlock Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset
// The last packed block that is loaded always lands in zmm2, so that zmm2
// feeds zmm14-zmm19 for every ColumnCount.
.if \ColumnCount\() >= 48
LoadPackedMatrixB\Type\() zmm0,"[rsi+\VectorOffset\()]"
LoadPackedMatrixB\Type\() zmm1,"[rsi+r14+\VectorOffset\()]"
LoadPackedMatrixB\Type\() zmm2,"[rsi+r14*2+\VectorOffset\()]"
.elseif \ColumnCount\() >= 32
LoadPackedMatrixB\Type\() zmm1,"[rsi+\VectorOffset\()]"
LoadPackedMatrixB\Type\() zmm2,"[rsi+r14+\VectorOffset\()]"
.else
LoadPackedMatrixB\Type\() zmm2,"[rsi+\VectorOffset\()]"
.endif
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm16,zmm3,zmm2"
// Rows 4 through 6 are addressed from r8, which holds matrix A plus 3 rows.
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [r8+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [r8+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm19,zmm3,zmm2"
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple times
and advancing the matrix A and matrix B data pointers.
For an even RowCount, the loop is first run four quads at a time and any
remaining quads are then handled one at a time. For an odd RowCount, only
the one quad at a time loop is generated.
Arguments:
Isa - Supplies the instruction set architecture string.
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rdi - Supplies the address into the matrix A data.
r8 - Supplies the address into the matrix A data plus 3 rows.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
rbp - Used as the count of matrix A row bytes that remain to be processed.
--*/
.macro ComputeBlockLoopU8S8 Isa, ColumnCount, RowCount
mov rbp,rcx # reload row length remaining
.if ((\RowCount\() & 1) == 0)
// Each pass consumes 16 bytes of a matrix A row and 256 bytes of matrix B.
sub rbp,4*4
jb .LProcessRemainingBlocks\@
.LComputeBlockBy4Loop\@:
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0*64, 0
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 1*64, 4
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 2*64, 8
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 3*64, 12
add rdi,4*4 # advance matrix A by 1 quad
.if \RowCount\() > 3
add r8,4*4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,4*64 # advance matrix B
sub rbp,4*4 # decrement quads remaining
jae .LComputeBlockBy4Loop\@
.LProcessRemainingBlocks\@:
add rbp,4*4 # correct for over-subtract above
jz .LComputeBlockLoopExit\@
.endif
// NOTE(review): this loop exits only when the count reaches exactly zero, so
// the remaining length is presumably a nonzero multiple of 4 on entry -
// confirm against the callers for the odd RowCount case.
.LComputeBlockBy1Loop\@:
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
.if \RowCount\() > 3
add r8,4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,64 # advance matrix B
sub rbp,4 # decrement quads remaining
jnz .LComputeBlockBy1Loop\@
.LComputeBlockLoopExit\@:
.endm
//
// U8U8 variant of the block compute loop. It takes the same arguments and
// implicit arguments as ComputeBlockLoopU8S8, but has no unrolled loop: each
// pass consumes 4 bytes of a matrix A row and 32 bytes of matrix B, with rbp
// counting the matrix A row bytes that remain.
//
// NOTE(review): the loop exits only when the count reaches exactly zero, so
// the row length is presumably a nonzero multiple of 4 - confirm against the
// callers.
//
.macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount
mov rbp,rcx # reload row length remaining
.LComputeBlockBy1Loop\@:
ComputeBlock U8U8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 pair
.if \RowCount\() > 3
add r8,4 # advance matrix A plus 3 rows by 1 pair
.endif
add rsi,32 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endm
/*++
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows.
The accumulators are first initialized from the row and column sums, then
one of three block compute loops is run based on the kernel type selector
stored in the stack frame. All three loops are generated in every expansion
of this macro.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rdx - Supplies the address of matrix C. This is only used to compute r8 on
exit.
rcx - Supplies the length in bytes of a row from matrix A.
r11 - Supplies the address of the row sum buffer.
r12 - Supplies the address of the column sum buffer. This is advanced by
ColumnCount elements when ColumnCount is at least 32.
r13 - Supplies the address of the per-column zero points of matrix B, or
zero if the row sums are used without scaling. This is advanced by
ColumnCount elements when it is non-zero and ColumnCount is at least 32.
r8 - Used as the address of matrix A plus 3 rows during the compute loop
when RowCount is greater than 3. On exit, it holds the address of matrix C
plus 3 rows.
zmm14-zmm31 - Receives the block accumulators.
zmm0-zmm5 - Used as scratch.
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
//
// Initialize the accumulators with the row and column sums.
//
// The column sums are loaded so that zmm0 always holds the last 16 columns
// of the block, matching the accumulators zmm14-zmm19.
//
.if \ColumnCount\() >= 32
.if \ColumnCount\() >= 48
vmovdqu32 zmm2,ZMMWORD PTR [r12]
vmovdqu32 zmm1,ZMMWORD PTR [r12+64]
vmovdqu32 zmm0,ZMMWORD PTR [r12+128]
.else
vmovdqu32 zmm1,ZMMWORD PTR [r12]
vmovdqu32 zmm0,ZMMWORD PTR [r12+64]
.endif
add_immed r12,\ColumnCount\()*4 # advance ColumnSumBuffer by N columns
.else
vmovdqu32 zmm0,ZMMWORD PTR [r12]
.endif
test r13,r13 # per column zero points?
jz .LSkipScaleByZeroPointB\@
//
// Per-column zero points: each accumulator starts as the column sum plus the
// row sum multiplied by the zero point of that column. The zero points are
// loaded into zmm3-zmm5 in the same column order as the column sums above.
//
.if \ColumnCount\() >= 32
.if \ColumnCount\() >= 48
vmovdqu32 zmm5,ZMMWORD PTR [r13]
vmovdqu32 zmm4,ZMMWORD PTR [r13+64]
vmovdqu32 zmm3,ZMMWORD PTR [r13+128]
.else
vmovdqu32 zmm4,ZMMWORD PTR [r13]
vmovdqu32 zmm3,ZMMWORD PTR [r13+64]
.endif
add_immed r13,\ColumnCount\()*4 # advance ZeroPointB by N columns
.else
vmovdqu32 zmm3,ZMMWORD PTR [r13]
.endif
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld zmm14,zmm3,DWORD PTR [r11]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld zmm20,zmm4,DWORD PTR [r11]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpmulld zmm26,zmm5,DWORD PTR [r11]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,zmm14"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,zmm20"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,zmm26"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld zmm15,zmm3,DWORD PTR [r11+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpmulld zmm21,zmm4,DWORD PTR [r11+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpmulld zmm27,zmm5,DWORD PTR [r11+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,zmm15"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,zmm21"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,zmm27"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld zmm16,zmm3,DWORD PTR [r11+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpmulld zmm22,zmm4,DWORD PTR [r11+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpmulld zmm28,zmm5,DWORD PTR [r11+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,zmm16"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,zmm22"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,zmm28"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld zmm17,zmm3,DWORD PTR [r11+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpmulld zmm23,zmm4,DWORD PTR [r11+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpmulld zmm29,zmm5,DWORD PTR [r11+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,zmm17"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,zmm23"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,zmm29"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld zmm18,zmm3,DWORD PTR [r11+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpmulld zmm24,zmm4,DWORD PTR [r11+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpmulld zmm30,zmm5,DWORD PTR [r11+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,zmm18"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,zmm24"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,zmm30"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld zmm19,zmm3,DWORD PTR [r11+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpmulld zmm25,zmm4,DWORD PTR [r11+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpmulld zmm31,zmm5,DWORD PTR [r11+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,zmm19"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,zmm25"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,zmm31"
jmp .LAccumulatorsInitialized\@
//
// No per-column zero points: each accumulator starts as the column sum plus
// the row sum, with the row sum used exactly as stored in the buffer.
//
.LSkipScaleByZeroPointB\@:
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r11]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r11]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r11]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r11+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r11+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r11+4]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r11+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r11+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r11+8]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r11+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r11+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r11+12]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r11+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r11+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r11+16]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r11+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r11+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r11+20]{1to16}"
.LAccumulatorsInitialized\@:
//
// Iterate over the length of a matrix A row to produce the output accumulators.
//
.if \RowCount\() > 3
lea r8,[rcx*2+rcx]
add r8,rdi # compute matrix A plus 3 rows
.endif
//
// Select the compute loop from the kernel type selector: zero runs the U8S8
// Avx512Core loop, a positive value runs the U8U8 Avx512Core loop, and a
// negative value falls through to the U8S8 Avx512Vnni loop.
//
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
je .LProduceWithU8S8Avx512Core\@
jg .LProduceWithU8U8Avx512Core\@
ComputeBlockLoopU8S8 Avx512Vnni, \ColumnCount\(), \RowCount\()
jmp .LExitProduceOutputBlock\@
.LProduceWithU8U8Avx512Core\@:
ComputeBlockLoopU8U8 Avx512Core, \ColumnCount\(), \RowCount\()
jmp .LExitProduceOutputBlock\@
.LProduceWithU8S8Avx512Core\@:
ComputeBlockLoopU8S8 Avx512Core, \ColumnCount\(), \RowCount\()
.LExitProduceOutputBlock\@:
// r8 is repurposed here: it is no longer needed for matrix A and now addresses
// matrix C for rows 4 through 6.
.if \RowCount\() > 3
lea r8,[rax*2+rax]
add r8,rdx # compute matrix C plus 3 rows
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows. Columns are consumed in groups of up to 48 per iteration. Each
group is written to matrix C as up to three 16 column stores; only the last
16 column store of a group is masked with k1 so that a partial block can be
written. When no columns remain, eax is set to RowCount and control leaves
through .LExitKernel.
Arguments:
RowCount - Supplies the number of rows to process.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
rbx - Supplies the address of matrix A. This copy is used to reload rdi
after each group of columns has been stored.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the length in bytes of a row from matrix A. This register is
overwritten while building the mask for a final partial block.
r10b - Supplies the zero mode flag.
r11 - Supplies the address of the row sum buffer.
r12 - Supplies the address of the column sum buffer.
r13 - Supplies the address of the per-column zero points of matrix B, or
zero. This is consumed by ProduceOutputBlock.
r14 - Supplies the stride in bytes between packed blocks of matrix B.
k1 - Supplies the write mask used for the last 16 column store of a group.
r8 - Used as the address of matrix C plus 3 rows when RowCount is greater
than 3.
rbp - Used as scratch to build the write mask.
--*/
.macro ProcessCountM RowCount
//
// Select the widest column path that the remaining column count allows: more
// than 32 columns takes the 48 column path, 17 to 32 columns takes the 32
// column path, and 16 or fewer columns takes the 16 column path.
//
cmp r9,32
ja .LProcessNextColumnLoop48xN\@
cmp r9,16
jbe .LProcessRemainingCountN\@
.LProcessNextColumnLoop32xN\@:
ProduceOutputBlock 32, \RowCount\()
add rsi,r14 # advance matrix B by packed block stride
//
// Store the 16 columns held in zmm20-zmm25. The 48 column path also joins
// here after it has stored its first 16 columns.
//
.LOutput32xNBlock\@:
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput32xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [r8]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [r8+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [r8+rax*2]"
.LSkipAccumulateOutput32xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm23"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm24"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm25"
add rdx,16*4 # advance matrix C by 16 columns
.if \RowCount\() > 3
add r8,16*4 # advance matrix C plus 3 rows by 16 columns
.endif
sub r9,16
//
// Store the last 16 columns of the group, held in zmm14-zmm19. If fewer than
// 16 columns remain, k1 is rebuilt with one bit set per remaining column and
// the column count is cleared so that the loop exits after this store.
//
.LOutput16xNBlock\@:
sub r9,16
jae .LOutput16xNBlockWithMask\@
lea rcx,[r9+16] # correct for over-subtract above
mov ebp,1
shl ebp,cl
dec ebp
kmovw k1,ebp # update mask for remaining columns
xor r9,r9 # no more columns remaining
.LOutput16xNBlockWithMask\@:
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlockWithMask\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [r8]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [r8+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [r8+rax*2]"
.LSkipAccumulateOutput16xNBlockWithMask\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8]{k1},zmm17"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm18"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm19"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,rbx # reload matrix A
//
// Pick the path for the next group of columns. When no columns remain, the
// row count is placed in eax as the return value of the kernel.
//
cmp r9,32
ja .LProcessNextColumnLoop48xN\@
cmp r9,16
ja .LProcessNextColumnLoop32xN\@
test r9,r9
jnz .LProcessRemainingCountN\@
mov eax,\RowCount\()
jmp .LExitKernel
.LProcessRemainingCountN\@:
ProduceOutputBlock 16, \RowCount\()
jmp .LOutput16xNBlock\@
//
// The 48 column path stores the 16 columns held in zmm26-zmm31 and then
// reuses the 32 column and 16 column store code above for the rest.
//
.LProcessNextColumnLoop48xN\@:
ProduceOutputBlock 48, \RowCount\()
lea rsi,[rsi+r14*2] # advance matrix B by packed block stride
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput48xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [r8]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [r8+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [r8+rax*2]"
.LSkipAccumulateOutput48xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm29"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm30"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm31"
add rdx,16*4 # advance matrix C by 16 columns
.if \RowCount\() > 3
add r8,16*4 # advance matrix C plus 3 rows by 16 columns
.endif
sub r9,16
jmp .LOutput32xNBlock\@
.endm
//
// Reduce code size for the various types of kernels by sharing the outer logic
// and switching on the selector codes (using sign bit to discriminate).
//
// Each public entry point loads a selector into eax and jumps to the shared
// kernel, which saves the selector in its stack frame:
//
//      1 - U8U8 kernel using AVX512 core instructions.
//      0 - U8S8 kernel using AVX512 core instructions.
//     -1 - U8S8 kernel using AVX512 VNNI instructions.
//
// The argument registers and the stack are left untouched, so the shared
// kernel sees the arguments exactly as the caller passed them.
//
FUNCTION_ENTRY MlasGemmU8U8KernelAvx512Core
mov eax,1
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core)
FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Core
xor eax,eax
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core)
FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Vnni
mov eax,-1
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core)
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
This routine is the shared body for the U8U8 and U8S8 entry points above. In
addition to the arguments below, it expects the kernel type selector in eax
(positive for U8U8, zero for U8S8 with AVX512 core, negative for U8S8 with
AVX512 VNNI).
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PackedCountK (rcx) - Supplies the number of packed columns from matrix A and
the number of packed rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumBuffer - Supplies the sum of each row from matrix A. These values have
been pre-scaled by the zero point offset of matrix B if the offset is
per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
scaled by the per-column zero point offsets of matrix B. These values are
accumulated into every row of matrix C.
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
B, else nullptr if the matrix B is using per-tensor quantization.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
FUNCTION_ENTRY MlasGemmU8X8KernelAvx512Core
push rbp
push rbx
push r12
push r13
push r14
// Save the kernel type selector before rax is reused for ldc below.
mov DWORD PTR .LGemmU8X8KernelFrame_type[rsp],eax
mov rbx,rdi
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
mov r12,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
mov r13,.LGemmU8X8KernelFrame_ZeroPointB[rsp]
mov ebp,-1
kmovw k1,ebp # update mask to write all columns
// Negating -1 leaves 1 in ebp, which is then broadcast to every word of zmm13.
neg ebp
vpbroadcastw zmm13,ebp # generate 512-bit word vector [0x0001]
// rbp is the row length times 8 and r14 is the row length times 16. The U8U8
// type (positive selector) uses the smaller stride.
lea rbp,[rcx*8]
lea r14,[rbp*2]
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
cmovg r14,rbp # select matrix B packed stride
//
// Process CountM rows of the matrices.
//
// Any count above 5 takes the 6 row path. Counts of 5, 4, 3 and 1 take the
// matching path, and any other count falls through to the 2 row path. Each
// ProcessCountM expansion ends in an unconditional jump and only leaves
// through .LExitKernel after loading eax with its row count, so one expansion
// never runs into the next. r8 is reused as scratch by the expansions once the
// dispatch below has been made.
//
cmp r8,5
ja .LProcessCountM6
je .LProcessCountM5
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
ProcessCountM 4
.LProcessCountM6:
ProcessCountM 6
//
// Restore non-volatile registers and return.
//
.LExitKernel:
vzeroupper
pop r14
pop r13
pop r12
pop rbx
pop rbp
ret
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.LProcessCountM5:
ProcessCountM 5
.end

View file

@ -22,6 +22,32 @@ Abstract:
/*++
Macro Description:
This macro emits the assembler directives to annotate a new function.
The function is aligned with ".p2align 4" (a 2^4 = 16 byte boundary), its
symbol is exported with ".globl" and the entry label is defined.
When __APPLE__ is defined, the symbol is emitted with a leading underscore
and no ".type" directive is generated. Otherwise the symbol uses the name as
given and is additionally marked with ".type <name>,@function".
Arguments:
FunctionName - Supplies the name of the function, without any platform
specific underscore prefix.
--*/
.macro FUNCTION_ENTRY FunctionName
.p2align 4
#if defined(__APPLE__)
.globl _\FunctionName\()
_\FunctionName\():
#else
.globl \FunctionName\()
.type \FunctionName\(),@function
\FunctionName\():
#endif
.endm
/*++
Macro Description:
This macro generates an optimization for "add reg,128" which can instead

View file

@ -29,8 +29,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
MatMulInteger);
Status MatMulInteger::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const auto* a = ctx->Input<Tensor>(0);
const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
@ -58,49 +56,33 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
b_offset = *static_cast<const uint8_t*>(b_zero_point->DataRaw());
}
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
gemm_params.M = static_cast<size_t>(helper.M());
gemm_params.N = static_cast<size_t>(helper.N());
gemm_params.K = static_cast<size_t>(helper.K());
gemm_params.lda = gemm_params.K;
gemm_params.ZeroPointA = a_offset;
gemm_params.ldb = gemm_params.N;
gemm_params.ZeroPointB = &b_offset;
gemm_params.ldc = gemm_params.N;
const auto* a_data = a->template Data<uint8_t>();
auto* y_data = y->template MutableData<int32_t>();
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (packed_b_) {
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
a_offset,
packed_b_.get(),
b_offset,
b_is_signed_,
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
thread_pool);
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
gemm_params.A = a_data + helper.LeftOffsets()[i];
// Select the B operand for this batch entry: prefer the pre-packed weights,
// otherwise read the raw tensor data at this batch's right-hand offset.
if (packed_b_) {
    gemm_params.B = packed_b_.get();
    gemm_params.BIsPacked = true;
    gemm_params.BIsSigned = b_is_signed_;
} else if (b != nullptr) {
    // The tensor is addressed as raw bytes; signedness is reported separately.
    gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + helper.RightOffsets()[i];
    gemm_params.BIsSigned = b->IsDataType<int8_t>();
} else {
    // Neither a packed buffer nor an input tensor is available.
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null.");
}
return Status::OK();
}
#endif
if (b != nullptr) {
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
const bool b_is_signed = b->IsDataType<int8_t>();
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
a_offset,
b_data + helper.RightOffsets()[i],
static_cast<size_t>(helper.N()),
b_offset,
b_is_signed,
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
thread_pool);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null.");
gemm_params.C = y_data + helper.OutputOffsets()[i];
MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
}
return Status::OK();

View file

@ -11,7 +11,6 @@ class MatMulIntegerBase : public OpKernel {
public:
MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {}
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override {
is_packed = false;
@ -43,7 +42,6 @@ class MatMulIntegerBase : public OpKernel {
}
return Status::OK();
}
#endif
protected:
bool b_is_signed_{true};

View file

@ -76,20 +76,22 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
gemm_params.M = static_cast<size_t>(helper.M());
gemm_params.N = static_cast<size_t>(helper.N());
gemm_params.K = static_cast<size_t>(helper.K());
gemm_params.lda = gemm_params.K;
gemm_params.ZeroPointA = *a_offset->template Data<uint8_t>();
gemm_params.ldb = gemm_params.N;
gemm_params.ZeroPointB = static_cast<const uint8_t*>(b_offset->DataRaw());
gemm_params.BIsSigned = b->IsDataType<int8_t>();
gemm_params.C = gemm_output;
gemm_params.ldc = gemm_params.N;
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a->template Data<uint8_t>() + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
*a_offset->template Data<uint8_t>(),
static_cast<const uint8_t*>(b->DataRaw()) + helper.RightOffsets()[i],
static_cast<size_t>(helper.N()),
*static_cast<const uint8_t*>(b_offset->DataRaw()),
b->IsDataType<int8_t>(),
gemm_output,
static_cast<size_t>(helper.N()),
ctx->GetOperatorThreadPool());
gemm_params.A = a->template Data<uint8_t>() + helper.LeftOffsets()[i];
gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + helper.RightOffsets()[i];
MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
MlasRequantizeOutput(gemm_output,
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],

View file

@ -149,19 +149,19 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
}
}
MlasGemm(static_cast<size_t>(M / conv_attrs_.group),
static_cast<size_t>(output_image_size),
static_cast<size_t>(kernel_dim),
Wdata + group_id * W_offset,
static_cast<size_t>(kernel_dim),
filter_offset,
col_buffer_data == nullptr ? Xdata : col_buffer_data,
static_cast<size_t>(output_image_size),
input_offset,
false,
Ydata,
static_cast<size_t>(output_image_size),
thread_pool);
// Describe the quantized GEMM for the current group: the filter data is the A
// operand and the input image (or the column buffer, when one was built) is
// the B operand.
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
gemm_params.M = static_cast<size_t>(M / conv_attrs_.group);
gemm_params.N = static_cast<size_t>(output_image_size);
gemm_params.K = static_cast<size_t>(kernel_dim);
gemm_params.A = Wdata + group_id * W_offset;
gemm_params.lda = static_cast<size_t>(kernel_dim);
gemm_params.ZeroPointA = filter_offset;
// Each assignment is its own statement; the original chained this line to the
// next one with a comma operator.
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
gemm_params.ldb = static_cast<size_t>(output_image_size);
// ZeroPointB is passed by address and must stay valid until MlasGemm returns.
gemm_params.ZeroPointB = &input_offset;
gemm_params.C = Ydata;
gemm_params.ldc = static_cast<size_t>(output_image_size);
MlasGemm(&gemm_params, thread_pool);
Xdata += X_offset;
Ydata += Y_offset;

View file

@ -43,10 +43,8 @@ class QLinearConv : public OpKernel {
ConvAttributes conv_attrs_;
TensorShape W_shape_;
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
BufferUniquePtr packed_W_buffer_;
size_t packed_W_size_;
#endif
BufferUniquePtr reordered_W_buffer_;
bool is_W_signed_;
bool is_W_packed_;
@ -116,7 +114,6 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed
auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
const size_t group_count = static_cast<size_t>(conv_attrs_.group);
const size_t group_output_channels = output_channels / group_count;
const size_t kernel_dim = group_input_channels * kernel_size;
@ -151,7 +148,6 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed
return Status::OK();
}
}
#endif
auto* reordered_W = static_cast<uint8_t*>(alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * output_channels * group_input_channels * kernel_size));
reordered_W_buffer_ = BufferUniquePtr(reordered_W, BufferDeleter(alloc));
@ -279,13 +275,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
// Handle the case of a dynamic weight filter.
BufferUniquePtr reordered_W_buffer;
uint8_t* reordered_W = nullptr;
bool use_reordered_W = true;
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (packed_W_buffer_) {
use_reordered_W = false;
}
#endif
if (use_reordered_W) {
if (!packed_W_buffer_) {
if (W == nullptr) {
// Weight was constant and reordered.
reordered_W = static_cast<uint8_t*>(reordered_W_buffer_.get());
@ -307,7 +297,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
int64_t group_output_channels = M / group_count;
// Test for depthwise convolution.
const bool is_depthwise_conv = (use_reordered_W && group_input_channels == 1 && group_output_channels == 1);
const bool is_depthwise_conv = (reordered_W != nullptr && group_input_channels == 1 && group_output_channels == 1);
if (is_depthwise_conv) {
// Update the input and output channels to the number of groups in order to
// reuse as much of the below standard convolution path.
@ -510,39 +500,25 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
worker_gemm_input = input_data + output_start * kernel_dim;
}
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
gemm_params.M = static_cast<size_t>(output_count);
gemm_params.N = static_cast<size_t>(group_output_channels);
gemm_params.K = static_cast<size_t>(kernel_dim);
gemm_params.A = worker_gemm_input;
gemm_params.lda = static_cast<size_t>(kernel_dim);
gemm_params.ZeroPointA = X_zero_point_value;
if (packed_W_buffer_) {
MlasGemm(
static_cast<size_t>(output_count),
static_cast<size_t>(group_output_channels),
static_cast<size_t>(kernel_dim),
worker_gemm_input,
static_cast<size_t>(kernel_dim),
X_zero_point_value,
static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_,
W_zero_point_value,
is_W_signed,
worker_gemm_output + group_id * group_output_channels,
static_cast<size_t>(M),
nullptr);
} else
#endif
{
MlasGemm(
static_cast<size_t>(output_count),
static_cast<size_t>(group_output_channels),
static_cast<size_t>(kernel_dim),
worker_gemm_input,
static_cast<size_t>(kernel_dim),
X_zero_point_value,
reordered_W + group_id * group_output_channels,
static_cast<size_t>(M),
W_zero_point_value,
is_W_signed,
worker_gemm_output + group_id * group_output_channels,
static_cast<size_t>(M),
nullptr);
// Pre-packed weights: step to this group's slice of the packed buffer. Each
// assignment is its own statement (the original joined them with a comma).
gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_;
gemm_params.BIsPacked = true;
} else {
// Reordered weights: step to this group's first output channel; ldb is set to
// M. Each assignment is its own statement (the original joined them with a
// comma).
gemm_params.B = reordered_W + group_id * group_output_channels;
gemm_params.ldb = static_cast<size_t>(M);
}
gemm_params.ZeroPointB = &W_zero_point_value;
gemm_params.BIsSigned = is_W_signed;
gemm_params.C = worker_gemm_output + group_id * group_output_channels;
gemm_params.ldc = static_cast<size_t>(M);
MlasGemm(&gemm_params, nullptr);
}
}

View file

@ -286,39 +286,23 @@ void ComputeGemm(const int M,
C, ldc, scale_multiplier.data(), nullptr,
beta == 1.0f ? MLAS_QGEMM_OUTPUT_MODE::AccumulateMode : MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
scale_multiplier.size() == 1 ? MLAS_QUANTIZATION_GRANULARITY::PerMatrix : MLAS_QUANTIZATION_GRANULARITY::PerColumn);
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (weights.is_prepacked_) {
MlasGemm(static_cast<size_t>(M),
static_cast<size_t>(N),
static_cast<size_t>(K),
a_data_quant,
static_cast<size_t>(K),
a_zero_point,
weights.buffer_,
b_zero_point,
b_is_signed,
C_buffer,
ld_C_buffer,
thread_pool,
&output_processor);
return;
}
#endif
MlasGemm(static_cast<size_t>(M),
static_cast<size_t>(N),
static_cast<size_t>(K),
a_data_quant,
static_cast<size_t>(K),
a_zero_point,
static_cast<const uint8_t*>(weights.buffer_),
static_cast<size_t>(N),
b_zero_point,
b_is_signed,
C_buffer,
ld_C_buffer,
thread_pool,
&output_processor);
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
gemm_params.M = static_cast<size_t>(M);
gemm_params.N = static_cast<size_t>(N);
gemm_params.K = static_cast<size_t>(K);
gemm_params.A = a_data_quant;
gemm_params.lda = static_cast<size_t>(K);
gemm_params.ZeroPointA = a_zero_point;
gemm_params.B = weights.buffer_;
gemm_params.ldb = static_cast<size_t>(N);
gemm_params.ZeroPointB = &b_zero_point;
gemm_params.BIsPacked = weights.is_prepacked_;
gemm_params.BIsSigned = b_is_signed;
gemm_params.C = C_buffer;
gemm_params.ldc = ld_C_buffer;
gemm_params.OutputProcessor = &output_processor;
MlasGemm(&gemm_params, thread_pool);
}
namespace deepcpu {

View file

@ -537,63 +537,7 @@ public:
};
template<bool Packed>
class MlasQgemmU8X8U8X8TestBase;
template<>
class MlasQgemmU8X8U8X8TestBase<false> : public MlasTestBase
{
protected:
void
TestGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
uint8_t offb,
bool BIsSigned,
int32_t* C,
size_t ldc
)
{
MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc, threadpool);
}
void
TestGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
uint8_t offb,
bool BIsSigned,
float* C,
size_t ldc,
float CScale,
const float* Bias
)
{
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
MlasGemm(M, N, K,
A, lda, offa,
B, ldb, offb, BIsSigned,
reinterpret_cast<int32_t*>(C), ldc,
threadpool,
&scale_bias_processor);
}
};
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
template<>
class MlasQgemmU8X8U8X8TestBase<true> : public MlasTestBase
class MlasQgemmU8X8U8X8TestBase : public MlasTestBase
{
private:
void*
@ -628,8 +572,69 @@ protected:
size_t ldc
)
{
const void* PackedB = PackB(N, K, B, ldb, BIsSigned);
MlasGemm(M, N, K, A, lda, offa, PackedB, offb, BIsSigned, C, ldc, threadpool);
MLAS_GEMM_U8X8_PARAMETERS GemmParameters;
GemmParameters.M = M;
GemmParameters.N = N;
GemmParameters.K = K;
GemmParameters.A = A;
GemmParameters.lda = lda;
GemmParameters.ZeroPointA = offa;
GemmParameters.ZeroPointB = &offb;
GemmParameters.BIsSigned = BIsSigned;
GemmParameters.C = C;
GemmParameters.ldc = ldc;
if (Packed) {
GemmParameters.B = PackB(N, K, B, ldb, BIsSigned);
GemmParameters.BIsPacked = true;
} else {
GemmParameters.B = B;
GemmParameters.ldb = ldb;
}
MlasGemm(&GemmParameters, threadpool);
}
void
TestGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
const uint8_t* offb,
bool BIsSigned,
int32_t* C,
size_t ldc
)
{
MLAS_GEMM_U8X8_PARAMETERS GemmParameters;
GemmParameters.M = M;
GemmParameters.N = N;
GemmParameters.K = K;
GemmParameters.A = A;
GemmParameters.lda = lda;
GemmParameters.ZeroPointA = offa;
GemmParameters.ZeroPointB = offb;
GemmParameters.BIsSigned = BIsSigned;
GemmParameters.PerColumnZeroPoints = true;
GemmParameters.C = C;
GemmParameters.ldc = ldc;
if (Packed) {
GemmParameters.B = PackB(N, K, B, ldb, BIsSigned);
GemmParameters.BIsPacked = true;
} else {
GemmParameters.B = B;
GemmParameters.ldb = ldb;
}
MlasGemm(&GemmParameters, threadpool);
}
void
@ -650,22 +655,37 @@ protected:
const float* Bias
)
{
const void* PackedB = PackB(N, K, B, ldb, BIsSigned);
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
MlasGemm(M, N, K,
A, lda, offa,
PackedB, offb, BIsSigned,
reinterpret_cast<int32_t*>(C), ldc,
threadpool,
&scale_bias_processor);
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR ScaleBiasProcessor(C, ldc, &CScale, Bias);
MLAS_GEMM_U8X8_PARAMETERS GemmParameters;
GemmParameters.M = M;
GemmParameters.N = N;
GemmParameters.K = K;
GemmParameters.A = A;
GemmParameters.lda = lda;
GemmParameters.ZeroPointA = offa;
GemmParameters.ZeroPointB = &offb;
GemmParameters.BIsSigned = BIsSigned;
GemmParameters.C = reinterpret_cast<int32_t*>(C);
GemmParameters.ldc = ldc;
GemmParameters.OutputProcessor = &ScaleBiasProcessor;
if (Packed) {
GemmParameters.B = PackB(N, K, B, ldb, BIsSigned);
GemmParameters.BIsPacked = true;
} else {
GemmParameters.B = B;
GemmParameters.ldb = ldb;
}
MlasGemm(&GemmParameters, threadpool);
}
private:
MatrixGuardBuffer<uint8_t> BufferBPacked;
};
#endif
template<typename xint8_t, typename OutputType, bool Packed>
class MlasQgemmU8X8Test;
@ -690,6 +710,23 @@ private:
Test(M, N, K, A, K, offa, B, N, offb, C, CReference, N);
}
void
Test(
size_t M,
size_t N,
size_t K,
uint8_t offa
)
{
const uint8_t* A = BufferA.GetBuffer(K * M);
const uint8_t* B = BufferB.GetBuffer(N * K);
const uint8_t* ZeroPointB = BufferZeroPointB.GetBuffer(N);
int32_t* C = BufferC.GetBuffer(N * M);
int32_t* CReference = BufferCReference.GetBuffer(N * M);
Test(M, N, K, A, K, offa, B, N, ZeroPointB, C, CReference, N);
}
void
Test(
size_t M,
@ -720,6 +757,36 @@ private:
}
}
void
Test(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
const uint8_t *offb,
int32_t* C,
int32_t* CReference,
size_t ldc
)
{
std::fill_n(C, M * N, -1);
std::fill_n(CReference, M * N, -1);
this->TestGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc);
ReferenceQgemm(M, N, K, A, lda, offa, (const xint8_t*)B, ldb, (const xint8_t*)offb, CReference, ldc);
for (size_t f = 0; f < M * N; f++) {
if (C[f] != CReference[f]) {
printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d!\n", M, N, K, int(offa));
break;
}
}
}
void
ReferenceQgemm(
size_t M,
@ -755,8 +822,44 @@ private:
}
}
void
ReferenceQgemm(
    size_t M,
    size_t N,
    size_t K,
    const uint8_t* A,
    size_t lda,
    uint8_t offa,
    const xint8_t* B,
    size_t ldb,
    const xint8_t* offb,
    int32_t* C,
    size_t ldc
    )
{
    //
    // Straightforward reference GEMM with a per-column zero point for matrix B:
    //
    //   C[m, n] = sum over k of (B[k, n] - offb[n]) * (A[m, k] - offa)
    //
    // Every output element is overwritten; nothing is accumulated into C.
    //

    for (size_t Row = 0; Row < M; Row++) {

        const uint8_t* RowA = A + (Row * lda);
        int32_t* RowC = C + (Row * ldc);

        for (size_t Column = 0; Column < N; Column++) {

            int32_t Accumulator = 0;

            for (size_t Depth = 0; Depth < K; Depth++) {
                // Both operands are widened to 32 bits before the zero points
                // are subtracted.
                const int32_t AdjustedB = int32_t(B[(Depth * ldb) + Column]) - offb[Column];
                const int32_t AdjustedA = int32_t(RowA[Depth]) - offa;
                Accumulator += (AdjustedB * AdjustedA);
            }

            RowC[Column] = Accumulator;
        }
    }
}
MatrixGuardBuffer<uint8_t> BufferA;
MatrixGuardBuffer<uint8_t> BufferB;
MatrixGuardBuffer<uint8_t> BufferZeroPointB;
MatrixGuardBuffer<int32_t> BufferC;
MatrixGuardBuffer<int32_t> BufferCReference;
const bool BIsSigned = std::is_signed<xint8_t>::value;
@ -769,12 +872,15 @@ public:
{
for (size_t b = 1; b < 16; b++) {
Test(b, b, b, 14, 211);
Test(b, b, b, 21);
}
for (size_t b = 1; b < 16; b++) {
Test(b, b, b, 14, 211);
Test(b, b, b, 17);
}
for (size_t b = 16; b <= 256; b <<= 1) {
Test(b, b, b, 34, 1);
Test(b, b, b, 1);
}
for (size_t b = 256; b < 320; b += 32) {
Test(b, b, b, 85, 173);
@ -786,6 +892,7 @@ public:
}
Test(43, 500, 401, 183, 223);
Test(1023, 1023, 1023, 5, 8);
Test(1023, 1023, 1023, 7);
}
void
@ -3130,7 +3237,6 @@ RunThreadedTests(
printf("QGEMM U8U8=float tests.\n");
onnxruntime::make_unique<MlasQgemmU8X8Test<uint8_t, float, false>>()->ExecuteShort();
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (MlasGemmPackBSize(128, 128, true) > 0) {
printf("QGEMM U8S8=int32_t packed tests.\n");
onnxruntime::make_unique<MlasQgemmU8X8Test<int8_t, int32_t, true>>()->ExecuteShort();
@ -3143,7 +3249,6 @@ RunThreadedTests(
printf("QGEMM U8U8=float packed tests.\n");
onnxruntime::make_unique<MlasQgemmU8X8Test<uint8_t, float, true>>()->ExecuteShort();
}
#endif
printf("Conv2D tests.\n");
onnxruntime::make_unique<MlasConv2DTest>()->ExecuteShort();