mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
MLAS: quantized GEMM update (#6916)
Various updates to the int8_t GEMMs: 1) Add ARM64 udot kernel to take advantage of dot product instructions available in newer cores. Some models run 4x faster than the stock implementation we used before. 2) Refactor the x64 kernels to share common code for AVX2(u8u8/u8s8/avxvnni) vs AVX512(u8u8/u8s8/avx512vnni) to reduce binary size. 3) Extend kernels to support per-column zero points for matrix B. This is not currently wired to an operator.
This commit is contained in:
parent
bc319bd7aa
commit
a8b897f710
45 changed files with 6646 additions and 5723 deletions
|
|
@ -28,6 +28,7 @@ if(MSVC)
|
|||
if(onnxruntime_target_platform STREQUAL "ARM64")
|
||||
set(mlas_platform_preprocess_srcs
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/SgemmKernelNeon.asm
|
||||
)
|
||||
|
||||
|
|
@ -81,15 +82,13 @@ if(MSVC)
|
|||
${mlas_platform_srcs_avx2}
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Core.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvxVnni.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Core.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelSse2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelAvx.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelFma3.asm
|
||||
|
|
@ -183,6 +182,7 @@ else()
|
|||
|
||||
set(mlas_platform_srcs
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemvKernelNeon.S
|
||||
)
|
||||
|
|
@ -246,8 +246,8 @@ else()
|
|||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx2.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvxVnni.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvxVnni.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/DgemmKernelFma3.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S
|
||||
|
|
@ -322,11 +322,9 @@ else()
|
|||
|
||||
if(COMPILES_AVX512CORE)
|
||||
set(mlas_platform_srcs_avx512core
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Core.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Core.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S
|
||||
)
|
||||
if(HAS_AVX512CORE)
|
||||
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl")
|
||||
|
|
|
|||
|
|
@ -22,10 +22,7 @@ class QAttention : public OpKernel, public AttentionCPUBase {
|
|||
QAttention(const OpKernelInfo& info);
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
|
||||
#endif
|
||||
|
||||
private:
|
||||
BufferUniquePtr packed_weights_;
|
||||
|
|
@ -51,7 +48,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
|
|||
template <typename T>
|
||||
QAttention<T>::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info) {}
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
template <typename T>
|
||||
Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_packed) {
|
||||
is_packed = false;
|
||||
|
|
@ -98,7 +94,6 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
|
|||
is_packed = true;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
Status QAttention<T>::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -217,44 +212,29 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
head_size,
|
||||
&dequant_scale,
|
||||
bias_data + weights_offset);
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
|
||||
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
|
||||
gemm_params.M = sequence_length;
|
||||
gemm_params.N = head_size;
|
||||
gemm_params.K = input_hidden_size;
|
||||
gemm_params.A = input_data + input_offset;
|
||||
gemm_params.lda = input_hidden_size;
|
||||
gemm_params.ZeroPointA = input_zero_point;
|
||||
if (packed_weights_) {
|
||||
const auto* packed_weight =
|
||||
static_cast<const uint8_t*>(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size);
|
||||
|
||||
MlasGemm(
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
input_hidden_size, // K = D
|
||||
input_data + input_offset, // A
|
||||
input_hidden_size, // lda = D
|
||||
input_zero_point, // input zero point
|
||||
packed_weight, // B
|
||||
weight_zero_point, // weight zero point
|
||||
weights_is_signed, // weight data type
|
||||
reinterpret_cast<int32_t*>(qkv_dest + qkv_offset), // C
|
||||
head_size, // ldc
|
||||
nullptr, // use single-thread
|
||||
&scale_bias_processor); // output processor
|
||||
|
||||
continue;
|
||||
gemm_params.B = packed_weight;
|
||||
gemm_params.BIsPacked = true;
|
||||
} else {
|
||||
gemm_params.B = weights_data + weights_offset;
|
||||
gemm_params.ldb = 3 * hidden_size;
|
||||
}
|
||||
#endif
|
||||
MlasGemm(
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
input_hidden_size, // K = D
|
||||
input_data + input_offset, // A
|
||||
input_hidden_size, // lda = D
|
||||
input_zero_point, // input zero point
|
||||
weights_data + weights_offset, // B
|
||||
3 * hidden_size, // ldb = 3NH
|
||||
weight_zero_point, // weight zero point
|
||||
weights_is_signed, // weight data type
|
||||
reinterpret_cast<int32_t*>(qkv_dest + qkv_offset), // C
|
||||
head_size, // ldc
|
||||
nullptr, // use single-thread
|
||||
&scale_bias_processor); // post processor
|
||||
gemm_params.ZeroPointB = &weight_zero_point;
|
||||
gemm_params.BIsSigned = weights_is_signed;
|
||||
gemm_params.C = reinterpret_cast<int32_t*>(qkv_dest + qkv_offset);
|
||||
gemm_params.ldc = head_size;
|
||||
gemm_params.OutputProcessor = &scale_bias_processor;
|
||||
MlasGemm(&gemm_params, nullptr);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,19 +10,14 @@ using namespace rnn::detail;
|
|||
class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
|
||||
public:
|
||||
DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {}
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
|
||||
#endif
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
~DynamicQuantizeLSTM() override = default;
|
||||
|
||||
private:
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
Status ComputeImpl(OpKernelContext& context) const;
|
||||
|
|
@ -33,7 +28,6 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
|
|||
bool is_R_signed_;
|
||||
};
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status DynamicQuantizeLSTM::TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed) {
|
||||
const auto& shape = weights.Shape();
|
||||
if (shape.NumDimensions() != 3) {
|
||||
|
|
@ -83,7 +77,6 @@ Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, bool& i
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
#define WeightCheck(weight_shape, weight_name) \
|
||||
if (weight_shape.NumDimensions() != 1 && weight_shape.NumDimensions() != 2 || \
|
||||
|
|
|
|||
|
|
@ -44,11 +44,19 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
|
|||
if (y->Shape().Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
|
||||
gemm_params.M = static_cast<size_t>(helper.M());
|
||||
gemm_params.N = static_cast<size_t>(helper.N());
|
||||
gemm_params.K = static_cast<size_t>(helper.K());
|
||||
gemm_params.lda = gemm_params.K;
|
||||
gemm_params.ZeroPointA = a_zero_point;
|
||||
gemm_params.ldb = gemm_params.N;
|
||||
gemm_params.ZeroPointB = &b_zero_point;
|
||||
gemm_params.ldc = gemm_params.N;
|
||||
|
||||
auto* y_data = y->template MutableData<float>();
|
||||
const auto* bias_data = bias_tensor != nullptr ? bias_tensor->Data<float>() : nullptr;
|
||||
|
||||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
||||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
|
|
@ -56,40 +64,18 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
|
|||
&multiplier,
|
||||
bias_data);
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
gemm_params.A = a_data + helper.LeftOffsets()[i];
|
||||
if (packed_b_) {
|
||||
MlasGemm(static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_zero_point,
|
||||
packed_b_.get(),
|
||||
b_zero_point,
|
||||
b_is_signed_,
|
||||
reinterpret_cast<int32_t*>(y_data + helper.OutputOffsets()[i]),
|
||||
static_cast<size_t>(helper.N()),
|
||||
thread_pool,
|
||||
&scale_bias_processor);
|
||||
continue;
|
||||
gemm_params.B = packed_b_.get();
|
||||
gemm_params.BIsPacked = true;
|
||||
gemm_params.BIsSigned = b_is_signed_;
|
||||
} else {
|
||||
gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + + helper.RightOffsets()[i];
|
||||
gemm_params.BIsSigned = b->IsDataType<int8_t>();
|
||||
}
|
||||
#endif
|
||||
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
|
||||
const bool b_is_signed = b->IsDataType<int8_t>();
|
||||
MlasGemm(static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_zero_point,
|
||||
b_data + helper.RightOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
b_zero_point,
|
||||
b_is_signed,
|
||||
reinterpret_cast<int32_t*>(y_data + helper.OutputOffsets()[i]),
|
||||
static_cast<size_t>(helper.N()),
|
||||
thread_pool,
|
||||
&scale_bias_processor);
|
||||
gemm_params.C = reinterpret_cast<int32_t*>(y_data) + helper.OutputOffsets()[i];
|
||||
gemm_params.OutputProcessor = &scale_bias_processor;
|
||||
MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -61,10 +61,6 @@ Abstract:
|
|||
#define MLAS_SUPPORTS_GEMM_DOUBLE
|
||||
#endif
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_ARM64) || (defined(MLAS_TARGET_ARM) && !defined(_MSC_VER))
|
||||
#define MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
#endif
|
||||
|
||||
//
|
||||
// Basic Linear Algebra Subprograms (BLAS) types.
|
||||
//
|
||||
|
|
@ -273,41 +269,29 @@ private:
|
|||
MLAS_QUANTIZATION_GRANULARITY QuantGran_;
|
||||
};
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
size_t ldb,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool,
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
|
||||
);
|
||||
struct MLAS_GEMM_U8X8_PARAMETERS {
|
||||
size_t M = 0;
|
||||
size_t N = 0;
|
||||
size_t K = 0;
|
||||
const uint8_t* A = nullptr;
|
||||
size_t lda = 0;
|
||||
uint8_t ZeroPointA = 0;
|
||||
const void* B = 0;
|
||||
size_t ldb = 0;
|
||||
const uint8_t* ZeroPointB = nullptr;
|
||||
bool BIsPacked = false;
|
||||
bool BIsSigned = false;
|
||||
bool PerColumnZeroPoints = false;
|
||||
int32_t* C = nullptr;
|
||||
size_t ldc = 0;
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr;
|
||||
};
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const void* PackedB,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool,
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
|
||||
const MLAS_GEMM_U8X8_PARAMETERS* Parameters,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
//
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ Abstract:
|
|||
// Stack frame layout for the U8X8 kernel.
|
||||
//
|
||||
|
||||
.equ .LGemmU8X8KernelFrame_SavedGeneralRegisters, (6 * 4)
|
||||
.equ .LGemmU8X8KernelFrame_SavedGeneralRegisters, (7 * 4)
|
||||
.equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (8 * 8)
|
||||
.equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedGeneralRegisters + .LGemmU8X8KernelFrame_SavedNeonRegisters
|
||||
.equ .LGemmU8X8KernelFrame_CountM, 0 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
|
|
@ -33,7 +33,7 @@ Abstract:
|
|||
.equ .LGemmU8X8KernelFrame_ldc, 8 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 12 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 16 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
.equ .LGemmU8X8KernelFrame_DepthValue, 20 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
.equ .LGemmU8X8KernelFrame_ZeroPointB, 20 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 24 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
|
||||
.text
|
||||
|
|
@ -48,10 +48,10 @@ Routine Description:
|
|||
Arguments:
|
||||
|
||||
A (r0) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackANeon.
|
||||
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>.
|
||||
|
||||
B (r1) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackBNeon.
|
||||
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_NEON>.
|
||||
|
||||
C (r2) - Supplies the address of matrix C.
|
||||
|
||||
|
|
@ -67,17 +67,18 @@ Arguments:
|
|||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by
|
||||
the zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A. These values have
|
||||
been pre-scaled by the zero point offset of matrix B if the offset is
|
||||
per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
scaled by the per-column zero point offsets of matrix B. These values are
|
||||
accumulated into every row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
accumulated into every element of matrix C.
|
||||
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized, else
|
||||
false if the output matrix is accumulated into.
|
||||
|
|
@ -95,26 +96,24 @@ Return Value:
|
|||
//
|
||||
// q0-q1 (d0-d3) matrix B data
|
||||
// q2-q3 (d4-d7) matrix A data
|
||||
// q4 (d8-d9) packed matrix B data
|
||||
// q5 (d10-d11) RowSumBufferData + DepthValue
|
||||
// q4 (d8-d9) packed matrix A data
|
||||
// q5 (d10-d11) RowSumBuffer data
|
||||
// q6-q7 (d12-d15) ColumnSumBuffer data
|
||||
// q8-q15 accumulators[4][2]
|
||||
//
|
||||
|
||||
push {r4,r5,r6,r7,r8,r10}
|
||||
push {r4,r5,r6,r7,r8,r9,r10}
|
||||
vpush {d8-d15}
|
||||
ldr r4,[sp,#.LGemmU8X8KernelFrame_CountM]
|
||||
ldr r5,[sp,#.LGemmU8X8KernelFrame_ZeroMode]
|
||||
ldr r7,[sp,#.LGemmU8X8KernelFrame_RowSumBuffer]
|
||||
ldr r7,[sp,#.LGemmU8X8KernelFrame_ZeroPointB]
|
||||
ldr r8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer]
|
||||
vldr s0,[sp,#.LGemmU8X8KernelFrame_DepthValue]
|
||||
ldr r9,[sp,#.LGemmU8X8KernelFrame_RowSumBuffer]
|
||||
ldr r10,[sp,#.LGemmU8X8KernelFrame_ldc]
|
||||
ldr r12,[sp,#.LGemmU8X8KernelFrame_CountN]
|
||||
vdup.32 q5,d0[0] // broadcast DepthValue
|
||||
vld1.32 {d12-d13},[r7]
|
||||
vld1.32 {d10-d11},[r9] // load RowSumBuffer
|
||||
mov r6,r0
|
||||
mov r7,r3
|
||||
vadd.u32 q5,q5,q6 // add row fixups and DepthValue
|
||||
mov r9,r3
|
||||
cmp r4,#1 // CountM == 1?
|
||||
beq .LGemmU8X8.M1.ProcessNextColumnLoop
|
||||
cmp r4,#4 // CountM < 4?
|
||||
|
|
@ -128,12 +127,36 @@ Return Value:
|
|||
vldr d0,[r1] // load packed B0
|
||||
mov r0,r6 // reload matrix A
|
||||
vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer
|
||||
mov r3,r7 // reload PackedCountK
|
||||
mov r3,r9 // reload PackedCountK
|
||||
vmovl.u8 q0,d0
|
||||
vdup.32 q9,d10[0]
|
||||
vdup.32 q11,d10[1]
|
||||
vdup.32 q13,d11[0]
|
||||
vdup.32 q15,d11[1]
|
||||
cbz r7,.LGemmU8X8.M4.SkipScaleByZeroPointB
|
||||
vld1.32 {d8-d9},[r7]! // load ZeroPointB0
|
||||
vmul.u32 q8,q9,q4
|
||||
vmul.u32 q10,q11,q4
|
||||
vmul.u32 q12,q13,q4
|
||||
vmul.u32 q14,q15,q4
|
||||
vld1.32 {d8-d9},[r7]! // load ZeroPointB1
|
||||
vmul.u32 q9,q9,q4
|
||||
vmul.u32 q11,q11,q4
|
||||
vmul.u32 q13,q13,q4
|
||||
vmul.u32 q15,q15,q4
|
||||
vldr d8,[r0] // load first packed A0
|
||||
vadd.u32 q8,q8,q6
|
||||
vadd.u32 q9,q9,q7
|
||||
vadd.u32 q10,q10,q6
|
||||
vadd.u32 q11,q11,q7
|
||||
vldr d9,[r0,#8] // load first packed A1
|
||||
vadd.u32 q12,q12,q6
|
||||
vadd.u32 q13,q13,q7
|
||||
vadd.u32 q14,q14,q6
|
||||
vadd.u32 q15,q15,q7
|
||||
b .LGemmU8X8.M4.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M4.SkipScaleByZeroPointB:
|
||||
vldr d8,[r0] // load first packed A0
|
||||
vadd.u32 q8,q9,q6
|
||||
vadd.u32 q9,q9,q7
|
||||
|
|
@ -244,7 +267,7 @@ Return Value:
|
|||
.LGemmU8X8.M4.ExitKernel:
|
||||
mov r0,#4 // return number of rows handled
|
||||
vpop {d8-d15}
|
||||
pop {r4,r5,r6,r7,r8,r10}
|
||||
pop {r4,r5,r6,r7,r8,r9,r10}
|
||||
bx lr
|
||||
|
||||
//
|
||||
|
|
@ -353,10 +376,24 @@ Return Value:
|
|||
vldr d0,[r1] // load packed B0
|
||||
mov r0,r6 // reload matrix A
|
||||
vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer
|
||||
mov r3,r7 // reload PackedCountK
|
||||
mov r3,r9 // reload PackedCountK
|
||||
vmovl.u8 q0,d0
|
||||
vdup.32 q9,d10[0]
|
||||
vdup.32 q11,d10[1]
|
||||
cbz r7,.LGemmU8X8.M2.SkipScaleByZeroPointB
|
||||
vld1.32 {d28-d31},[r7]! // load ZeroPointB
|
||||
vmul.u32 q8,q9,q14
|
||||
vmul.u32 q9,q9,q15
|
||||
vmul.u32 q10,q11,q14
|
||||
vmul.u32 q11,q11,q15
|
||||
vld1.32 d8,[r0]! // load first packed A0
|
||||
vadd.u32 q8,q8,q6
|
||||
vadd.u32 q9,q9,q7
|
||||
vadd.u32 q10,q10,q6
|
||||
vadd.u32 q11,q11,q7
|
||||
b .LGemmU8X8.M2.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M2.SkipScaleByZeroPointB:
|
||||
vld1.32 d8,[r0]! // load first packed A0
|
||||
vadd.u32 q8,q9,q6
|
||||
vadd.u32 q9,q9,q7
|
||||
|
|
@ -425,7 +462,7 @@ Return Value:
|
|||
.LGemmU8X8.M2.ExitKernel:
|
||||
mov r0,#2 // return number of rows handled
|
||||
vpop {d8-d15}
|
||||
pop {r4,r5,r6,r7,r8,r10}
|
||||
pop {r4,r5,r6,r7,r8,r9,r10}
|
||||
bx lr
|
||||
|
||||
//
|
||||
|
|
@ -502,9 +539,19 @@ Return Value:
|
|||
vldr d0,[r1] // load packed B0
|
||||
mov r0,r6 // reload matrix A
|
||||
vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer
|
||||
mov r3,r7 // reload PackedCountK
|
||||
mov r3,r9 // reload PackedCountK
|
||||
vmovl.u8 q0,d0
|
||||
vdup.32 q9,d10[0]
|
||||
cbz r7,.LGemmU8X8.M1.SkipScaleByZeroPointB
|
||||
vld1.32 {d28-d31},[r7]! // load ZeroPointB
|
||||
vmul.u32 q8,q9,q14
|
||||
vmul.u32 q9,q9,q15
|
||||
vld1.32 d8[0],[r0]! // load first packed A0
|
||||
vadd.u32 q8,q8,q6
|
||||
vadd.u32 q9,q9,q7
|
||||
b .LGemmU8X8.M1.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M1.SkipScaleByZeroPointB:
|
||||
vld1.32 d8[0],[r0]! // load first packed A0
|
||||
vadd.u32 q8,q9,q6
|
||||
vadd.u32 q9,q9,q7
|
||||
|
|
@ -554,7 +601,7 @@ Return Value:
|
|||
.LGemmU8X8.M1.ExitKernel:
|
||||
mov r0,#1 // return number of rows handled
|
||||
vpop {d8-d15}
|
||||
pop {r4,r5,r6,r7,r8,r10}
|
||||
pop {r4,r5,r6,r7,r8,r9,r10}
|
||||
bx lr
|
||||
|
||||
//
|
||||
|
|
|
|||
52
onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h
Normal file
52
onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
AssembleDotProduct.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains macros to build Advanced SIMD dot product instructions
|
||||
for toolchains that do not natively support this newer instruction set
|
||||
extension.
|
||||
|
||||
This implementation uses ARM v8.4 dot product instructions.
|
||||
|
||||
--*/
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro builds a UDOT instruction of the form:
|
||||
|
||||
UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index]
|
||||
|
||||
Arguments:
|
||||
|
||||
DestReg - Specifies the destination register.
|
||||
|
||||
Src1Reg - Specifies the first source register.
|
||||
|
||||
Src2Reg - Specifies the second source register.
|
||||
|
||||
Index - Specifies the element index of the second source register.
|
||||
|
||||
--*/
|
||||
|
||||
.macro UdotByElement DestReg, Src1Reg, Src2Reg, Index
|
||||
|
||||
.set Instruction, 0x6F80E000
|
||||
.set Instruction, Instruction + (\DestReg\() << 0)
|
||||
.set Instruction, Instruction + (\Src1Reg\() << 5)
|
||||
.set Instruction, Instruction + (\Src2Reg\() << 16)
|
||||
.set Instruction, Instruction + ((\Index\() & 2) << 10)
|
||||
.set Instruction, Instruction + ((\Index\() & 1) << 21)
|
||||
|
||||
.inst Instruction
|
||||
|
||||
.endm
|
||||
|
|
@ -22,12 +22,8 @@ Abstract:
|
|||
//
|
||||
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0
|
||||
.equ .LGemmU8X8KernelFrame_DepthValue, 8
|
||||
#if defined(__APPLE__)
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 12
|
||||
#else
|
||||
.equ .LGemmU8X8KernelFrame_ZeroPointB, 8
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 16
|
||||
#endif
|
||||
|
||||
.text
|
||||
|
||||
|
|
@ -41,10 +37,10 @@ Routine Description:
|
|||
Arguments:
|
||||
|
||||
A (x0) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackANeon.
|
||||
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_NEON>.
|
||||
|
||||
B (x1) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackBNeon.
|
||||
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_NEON>.
|
||||
|
||||
C (x2) - Supplies the address of matrix C.
|
||||
|
||||
|
|
@ -60,17 +56,18 @@ Arguments:
|
|||
|
||||
ldc (x6) - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer (x7) - Supplies the sum of each row from matrix A multiplied by
|
||||
the zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values
|
||||
have been pre-scaled by the zero point offset of matrix B if the offset
|
||||
is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
scaled by the per-column zero point offsets of matrix B. These values are
|
||||
accumulated into every row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
accumulated into every element of matrix C.
|
||||
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized, else
|
||||
false if the output matrix is accumulated into.
|
||||
|
|
@ -84,13 +81,11 @@ Return Value:
|
|||
FUNCTION_ENTRY MlasGemmU8X8KernelNeon
|
||||
|
||||
ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer]
|
||||
ldr s27,[sp,#.LGemmU8X8KernelFrame_DepthValue]
|
||||
ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB]
|
||||
ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode]
|
||||
dup v27.4s,v27.s[0]
|
||||
mov x14,x0
|
||||
ld1 {v0.4s},[x7]
|
||||
ld1 {v27.4s},[x7]
|
||||
mov x15,x3
|
||||
add v27.4s,v27.4s,v0.4s // broadcast add DepthValue
|
||||
dup v24.4s,v27.s[0] // broadcast row fixups
|
||||
cmp x4,#1 // CountM == 1?
|
||||
beq .LGemmU8X8.M1.ProcessNextColumnLoop
|
||||
|
|
@ -111,6 +106,30 @@ Return Value:
|
|||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
|
||||
uxtl v0.8h,v0.8b
|
||||
cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB
|
||||
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
|
||||
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
|
||||
mul v16.4s,v24.4s,v28.4s
|
||||
mul v17.4s,v24.4s,v29.4s
|
||||
mul v18.4s,v25.4s,v28.4s
|
||||
mul v19.4s,v25.4s,v29.4s
|
||||
mul v20.4s,v26.4s,v28.4s
|
||||
mul v21.4s,v26.4s,v29.4s
|
||||
mul v22.4s,v27.4s,v28.4s
|
||||
mul v23.4s,v27.4s,v29.4s
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
ld1 {v5.8b},[x0],#8 // load first packed A1
|
||||
add v20.4s,v2.4s,v20.4s
|
||||
add v21.4s,v3.4s,v21.4s
|
||||
add v22.4s,v2.4s,v22.4s
|
||||
add v23.4s,v3.4s,v23.4s
|
||||
b .LGemmU8X8.M4.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M4.SkipScaleByZeroPointB:
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v24.4s
|
||||
add v17.4s,v3.4s,v24.4s
|
||||
|
|
@ -322,6 +341,21 @@ Return Value:
|
|||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
|
||||
uxtl v0.8h,v0.8b
|
||||
cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB
|
||||
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
|
||||
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
|
||||
mul v16.4s,v24.4s,v28.4s
|
||||
mul v17.4s,v24.4s,v29.4s
|
||||
mul v18.4s,v25.4s,v28.4s
|
||||
mul v19.4s,v25.4s,v29.4s
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
b .LGemmU8X8.M2.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M2.SkipScaleByZeroPointB:
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v24.4s
|
||||
add v17.4s,v3.4s,v24.4s
|
||||
|
|
@ -460,6 +494,17 @@ Return Value:
|
|||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
|
||||
uxtl v0.8h,v0.8b
|
||||
cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB
|
||||
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
|
||||
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
|
||||
mul v16.4s,v24.4s,v28.4s
|
||||
mul v17.4s,v24.4s,v29.4s
|
||||
ldr s4,[x0],#4 // load first packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
b .LGemmU8X8.M1.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M1.SkipScaleByZeroPointB:
|
||||
ldr s4,[x0],#4 // load first packed A0
|
||||
add v16.4s,v2.4s,v24.4s
|
||||
add v17.4s,v3.4s,v24.4s
|
||||
|
|
|
|||
585
onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S
Normal file
585
onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S
Normal file
|
|
@ -0,0 +1,585 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelUdot.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses ARM v8.4 dot product instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "AssembleDotProduct.h"
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8X8 kernel.
|
||||
//
|
||||
|
||||
.equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (4 * 8)
|
||||
.equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedNeonRegisters
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
.equ .LGemmU8X8KernelFrame_ZeroPointB, 8 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 16 + .LGemmU8X8KernelFrame_SavedRegisters
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (x0) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>.
|
||||
|
||||
B (x1) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_UDOT>.
|
||||
|
||||
C (x2) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (x3) - Supplies the number of packed columns from matrix A and
|
||||
the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (x4) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (x5) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc (x6) - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values
|
||||
have been pre-scaled by the zero point offset of matrix B if the offset
|
||||
is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
scaled by the per-column zero point offsets of matrix B. These values are
|
||||
accumulated into every row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized, else
|
||||
false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8X8KernelUdot
|
||||
|
||||
stp d8,d9,[sp,#-32]! // push d8-d11; v8-v11 hold the broadcast row sums below
|
||||
stp d10,d11,[sp,#16]
|
||||
ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer] // x8 = ColumnSumBuffer (stack argument)
|
||||
ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB] // x9 = ZeroPointB, may be nullptr
|
||||
ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode] // w13 = ZeroMode: nonzero overwrites C, zero accumulates into C
|
||||
mov x14,x0 // keep the matrix A base; x0 is reloaded from x14 for every column block
|
||||
ld1 {v11.4s},[x7] // load four row sums from RowSumBuffer (four are read regardless of CountM)
|
||||
mov x15,x3 // keep PackedCountK; x3 is reloaded from x15 for every column block
|
||||
dup v8.4s,v11.s[0] // broadcast row fixups
|
||||
cmp x4,#1 // CountM == 1?
|
||||
beq .LGemmU8X8.M1.ProcessNextColumnLoop
|
||||
dup v9.4s,v11.s[1]
|
||||
cmp x4,#4 // CountM < 4?
|
||||
blo .LGemmU8X8.M2.ProcessNextColumnLoop // any CountM below 4 other than 1 takes the two-row path
|
||||
dup v10.4s,v11.s[2]
|
||||
dup v11.4s,v11.s[3] // v11 is overwritten last because it is the source of the earlier broadcasts
|
||||
|
||||
//
|
||||
// Process 4 rows of the matrices.
|
||||
//
|
||||
|
||||
.LGemmU8X8.M4.ProcessNextColumnLoop: // seed accumulators v16-v23 (4 rows x 8 columns) with the row/column sum fixups
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
mov x0,x14 // reload matrix A
|
||||
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
|
||||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
|
||||
cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB // ZeroPointB == nullptr: row sums are already scaled (per-tensor case)
|
||||
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
|
||||
mul v16.4s,v30.4s,v8.4s // row 0 sum * ZeroPointB[0..3]
|
||||
mul v18.4s,v30.4s,v9.4s // row 1
|
||||
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
|
||||
mul v20.4s,v30.4s,v10.4s // row 2
|
||||
mul v22.4s,v30.4s,v11.4s // row 3
|
||||
mul v17.4s,v31.4s,v8.4s // same four rows against ZeroPointB[4..7]
|
||||
mul v19.4s,v31.4s,v9.4s
|
||||
mul v21.4s,v31.4s,v10.4s
|
||||
mul v23.4s,v31.4s,v11.4s
|
||||
add v16.4s,v2.4s,v16.4s // add the column sums for columns 0..3
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v20.4s,v2.4s,v20.4s
|
||||
add v22.4s,v2.4s,v22.4s
|
||||
add v17.4s,v3.4s,v17.4s // add the column sums for columns 4..7
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
add v21.4s,v3.4s,v21.4s
|
||||
add v23.4s,v3.4s,v23.4s
|
||||
b .LGemmU8X8.M4.ComputeBlockLoopStart
|
||||
|
||||
.LGemmU8X8.M4.SkipScaleByZeroPointB: // per-tensor case: accumulator = column sum + row sum
|
||||
add v16.4s,v2.4s,v8.4s
|
||||
add v18.4s,v2.4s,v9.4s
|
||||
add v20.4s,v2.4s,v10.4s
|
||||
add v22.4s,v2.4s,v11.4s
|
||||
add v17.4s,v3.4s,v8.4s
|
||||
add v19.4s,v3.4s,v9.4s
|
||||
add v21.4s,v3.4s,v10.4s
|
||||
add v23.4s,v3.4s,v11.4s
|
||||
|
||||
//
|
||||
// The packing layout is set up to have a pair of four quad vectors from
|
||||
// packed matrix A and a pair of eight quad vectors from packed matrix B.
|
||||
// With this scheme, alternating loads from the packed matrices can be
|
||||
// interleaved with the dot product instructions.
|
||||
//
|
||||
// One negative consequence of using four rows here is that the accumulator
|
||||
// register tile is too small for processors with high out of order execution
|
||||
// windows (such as the Apple M1). The dot product instructions for a given
|
||||
// cell are too close to each other to avoid dependencies. To work around this,
|
||||
// the below loop uses a pair of accumulator registers that are then added
|
||||
// together when the loop finishes.
|
||||
//
|
||||
// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for
|
||||
// packed matrix A. At the time of this implementation, using a wider 128-bit
|
||||
// load didn't affect performance for higher end cores.
|
||||
//
|
||||
|
||||
.LGemmU8X8.M4.ComputeBlockLoopStart: // preload the first A data and clear the second accumulator set
|
||||
ldr d4,[x0],#32 // load packed A0.l
|
||||
movi v24.4s,#0 // v24-v31 form the second accumulator set, added into v16-v23 once the loop ends
|
||||
movi v25.4s,#0
|
||||
ldur d5,[x0,#-24] // load packed A0.h
|
||||
movi v26.4s,#0
|
||||
movi v27.4s,#0
|
||||
ldur d6,[x0,#-16] // load packed A1.l
|
||||
movi v28.4s,#0
|
||||
movi v29.4s,#0
|
||||
movi v30.4s,#0
|
||||
movi v31.4s,#0
|
||||
|
||||
.LGemmU8X8.M4.ComputeBlockLoop: // UdotByElement is not defined in this file; operands are presumably (accumulator, B vector, A vector, A lane) -- confirm in AssembleDotProduct.h
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 16, 0, 4, 0
|
||||
UdotByElement 18, 0, 4, 1
|
||||
ldur d7,[x0,#-8] // load packed A1.h
|
||||
UdotByElement 20, 0, 5, 0
|
||||
UdotByElement 22, 0, 5, 1
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
UdotByElement 17, 1, 4, 0
|
||||
UdotByElement 19, 1, 4, 1
|
||||
sub x3,x3,#1 // one fewer PackedCountK iteration remaining
|
||||
cbz x3,.LGemmU8X8.M4.ComputeBlockLoopFinish // last iteration: finish without fetching any further A data
|
||||
ldr d4,[x0],#32 // load packed A0.l
|
||||
UdotByElement 21, 1, 5, 0
|
||||
UdotByElement 23, 1, 5, 1
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 24, 0, 6, 0
|
||||
UdotByElement 26, 0, 6, 1
|
||||
ldur d5,[x0,#-24] // load packed A0.h
|
||||
UdotByElement 28, 0, 7, 0
|
||||
UdotByElement 30, 0, 7, 1
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
UdotByElement 25, 1, 6, 0
|
||||
UdotByElement 27, 1, 6, 1
|
||||
ldur d6,[x0,#-16] // load packed A1.l
|
||||
UdotByElement 29, 1, 7, 0
|
||||
UdotByElement 31, 1, 7, 1
|
||||
b .LGemmU8X8.M4.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M4.ComputeBlockLoopFinish: // tail of the final iteration: same dot products as the loop body, minus the loads of further A/B0 data
|
||||
UdotByElement 21, 1, 5, 0
|
||||
UdotByElement 23, 1, 5, 1
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 24, 0, 6, 0
|
||||
UdotByElement 26, 0, 6, 1
|
||||
UdotByElement 28, 0, 7, 0
|
||||
UdotByElement 30, 0, 7, 1
|
||||
UdotByElement 25, 1, 6, 0
|
||||
UdotByElement 27, 1, 6, 1
|
||||
UdotByElement 29, 1, 7, 0
|
||||
UdotByElement 31, 1, 7, 1
|
||||
add x10,x2,x6,lsl #2 // compute output row 2
|
||||
add v16.4s,v16.4s,v24.4s // fold high results into low results
|
||||
add v18.4s,v18.4s,v26.4s
|
||||
add v20.4s,v20.4s,v28.4s
|
||||
add v22.4s,v22.4s,v30.4s
|
||||
add x11,x10,x6,lsl #2 // compute output row 3
|
||||
add v17.4s,v17.4s,v25.4s
|
||||
add v19.4s,v19.4s,v27.4s
|
||||
add v21.4s,v21.4s,v29.4s
|
||||
add v23.4s,v23.4s,v31.4s
|
||||
add x12,x11,x6,lsl #2 // compute output row 4
|
||||
subs x5,x5,#8 // adjust CountN remaining
|
||||
blo .LGemmU8X8.M4.StoreOutputPartial // fewer than 8 columns were left
|
||||
cbnz x13,.LGemmU8X8.M4.SkipAccumulateOutput // ZeroMode set: store without reading C
|
||||
ldp q0,q1,[x2] // accumulate mode: add the existing 4x8 block of C
|
||||
ldp q2,q3,[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v17.4s,v17.4s,v1.4s
|
||||
ldp q4,q5,[x11]
|
||||
add v18.4s,v18.4s,v2.4s
|
||||
add v19.4s,v19.4s,v3.4s
|
||||
ldp q6,q7,[x12]
|
||||
add v20.4s,v20.4s,v4.4s
|
||||
add v21.4s,v21.4s,v5.4s
|
||||
add v22.4s,v22.4s,v6.4s
|
||||
add v23.4s,v23.4s,v7.4s
|
||||
|
||||
.LGemmU8X8.M4.SkipAccumulateOutput:
|
||||
stp q16,q17,[x2],#32 // store row 1 and advance C by 8 columns; rows 2-4 are recomputed from x2 each pass
|
||||
stp q18,q19,[x10]
|
||||
stp q20,q21,[x11]
|
||||
stp q22,q23,[x12]
|
||||
cbnz x5,.LGemmU8X8.M4.ProcessNextColumnLoop // more columns remain
|
||||
|
||||
.LGemmU8X8.M4.ExitKernel:
|
||||
mov x0,#4 // return number of rows handled
|
||||
ldp d10,d11,[sp,#16] // restore d8-d11 pushed at entry
|
||||
ldp d8,d9,[sp],#32
|
||||
ret
|
||||
|
||||
//
|
||||
// Store the partial 1 to 7 columns either overwriting the output matrix or
|
||||
// accumulating into the existing contents of the output matrix.
|
||||
//
|
||||
|
||||
.LGemmU8X8.M4.StoreOutputPartial:
|
||||
cbz x13,.LGemmU8X8.M4.StoreOutputPartial.AddMode // ZeroMode clear: take the read-modify-write path
|
||||
|
||||
.LGemmU8X8.M4.StoreOutputPartial.ZeroMode:
|
||||
tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode // x5 is (remaining - 8) here, so its low three bits still equal the remaining column count
|
||||
st1 {v16.4s},[x2],#16 // store 4 columns of each row
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
st1 {v20.4s},[x11],#16
|
||||
mov v20.16b,v21.16b
|
||||
st1 {v22.4s},[x12],#16
|
||||
mov v22.16b,v23.16b
|
||||
|
||||
.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode:
|
||||
tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode
|
||||
st1 {v16.2s},[x2],#8 // store 2 columns of each row
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
st1 {v20.2s},[x11],#8
|
||||
dup v20.4s,v20.s[2]
|
||||
st1 {v22.2s},[x12],#8
|
||||
dup v22.4s,v22.s[2]
|
||||
|
||||
.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode:
|
||||
tbz x5,#0,.LGemmU8X8.M4.ExitKernel
|
||||
st1 {v16.s}[0],[x2] // store the final single column of each row
|
||||
st1 {v18.s}[0],[x10]
|
||||
st1 {v20.s}[0],[x11]
|
||||
st1 {v22.s}[0],[x12]
|
||||
b .LGemmU8X8.M4.ExitKernel
|
||||
|
||||
.LGemmU8X8.M4.StoreOutputPartial.AddMode: // same 4/2/1 column split as the ZeroMode path, but adds the existing C values first
|
||||
tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.AddMode
|
||||
ld1 {v0.4s},[x2]
|
||||
ld1 {v1.4s},[x10]
|
||||
ld1 {v2.4s},[x11]
|
||||
ld1 {v3.4s},[x12]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
add v20.4s,v20.4s,v2.4s
|
||||
add v22.4s,v22.4s,v3.4s
|
||||
st1 {v20.4s},[x11],#16
|
||||
mov v20.16b,v21.16b
|
||||
st1 {v22.4s},[x12],#16
|
||||
mov v22.16b,v23.16b
|
||||
|
||||
.LGemmU8X8.M4.StoreOutputPartial2.AddMode:
|
||||
tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.AddMode
|
||||
ld1 {v0.2s},[x2] // 64-bit load; the full-width add below still leaves lane 2 usable for the dup that follows
|
||||
ld1 {v1.2s},[x10]
|
||||
ld1 {v2.2s},[x11]
|
||||
ld1 {v3.2s},[x12]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
add v20.4s,v20.4s,v2.4s
|
||||
add v22.4s,v22.4s,v3.4s
|
||||
st1 {v20.2s},[x11],#8
|
||||
dup v20.4s,v20.s[2]
|
||||
st1 {v22.2s},[x12],#8
|
||||
dup v22.4s,v22.s[2]
|
||||
|
||||
.LGemmU8X8.M4.StoreOutputPartial1.AddMode:
|
||||
tbz x5,#0,.LGemmU8X8.M4.ExitKernel
|
||||
ld1 {v0.s}[0],[x2] // only lane 0 is loaded and only lane 0 is stored back
|
||||
ld1 {v1.s}[0],[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
ld1 {v2.s}[0],[x11]
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
ld1 {v3.s}[0],[x12]
|
||||
add v20.4s,v20.4s,v2.4s
|
||||
st1 {v16.s}[0],[x2]
|
||||
st1 {v18.s}[0],[x10]
|
||||
add v22.4s,v22.4s,v3.4s
|
||||
st1 {v20.s}[0],[x11]
|
||||
st1 {v22.s}[0],[x12]
|
||||
b .LGemmU8X8.M4.ExitKernel
|
||||
|
||||
//
|
||||
// Process 2 rows of the matrices.
|
||||
//
|
||||
|
||||
.LGemmU8X8.M2.ProcessNextColumnLoop: // seed accumulators v16-v19 (2 rows x 8 columns) with the row/column sum fixups
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
mov x0,x14 // reload matrix A
|
||||
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
|
||||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
|
||||
cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB // ZeroPointB == nullptr: row sums are already scaled (per-tensor case)
|
||||
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
|
||||
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
|
||||
mul v16.4s,v30.4s,v8.4s // row 0 sum * ZeroPointB[0..3]
|
||||
mul v18.4s,v30.4s,v9.4s // row 1 sum * ZeroPointB[0..3]
|
||||
mul v17.4s,v31.4s,v8.4s // row 0 sum * ZeroPointB[4..7]
|
||||
mul v19.4s,v31.4s,v9.4s // row 1 sum * ZeroPointB[4..7]
|
||||
ld1 {v4.16b},[x0],#16 // load packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
b .LGemmU8X8.M2.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M2.SkipScaleByZeroPointB: // per-tensor case: accumulator = column sum + row sum
|
||||
ld1 {v4.16b},[x0],#16 // load packed A0
|
||||
add v16.4s,v2.4s,v8.4s
|
||||
add v18.4s,v2.4s,v9.4s
|
||||
add v17.4s,v3.4s,v8.4s
|
||||
add v19.4s,v3.4s,v9.4s
|
||||
|
||||
.LGemmU8X8.M2.ComputeBlockLoop: // each pass consumes 16 bytes of A (all four lanes of v4) and four 16-byte vectors of B
|
||||
UdotByElement 16, 0, 4, 0
|
||||
UdotByElement 17, 1, 4, 0
|
||||
UdotByElement 18, 0, 4, 1
|
||||
UdotByElement 19, 1, 4, 1
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 16, 0, 4, 2
|
||||
UdotByElement 17, 1, 4, 2
|
||||
UdotByElement 18, 0, 4, 3
|
||||
UdotByElement 19, 1, 4, 3
|
||||
sub x3,x3,#1 // one fewer PackedCountK iteration remaining
|
||||
cbz x3,.LGemmU8X8.M2.ComputeBlockLoopFinish
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
ld1 {v4.16b},[x0],#16 // load packed A0
|
||||
b .LGemmU8X8.M2.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M2.ComputeBlockLoopFinish:
|
||||
add x10,x2,x6,lsl #2 // compute output row 2
|
||||
subs x5,x5,#8 // adjust CountN remaining
|
||||
blo .LGemmU8X8.M2.StoreOutputPartial // fewer than 8 columns were left
|
||||
cbnz x13,.LGemmU8X8.M2.SkipAccumulateOutput // ZeroMode set: store without reading C
|
||||
ldp q0,q1,[x2] // accumulate mode: add the existing 2x8 block of C
|
||||
ldp q2,q3,[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v17.4s,v17.4s,v1.4s
|
||||
add v18.4s,v18.4s,v2.4s
|
||||
add v19.4s,v19.4s,v3.4s
|
||||
|
||||
.LGemmU8X8.M2.SkipAccumulateOutput:
|
||||
stp q16,q17,[x2],#32 // store row 1 and advance C by 8 columns
|
||||
stp q18,q19,[x10]
|
||||
cbnz x5,.LGemmU8X8.M2.ProcessNextColumnLoop // more columns remain
|
||||
|
||||
.LGemmU8X8.M2.ExitKernel:
|
||||
mov x0,#2 // return number of rows handled
|
||||
ldp d10,d11,[sp,#16] // restore d8-d11 pushed at entry
|
||||
ldp d8,d9,[sp],#32
|
||||
ret
|
||||
|
||||
//
|
||||
// Store the partial 1 to 7 columns either overwriting the output matrix or
|
||||
// accumulating into the existing contents of the output matrix.
|
||||
//
|
||||
|
||||
.LGemmU8X8.M2.StoreOutputPartial:
|
||||
cbz x13,.LGemmU8X8.M2.StoreOutputPartial.AddMode // ZeroMode clear: take the read-modify-write path
|
||||
|
||||
.LGemmU8X8.M2.StoreOutputPartial.ZeroMode:
|
||||
tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode // x5 is (remaining - 8) here, so its low three bits still equal the remaining column count
|
||||
st1 {v16.4s},[x2],#16 // store 4 columns of each row
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
|
||||
.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode:
|
||||
tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode
|
||||
st1 {v16.2s},[x2],#8 // store 2 columns of each row
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
|
||||
.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode:
|
||||
tbz x5,#0,.LGemmU8X8.M2.ExitKernel
|
||||
st1 {v16.s}[0],[x2] // store the final single column of each row
|
||||
st1 {v18.s}[0],[x10]
|
||||
b .LGemmU8X8.M2.ExitKernel
|
||||
|
||||
.LGemmU8X8.M2.StoreOutputPartial.AddMode: // same 4/2/1 column split, adding the existing C values first
|
||||
tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.AddMode
|
||||
ld1 {v0.4s},[x2]
|
||||
ld1 {v1.4s},[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
|
||||
.LGemmU8X8.M2.StoreOutputPartial2.AddMode:
|
||||
tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.AddMode
|
||||
ld1 {v0.2s},[x2]
|
||||
ld1 {v1.2s},[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
|
||||
.LGemmU8X8.M2.StoreOutputPartial1.AddMode:
|
||||
tbz x5,#0,.LGemmU8X8.M2.ExitKernel
|
||||
ld1 {v0.s}[0],[x2] // only lane 0 is loaded and only lane 0 is stored back
|
||||
ld1 {v1.s}[0],[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.s}[0],[x2]
|
||||
st1 {v18.s}[0],[x10]
|
||||
b .LGemmU8X8.M2.ExitKernel
|
||||
|
||||
//
|
||||
// Process 1 row of the matrices.
|
||||
//
|
||||
|
||||
.LGemmU8X8.M1.ProcessNextColumnLoop: // seed accumulators v16-v17 (1 row x 8 columns) with the row/column sum fixups
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
mov x0,x14 // reload matrix A
|
||||
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
|
||||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
|
||||
cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB // ZeroPointB == nullptr: row sum is already scaled (per-tensor case)
|
||||
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
|
||||
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
|
||||
mul v16.4s,v30.4s,v8.4s // row sum * ZeroPointB[0..3]
|
||||
mul v17.4s,v31.4s,v8.4s // row sum * ZeroPointB[4..7]
|
||||
ldr d4,[x0],#8 // load packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
b .LGemmU8X8.M1.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M1.SkipScaleByZeroPointB: // per-tensor case: accumulator = column sum + row sum
|
||||
ldr d4,[x0],#8 // load packed A0
|
||||
add v16.4s,v2.4s,v8.4s
|
||||
add v17.4s,v3.4s,v8.4s
|
||||
|
||||
.LGemmU8X8.M1.ComputeBlockLoop: // each pass consumes 8 bytes of A (lanes 0 and 1 of v4) and four 16-byte vectors of B
|
||||
UdotByElement 16, 0, 4, 0
|
||||
UdotByElement 17, 1, 4, 0
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 16, 0, 4, 1
|
||||
UdotByElement 17, 1, 4, 1
|
||||
sub x3,x3,#1 // one fewer PackedCountK iteration remaining
|
||||
cbz x3,.LGemmU8X8.M1.ComputeBlockLoopFinish
|
||||
ldr d4,[x0],#8 // load packed A0
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
b .LGemmU8X8.M1.ComputeBlockLoop
|
||||
|
||||
.LGemmU8X8.M1.ComputeBlockLoopFinish:
|
||||
subs x5,x5,#8 // adjust CountN remaining
|
||||
blo .LGemmU8X8.M1.StoreOutputPartial // fewer than 8 columns were left
|
||||
cbnz x13,.LGemmU8X8.M1.SkipAccumulateOutput // ZeroMode set: store without reading C
|
||||
ldp q0,q1,[x2] // accumulate mode: add the existing 1x8 block of C
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v17.4s,v17.4s,v1.4s
|
||||
|
||||
.LGemmU8X8.M1.SkipAccumulateOutput:
|
||||
stp q16,q17,[x2],#32 // store the row and advance C by 8 columns
|
||||
cbnz x5,.LGemmU8X8.M1.ProcessNextColumnLoop // more columns remain
|
||||
|
||||
.LGemmU8X8.M1.ExitKernel:
|
||||
mov x0,#1 // return number of rows handled
|
||||
ldp d10,d11,[sp,#16] // restore d8-d11 pushed at entry
|
||||
ldp d8,d9,[sp],#32
|
||||
ret
|
||||
|
||||
//
|
||||
// Store the partial 1 to 7 columns either overwriting the output matrix or
|
||||
// accumulating into the existing contents of the output matrix.
|
||||
//
|
||||
|
||||
.LGemmU8X8.M1.StoreOutputPartial:
|
||||
cbz x13,.LGemmU8X8.M1.StoreOutputPartial.AddMode // ZeroMode clear: take the read-modify-write path
|
||||
|
||||
.LGemmU8X8.M1.StoreOutputPartial.ZeroMode:
|
||||
tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode // x5 is (remaining - 8) here, so its low three bits still equal the remaining column count
|
||||
st1 {v16.4s},[x2],#16 // store 4 columns
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
|
||||
.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode:
|
||||
tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode
|
||||
st1 {v16.2s},[x2],#8 // store 2 columns
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
|
||||
.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode:
|
||||
tbz x5,#0,.LGemmU8X8.M1.ExitKernel
|
||||
st1 {v16.s}[0],[x2] // store the final single column
|
||||
b .LGemmU8X8.M1.ExitKernel
|
||||
|
||||
.LGemmU8X8.M1.StoreOutputPartial.AddMode: // same 4/2/1 column split, adding the existing C values first
|
||||
tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.AddMode
|
||||
ld1 {v0.4s},[x2]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
|
||||
.LGemmU8X8.M1.StoreOutputPartial2.AddMode:
|
||||
tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.AddMode
|
||||
ld1 {v0.2s},[x2]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
|
||||
.LGemmU8X8.M1.StoreOutputPartial1.AddMode:
|
||||
tbz x5,#0,.LGemmU8X8.M1.ExitKernel
|
||||
ld1 {v0.s}[0],[x2] // only lane 0 is loaded and only lane 0 is stored back
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
st1 {v16.s}[0],[x2]
|
||||
b .LGemmU8X8.M1.ExitKernel
|
||||
|
||||
.end
|
||||
|
|
@ -19,9 +19,10 @@
|
|||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8X8KernelAvx2Common.inc
|
||||
.list
|
||||
|
||||
EXTERN MlasMaskMoveTableAvx:NEAR
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8S8 CopyPackA routine.
|
||||
;
|
||||
|
|
@ -142,9 +143,9 @@ GemmU8S8CopyPackBFrame ENDS
|
|||
and eax,15 ; isolate unaligned count
|
||||
add eax,3
|
||||
shr eax,2 ; align unaligned count to quad count
|
||||
mov DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp],eax
|
||||
vpbroadcastd xmm10,DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp]
|
||||
vpcmpgtd xmm10,xmm10,XMMWORD PTR [MlasMaskMoveAvx]
|
||||
neg rax
|
||||
lea rbx,MlasMaskMoveTableAvx+8*4
|
||||
vmovdqu xmm10,XMMWORD PTR [rbx+rax*4]
|
||||
|
||||
;
|
||||
; Zero initialize the padded stack buffers.
|
||||
|
|
@ -776,281 +777,4 @@ StoreColumnSumBufferNUnaligned:
|
|||
|
||||
NESTED_END MlasGemmU8S8CopyPackBAvx2, _TEXT
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single row of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; Vec2Reg - Supplies the low block accumulator register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; ymm0 - Supplies the first vector loaded from matrix B.
|
||||
;
|
||||
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
;
|
||||
; ymm12 - Supplies a 256-bit vector with the broadcast word value 0x0001.
|
||||
;
|
||||
|
||||
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
vpmaddubsw ymm3,ymm2,ymm0 ; unsigned bytes of ymm2 (A) times signed bytes of ymm0 (B), adjacent pairs summed to words
|
||||
vpmaddwd ymm3,ymm3,ymm12 ; multiply words by 0x0001 and sum adjacent pairs into dwords
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd Vec1Reg,Vec1Reg,ymm3 ; accumulate the ymm0 products
|
||||
vpmaddubsw ymm2,ymm2,ymm1 ; repeat for the second B vector; ymm2 is consumed here
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm2 ; accumulate the ymm1 products
|
||||
ELSE
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm3 ; single B vector: only Vec2Reg is updated
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm11 - Supplies the block accumulators.
|
||||
;
|
||||
; ymm12 - Supplies a 256-bit vector with the broadcast word value 0x0001.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
IF RowCount EQ 1 ; single-row case: multiply straight from memory instead of staging B in ymm0/ymm1
|
||||
vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset] ; broadcast 4 bytes of matrix A
|
||||
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset]
|
||||
vpmaddwd ymm3,ymm3,ymm12 ; widen the word pairs to dword sums (ymm12 holds words of 0x0001)
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd ymm4,ymm4,ymm3
|
||||
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32] ; second 32-byte vector of matrix B
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd ymm5,ymm5,ymm2
|
||||
ELSE
|
||||
vpaddd ymm5,ymm5,ymm3 ; single B vector: only ymm5 accumulates
|
||||
ENDIF
|
||||
ELSE
|
||||
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] ; stage the first B vector for reuse by every row
|
||||
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]>
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times while advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm11 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,64 ; advance matrix B
|
||||
sub rsi,4 ; one quad (4 bytes of the row) consumed
|
||||
jnz ComputeBlockBy1Loop ; loop until the row length reaches zero
|
||||
|
||||
ENDM
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8S8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8S8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
|
||||
; the number of packed rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
; zero point offset of matrix B. These values are accumulated into every
|
||||
; row of matrix C.
|
||||
;
|
||||
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
; of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
; accumulated into every element of matrix C.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
NESTED_ENTRY MlasGemmU8S8KernelAvx2, _TEXT
|
||||
|
||||
rex_push_reg rbp
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR13)
|
||||
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
|
||||
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
|
||||
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
|
||||
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
|
||||
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
|
||||
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
|
||||
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
mov rdi,rcx ; keep a copy of the matrix A address (presumably for ProcessCountM, defined elsewhere -- confirm)
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp] ; rbp = CountN
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp]
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
shl r9,2 ; convert to row length
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] ; r10 = ZeroMode, zero-extended from a byte
|
||||
mov r11,GemmU8X8KernelFrame.CountM[rsp] ; r11 = CountM; also the value returned at ExitKernel
|
||||
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp] ; r12 = RowSumBuffer
|
||||
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] ; r13 = ColumnSumBuffer
|
||||
vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF]
|
||||
vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001]
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
|
||||
|
||||
cmp r11,3
|
||||
ja ProcessCountM4 ; CountM > 3: handle four rows
|
||||
je ProcessCountM3 ; CountM == 3
|
||||
cmp r11,1
|
||||
je ProcessCountM1 ; CountM == 1
|
||||
|
||||
ProcessCountM2: ; reached by fallthrough for every other CountM value
|
||||
ProcessCountM 2
|
||||
|
||||
ProcessCountM4:
|
||||
mov r11d,4 ; return 4 rows handled
|
||||
ProcessCountM 4, Fallthrough
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
|
||||
ExitKernel:
|
||||
mov eax,r11d ; return value: number of rows handled
|
||||
vzeroupper
|
||||
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
|
||||
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
|
||||
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
|
||||
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
|
||||
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
|
||||
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
|
||||
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
|
||||
add rsp,(GemmU8X8KernelFrame.SavedR13) ; release the area reserved by alloc_stack
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r13 ; pops mirror the prologue pushes in reverse order
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
ProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
NESTED_END MlasGemmU8S8KernelAvx2, _TEXT
|
||||
|
||||
END
|
||||
|
|
|
|||
|
|
@ -1,91 +0,0 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8S8KernelAvx512Common.inc
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module contains common kernel macros and structures for the quantized
|
||||
; integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
|
||||
; AVX512VNNI kernels.
|
||||
;
|
||||
;--
|
||||
|
||||
INCLUDE QgemmU8X8KernelAvx512Common.inc
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times while advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy4Loop
|
||||
LOCAL ProcessRemainingBlocks
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
LOCAL ComputeBlockLoopExit
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
IF ((RowCount AND 1) EQ 0) ; the unrolled-by-4 loop is emitted only for even row counts
|
||||
sub rsi,4*4 ; bias the count so the carry flag reports fewer than 4 quads left
|
||||
jb ProcessRemainingBlocks
|
||||
|
||||
ComputeBlockBy4Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0*64, 0
|
||||
ComputeBlock ColumnCount, RowCount, 1*64, 4
|
||||
ComputeBlock ColumnCount, RowCount, 2*64, 8
|
||||
ComputeBlock ColumnCount, RowCount, 3*64, 12
|
||||
add rcx,4*4 ; advance matrix A by 4 quads
|
||||
IF RowCount GT 3
|
||||
add rbx,4*4 ; advance matrix A plus 3 rows by 4 quads
|
||||
ENDIF
|
||||
add rdx,4*64 ; advance matrix B
|
||||
sub rsi,4*4 ; decrement quads remaining
|
||||
jae ComputeBlockBy4Loop
|
||||
|
||||
ProcessRemainingBlocks:
|
||||
add rsi,4*4 ; correct for over-subtract above
|
||||
jz ComputeBlockLoopExit ; nothing left for the single-quad loop
|
||||
ENDIF
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,64 ; advance matrix B
|
||||
sub rsi,4 ; decrement quads remaining
|
||||
jnz ComputeBlockBy1Loop
|
||||
|
||||
ComputeBlockLoopExit:
|
||||
|
||||
ENDM
|
||||
|
|
@ -1,130 +0,0 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8S8KernelAvx512Core.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVX512 core instructions (BW/DQ/VL).
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8S8KernelAvx512Common.inc
|
||||
.list
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single cell of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; AccumReg - Supplies the register to accumulate into.
|
||||
;
|
||||
; Mult1Reg - Supplies the first multiplication operand register.
|
||||
;
|
||||
; Mult2Reg - Supplies the second multiplication operand register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; zmm4 - Supplies a scratch register for intermediate results.
|
||||
;
|
||||
; zmm5 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
|
||||
;
|
||||
|
||||
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
vpmaddubsw zmm4,Mult1Reg,Mult2Reg ; unsigned bytes of Mult1Reg * signed bytes of Mult2Reg, adjacent pairs summed to saturated words
|
||||
vpmaddwd zmm4,zmm4,zmm5 ; multiply words by 0x0001 (zmm5) and sum adjacent pairs into dwords
|
||||
vpaddd AccumReg,AccumReg,zmm4 ; add the dword partial sums into the accumulator
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
IF ColumnCount GE 48 ; three packed B blocks, r14 bytes apart
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset] ; first B block, feeds accumulators zmm26-zmm31
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset] ; second B block, feeds accumulators zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset] ; third B block, feeds accumulators zmm14-zmm19
|
||||
ELSEIF ColumnCount GE 32 ; two packed B blocks; zmm0 is left unused
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset] ; first B block, feeds accumulators zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset] ; second B block, feeds accumulators zmm14-zmm19
|
||||
ELSE ; single packed B block
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset] ; only B block, feeds accumulators zmm14-zmm19
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Generate the GEMM kernel.
|
||||
;
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512Core
|
||||
|
||||
END
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8S8KernelAvx512Vnni.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVX512VNNI instructions.
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8S8KernelAvx512Common.inc
|
||||
INCLUDE AssembleAvx512Vnni.inc
|
||||
.list
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
IF ColumnCount GE 48 ; three packed B blocks, r14 bytes apart
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset] ; first B block, feeds accumulators zmm26-zmm31
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset] ; second B block, feeds accumulators zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset] ; third B block, feeds accumulators zmm14-zmm19
|
||||
ELSEIF ColumnCount GE 32 ; two packed B blocks; zmm0 is left unused
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset] ; first B block, feeds accumulators zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset] ; second B block, feeds accumulators zmm14-zmm19
|
||||
ELSE ; single packed B block
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset] ; only B block, feeds accumulators zmm14-zmm19
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Generate the GEMM kernel.
|
||||
;
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
|
||||
|
||||
END
|
||||
|
|
@ -1,298 +0,0 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Intel Corporation 2020. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8S8KernelAvxVnni.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVXVNNI instructions.
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8X8KernelAvx2Common.inc
|
||||
INCLUDE AssembleAvxVnni.inc
|
||||
.list
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single row of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; Vec2Reg - Supplies the low block accumulator register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; ymm0 - Supplies the first vector loaded from matrix B.
|
||||
;
|
||||
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
;
|
||||
|
||||
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
IF ColumnCount EQ 16 ; two B vectors (ymm0, ymm1) and two accumulators per row
|
||||
VpdpbusdsYmmYmmYmm Vec1Reg,ymm2,ymm0 ; NOTE(review): macro defined outside this view; presumably encodes vpdpbusds - confirm
|
||||
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm1 ; same broadcast value (ymm2) against the second B vector
|
||||
ELSE ; single B vector: only the low accumulator is updated
|
||||
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm0
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] ; first 32-byte vector of packed B; second vector (ymm1) is loaded below only when ColumnCount >= 16
|
||||
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]>
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, ymm12, ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, ymm14, ymm15>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times and advance the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
|
||||
mov rsi,r9 ; rsi = bytes of the matrix A row still to process
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,64 ; advance matrix B by 64 bytes
|
||||
sub rsi,4 ; one quad (4 bytes) of the row consumed
|
||||
jnz ComputeBlockBy1Loop ; loop until the row is exhausted
|
||||
|
||||
ENDM
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8S8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8S8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
|
||||
; the number of packed rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
; zero point offset of matrix B. These values are accumulated into every
|
||||
; row of matrix C.
|
||||
;
|
||||
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
; of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
; accumulated into every element of matrix C.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
NESTED_ENTRY MlasGemmU8S8KernelAvxVnni, _TEXT
|
||||
|
||||
rex_push_reg rbp ; prologue: save the general registers this routine overwrites
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR13) ; reserve the frame area below the pushed registers
|
||||
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 ; xmm6-xmm15 are saved here and restored at ExitKernel
|
||||
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
|
||||
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
|
||||
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
|
||||
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
|
||||
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
|
||||
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
|
||||
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
|
||||
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
|
||||
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
mov rdi,rcx ; keep a copy of the matrix A address (rcx is advanced by ComputeBlockLoop)
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp] ; rbp = CountN
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp] ; rax = ldc
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
shl r9,2 ; PackedCountK * 4 = row length in bytes
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] ; r10 = ZeroMode flag, zero-extended
|
||||
mov r11,GemmU8X8KernelFrame.CountM[rsp] ; r11 = CountM; later copied to eax as the return value
|
||||
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp] ; r12 = RowSumBuffer
|
||||
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] ; r13 = ColumnSumBuffer
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
|
||||
|
||||
cmp r11,5 ; dispatch on CountM using unsigned compares
|
||||
ja ProcessCountM6 ; CountM > 5: handle 6 rows
|
||||
je ProcessCountM5
|
||||
cmp r11,3
|
||||
ja ProcessCountM4 ; CountM == 4
|
||||
je ProcessCountM3
|
||||
cmp r11,1
|
||||
je ProcessCountM1 ; any other value falls through to the 2-row path below
|
||||
|
||||
ProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
ProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
ProcessCountM6:
|
||||
mov r11d,6 ; return 6 rows handled
|
||||
ProcessCountM 6, Fallthrough
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
|
||||
ExitKernel:
|
||||
mov eax,r11d ; return value = number of rows handled
|
||||
vzeroupper ; zero the upper halves of the ymm registers before returning
|
||||
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
|
||||
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
|
||||
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
|
||||
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
|
||||
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
|
||||
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
|
||||
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
|
||||
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
|
||||
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
|
||||
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
|
||||
add rsp,(GemmU8X8KernelFrame.SavedR13) ; release the area reserved by alloc_stack in the prologue
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r13 ; pops mirror the prologue pushes in reverse order
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
ProcessCountM1:
|
||||
ProcessCountM 1 ; odd row counts are emitted after the epilogue
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
ProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
NESTED_END MlasGemmU8S8KernelAvxVnni, _TEXT
|
||||
|
||||
END
|
||||
|
|
@ -19,9 +19,10 @@
|
|||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8X8KernelAvx2Common.inc
|
||||
.list
|
||||
|
||||
EXTERN MlasMaskMoveTableAvx:NEAR
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8U8 CopyPackA routine.
|
||||
;
|
||||
|
|
@ -136,9 +137,9 @@ GemmU8U8CopyPackBFrame ENDS
|
|||
and eax,15 ; isolate unaligned count
|
||||
inc eax
|
||||
shr eax,1 ; align unaligned count to pair count
|
||||
mov DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp],eax
|
||||
vpbroadcastd ymm9,DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp]
|
||||
vpcmpgtd ymm9,ymm9,YMMWORD PTR [MlasMaskMoveAvx]
|
||||
neg rax
|
||||
lea rbx,MlasMaskMoveTableAvx+8*4
|
||||
vmovdqu ymm9,YMMWORD PTR [rbx+rax*4]
|
||||
|
||||
;
|
||||
; Zero initialize the padded stack buffers.
|
||||
|
|
@ -663,305 +664,4 @@ StoreColumnSumBufferNUnaligned:
|
|||
|
||||
NESTED_END MlasGemmU8U8CopyPackBAvx2, _TEXT
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single row of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; Vec2Reg - Supplies the low block accumulator register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; ymm0 - Supplies the first vector loaded from matrix B.
|
||||
;
|
||||
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
;
|
||||
|
||||
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
vpmaddwd ymm3,ymm2,ymm0 ; word products of broadcast A (ymm2) and first B vector, adjacent pairs summed to dwords in scratch ymm3
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd Vec1Reg,Vec1Reg,ymm3 ; accumulate into the high block accumulator
|
||||
vpmaddwd ymm2,ymm2,ymm1 ; second B vector; overwrites ymm2, so the broadcast value is consumed here
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm2 ; accumulate into the low block accumulator
|
||||
ELSE
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm3 ; single B vector: only the low accumulator is updated
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset] ; load 16 bytes of packed B and zero-extend them to 16 words; ymm1 gets the next 16 bytes below when ColumnCount >= 16
|
||||
EmitIfCountGE ColumnCount, 16, <vpmovzxbw ymm1,XMMWORD PTR [rdx+VectorOffset+16]>
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, ymm12, ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, ymm14, ymm15>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times and advance the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy2Loop
|
||||
LOCAL ProcessRemainingBlocks
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
LOCAL ComputeBlockLoopExit
|
||||
|
||||
mov rsi,r9 ; rsi = bytes of the matrix A row still to process
|
||||
|
||||
IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0) ; 2x-unrolled path: full-width blocks with an even RowCount only
|
||||
sub rsi,2*4 ; pre-subtract 2 pairs (8 bytes); borrow means fewer than 2 remain
|
||||
jb ProcessRemainingBlocks ; not enough for one unrolled pass
|
||||
|
||||
ComputeBlockBy2Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
ComputeBlock ColumnCount, RowCount, 32, 4
|
||||
add rcx,2*4 ; advance matrix A by 2 pairs
|
||||
IF RowCount GT 3
|
||||
add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs
|
||||
ENDIF
|
||||
add rdx,2*32 ; advance matrix B by the two 32-byte blocks consumed above
|
||||
sub rsi,2*4 ; decrement remaining length by 2 pairs
|
||||
jae ComputeBlockBy2Loop ; no borrow: at least 2 more pairs remain
|
||||
|
||||
ProcessRemainingBlocks:
|
||||
add rsi,2*4 ; correct for over-subtract above
|
||||
jz ComputeBlockLoopExit ; even number of pairs: nothing left
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rdx,32 ; advance matrix B; rcx/rbx are not advanced after this final pair (NOTE(review): presumably reloaded by the caller - confirm)
|
||||
ELSE
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 pair
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
|
||||
ENDIF
|
||||
add rdx,32 ; advance matrix B
|
||||
sub rsi,4 ; one pair (4 bytes) of the row consumed
|
||||
jnz ComputeBlockBy1Loop ; loop until the row is exhausted
|
||||
ENDIF
|
||||
|
||||
ComputeBlockLoopExit:
|
||||
|
||||
ENDM
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8U8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8U8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
|
||||
; the number of packed rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
; zero point offset of matrix B. These values are accumulated into every
|
||||
; row of matrix C.
|
||||
;
|
||||
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
; of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
; accumulated into every element of matrix C.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
NESTED_ENTRY MlasGemmU8U8KernelAvx2, _TEXT
|
||||
|
||||
rex_push_reg rbp ; prologue: save the general registers this routine overwrites
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR13) ; reserve the frame area below the pushed registers
|
||||
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 ; xmm6-xmm15 are saved here and restored at ExitKernel
|
||||
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
|
||||
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
|
||||
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
|
||||
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
|
||||
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
|
||||
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
|
||||
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
|
||||
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
|
||||
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
mov rdi,rcx ; keep a copy of the matrix A address (rcx is advanced by ComputeBlockLoop)
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp] ; rbp = CountN
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp] ; rax = ldc
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
shl r9,2 ; PackedCountK * 4 = row length in bytes
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] ; r10 = ZeroMode flag, zero-extended
|
||||
mov r11,GemmU8X8KernelFrame.CountM[rsp] ; r11 = CountM; later copied to eax as the return value
|
||||
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp] ; r12 = RowSumBuffer
|
||||
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] ; r13 = ColumnSumBuffer
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
|
||||
|
||||
cmp r11,5 ; dispatch on CountM using unsigned compares
|
||||
ja ProcessCountM6 ; CountM > 5: handle 6 rows
|
||||
je ProcessCountM5
|
||||
cmp r11,3
|
||||
ja ProcessCountM4 ; CountM == 4
|
||||
je ProcessCountM3
|
||||
cmp r11,1
|
||||
je ProcessCountM1 ; any other value falls through to the 2-row path below
|
||||
|
||||
ProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
ProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
ProcessCountM6:
|
||||
mov r11d,6 ; return 6 rows handled
|
||||
ProcessCountM 6, Fallthrough
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
|
||||
ExitKernel:
|
||||
mov eax,r11d ; return value = number of rows handled
|
||||
vzeroupper ; zero the upper halves of the ymm registers before returning
|
||||
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
|
||||
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
|
||||
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
|
||||
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
|
||||
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
|
||||
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
|
||||
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
|
||||
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
|
||||
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
|
||||
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
|
||||
add rsp,(GemmU8X8KernelFrame.SavedR13) ; release the area reserved by alloc_stack in the prologue
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r13 ; pops mirror the prologue pushes in reverse order
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
ProcessCountM1:
|
||||
ProcessCountM 1 ; odd row counts are emitted after the epilogue
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
ProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
NESTED_END MlasGemmU8U8KernelAvx2, _TEXT
|
||||
|
||||
END
|
||||
|
|
|
|||
|
|
@ -1,172 +0,0 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8U8KernelAvx512Core.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVX512 core instructions (BW/DQ/VL).
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8X8KernelAvx512Common.inc
|
||||
.list
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single cell of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; AccumReg - Supplies the register to accumulate into.
|
||||
;
|
||||
; Mult1Reg - Supplies the first multiplication operand register.
|
||||
;
|
||||
; Mult2Reg - Supplies the second multiplication operand register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; zmm4 - Supplies a scratch register for intermediate results.
|
||||
;
|
||||
|
||||
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg ; AccumReg += pairwise dot products of Mult1Reg and Mult2Reg
|
||||
|
||||
vpmaddwd zmm4,Mult1Reg,Mult2Reg ; multiply signed 16-bit words and sum each adjacent pair into a 32-bit lane (zmm4 is scratch)
|
||||
vpaddd AccumReg,AccumReg,zmm4 ; add the 32-bit partial sums into the accumulator
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset ; one step along the shared dimension for up to 6 rows by up to 48 columns
|
||||
|
||||
IF ColumnCount GE 48 ; three packed blocks of matrix B, spaced r14 bytes apart
|
||||
vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset] ; zero-extend 32 bytes of matrix B to 32 words
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset]
|
||||
ELSEIF ColumnCount GE 32 ; two packed blocks: only zmm1 and zmm2 are loaded
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset]
|
||||
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset]
|
||||
ELSE ; single packed block: only zmm2 is loaded
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset]
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]> ; row 1 of matrix A; EmitIfCountGE is defined outside this chunk, presumably emits the text only when RowCount is at least 1
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]> ; row 2 of matrix A
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]> ; row 3 of matrix A
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]> ; row 4 of matrix A, addressed through rbx (matrix A plus 3 rows)
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]> ; row 5 of matrix A
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]> ; row 6 of matrix A
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times, advancing the matrix A and matrix B data pointers after each pass.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount ; runs ComputeBlock once per 4 bytes of the matrix A row length in r9
|
||||
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 pair
|
||||
IF RowCount GT 3 ; rbx is only read by rows 4 to 6
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
|
||||
ENDIF
|
||||
add rdx,32 ; advance matrix B
|
||||
sub rsi,4 ; 4 bytes of the matrix A row consumed per iteration
|
||||
jnz ComputeBlockBy1Loop ; NOTE(review): exits only when rsi reaches exactly zero, so r9 is assumed to be a nonzero multiple of 4
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Generate the GEMM kernel.
|
||||
;
|
||||
|
||||
GemmU8X8KernelAvx512Function U8U8, Avx512Core
|
||||
|
||||
END
|
||||
940
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm
Normal file
940
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm
Normal file
|
|
@ -0,0 +1,940 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8X8KernelAvx2.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVX2 and AVX VNNI instructions.
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE AssembleAvxVnni.inc
|
||||
.list
|
||||
|
||||
EXTERN MlasMaskMoveTableAvx:NEAR
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8X8 kernel.
|
||||
;
|
||||
|
||||
GemmU8X8KernelFrame STRUCT ; stack layout addressed relative to rsp inside the kernel
|
||||
|
||||
SavedXmm6 OWORD ? ; SavedXmm6-SavedXmm15: presumably the save area for xmm6-xmm15 (prologue not in this chunk)
|
||||
SavedXmm7 OWORD ?
|
||||
SavedXmm8 OWORD ?
|
||||
SavedXmm9 OWORD ?
|
||||
SavedXmm10 OWORD ?
|
||||
SavedXmm11 OWORD ?
|
||||
SavedXmm12 OWORD ?
|
||||
SavedXmm13 OWORD ?
|
||||
SavedXmm14 OWORD ?
|
||||
SavedXmm15 OWORD ?
|
||||
Padding QWORD ? ; presumably keeps the OWORD save area 16-byte aligned - TODO confirm
|
||||
SavedR13 QWORD ? ; SavedR13-SavedRbp: presumably the pushed general purpose registers
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
PreviousP1Home QWORD ? ; low DWORD is compared against zero in ProduceOutputBlock to choose the compute loop
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
CountM QWORD ? ; CountM-ZeroMode: presumably the stack-passed kernel arguments - verify against the caller
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumBuffer QWORD ?
|
||||
ColumnSumBuffer QWORD ?
|
||||
ZeroPointB QWORD ?
|
||||
ZeroMode QWORD ?
|
||||
|
||||
GemmU8X8KernelFrame ENDS
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single row of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; Vec2Reg - Supplies the low block accumulator register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; ymm0 - Supplies the first vector loaded from matrix B.
|
||||
;
|
||||
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
;
|
||||
; ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
|
||||
;
|
||||
|
||||
MultiplyAccumulateRowU8S8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg ; accumulates one output row from ymm2 (matrix A) and ymm0/ymm1 (matrix B)
|
||||
|
||||
vpmaddubsw ymm3,ymm2,ymm0 ; unsigned bytes of ymm2 times signed bytes of ymm0, adjacent pairs summed to saturated words
|
||||
vpmaddwd ymm3,ymm3,ymm12 ; multiply by the 0x0001 words to sum adjacent words into 32-bit lanes
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd Vec1Reg,Vec1Reg,ymm3 ; first 8 columns go to Vec1Reg
|
||||
vpmaddubsw ymm2,ymm2,ymm1 ; ymm2 is overwritten here, so the caller must reload it for the next row
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm2 ; second 8 columns go to Vec2Reg
|
||||
ELSE
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm3 ; with 8 columns only Vec2Reg is used
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm11 - Supplies the block accumulators.
|
||||
;
|
||||
; ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
|
||||
;
|
||||
|
||||
ComputeBlockU8S8Avx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset ; one step along the shared dimension for up to 4 rows by 8 or 16 columns
|
||||
|
||||
IF RowCount EQ 1 ; single row: use matrix B directly as a memory operand instead of loading ymm0/ymm1
|
||||
vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset] ; broadcast 4 bytes of matrix A
|
||||
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset] ; unsigned A bytes times signed B bytes, pairs summed to words
|
||||
vpmaddwd ymm3,ymm3,ymm12 ; sum adjacent words into 32-bit lanes using the 0x0001 words
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd ymm4,ymm4,ymm3
|
||||
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32]
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd ymm5,ymm5,ymm2
|
||||
ELSE
|
||||
vpaddd ymm5,ymm5,ymm3 ; with 8 columns the single accumulator is ymm5
|
||||
ENDIF
|
||||
ELSE ; multiple rows: load matrix B once and reuse it for every row
|
||||
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset]
|
||||
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]>
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]> ; row 1 of matrix A
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm4, ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]> ; row 2 of matrix A
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm6, ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]> ; row 3 of matrix A
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm8, ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]> ; row 4 of matrix A, addressed through rbx
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRowU8S8Avx2 ColumnCount, ymm10, ymm11>
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single row of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; Vec2Reg - Supplies the low block accumulator register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; ymm0 - Supplies the first vector loaded from matrix B.
|
||||
;
|
||||
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
;
|
||||
|
||||
MultiplyAccumulateRowU8S8AvxVnni MACRO ColumnCount, Vec1Reg, Vec2Reg ; accumulates one output row; unlike the Avx2 form it leaves ymm2 intact
|
||||
|
||||
IF ColumnCount EQ 16
|
||||
VpdpbusdsYmmYmmYmm Vec1Reg,ymm2,ymm0 ; helper macro defined outside this chunk, presumably encodes vpdpbusds - TODO confirm
|
||||
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm1 ; second 8 columns go to Vec2Reg
|
||||
ELSE
|
||||
VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm0 ; with 8 columns only Vec2Reg is used
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockU8S8AvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset ; one step along the shared dimension for up to 6 rows by 8 or 16 columns
|
||||
|
||||
vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] ; first 32 bytes of matrix B, reused for every row
|
||||
EmitIfCountGE ColumnCount, 16, <vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32]> ; second 32 bytes, only for 16 columns
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]> ; row 1 of matrix A
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm4, ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]> ; row 2 of matrix A
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm6, ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]> ; row 3 of matrix A
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm8, ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]> ; row 4 of matrix A, addressed through rbx
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm10, ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]> ; row 5 of matrix A
|
||||
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm12, ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]> ; row 6 of matrix A
|
||||
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRowU8S8AvxVnni ColumnCount, ymm14, ymm15>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple times
|
||||
; while advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Isa - Supplies the instruction set architecture string.
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators (only ymm4-ymm11 for Avx2).
|
||||
;
|
||||
|
||||
ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount ; runs the Isa-specific U8S8 block macro once per 4 bytes of the row length in r9
|
||||
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3 ; rbx is only read by rows 4 and above
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,64 ; advance matrix B
|
||||
sub rsi,4 ; 4 bytes of the matrix A row consumed per iteration
|
||||
jnz ComputeBlockBy1Loop ; NOTE(review): exits only when rsi reaches exactly zero, so r9 is assumed to be a nonzero multiple of 4
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single row of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; Vec2Reg - Supplies the low block accumulator register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; ymm0 - Supplies the first vector loaded from matrix B.
|
||||
;
|
||||
; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
; is 16).
|
||||
;
|
||||
; ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
;
|
||||
|
||||
MultiplyAccumulateRowU8U8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg ; accumulates one output row from ymm2 (matrix A) and ymm0/ymm1 (matrix B words)
|
||||
|
||||
vpmaddwd ymm3,ymm2,ymm0 ; multiply 16-bit words and sum each adjacent pair into a 32-bit lane
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd Vec1Reg,Vec1Reg,ymm3 ; first 8 columns go to Vec1Reg
|
||||
vpmaddwd ymm2,ymm2,ymm1 ; ymm2 is overwritten here, so the caller must reload it for the next row
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm2 ; second 8 columns go to Vec2Reg
|
||||
ELSE
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm3 ; with 8 columns only Vec2Reg is used
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockU8U8Avx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset ; one step along the shared dimension for up to 6 rows by 8 or 16 columns
|
||||
|
||||
vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset] ; zero-extend 16 bytes of matrix B to 16 words
|
||||
EmitIfCountGE ColumnCount, 16, <vpmovzxbw ymm1,XMMWORD PTR [rdx+VectorOffset+16]> ; next 16 bytes, only for 16 columns
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]> ; row 1 of matrix A
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm4, ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]> ; row 2 of matrix A
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm6, ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]> ; row 3 of matrix A
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm8, ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]> ; row 4 of matrix A, addressed through rbx
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm10, ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]> ; row 5 of matrix A
|
||||
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm12, ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]> ; row 6 of matrix A
|
||||
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRowU8U8Avx2 ColumnCount, ymm14, ymm15>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple times
|
||||
; while advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Isa - Supplies the instruction set architecture string.
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount ; runs the Isa-specific U8U8 block macro over the row length in r9
|
||||
|
||||
LOCAL ComputeBlockBy2Loop
|
||||
LOCAL ProcessRemainingBlocks
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
LOCAL ExitComputeBlockLoop
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0) ; unrolled-by-2 loop only for 16 columns with an even row count
|
||||
sub rsi,2*4 ; pre-subtract two blocks so the carry flag reports fewer than 2 blocks left
|
||||
jb ProcessRemainingBlocks
|
||||
|
||||
ComputeBlockBy2Loop:
|
||||
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0
|
||||
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 32, 4 ; second block: 32 bytes further into matrix B, 4 bytes further into matrix A
|
||||
add rcx,2*4 ; advance matrix A by 2 pairs
|
||||
IF RowCount GT 3
|
||||
add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs
|
||||
ENDIF
|
||||
add rdx,2*32 ; advance matrix B
|
||||
sub rsi,2*4
|
||||
jae ComputeBlockBy2Loop ; continue while at least 2 blocks remained before the subtract
|
||||
|
||||
ProcessRemainingBlocks:
|
||||
add rsi,2*4 ; correct for over-subtract above
|
||||
jz ExitComputeBlockLoop ; nothing left over
|
||||
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0 ; final single block; matrix A pointers are not advanced afterwards
|
||||
add rdx,32 ; advance matrix B
|
||||
ELSE
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 pair
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
|
||||
ENDIF
|
||||
add rdx,32 ; advance matrix B
|
||||
sub rsi,4
|
||||
jnz ComputeBlockBy1Loop ; NOTE(review): exits only when rsi reaches exactly zero, so r9 is assumed to be a nonzero multiple of 4
|
||||
ENDIF
|
||||
|
||||
ExitComputeBlockLoop:
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r11 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r12 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
; r13 - Optionally supplies the address of the matrix B zero point buffer.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount ; seeds the accumulators, runs one compute loop, then points rbx at matrix C plus 3 rows
|
||||
|
||||
LOCAL SkipScaleByZeroPointB
|
||||
LOCAL AccumulatorsInitialized
|
||||
LOCAL ProduceWithU8S8AvxVnni
|
||||
LOCAL ProduceWithU8U8Avx2
|
||||
LOCAL ExitProduceOutputBlock
|
||||
|
||||
;
|
||||
; Initialize the accumulators with the row and column sums.
|
||||
;
|
||||
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r11]> ; broadcast each 32-bit row sum into the second accumulator of its row
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r11+4]>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r11+8]>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r11+12]>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r11+16]>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r11+20]>
|
||||
IF ColumnCount EQ 16
|
||||
vmovdqu ymm0,YMMWORD PTR [r12] ; column sums for the first 8 columns
|
||||
vmovdqu ymm1,YMMWORD PTR [r12+32] ; column sums for the second 8 columns
|
||||
add r12,16*4 ; advance ColumnSumBuffer by 16 columns
|
||||
ELSE
|
||||
vmovdqu ymm1,YMMWORD PTR [r12] ; 8 columns: r12 is not advanced on this path
|
||||
ENDIF
|
||||
test r13,r13 ; per column zero points?
|
||||
jz SkipScaleByZeroPointB
|
||||
IF ColumnCount EQ 16
|
||||
vmovdqu ymm2,YMMWORD PTR [r13] ; zero points for the first 8 columns
|
||||
vmovdqu ymm3,YMMWORD PTR [r13+32] ; zero points for the second 8 columns
|
||||
add r13,16*4 ; advance ZeroPointB by 16 columns
|
||||
ELSE
|
||||
vmovdqu ymm3,YMMWORD PTR [r13] ; 8 columns: r13 is not advanced on this path
|
||||
ENDIF
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpmulld ymm4,ymm5,ymm2> ; per row: accumulator = row sum times zero point plus column sum
|
||||
EmitIfCountGE RowCount, 1, <vpmulld ymm5,ymm5,ymm3>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm0,ymm4>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm1,ymm5>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpmulld ymm6,ymm7,ymm2>
|
||||
EmitIfCountGE RowCount, 2, <vpmulld ymm7,ymm7,ymm3>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm0,ymm6>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm1,ymm7>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpmulld ymm8,ymm9,ymm2>
|
||||
EmitIfCountGE RowCount, 3, <vpmulld ymm9,ymm9,ymm3>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm0,ymm8>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm1,ymm9>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpmulld ymm10,ymm11,ymm2>
|
||||
EmitIfCountGE RowCount, 4, <vpmulld ymm11,ymm11,ymm3>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm0,ymm10>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm1,ymm11>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpmulld ymm12,ymm13,ymm2>
|
||||
EmitIfCountGE RowCount, 5, <vpmulld ymm13,ymm13,ymm3>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm0,ymm12>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm1,ymm13>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpmulld ymm14,ymm15,ymm2>
|
||||
EmitIfCountGE RowCount, 6, <vpmulld ymm15,ymm15,ymm3>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm0,ymm14>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm1,ymm15>
|
||||
jmp AccumulatorsInitialized
|
||||
|
||||
SkipScaleByZeroPointB:
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm0,ymm5> ; no per column zero points: accumulator = row sum plus column sum
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm1,ymm5>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm0,ymm7>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm1,ymm7>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm0,ymm9>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm1,ymm9>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm0,ymm11>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm1,ymm11>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm0,ymm13>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm1,ymm13>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm0,ymm15>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm1,ymm15>
|
||||
|
||||
AccumulatorsInitialized:
|
||||
|
||||
;
|
||||
; Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
;
|
||||
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r9*2+r9] ; rbx = 3 times the matrix A row length
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0 ; selector stored in the frame: positive, zero or negative picks the loop below
|
||||
jg ProduceWithU8U8Avx2 ; positive: U8U8 Avx2 loop
|
||||
IF RowCount LE 4 ; the U8S8 Avx2 loop is only generated for up to 4 rows
|
||||
jl ProduceWithU8S8AvxVnni ; negative: U8S8 AvxVnni loop
|
||||
ComputeBlockLoopU8S8 Avx2, ColumnCount, RowCount ; zero: U8S8 Avx2 loop
|
||||
jmp ExitProduceOutputBlock
|
||||
ENDIF
|
||||
|
||||
ProduceWithU8S8AvxVnni: ; also reached by fall through for any non-positive selector when RowCount is above 4
|
||||
ComputeBlockLoopU8S8 AvxVnni, ColumnCount, RowCount
|
||||
jmp ExitProduceOutputBlock
|
||||
|
||||
ProduceWithU8U8Avx2:
|
||||
ComputeBlockLoopU8U8 Avx2, ColumnCount, RowCount
|
||||
|
||||
ExitProduceOutputBlock:
|
||||
IF RowCount GT 3
|
||||
lea rbx,[rax*2+rax] ; rbx = 3 times the matrix C row length in bytes
|
||||
add rbx,r8 ; compute matrix C plus 3 rows
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A (kept to reload rcx per column block).
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10b - Supplies the zero mode flag.
|
||||
;
|
||||
; r11 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r12 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
; r13 - Optionally supplies the address of the matrix B zero point buffer.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount, Fallthrough
|
||||
|
||||
LOCAL ProcessNextColumnLoop16xN
|
||||
LOCAL SkipAccumulateOutput16xNBlock
|
||||
LOCAL OutputMasked16xNBlock
|
||||
LOCAL ExitProcessCountM
|
||||
LOCAL ProcessRemainingCountN
|
||||
LOCAL SkipAccumulateOutput8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked16xNBlock
|
||||
LOCAL OutputMasked8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked8xNBlock
|
||||
|
||||
cmp rbp,8
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
ProcessNextColumnLoop16xN:
|
||||
ProduceOutputBlock 16, RowCount
|
||||
sub rbp,16
|
||||
jb OutputMasked16xNBlock
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
|
||||
|
||||
SkipAccumulateOutput16xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
cmp rbp,8
|
||||
ja ProcessNextColumnLoop16xN
|
||||
test rbp,rbp
|
||||
jnz ProcessRemainingCountN
|
||||
|
||||
ExitProcessCountM:
|
||||
mov eax,RowCount
|
||||
jmp ExitKernel
|
||||
|
||||
ProcessRemainingCountN:
|
||||
ProduceOutputBlock 8, RowCount
|
||||
cmp rbp,8
|
||||
jb OutputMasked8xNBlock
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput8xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput8xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
|
||||
jmp ExitProcessCountM
|
||||
|
||||
OutputMasked16xNBlock:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked16xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutputMasked16xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
add r8,8*4 ; advance matrix C by 8 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
|
||||
ENDIF
|
||||
add rbp,8 ; correct for over-subtract above
|
||||
|
||||
OutputMasked8xNBlock:
|
||||
neg rbp
|
||||
lea rcx,MlasMaskMoveTableAvx+8*4
|
||||
vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4]
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked8xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
|
||||
|
||||
SkipAccumulateOutputMasked8xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
|
||||
jmp ExitProcessCountM
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Reduce code size for the various types of kernels by sharing the outer logic
|
||||
; and switching on the selector codes (using sign bit to discriminate).
|
||||
;
|
||||
|
||||
; Entry point for the U8S8 AVX-VNNI kernel variant. Loads the selector code -1
; (negative) into eax and jumps to the shared MlasGemmU8X8KernelAvx2 body, which
; saves eax into its stack frame before doing anything else. Only eax is
; modified here, so the argument registers and stack arguments reach the shared
; body unchanged.
LEAF_ENTRY MlasGemmU8S8KernelAvxVnni, _TEXT
|
||||
|
||||
mov eax,-1
|
||||
jmp MlasGemmU8X8KernelAvx2
|
||||
|
||||
LEAF_END MlasGemmU8S8KernelAvxVnni, _TEXT
|
||||
|
||||
; Entry point for the U8U8 AVX2 kernel variant. Loads the selector code 1
; (positive) into eax and jumps to the shared MlasGemmU8X8KernelAvx2 body, which
; saves eax into its stack frame before doing anything else. Only eax is
; modified here, so the argument registers and stack arguments reach the shared
; body unchanged.
LEAF_ENTRY MlasGemmU8U8KernelAvx2, _TEXT
|
||||
|
||||
mov eax,1
|
||||
jmp MlasGemmU8X8KernelAvx2
|
||||
|
||||
LEAF_END MlasGemmU8U8KernelAvx2, _TEXT
|
||||
|
||||
; Entry point for the U8S8 AVX2 kernel variant. Loads the selector code 0 into
; eax and jumps to the shared MlasGemmU8X8KernelAvx2 body. The shared body
; compares the saved selector against zero and, for this variant, bypasses the
; 5- and 6-row paths. Only eax is modified here, so the argument registers and
; stack arguments reach the shared body unchanged.
LEAF_ENTRY MlasGemmU8S8KernelAvx2, _TEXT
|
||||
|
||||
xor eax,eax
|
||||
jmp MlasGemmU8X8KernelAvx2
|
||||
|
||||
LEAF_END MlasGemmU8S8KernelAvx2, _TEXT
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
|
||||
; the number of packed rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumBuffer - Supplies the sum of each row from matrix A. These values have
|
||||
; been pre-scaled by the zero point offset of matrix B if the offset is
|
||||
; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
; scaled by the per-column zero point offsets of matrix B. These values are
|
||||
; accumulated into every row of matrix C.
|
||||
;
|
||||
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
; B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
; Shared body for the three entry points above. In addition to the documented
; arguments, eax carries the selector code set by the entry thunk (-1, 0 or 1).
NESTED_ENTRY MlasGemmU8X8KernelAvx2, _TEXT
|
||||
|
||||
; Prologue: push the non-volatile general registers used below, then reserve
; the frame bytes that precede SavedR13 and spill xmm6-xmm15 into them.
rex_push_reg rbp
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR13)
|
||||
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
|
||||
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
|
||||
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
|
||||
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
|
||||
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
|
||||
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
|
||||
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
|
||||
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
|
||||
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
|
||||
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
; Park the selector code in the first home slot; eax/rax is reused for ldc
; a few instructions below.
mov DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],eax
|
||||
; Keep a second copy of the matrix A address in rdi; the ProcessCountM
; expansions reload rcx from it after each column block.
mov rdi,rcx
|
||||
; Load the stack-passed arguments into registers: rbx = CountM,
; rbp = CountN, rax = ldc, r10 = ZeroMode (zero-extended byte),
; r11 = RowSumBuffer, r12 = ColumnSumBuffer, r13 = ZeroPointB.
mov rbx,GemmU8X8KernelFrame.CountM[rsp]
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp]
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
shl r9,2 ; convert to row length
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
|
||||
mov r11,GemmU8X8KernelFrame.RowSumBuffer[rsp]
|
||||
mov r12,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
|
||||
mov r13,GemmU8X8KernelFrame.ZeroPointB[rsp]
|
||||
vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF]
|
||||
vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001]
|
||||
; A zero selector (set by the MlasGemmU8S8KernelAvx2 entry) skips the 5/6 row
; checks, so that variant handles at most 4 rows per call.
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0
|
||||
je CheckCountM4OrMore ; U8S8 AVX2 kernel requires extra registers
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
|
||||
|
||||
; CountM > 5 takes the 6-row path, CountM == 5 the 5-row path; smaller counts
; fall through to the next check.
CheckCountM6OrMore:
|
||||
cmp rbx,5
|
||||
ja ProcessCountM6
|
||||
|
||||
je ProcessCountM5
|
||||
|
||||
|
||||
; CountM > 3 takes the 4-row path, CountM == 3 the 3-row path, CountM == 1 the
; 1-row path; anything else falls through into the 2-row path below.
; NOTE(review): a CountM of 0 would also fall through to the 2-row path;
; presumably callers never pass 0 - confirm.
CheckCountM4OrMore:
|
||||
cmp rbx,3
|
||||
ja ProcessCountM4
|
||||
|
||||
je ProcessCountM3
|
||||
|
||||
cmp rbx,1
|
||||
je ProcessCountM1
|
||||
|
||||
|
||||
ProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
|
||||
ProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
|
||||
ProcessCountM6:
|
||||
ProcessCountM 6
|
||||
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
|
||||
; Reached by a jump from each ProcessCountM expansion, which has already placed
; its RowCount in eax as the return value.
ExitKernel:
|
||||
vzeroupper
|
||||
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
|
||||
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
|
||||
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
|
||||
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
|
||||
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
|
||||
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
|
||||
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
|
||||
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
|
||||
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
|
||||
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
|
||||
; Release the same byte count that alloc_stack reserved in the prologue.
add rsp,(GemmU8X8KernelFrame.SavedR13)
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
; Pop in the reverse order of the prologue pushes.
pop r13
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
; The odd row-count expansions are placed after the return so that they are
; only reached through the dispatch jumps above.
ProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
|
||||
ProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
|
||||
NESTED_END MlasGemmU8X8KernelAvx2, _TEXT
|
||||
|
||||
END
|
||||
|
|
@ -1,302 +0,0 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8X8KernelAvx2Common.inc
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module contains common kernel macros and structures for the quantized
|
||||
; integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
|
||||
;
|
||||
;--
|
||||
|
||||
EXTERN MlasMaskMoveAvx:NEAR
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8S8 and U8U8 kernels.
|
||||
;
|
||||
|
||||
; The macros in this file address these fields as
; GemmU8X8KernelFrame.<field>[rsp], so the fields are listed from the lowest
; stack address upward.
GemmU8X8KernelFrame STRUCT
|
||||
|
||||
; Save slots for xmm6-xmm15, 16 bytes each.
SavedXmm6 OWORD ?
|
||||
SavedXmm7 OWORD ?
|
||||
SavedXmm8 OWORD ?
|
||||
SavedXmm9 OWORD ?
|
||||
SavedXmm10 OWORD ?
|
||||
SavedXmm11 OWORD ?
|
||||
SavedXmm12 OWORD ?
|
||||
SavedXmm13 OWORD ?
|
||||
SavedXmm14 OWORD ?
|
||||
SavedXmm15 OWORD ?
|
||||
; NOTE(review): presumably present so the xmm save area stays 16-byte aligned
; relative to the pushed registers and return address - confirm.
Padding QWORD ?
|
||||
; Save slots for the pushed general registers.
SavedR13 QWORD ?
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
; Home slots for the four register-passed arguments.
PreviousP1Home QWORD ?
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
; Stack-passed arguments. ProcessCountM in this file also overwrites the low
; DWORD of CountN as scratch when it builds the mask for a partial column block.
CountM QWORD ?
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumBuffer QWORD ?
|
||||
ColumnSumBuffer QWORD ?
|
||||
DepthValue QWORD ?
|
||||
ZeroMode QWORD ?
|
||||
|
||||
GemmU8X8KernelFrame ENDS
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount
|
||||
|
||||
;
|
||||
; Initialize the accumulators with the sum of the global depth value constant,
|
||||
; the column sums, and the row sums.
|
||||
;
|
||||
|
||||
; ymm1 starts as the broadcast depth value. For a 16 column block, ymm0 becomes
; depth + column sums 0-7 and ymm1 becomes depth + column sums 8-15, and r13 is
; advanced. Otherwise only ymm1 is formed (depth + column sums 0-7) and r13 is
; left unchanged.
vpbroadcastd ymm1,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
|
||||
add r13,16*4 ; advance ColumnSumBuffer by 16 columns
|
||||
ELSE
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
|
||||
ENDIF
|
||||
; Broadcast the row sum for each row (consecutive DWORDs at r12) into that
; row's odd-numbered accumulator.
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r12]>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r12+4]>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r12+8]>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r12+12]>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r12+16]>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r12+20]>
|
||||
; Per row: the even accumulator = row sum + ymm0, and the odd accumulator =
; row sum + ymm1. The even one is computed first because it reads the row sum
; from the odd register before that register is updated.
; NOTE(review): EmitIfCount2GE is defined elsewhere; presumably it emits only
; when RowCount and ColumnCount both meet their thresholds - confirm.
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm5,ymm0>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm7,ymm0>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm9,ymm0>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm11,ymm0>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm13,ymm0>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm15,ymm0>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm1>
|
||||
|
||||
;
|
||||
; Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
;
|
||||
|
||||
; rbx is used twice: before the loop it is set to matrix A + 3 rows (3*r9 + rcx)
; and after the loop it is set to matrix C + 3 rows (3*rax + r8). It is only
; set when more than 3 rows are being produced.
IF RowCount GT 3
|
||||
lea rbx,[r9*2+r9]
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
; ComputeBlockLoop is supplied by the including kernel source.
ComputeBlockLoop ColumnCount, RowCount
|
||||
IF RowCount GT 3
|
||||
lea rbx,[rax*2+rax]
|
||||
add rbx,r8 ; compute matrix C plus 3 rows
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Fallthrough - Supplies a non-blank value if the macro may fall through to
|
||||
; the ExitKernel label.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A.
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10b - Supplies the zero mode flag.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount, Fallthrough
|
||||
|
||||
LOCAL ProcessNextColumnLoop16xN
|
||||
LOCAL SkipAccumulateOutput16xNBlock
|
||||
LOCAL OutputMasked16xNBlock
|
||||
LOCAL ProcessRemainingCountN
|
||||
LOCAL SkipAccumulateOutput8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked16xNBlock
|
||||
LOCAL OutputMasked8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked8xNBlock
|
||||
|
||||
; With 8 or fewer columns, go straight to the single 8 column block.
cmp rbp,8
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
; Main loop: compute a 16 column block, then subtract 16 from the remaining
; column count. A borrow means fewer than 16 columns were left, so the second
; half of the block must be written through a mask.
ProcessNextColumnLoop16xN:
|
||||
ProduceOutputBlock 16, RowCount
|
||||
sub rbp,16
|
||||
jb OutputMasked16xNBlock
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlock
|
||||
; Not ZeroMode: add the existing contents of matrix C to the accumulators.
; Rows 1-3 are addressed from r8, rows 4-6 from rbx (matrix C plus 3 rows).
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
|
||||
|
||||
; Store the full 16 column block (two 8 column halves per row).
SkipAccumulateOutput16xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
; More than 8 columns left: run another 16 column block. Exactly zero left:
; done. Otherwise (1-8 left) fall through to the 8 column block.
cmp rbp,8
|
||||
ja ProcessNextColumnLoop16xN
|
||||
test rbp,rbp
|
||||
jz ExitKernel
|
||||
|
||||
; Final block of at most 8 columns; only the odd accumulators are produced.
; Fewer than 8 remaining columns go through the masked store path.
ProcessRemainingCountN:
|
||||
ProduceOutputBlock 8, RowCount
|
||||
cmp rbp,8
|
||||
jb OutputMasked8xNBlock
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput8xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput8xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
|
||||
jmp ExitKernel
|
||||
|
||||
; A 16 column block was computed but only 9-15 columns remained. The first 8
; columns (even accumulators) are written unmasked here; the rest (odd
; accumulators) are written by the masked 8 column path below.
OutputMasked16xNBlock:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked16xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutputMasked16xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
add r8,8*4 ; advance matrix C by 8 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
|
||||
ENDIF
|
||||
; rbp went negative in the "sub rbp,16" above; adding 8 leaves the number of
; columns (1-7) still to be written.
add rbp,8 ; correct for over-subtract above
|
||||
|
||||
; Build the store mask in ymm0: the remaining column count is written to the
; CountN stack slot (used here as scratch), broadcast to every lane, and
; compared (signed greater-than) against the table at MlasMaskMoveAvx.
; NOTE(review): MlasMaskMoveAvx is external; presumably it holds ascending lane
; indices so that lanes below the remaining count are selected - confirm.
OutputMasked8xNBlock:
|
||||
mov DWORD PTR GemmU8X8KernelFrame.CountN[rsp],ebp
|
||||
vpbroadcastd ymm0,DWORD PTR GemmU8X8KernelFrame.CountN[rsp]
|
||||
vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx]
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked8xNBlock
|
||||
; Not ZeroMode: masked-load the existing C values into the even registers
; (free at this point) and add them to the odd accumulators.
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
|
||||
|
||||
; Masked store of the final partial block.
SkipAccumulateOutputMasked8xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
|
||||
; The jump is emitted only when the Fallthrough argument is blank; with a
; non-blank argument, execution continues into whatever follows the expansion.
IFB <Fallthrough>
|
||||
jmp ExitKernel
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
|
@ -1,438 +0,0 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8X8KernelAvx512Common.inc
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module contains common kernel macros and structures for the quantized
|
||||
; integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
|
||||
; AVX512VNNI kernels.
|
||||
;
|
||||
;--
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8S8 and U8U8 kernels.
|
||||
;
|
||||
|
||||
; The macros in this file address these fields as
; GemmU8X8KernelFrame.<field>[rsp], so the fields are listed from the lowest
; stack address upward.
GemmU8X8KernelFrame STRUCT
|
||||
|
||||
; Save slots for xmm14 and xmm15, 16 bytes each.
; NOTE(review): the macros in this file accumulate in zmm14-zmm31; presumably
; only these two registers need preserving - confirm against the prologue.
SavedXmm14 OWORD ?
|
||||
SavedXmm15 OWORD ?
|
||||
; Save slots for the pushed general registers.
SavedR14 QWORD ?
|
||||
SavedR13 QWORD ?
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
; Home slots for the four register-passed arguments.
PreviousP1Home QWORD ?
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
; Stack-passed arguments.
CountM QWORD ?
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumBuffer QWORD ?
|
||||
ColumnSumBuffer QWORD ?
|
||||
DepthValue QWORD ?
|
||||
ZeroMode QWORD ?
|
||||
|
||||
GemmU8X8KernelFrame ENDS
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount
|
||||
|
||||
;
|
||||
; Initialize the accumulators with the sum of the global depth value constant,
|
||||
; the column sums, and the row sums.
|
||||
;
|
||||
|
||||
; zmm3 holds the broadcast depth value. The column sums are added per group of
; 16 columns so that zmm0 always receives the LAST group of the block:
;   48 columns: zmm2 = columns 0-15, zmm1 = columns 16-31, zmm0 = columns 32-47
;   32 columns: zmm1 = columns 0-15, zmm0 = columns 16-31
;   16 columns: zmm0 = columns 0-15
; r13 is advanced only for the 32 and 48 column blocks.
vpbroadcastd zmm3,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
|
||||
IF ColumnCount GE 32
|
||||
IF ColumnCount GE 48
|
||||
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
|
||||
ELSE
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
|
||||
ENDIF
|
||||
add_immed r13,ColumnCount*4 ; advance ColumnSumBuffer by N columns
|
||||
ELSE
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
|
||||
ENDIF
|
||||
; Add each row's sum (a DWORD at r12, broadcast from memory) to the column
; group values. Accumulators zmm14-zmm19 are built from zmm0, zmm20-zmm25 from
; zmm1 and zmm26-zmm31 from zmm2, one register per row in each group.
; NOTE(review): EmitIfCount2GE is defined elsewhere; presumably it emits only
; when RowCount and ColumnCount both meet their thresholds - confirm.
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,DWORD BCST [r12]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,DWORD BCST [r12]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,DWORD BCST [r12]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,DWORD BCST [r12+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,DWORD BCST [r12+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,DWORD BCST [r12+4]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,DWORD BCST [r12+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,DWORD BCST [r12+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,DWORD BCST [r12+8]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,DWORD BCST [r12+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,DWORD BCST [r12+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,DWORD BCST [r12+12]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,DWORD BCST [r12+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,DWORD BCST [r12+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,DWORD BCST [r12+16]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,DWORD BCST [r12+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,DWORD BCST [r12+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,DWORD BCST [r12+20]>
|
||||
|
||||
;
|
||||
; Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
;
|
||||
|
||||
; rbx is used twice: before the loop it is set to matrix A + 3 rows (3*r9 + rcx)
; and after the loop it is set to matrix C + 3 rows (r8 + 3*rax). It is only
; set when more than 3 rows are being produced.
IF RowCount GT 3
|
||||
lea rbx,[r9*2+r9]
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
; ComputeBlockLoop is supplied by the including kernel source.
ComputeBlockLoop ColumnCount, RowCount
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A.
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10b - Supplies the zero mode flag.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount
|
||||
|
||||
LOCAL ProcessNextColumnLoop32xN
|
||||
LOCAL Output32xNBlock
|
||||
LOCAL SkipAccumulateOutput32xNBlock
|
||||
LOCAL Output16xNBlock
|
||||
LOCAL Output16xNBlockWithMask
|
||||
LOCAL SkipAccumulateOutput16xNBlockWithMask
|
||||
LOCAL ProcessRemainingCountN
|
||||
LOCAL ProcessNextColumnLoop48xN
|
||||
LOCAL SkipAccumulateOutput48xNBlock
|
||||
|
||||
; Pick the block width from the remaining column count in rbp: more than 32
; uses the 48 column block, 17-32 the 32 column block (fall through), and 16 or
; fewer the single 16 column block.
cmp rbp,32
|
||||
ja ProcessNextColumnLoop48xN
|
||||
cmp rbp,16
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
ProcessNextColumnLoop32xN:
|
||||
ProduceOutputBlock 32, RowCount
|
||||
add rdx,r14 ; advance matrix B by packed block stride
|
||||
|
||||
; Write the 16 column group held in zmm20-zmm25, unmasked. Also entered from
; the 48 column path after its first group has been written.
Output32xNBlock:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput32xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm21,zmm21,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm22,zmm22,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput32xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm20>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm21>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm22>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm23>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
|
||||
ENDIF
|
||||
sub rbp,16
|
||||
|
||||
; Write the last 16 column group of the block (zmm14-zmm19). If at least 16
; columns remain, the current k1 mask is used as-is; otherwise k1 is rebuilt as
; (1 << remaining) - 1 and the remaining count is cleared. Building the mask
; clobbers rcx and esi; rcx is reloaded from rdi after the store.
; NOTE(review): k1 is not initialized inside this macro; presumably the
; enclosing function sets it to all ones beforehand - confirm.
Output16xNBlock:
|
||||
sub rbp,16
|
||||
jae Output16xNBlockWithMask
|
||||
lea ecx,[ebp+16] ; correct for over-subtract above
|
||||
mov esi,1
|
||||
shl esi,cl
|
||||
dec esi
|
||||
kmovw k1,esi ; update mask for remaining columns
|
||||
xor ebp,ebp ; no more columns remaining
|
||||
|
||||
Output16xNBlockWithMask:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlockWithMask
|
||||
; Not ZeroMode: add the existing C values, restricted to the lanes selected
; by k1.
EmitIfCountGE RowCount, 1, <vpaddd zmm14{k1},zmm14,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm15{k1},zmm15,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm16{k1},zmm16,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput16xNBlockWithMask:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm14>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm15>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm16>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
; Choose the next block width with the same thresholds as on entry, or leave
; when no columns remain. rbx is not advanced here because ProduceOutputBlock
; recomputes it for the next block.
cmp rbp,32
|
||||
ja ProcessNextColumnLoop48xN
|
||||
cmp rbp,16
|
||||
ja ProcessNextColumnLoop32xN
|
||||
test rbp,rbp
|
||||
jz ExitKernel
|
||||
|
||||
; 1-16 columns remain: compute a single 16 column block and write it through
; the (possibly partial) mask path above. rdx is not advanced on this path.
ProcessRemainingCountN:
|
||||
ProduceOutputBlock 16, RowCount
|
||||
jmp Output16xNBlock
|
||||
|
||||
; 48 column block: write the first 16 column group (zmm26-zmm31) here, then
; reuse the 32 column and 16 column output paths for the other two groups.
; rdx is advanced by two packed block strides.
ProcessNextColumnLoop48xN:
|
||||
ProduceOutputBlock 48, RowCount
|
||||
lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput48xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm26,zmm26,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm27,zmm27,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm28,zmm28,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput48xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm26>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm27>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm28>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm29>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
|
||||
ENDIF
|
||||
sub rbp,16
|
||||
jmp Output32xNBlock
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates the common AVX512 code for the inner kernel to compute
|
||||
; matrix multiplication.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Type - Supplies the kernel type string for function tags.
|
||||
;
|
||||
; Isa - Supplies the instruction set architecture string for function tags.
|
||||
;
|
||||
|
||||
GemmU8X8KernelAvx512Function MACRO Type, Isa
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
|
||||
; the number of packed rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
; zero point offset of matrix B. These values are accumulated into every
|
||||
; row of matrix C.
|
||||
;
|
||||
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
; of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
; accumulated into every element of matrix C.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
NESTED_ENTRY MlasGemm&Type&Kernel&Isa&, _TEXT
|
||||
|
||||
rex_push_reg rbp
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
push_reg r14
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR14)
|
||||
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
|
||||
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
mov rdi,rcx ; keep a copy of the matrix A address
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp]
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
shl r9,2 ; convert to row length
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
|
||||
mov r11,GemmU8X8KernelFrame.CountM[rsp]
|
||||
mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp]
|
||||
mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
|
||||
mov esi,-1
|
||||
kmovw k1,esi ; update mask to write all columns
|
||||
; The packed stride of matrix B depends on the kernel type: row length times
; 16 for U8S8 and row length times 8 otherwise. Only the U8S8 Avx512Core
; variant needs the word vector of ones in zmm5.
IFIDNI <Type>, <U8S8>
|
||||
IFIDNI <Isa>, <Avx512Core>
|
||||
neg esi ; esi = 1 (negation of -1)
|
||||
vpbroadcastw zmm5,esi ; generate 512-bit word vector [0x0001]
|
||||
ENDIF
|
||||
mov r14,r9
|
||||
shl r14,4 ; compute matrix B packed stride
|
||||
ELSE
|
||||
lea r14,[r9*8] ; compute matrix B packed stride
|
||||
ENDIF
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
; r11 holds CountM. Values above 5 go to ProcessCountM6, which forces the
; returned row count to 6. Values 5, 4, 3 and 1 branch to their own label and
; every remaining value falls through to ProcessCountM2.
;
|
||||
|
||||
cmp r11,5
|
||||
ja ProcessCountM6
|
||||
je ProcessCountM5
|
||||
cmp r11,3
|
||||
ja ProcessCountM4
|
||||
je ProcessCountM3
|
||||
cmp r11,1
|
||||
je ProcessCountM1
|
||||
|
||||
ProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
ProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
ProcessCountM6:
|
||||
mov r11d,6 ; return 6 rows handled
|
||||
ProcessCountM 6
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
|
||||
ExitKernel:
|
||||
mov eax,r11d ; return value is the row count held in r11
|
||||
vzeroupper
|
||||
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
|
||||
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
|
||||
add rsp,(GemmU8X8KernelFrame.SavedR14) ; release the area reserved by alloc_stack
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
ProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
ProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
NESTED_END MlasGemm&Type&Kernel&Isa&, _TEXT
|
||||
|
||||
ENDM
|
||||
764
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm
Normal file
764
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm
Normal file
|
|
@ -0,0 +1,764 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8X8KernelAvx512Core.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions.
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE AssembleAvx512Vnni.inc
|
||||
.list
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8X8 kernel.
|
||||
;
|
||||
|
||||
; Layout of the kernel stack as seen from rsp after the prologue: xmm spill
; slots, the pushed non-volatile registers, the return address, the four
; register home slots and finally the stack-passed arguments.
GemmU8X8KernelFrame STRUCT
|
||||
|
||||
SavedXmm13 OWORD ? ; spill slots written by save_xmm128 in the kernel prologue
|
||||
SavedXmm14 OWORD ?
|
||||
SavedXmm15 OWORD ?
|
||||
SavedR14 QWORD ? ; this offset is also the size passed to alloc_stack
|
||||
SavedR13 QWORD ?
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
PreviousP1Home QWORD ? ; low dword is reused by the kernel to hold the selector code
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
CountM QWORD ? ; first argument read from the stack by the kernel
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumBuffer QWORD ?
|
||||
ColumnSumBuffer QWORD ?
|
||||
ZeroPointB QWORD ? ; tested against zero before use (optional pointer)
|
||||
ZeroMode QWORD ? ; only the low byte is read by the kernel
|
||||
|
||||
GemmU8X8KernelFrame ENDS
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to load packed data from matrix B.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; VecReg - Supplies the register to load the data into.
|
||||
;
|
||||
; AddressOperand - Supplies the address operand.
|
||||
;
|
||||
|
||||
LoadPackedMatrixBU8S8 MACRO VecReg, AddressOperand
|
||||
|
||||
; U8S8 variant: load 64 bytes of packed matrix B without conversion.
vmovdqu32 VecReg,ZMMWORD PTR AddressOperand
|
||||
|
||||
ENDM
|
||||
|
||||
LoadPackedMatrixBU8U8 MACRO VecReg, AddressOperand
|
||||
|
||||
; U8U8 variant: load 32 bytes of packed matrix B and zero extend each byte to
; a word so the data can be fed to vpmaddwd.
vpmovzxbw VecReg,YMMWORD PTR AddressOperand
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single cell of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; AccumReg - Supplies the register to accumulate into.
|
||||
;
|
||||
; Mult1Reg - Supplies the first multiplication operand register.
|
||||
;
|
||||
; Mult2Reg - Supplies the second multiplication operand register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; zmm4 - Supplies a scratch register for intermediate results.
|
||||
;
|
||||
; zmm13 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
|
||||
;
|
||||
|
||||
MultiplyAccumulateCellU8S8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
; Multiply unsigned bytes of Mult1Reg by signed bytes of Mult2Reg and add
; adjacent products into words (vpmaddubsw saturates the word sums).
vpmaddubsw zmm4,Mult1Reg,Mult2Reg
|
||||
; Multiplying by the word vector of ones in zmm13 adds adjacent word pairs
; into doublewords.
vpmaddwd zmm4,zmm4,zmm13
|
||||
vpaddd AccumReg,AccumReg,zmm4
|
||||
|
||||
ENDM
|
||||
|
||||
MultiplyAccumulateCellU8S8Avx512Vnni MACRO AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
; NOTE(review): VpdpbusdsZmmZmmZmm is not defined in this file; presumably it
; comes from AssembleAvx512Vnni.inc and emits a vpdpbusds instruction that
; performs the whole multiply and accumulate step - confirm.
VpdpbusdsZmmZmmZmm AccumReg,Mult1Reg,Mult2Reg
|
||||
|
||||
ENDM
|
||||
|
||||
MultiplyAccumulateCellU8U8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
; Multiply word pairs and add each pair of products into a doubleword, using
; zmm4 as scratch, then add the result to the accumulator.
vpmaddwd zmm4,Mult1Reg,Mult2Reg
|
||||
vpaddd AccumReg,AccumReg,zmm4
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Type - Supplies the type of kernel to generate (U8S8 or U8U8).
|
||||
;
|
||||
; Isa - Supplies the instruction set architecture string.
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm13 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
; Load one vector of packed matrix B per group of 16 columns. The loads are
; arranged so that the last group always lands in zmm2 (accumulators
; zmm14-zmm19), the group before it in zmm1 (zmm20-zmm25) and the first of
; three groups in zmm0 (zmm26-zmm31).
IF ColumnCount GE 48
|
||||
LoadPackedMatrixB&Type& zmm0,[rdx+VectorOffset]
|
||||
LoadPackedMatrixB&Type& zmm1,[rdx+r14+VectorOffset]
|
||||
LoadPackedMatrixB&Type& zmm2,[rdx+r14*2+VectorOffset]
|
||||
ELSEIF ColumnCount GE 32
|
||||
LoadPackedMatrixB&Type& zmm1,[rdx+VectorOffset]
|
||||
LoadPackedMatrixB&Type& zmm2,[rdx+r14+VectorOffset]
|
||||
ELSE
|
||||
LoadPackedMatrixB&Type& zmm2,[rdx+VectorOffset]
|
||||
ENDIF
|
||||
; For each row, broadcast one doubleword of matrix A into zmm3 and multiply
; it against every loaded matrix B vector. Rows 1-3 are addressed from rcx
; and rows 4-6 from rbx.
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell&Type&&Isa& zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell&Type&&Isa& zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell&Type&&Isa& zmm19,zmm3,zmm2>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple times
|
||||
; and advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Isa - Supplies the instruction set architecture string.
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy4Loop
|
||||
LOCAL ProcessRemainingBlocks
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
LOCAL ComputeBlockLoopExit
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
; The loop unrolled by four quads is only generated for even row counts; odd
; row counts use the single quad loop below for the entire row.
IF ((RowCount AND 1) EQ 0)
|
||||
sub rsi,4*4
|
||||
jb ProcessRemainingBlocks
|
||||
|
||||
ComputeBlockBy4Loop:
|
||||
; Each step consumes 4 bytes per row of matrix A and 64 bytes of matrix B.
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0*64, 0
|
||||
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 1*64, 4
|
||||
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 2*64, 8
|
||||
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 3*64, 12
|
||||
add rcx,4*4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3
|
||||
add rbx,4*4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,4*64 ; advance matrix B
|
||||
sub rsi,4*4 ; decrement quads remaining
|
||||
jae ComputeBlockBy4Loop
|
||||
|
||||
ProcessRemainingBlocks:
|
||||
add rsi,4*4 ; correct for over-subtract above
|
||||
jz ComputeBlockLoopExit
|
||||
ENDIF
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,64 ; advance matrix B
|
||||
sub rsi,4 ; decrement quads remaining
|
||||
jnz ComputeBlockBy1Loop
|
||||
|
||||
ComputeBlockLoopExit:
|
||||
|
||||
ENDM
|
||||
|
||||
ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
; Each iteration consumes 4 bytes per row of matrix A and 32 bytes of packed
; matrix B. The loop body runs at least once because the count is tested
; after the decrement.
ComputeBlockBy1Loop:
|
||||
ComputeBlock U8U8, Isa, ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 pair
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
|
||||
ENDIF
|
||||
add rdx,32 ; advance matrix B
|
||||
sub rsi,4 ; decrement row length remaining
|
||||
jnz ComputeBlockBy1Loop
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r11 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r12 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL SkipScaleByZeroPointB
|
||||
LOCAL AccumulatorsInitialized
|
||||
LOCAL ProduceWithU8S8Avx512Core
|
||||
LOCAL ProduceWithU8U8Avx512Core
|
||||
LOCAL ExitProduceOutputBlock
|
||||
|
||||
;
; Additional inputs read by this macro:
;
; r8 - address of matrix C, used to recompute rbx on exit.
;
; r13 - address of the per-column zero points of matrix B, or zero to
; use the row sums unscaled.
;
; PreviousP1Home (stack) - selector code that picks the compute loop.
;
; The column sum and zero point pointers are only advanced when ColumnCount
; is 32 or 48.
;

;
|
||||
; Initialize the accumulators with the row and column sums.
|
||||
;
|
||||
|
||||
IF ColumnCount GE 32
|
||||
IF ColumnCount GE 48
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [r12]
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [r12+64]
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [r12+128]
|
||||
ELSE
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [r12]
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [r12+64]
|
||||
ENDIF
|
||||
add_immed r12,ColumnCount*4 ; advance ColumnSumBuffer by N columns
|
||||
ELSE
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [r12]
|
||||
ENDIF
|
||||
test r13,r13 ; per column zero points?
|
||||
jz SkipScaleByZeroPointB
|
||||
IF ColumnCount GE 32
|
||||
IF ColumnCount GE 48
|
||||
vmovdqu32 zmm5,ZMMWORD PTR [r13]
|
||||
vmovdqu32 zmm4,ZMMWORD PTR [r13+64]
|
||||
vmovdqu32 zmm3,ZMMWORD PTR [r13+128]
|
||||
ELSE
|
||||
vmovdqu32 zmm4,ZMMWORD PTR [r13]
|
||||
vmovdqu32 zmm3,ZMMWORD PTR [r13+64]
|
||||
ENDIF
|
||||
add_immed r13,ColumnCount*4 ; advance ZeroPointB by N columns
|
||||
ELSE
|
||||
vmovdqu32 zmm3,ZMMWORD PTR [r13]
|
||||
ENDIF
|
||||
; Per-column path: accumulator = column sum + zero point of the column times
; the row sum of the row (row sums are broadcast from r11).
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpmulld zmm14,zmm3,DWORD BCST [r11]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpmulld zmm20,zmm4,DWORD BCST [r11]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpmulld zmm26,zmm5,DWORD BCST [r11]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,zmm14>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,zmm20>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,zmm26>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpmulld zmm15,zmm3,DWORD BCST [r11+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpmulld zmm21,zmm4,DWORD BCST [r11+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpmulld zmm27,zmm5,DWORD BCST [r11+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,zmm15>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,zmm21>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,zmm27>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpmulld zmm16,zmm3,DWORD BCST [r11+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpmulld zmm22,zmm4,DWORD BCST [r11+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpmulld zmm28,zmm5,DWORD BCST [r11+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,zmm16>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,zmm22>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,zmm28>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpmulld zmm17,zmm3,DWORD BCST [r11+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpmulld zmm23,zmm4,DWORD BCST [r11+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpmulld zmm29,zmm5,DWORD BCST [r11+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,zmm17>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,zmm23>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,zmm29>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpmulld zmm18,zmm3,DWORD BCST [r11+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpmulld zmm24,zmm4,DWORD BCST [r11+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpmulld zmm30,zmm5,DWORD BCST [r11+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,zmm18>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,zmm24>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,zmm30>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpmulld zmm19,zmm3,DWORD BCST [r11+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpmulld zmm25,zmm4,DWORD BCST [r11+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpmulld zmm31,zmm5,DWORD BCST [r11+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,zmm19>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,zmm25>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,zmm31>
|
||||
jmp AccumulatorsInitialized
|
||||
|
||||
SkipScaleByZeroPointB:
|
||||
; No per-column zero points: accumulator = column sum + row sum of the row.
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,DWORD BCST [r11]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,DWORD BCST [r11]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,DWORD BCST [r11]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,DWORD BCST [r11+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,DWORD BCST [r11+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,DWORD BCST [r11+4]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,DWORD BCST [r11+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,DWORD BCST [r11+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,DWORD BCST [r11+8]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,DWORD BCST [r11+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,DWORD BCST [r11+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,DWORD BCST [r11+12]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,DWORD BCST [r11+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,DWORD BCST [r11+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,DWORD BCST [r11+16]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,DWORD BCST [r11+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,DWORD BCST [r11+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,DWORD BCST [r11+20]>
|
||||
|
||||
AccumulatorsInitialized:
|
||||
|
||||
;
|
||||
; Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
;
|
||||
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r9*2+r9]
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
; Dispatch on the selector code stored in PreviousP1Home: zero selects the
; U8S8 Avx512Core loop, a positive value selects the U8U8 Avx512Core loop
; and a negative value falls through to the U8S8 Avx512Vnni loop.
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0
|
||||
je ProduceWithU8S8Avx512Core
|
||||
jg ProduceWithU8U8Avx512Core
|
||||
ComputeBlockLoopU8S8 Avx512Vnni, ColumnCount, RowCount
|
||||
jmp ExitProduceOutputBlock
|
||||
|
||||
ProduceWithU8U8Avx512Core:
|
||||
ComputeBlockLoopU8U8 Avx512Core, ColumnCount, RowCount
|
||||
jmp ExitProduceOutputBlock
|
||||
|
||||
ProduceWithU8S8Avx512Core:
|
||||
ComputeBlockLoopU8S8 Avx512Core, ColumnCount, RowCount
|
||||
|
||||
ExitProduceOutputBlock:
|
||||
; rbx switches from addressing matrix A to addressing matrix C here.
IF RowCount GT 3
|
||||
lea rbx,[rax*2+rax]
|
||||
add rbx,r8 ; compute matrix C plus 3 rows
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A.
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10b - Supplies the zero mode flag.
|
||||
;
|
||||
; r11 - Supplies the address of the row sum buffer.
|
||||
;
|
||||
; r12 - Supplies the address of the column sum buffer.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount
|
||||
|
||||
LOCAL ProcessNextColumnLoop32xN
|
||||
LOCAL Output32xNBlock
|
||||
LOCAL SkipAccumulateOutput32xNBlock
|
||||
LOCAL Output16xNBlock
|
||||
LOCAL Output16xNBlockWithMask
|
||||
LOCAL SkipAccumulateOutput16xNBlockWithMask
|
||||
LOCAL ProcessRemainingCountN
|
||||
LOCAL ProcessNextColumnLoop48xN
|
||||
LOCAL SkipAccumulateOutput48xNBlock
|
||||
|
||||
; Pick the widest block for the remaining columns in rbp: 48 when more than
; 32 remain, 32 when more than 16 remain, otherwise a final block of up to 16
; columns. Wider blocks store their leading 16 columns and then chain into
; the output code of the next narrower block.
cmp rbp,32
|
||||
ja ProcessNextColumnLoop48xN
|
||||
cmp rbp,16
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
ProcessNextColumnLoop32xN:
|
||||
ProduceOutputBlock 32, RowCount
|
||||
add rdx,r14 ; advance matrix B by packed block stride
|
||||
|
||||
Output32xNBlock:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput32xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm21,zmm21,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm22,zmm22,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput32xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm20>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm21>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm22>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm23>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
|
||||
ENDIF
|
||||
sub rbp,16
|
||||
|
||||
Output16xNBlock:
|
||||
sub rbp,16
|
||||
jae Output16xNBlockWithMask
|
||||
; Fewer than 16 columns remain: build a mask with one bit set per remaining
; column. This overwrites rcx, which is reloaded from rdi after the stores.
lea ecx,[ebp+16] ; correct for over-subtract above
|
||||
mov esi,1
|
||||
shl esi,cl
|
||||
dec esi
|
||||
kmovw k1,esi ; update mask for remaining columns
|
||||
xor ebp,ebp ; no more columns remaining
|
||||
|
||||
Output16xNBlockWithMask:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlockWithMask
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm14{k1},zmm14,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm15{k1},zmm15,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm16{k1},zmm16,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput16xNBlockWithMask:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm14>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm15>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm16>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
cmp rbp,32
|
||||
ja ProcessNextColumnLoop48xN
|
||||
cmp rbp,16
|
||||
ja ProcessNextColumnLoop32xN
|
||||
test rbp,rbp
|
||||
jnz ProcessRemainingCountN
|
||||
; All columns are done: the return value is the number of rows handled.
mov eax,RowCount
|
||||
jmp ExitKernel
|
||||
|
||||
ProcessRemainingCountN:
|
||||
ProduceOutputBlock 16, RowCount
|
||||
jmp Output16xNBlock
|
||||
|
||||
ProcessNextColumnLoop48xN:
|
||||
ProduceOutputBlock 48, RowCount
|
||||
lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput48xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm26,zmm26,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm27,zmm27,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm28,zmm28,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput48xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm26>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm27>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm28>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm29>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
|
||||
ENDIF
|
||||
sub rbp,16
|
||||
jmp Output32xNBlock
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Reduce code size for the various types of kernels by sharing the outer logic
|
||||
; and switching on the selector codes (using sign bit to discriminate).
|
||||
;
|
||||
|
||||
; Entry stub: pass a negative selector in eax and tail-jump to the shared
; kernel, which takes the U8S8 Avx512Vnni compute loop for negative values.
LEAF_ENTRY MlasGemmU8S8KernelAvx512Vnni, _TEXT
|
||||
|
||||
mov eax,-1
|
||||
jmp MlasGemmU8X8KernelAvx512Core
|
||||
|
||||
LEAF_END MlasGemmU8S8KernelAvx512Vnni, _TEXT
|
||||
|
||||
; Entry stub: pass a positive selector in eax and tail-jump to the shared
; kernel, which takes the U8U8 Avx512Core compute loop for positive values.
LEAF_ENTRY MlasGemmU8U8KernelAvx512Core, _TEXT
|
||||
|
||||
mov eax,1
|
||||
jmp MlasGemmU8X8KernelAvx512Core
|
||||
|
||||
LEAF_END MlasGemmU8U8KernelAvx512Core, _TEXT
|
||||
|
||||
; Entry stub: pass a zero selector in eax and tail-jump to the shared kernel,
; which takes the U8S8 Avx512Core compute loop for zero.
LEAF_ENTRY MlasGemmU8S8KernelAvx512Core, _TEXT
|
||||
|
||||
xor eax,eax
|
||||
jmp MlasGemmU8X8KernelAvx512Core
|
||||
|
||||
LEAF_END MlasGemmU8S8KernelAvx512Core, _TEXT
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PackedCountK (r9) - Supplies the number of packed columns from matrix A and
|
||||
; the number of packed rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumBuffer - Supplies the sum of each row from matrix A. These values have
|
||||
; been pre-scaled by the zero point offset of matrix B if the offset is
|
||||
; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
; scaled by the per-column zero point offsets of matrix B. These values are
|
||||
; accumulated into every row of matrix C.
|
||||
;
|
||||
; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
; B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
; Shared kernel body. On entry eax carries the selector code set by the three
; entry stubs above (negative, positive or zero).
NESTED_ENTRY MlasGemmU8X8KernelAvx512Core, _TEXT
|
||||
|
||||
rex_push_reg rbp
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
push_reg r14
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR14)
|
||||
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
|
||||
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
|
||||
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
; Park the selector in a home slot because rax is reused for ldc below.
mov DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],eax
|
||||
mov rdi,rcx ; keep a copy of the matrix A address
|
||||
mov rbx,GemmU8X8KernelFrame.CountM[rsp]
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp]
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
shl r9,2 ; convert to row length
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
|
||||
mov r11,GemmU8X8KernelFrame.RowSumBuffer[rsp]
|
||||
mov r12,GemmU8X8KernelFrame.ColumnSumBuffer[rsp]
|
||||
mov r13,GemmU8X8KernelFrame.ZeroPointB[rsp]
|
||||
mov esi,-1
|
||||
kmovw k1,esi ; update mask to write all columns
|
||||
neg esi ; esi = 1 (negation of -1)
|
||||
vpbroadcastw zmm13,esi ; generate 512-bit word vector [0x0001]
|
||||
lea rsi,[r9*8] ; compute matrix B packed stride (U8U8)
|
||||
lea r14,[rsi*2] ; compute matrix B packed stride (U8S8)
|
||||
cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0
|
||||
cmovg r14,rsi ; select matrix B packed stride
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
; rbx holds CountM. Values above 5 are handled as 6 rows and values below 2
; fall through to the single row case. Each ProcessCountM expansion loads its
; row count into eax and jumps to ExitKernel, so the expansions below do not
; fall through into one another.
;
|
||||
|
||||
cmp rbx,5
|
||||
ja ProcessCountM6
|
||||
je ProcessCountM5
|
||||
cmp rbx,3
|
||||
ja ProcessCountM4
|
||||
je ProcessCountM3
|
||||
cmp rbx,1
|
||||
ja ProcessCountM2
|
||||
|
||||
ProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
ProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
ProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
ProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
ProcessCountM6:
|
||||
ProcessCountM 6
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
|
||||
ExitKernel:
|
||||
vzeroupper
|
||||
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
|
||||
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
|
||||
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
|
||||
add rsp,(GemmU8X8KernelFrame.SavedR14) ; release the area reserved by alloc_stack
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
NESTED_END MlasGemmU8X8KernelAvx512Core, _TEXT
|
||||
|
||||
END
|
||||
46
onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h
Normal file
46
onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
AssembleDotProduct.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains macros to build Advanced SIMD dot product instructions
|
||||
for toolchains that do not natively support this newer instruction set
|
||||
extension.
|
||||
|
||||
This implementation uses the Advanced SIMD dot product instructions (UDOT), which are optional from ARM v8.2 and mandatory from ARM v8.4.
|
||||
|
||||
--*/
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro builds a UDOT instruction of the form:
|
||||
|
||||
UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index]
|
||||
|
||||
Arguments:
|
||||
|
||||
DestReg - Specifies the destination register.
|
||||
|
||||
Src1Reg - Specifies the first source register.
|
||||
|
||||
Src2Reg - Specifies the second source register.
|
||||
|
||||
Index - Specifies the element index of the second source register.
|
||||
|
||||
--*/
|
||||
|
||||
MACRO
|
||||
UdotByElement $DestReg, $Src1Reg, $Src2Reg, $Index
|
||||
|
||||
// The instruction is emitted as a raw 32-bit word so that assemblers without
// dot product support can still build it. All arguments are plain numbers
// (register numbers, not register names) and are not masked or range checked.
//
// 0x6F80E000 is the base opcode. DestReg is OR'd in starting at bit 0,
// Src1Reg starting at bit 5 and Src2Reg starting at bit 16. The element index
// is split across two separate bits: index bit 1 is moved to bit 11
// ((Index & 2) << 10) and index bit 0 is moved to bit 21 ((Index & 1) << 21).
DCD 0x6F80E000:OR:($DestReg):OR:($Src1Reg:SHL:5):OR:($Src2Reg:SHL:16):OR:(($Index:AND:2):SHL:10):OR:(($Index:AND:1):SHL:21)
|
||||
|
||||
MEND
|
||||
|
|
@ -22,7 +22,7 @@ Abstract:
|
|||
//
|
||||
|
||||
#define GemmU8X8KernelFrame_ColumnSumBuffer 0
|
||||
#define GemmU8X8KernelFrame_DepthValue 8
|
||||
#define GemmU8X8KernelFrame_ZeroPointB 8
|
||||
#define GemmU8X8KernelFrame_ZeroMode 16
|
||||
|
||||
//
|
||||
|
|
@ -75,10 +75,6 @@ Arguments:
|
|||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized, else
|
||||
false if the output matrix is accumulated into.
|
||||
|
||||
|
|
@ -91,18 +87,16 @@ Return Value:
|
|||
LEAF_ENTRY MlasGemmU8X8KernelNeon
|
||||
|
||||
ldr x8,[sp,#GemmU8X8KernelFrame_ColumnSumBuffer]
|
||||
ldr s27,[sp,#GemmU8X8KernelFrame_DepthValue]
|
||||
ldr x9,[sp,#GemmU8X8KernelFrame_ZeroPointB]
|
||||
ldrb w13,[sp,#GemmU8X8KernelFrame_ZeroMode]
|
||||
dup v27.4s,v27.s[0]
|
||||
mov x14,x0
|
||||
ld1 {v0.4s},[x7]
|
||||
ld1 {v27.4s},[x7]
|
||||
mov x15,x3
|
||||
add v27.4s,v27.4s,v0.4s // broadcast add DepthValue
|
||||
dup v24.4s,v27.s[0] // broadcast row fixups
|
||||
cmp x4,#1 // CountM == 1?
|
||||
beq ProcessNextColumnLoopM1
|
||||
dup v25.4s,v27.s[1]
|
||||
cmp x4,#4 // CountM < 4 ?
|
||||
cmp x4,#4 // CountM < 4?
|
||||
blo ProcessNextColumnLoopM2
|
||||
dup v26.4s,v27.s[2]
|
||||
dup v27.4s,v27.s[3]
|
||||
|
|
@ -118,6 +112,30 @@ ProcessNextColumnLoopM4
|
|||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
|
||||
uxtl v0.8h,v0.8b
|
||||
cbz x9,SkipScaleByZeroPointBM4
|
||||
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
|
||||
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
|
||||
mul v16.4s,v24.4s,v28.4s
|
||||
mul v17.4s,v24.4s,v29.4s
|
||||
mul v18.4s,v25.4s,v28.4s
|
||||
mul v19.4s,v25.4s,v29.4s
|
||||
mul v20.4s,v26.4s,v28.4s
|
||||
mul v21.4s,v26.4s,v29.4s
|
||||
mul v22.4s,v27.4s,v28.4s
|
||||
mul v23.4s,v27.4s,v29.4s
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
ld1 {v5.8b},[x0],#8 // load first packed A1
|
||||
add v20.4s,v2.4s,v20.4s
|
||||
add v21.4s,v3.4s,v21.4s
|
||||
add v22.4s,v2.4s,v22.4s
|
||||
add v23.4s,v3.4s,v23.4s
|
||||
b ComputeBlockLoopM4
|
||||
|
||||
SkipScaleByZeroPointBM4
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v24.4s
|
||||
add v17.4s,v3.4s,v24.4s
|
||||
|
|
@ -329,6 +347,21 @@ ProcessNextColumnLoopM2
|
|||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
|
||||
uxtl v0.8h,v0.8b
|
||||
cbz x9,SkipScaleByZeroPointBM2
|
||||
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
|
||||
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
|
||||
mul v16.4s,v24.4s,v28.4s
|
||||
mul v17.4s,v24.4s,v29.4s
|
||||
mul v18.4s,v25.4s,v28.4s
|
||||
mul v19.4s,v25.4s,v29.4s
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
b ComputeBlockLoopM2
|
||||
|
||||
SkipScaleByZeroPointBM2
|
||||
ld1 {v4.8b},[x0],#8 // load first packed A0
|
||||
add v16.4s,v2.4s,v24.4s
|
||||
add v17.4s,v3.4s,v24.4s
|
||||
|
|
@ -467,6 +500,17 @@ ProcessNextColumnLoopM1
|
|||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1
|
||||
uxtl v0.8h,v0.8b
|
||||
cbz x9,SkipScaleByZeroPointBM1
|
||||
ld1 {v28.4s},[x9],#16 // load ZeroPointB0
|
||||
ld1 {v29.4s},[x9],#16 // load ZeroPointB1
|
||||
mul v16.4s,v24.4s,v28.4s
|
||||
mul v17.4s,v24.4s,v29.4s
|
||||
ldr s4,[x0],#4 // load first packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
b ComputeBlockLoopM1
|
||||
|
||||
SkipScaleByZeroPointBM1
|
||||
ldr s4,[x0],#4 // load first packed A0
|
||||
add v16.4s,v2.4s,v24.4s
|
||||
add v17.4s,v3.4s,v24.4s
|
||||
|
|
|
|||
587
onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm
Normal file
587
onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm
Normal file
|
|
@ -0,0 +1,587 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelUdot.asm
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses the Advanced SIMD dot product instructions (UDOT), which are optional from ARM v8.2 and mandatory from ARM v8.4.
|
||||
|
||||
--*/
|
||||
|
||||
#include "kxarm64.h"
|
||||
#include "AssembleDotProduct.h"
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8X8 kernel.
|
||||
//
|
||||
|
||||
#define GemmU8XKernelFrame_SavedNeonRegisters (4 * 8)
|
||||
#define GemmU8XKernelFrame_SavedRegisters GemmU8XKernelFrame_SavedNeonRegisters
|
||||
#define GemmU8XKernelFrame_ColumnSumBuffer (0 + GemmU8XKernelFrame_SavedRegisters)
|
||||
#define GemmU8XKernelFrame_ZeroPointB (8 + GemmU8XKernelFrame_SavedRegisters)
|
||||
#define GemmU8XKernelFrame_ZeroMode (16 + GemmU8XKernelFrame_SavedRegisters)
|
||||
|
||||
TEXTAREA
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (x0) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackA<MLAS_GEMM_U8X8_KERNEL_UDOT>.
|
||||
|
||||
B (x1) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackB<MLAS_GEMM_U8X8_KERNEL_UDOT>.
|
||||
|
||||
C (x2) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (x3) - Supplies the number of packed columns from matrix A and
|
||||
the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (x4) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (x5) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc (x6) - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values
|
||||
have been pre-scaled by the zero point offset of matrix B if the offset
|
||||
is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
scaled by the per-column zero point offsets of matrix B. These values are
|
||||
accumulated into every row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized, else
|
||||
false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
NESTED_ENTRY MlasGemmU8X8KernelUdot
|
||||
|
||||
// Save d8-d11: v8-v11 are used below to hold the broadcast row sums for the
// whole call. The two pairs occupy 32 bytes of stack, which is why the stack
// argument offsets are biased by GemmU8XKernelFrame_SavedRegisters.
PROLOG_SAVE_REG_PAIR d8,d9,#-32!
|
||||
PROLOG_SAVE_REG_PAIR d10,d11,#16
|
||||
ldr x8,[sp,#GemmU8XKernelFrame_ColumnSumBuffer] // x8 = ColumnSumBuffer (stack argument)
|
||||
ldr x9,[sp,#GemmU8XKernelFrame_ZeroPointB] // x9 = ZeroPointB, tested against zero per column block
|
||||
ldrb w13,[sp,#GemmU8XKernelFrame_ZeroMode] // w13 = ZeroMode (single byte, zero extended)
|
||||
mov x14,x0 // keep matrix A base; x0 is reloaded from x14 per column block
|
||||
// NOTE(review): this always reads four 32-bit row sums from x7, even when
// CountM is 1 or 2 -- presumably the caller's buffer is large enough; confirm.
ld1 {v11.4s},[x7]
|
||||
mov x15,x3 // keep PackedCountK; x3 is reloaded from x15 per column block
|
||||
dup v8.4s,v11.s[0] // broadcast row fixups
|
||||
cmp x4,#1 // CountM == 1?
|
||||
beq ProcessNextColumnLoopM1
|
||||
dup v9.4s,v11.s[1]
|
||||
cmp x4,#4 // CountM < 4?
|
||||
blo ProcessNextColumnLoopM2 // CountM of 2 or 3 takes the two row path
|
||||
dup v10.4s,v11.s[2]
|
||||
dup v11.4s,v11.s[3] // v11 is overwritten last since it is the source of the broadcasts
|
||||
|
||||
//
|
||||
// Process 4 rows of the matrices.
|
||||
//
|
||||
|
||||
ProcessNextColumnLoopM4
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
mov x0,x14 // reload matrix A
|
||||
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
|
||||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
|
||||
cbz x9,SkipScaleByZeroPointBM4
|
||||
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
|
||||
mul v16.4s,v30.4s,v8.4s
|
||||
mul v18.4s,v30.4s,v9.4s
|
||||
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
|
||||
mul v20.4s,v30.4s,v10.4s
|
||||
mul v22.4s,v30.4s,v11.4s
|
||||
mul v17.4s,v31.4s,v8.4s
|
||||
mul v19.4s,v31.4s,v9.4s
|
||||
mul v21.4s,v31.4s,v10.4s
|
||||
mul v23.4s,v31.4s,v11.4s
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v20.4s,v2.4s,v20.4s
|
||||
add v22.4s,v2.4s,v22.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
add v21.4s,v3.4s,v21.4s
|
||||
add v23.4s,v3.4s,v23.4s
|
||||
b ComputeBlockLoopStartM4
|
||||
|
||||
SkipScaleByZeroPointBM4
|
||||
add v16.4s,v2.4s,v8.4s
|
||||
add v18.4s,v2.4s,v9.4s
|
||||
add v20.4s,v2.4s,v10.4s
|
||||
add v22.4s,v2.4s,v11.4s
|
||||
add v17.4s,v3.4s,v8.4s
|
||||
add v19.4s,v3.4s,v9.4s
|
||||
add v21.4s,v3.4s,v10.4s
|
||||
add v23.4s,v3.4s,v11.4s
|
||||
|
||||
//
|
||||
// The packing layout is set up to have a pair of four quad vectors from
|
||||
// packed matrix A and a pair of eight quad vectors from packed matrix B.
|
||||
// With this scheme, alternating loads from the packed matrices can be
|
||||
// interleaved with the dot product instructions.
|
||||
//
|
||||
// One negative consequence of using four rows here is that the accumulator
|
||||
// register tile is too small for processors with high out of order execution
|
||||
// windows (such as the Apple M1). The dot product instructions for a given
|
||||
// cell are too close to each other to avoid dependencies. To work around this,
|
||||
// the below loop uses a pair of accumulator registers that are then added
|
||||
// together when the loop finishes.
|
||||
//
|
||||
// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for
|
||||
// packed matrix A. At the time of this implementation, using a wider 128-bit
|
||||
// load didn't affect performance for higher end cores.
|
||||
//
|
||||
|
||||
ComputeBlockLoopStartM4
|
||||
ldr d4,[x0],#32 // load packed A0.l
|
||||
movi v24.4s,#0
|
||||
movi v25.4s,#0
|
||||
ldur d5,[x0,#-24] // load packed A0.h
|
||||
movi v26.4s,#0
|
||||
movi v27.4s,#0
|
||||
ldur d6,[x0,#-16] // load packed A1.l
|
||||
movi v28.4s,#0
|
||||
movi v29.4s,#0
|
||||
movi v30.4s,#0
|
||||
movi v31.4s,#0
|
||||
|
||||
ComputeBlockLoopM4
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 16, 0, 4, 0
|
||||
UdotByElement 18, 0, 4, 1
|
||||
ldur d7,[x0,#-8] // load packed A1.h
|
||||
UdotByElement 20, 0, 5, 0
|
||||
UdotByElement 22, 0, 5, 1
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
UdotByElement 17, 1, 4, 0
|
||||
UdotByElement 19, 1, 4, 1
|
||||
sub x3,x3,#1
|
||||
cbz x3,ComputeBlockLoopFinishM4
|
||||
ldr d4,[x0],#32 // load packed A0.l
|
||||
UdotByElement 21, 1, 5, 0
|
||||
UdotByElement 23, 1, 5, 1
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 24, 0, 6, 0
|
||||
UdotByElement 26, 0, 6, 1
|
||||
ldur d5,[x0,#-24] // load packed A0.h
|
||||
UdotByElement 28, 0, 7, 0
|
||||
UdotByElement 30, 0, 7, 1
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
UdotByElement 25, 1, 6, 0
|
||||
UdotByElement 27, 1, 6, 1
|
||||
ldur d6,[x0,#-16] // load packed A1.l
|
||||
UdotByElement 29, 1, 7, 0
|
||||
UdotByElement 31, 1, 7, 1
|
||||
b ComputeBlockLoopM4
|
||||
|
||||
ComputeBlockLoopFinishM4
|
||||
UdotByElement 21, 1, 5, 0
|
||||
UdotByElement 23, 1, 5, 1
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 24, 0, 6, 0
|
||||
UdotByElement 26, 0, 6, 1
|
||||
UdotByElement 28, 0, 7, 0
|
||||
UdotByElement 30, 0, 7, 1
|
||||
UdotByElement 25, 1, 6, 0
|
||||
UdotByElement 27, 1, 6, 1
|
||||
UdotByElement 29, 1, 7, 0
|
||||
UdotByElement 31, 1, 7, 1
|
||||
add x10,x2,x6,lsl #2 // compute output row 2
|
||||
add v16.4s,v16.4s,v24.4s // fold high results into low results
|
||||
add v18.4s,v18.4s,v26.4s
|
||||
add v20.4s,v20.4s,v28.4s
|
||||
add v22.4s,v22.4s,v30.4s
|
||||
add x11,x10,x6,lsl #2 // compute output row 3
|
||||
add v17.4s,v17.4s,v25.4s
|
||||
add v19.4s,v19.4s,v27.4s
|
||||
add v21.4s,v21.4s,v29.4s
|
||||
add v23.4s,v23.4s,v31.4s
|
||||
add x12,x11,x6,lsl #2 // compute output row 4
|
||||
subs x5,x5,#8 // adjust CountN remaining
|
||||
blo StoreOutputPartialM4
|
||||
cbnz x13,SkipAccumulateOutputM4
|
||||
ldp q0,q1,[x2]
|
||||
ldp q2,q3,[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v17.4s,v17.4s,v1.4s
|
||||
ldp q4,q5,[x11]
|
||||
add v18.4s,v18.4s,v2.4s
|
||||
add v19.4s,v19.4s,v3.4s
|
||||
ldp q6,q7,[x12]
|
||||
add v20.4s,v20.4s,v4.4s
|
||||
add v21.4s,v21.4s,v5.4s
|
||||
add v22.4s,v22.4s,v6.4s
|
||||
add v23.4s,v23.4s,v7.4s
|
||||
|
||||
SkipAccumulateOutputM4
|
||||
stp q16,q17,[x2],#32
|
||||
stp q18,q19,[x10]
|
||||
stp q20,q21,[x11]
|
||||
stp q22,q23,[x12]
|
||||
cbnz x5,ProcessNextColumnLoopM4
|
||||
|
||||
ExitKernelM4
|
||||
mov x0,#4 // return number of rows handled
|
||||
EPILOG_RESTORE_REG_PAIR d10,d11,#16
|
||||
EPILOG_RESTORE_REG_PAIR d8,d9,#32!
|
||||
EPILOG_RETURN
|
||||
|
||||
//
|
||||
// Store the partial 1 to 7 columns either overwriting the output matrix or
|
||||
// accumulating into the existing contents of the output matrix.
|
||||
//
|
||||
|
||||
StoreOutputPartialM4
|
||||
cbz x13,StoreOutputPartialAddModeM4
|
||||
|
||||
StoreOutputPartialZeroModeM4
|
||||
tbz x5,#2,StoreOutputPartial2ZeroModeM4
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
st1 {v20.4s},[x11],#16
|
||||
mov v20.16b,v21.16b
|
||||
st1 {v22.4s},[x12],#16
|
||||
mov v22.16b,v23.16b
|
||||
|
||||
StoreOutputPartial2ZeroModeM4
|
||||
tbz x5,#1,StoreOutputPartial1ZeroModeM4
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
st1 {v20.2s},[x11],#8
|
||||
dup v20.4s,v20.s[2]
|
||||
st1 {v22.2s},[x12],#8
|
||||
dup v22.4s,v22.s[2]
|
||||
|
||||
StoreOutputPartial1ZeroModeM4
|
||||
tbz x5,#0,ExitKernelM4
|
||||
st1 {v16.s}[0],[x2]
|
||||
st1 {v18.s}[0],[x10]
|
||||
st1 {v20.s}[0],[x11]
|
||||
st1 {v22.s}[0],[x12]
|
||||
b ExitKernelM4
|
||||
|
||||
StoreOutputPartialAddModeM4
|
||||
tbz x5,#2,StoreOutputPartial2AddModeM4
|
||||
ld1 {v0.4s},[x2]
|
||||
ld1 {v1.4s},[x10]
|
||||
ld1 {v2.4s},[x11]
|
||||
ld1 {v3.4s},[x12]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
add v20.4s,v20.4s,v2.4s
|
||||
add v22.4s,v22.4s,v3.4s
|
||||
st1 {v20.4s},[x11],#16
|
||||
mov v20.16b,v21.16b
|
||||
st1 {v22.4s},[x12],#16
|
||||
mov v22.16b,v23.16b
|
||||
|
||||
StoreOutputPartial2AddModeM4
|
||||
tbz x5,#1,StoreOutputPartial1AddModeM4
|
||||
ld1 {v0.2s},[x2]
|
||||
ld1 {v1.2s},[x10]
|
||||
ld1 {v2.2s},[x11]
|
||||
ld1 {v3.2s},[x12]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
add v20.4s,v20.4s,v2.4s
|
||||
add v22.4s,v22.4s,v3.4s
|
||||
st1 {v20.2s},[x11],#8
|
||||
dup v20.4s,v20.s[2]
|
||||
st1 {v22.2s},[x12],#8
|
||||
dup v22.4s,v22.s[2]
|
||||
|
||||
StoreOutputPartial1AddModeM4
|
||||
tbz x5,#0,ExitKernelM4
|
||||
ld1 {v0.s}[0],[x2]
|
||||
ld1 {v1.s}[0],[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
ld1 {v2.s}[0],[x11]
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
ld1 {v3.s}[0],[x12]
|
||||
add v20.4s,v20.4s,v2.4s
|
||||
st1 {v16.s}[0],[x2]
|
||||
st1 {v18.s}[0],[x10]
|
||||
add v22.4s,v22.4s,v3.4s
|
||||
st1 {v20.s}[0],[x11]
|
||||
st1 {v22.s}[0],[x12]
|
||||
b ExitKernelM4
|
||||
|
||||
//
|
||||
// Process 2 rows of the matrices.
|
||||
//
|
||||
|
||||
ProcessNextColumnLoopM2
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
mov x0,x14 // reload matrix A
|
||||
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
|
||||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
|
||||
cbz x9,SkipScaleByZeroPointBM2
|
||||
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
|
||||
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
|
||||
mul v16.4s,v30.4s,v8.4s
|
||||
mul v18.4s,v30.4s,v9.4s
|
||||
mul v17.4s,v31.4s,v8.4s
|
||||
mul v19.4s,v31.4s,v9.4s
|
||||
ld1 {v4.16b},[x0],#16 // load packed A0
|
||||
add v16.4s,v2.4s,v16.4s
|
||||
add v18.4s,v2.4s,v18.4s
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
add v19.4s,v3.4s,v19.4s
|
||||
b ComputeBlockLoopM2
|
||||
|
||||
SkipScaleByZeroPointBM2
|
||||
ld1 {v4.16b},[x0],#16 // load packed A0
|
||||
add v16.4s,v2.4s,v8.4s
|
||||
add v18.4s,v2.4s,v9.4s
|
||||
add v17.4s,v3.4s,v8.4s
|
||||
add v19.4s,v3.4s,v9.4s
|
||||
|
||||
ComputeBlockLoopM2
|
||||
UdotByElement 16, 0, 4, 0
|
||||
UdotByElement 17, 1, 4, 0
|
||||
UdotByElement 18, 0, 4, 1
|
||||
UdotByElement 19, 1, 4, 1
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 16, 0, 4, 2
|
||||
UdotByElement 17, 1, 4, 2
|
||||
UdotByElement 18, 0, 4, 3
|
||||
UdotByElement 19, 1, 4, 3
|
||||
sub x3,x3,#1
|
||||
cbz x3,ComputeBlockLoopFinishM2
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
ld1 {v4.16b},[x0],#16 // load packed A0
|
||||
b ComputeBlockLoopM2
|
||||
|
||||
ComputeBlockLoopFinishM2
|
||||
add x10,x2,x6,lsl #2 // compute output row 2
|
||||
subs x5,x5,#8 // adjust CountN remaining
|
||||
blo StoreOutputPartialM2
|
||||
cbnz x13,SkipAccumulateOutputM2
|
||||
ldp q0,q1,[x2]
|
||||
ldp q2,q3,[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v17.4s,v17.4s,v1.4s
|
||||
add v18.4s,v18.4s,v2.4s
|
||||
add v19.4s,v19.4s,v3.4s
|
||||
|
||||
SkipAccumulateOutputM2
|
||||
stp q16,q17,[x2],#32
|
||||
stp q18,q19,[x10]
|
||||
cbnz x5,ProcessNextColumnLoopM2
|
||||
|
||||
ExitKernelM2
|
||||
mov x0,#2 // return number of rows handled
|
||||
EPILOG_RESTORE_REG_PAIR d10,d11,#16
|
||||
EPILOG_RESTORE_REG_PAIR d8,d9,#32!
|
||||
EPILOG_RETURN
|
||||
|
||||
//
|
||||
// Store the partial 1 to 7 columns either overwriting the output matrix or
|
||||
// accumulating into the existing contents of the output matrix.
|
||||
//
|
||||
|
||||
StoreOutputPartialM2
|
||||
cbz x13,StoreOutputPartialAddModeM2
|
||||
|
||||
StoreOutputPartialZeroModeM2
|
||||
tbz x5,#2,StoreOutputPartial2ZeroModeM2
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
|
||||
StoreOutputPartial2ZeroModeM2
|
||||
tbz x5,#1,StoreOutputPartial1ZeroModeM2
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
|
||||
StoreOutputPartial1ZeroModeM2
|
||||
tbz x5,#0,ExitKernelM2
|
||||
st1 {v16.s}[0],[x2]
|
||||
st1 {v18.s}[0],[x10]
|
||||
b ExitKernelM2
|
||||
|
||||
StoreOutputPartialAddModeM2
|
||||
tbz x5,#2,StoreOutputPartial2AddModeM2
|
||||
ld1 {v0.4s},[x2]
|
||||
ld1 {v1.4s},[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
st1 {v18.4s},[x10],#16
|
||||
mov v18.16b,v19.16b
|
||||
|
||||
StoreOutputPartial2AddModeM2
|
||||
tbz x5,#1,StoreOutputPartial1AddModeM2
|
||||
ld1 {v0.2s},[x2]
|
||||
ld1 {v1.2s},[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
st1 {v18.2s},[x10],#8
|
||||
dup v18.4s,v18.s[2]
|
||||
|
||||
StoreOutputPartial1AddModeM2
|
||||
tbz x5,#0,ExitKernelM2
|
||||
ld1 {v0.s}[0],[x2]
|
||||
ld1 {v1.s}[0],[x10]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v18.4s,v18.4s,v1.4s
|
||||
st1 {v16.s}[0],[x2]
|
||||
st1 {v18.s}[0],[x10]
|
||||
b ExitKernelM2
|
||||
|
||||
//
|
||||
// Process 1 row of the matrices.
|
||||
//
|
||||
|
||||
ProcessNextColumnLoopM1
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
mov x0,x14 // reload matrix A
|
||||
ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0]
|
||||
mov x3,x15 // reload PackedCountK
|
||||
ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4]
|
||||
cbz x9,SkipScaleByZeroPointBM1 // no per-column zero points: use the row sum as is
|
||||
ld1 {v30.4s},[x9],#16 // load ZeroPointB[0]
|
||||
ld1 {v31.4s},[x9],#16 // load ZeroPointB[4]
|
||||
mul v16.4s,v30.4s,v8.4s // scale the row sum by each column's zero point
|
||||
mul v17.4s,v31.4s,v8.4s
|
||||
ldr d4,[x0],#8 // load packed A0
|
||||
add v16.4s,v2.4s,v16.4s // initial accumulator = column sum + scaled row sum
|
||||
add v17.4s,v3.4s,v17.4s
|
||||
b ComputeBlockLoopM1
|
||||
|
||||
SkipScaleByZeroPointBM1
|
||||
ldr d4,[x0],#8 // load packed A0
|
||||
add v16.4s,v2.4s,v8.4s // initial accumulator = column sum + row sum
|
||||
add v17.4s,v3.4s,v8.4s
|
||||
|
||||
ComputeBlockLoopM1
|
||||
UdotByElement 16, 0, 4, 0
|
||||
UdotByElement 17, 1, 4, 0
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
UdotByElement 16, 0, 4, 1
|
||||
UdotByElement 17, 1, 4, 1
|
||||
sub x3,x3,#1
|
||||
cbz x3,ComputeBlockLoopFinishM1
|
||||
ldr d4,[x0],#8 // load packed A0
|
||||
ld1 {v0.16b},[x1],#16 // load packed B0
|
||||
ld1 {v1.16b},[x1],#16 // load packed B1
|
||||
b ComputeBlockLoopM1
|
||||
|
||||
ComputeBlockLoopFinishM1
|
||||
subs x5,x5,#8 // adjust CountN remaining
|
||||
blo StoreOutputPartialM1
|
||||
cbnz x13,SkipAccumulateOutputM1
|
||||
ldp q0,q1,[x2]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
add v17.4s,v17.4s,v1.4s
|
||||
|
||||
SkipAccumulateOutputM1
|
||||
stp q16,q17,[x2],#32
|
||||
cbnz x5,ProcessNextColumnLoopM1
|
||||
|
||||
ExitKernelM1
|
||||
mov x0,#1 // return number of rows handled
|
||||
EPILOG_RESTORE_REG_PAIR d10,d11,#16
|
||||
EPILOG_RESTORE_REG_PAIR d8,d9,#32!
|
||||
EPILOG_RETURN
|
||||
|
||||
//
|
||||
// Store the partial 1 to 7 columns either overwriting the output matrix or
|
||||
// accumulating into the existing contents of the output matrix.
|
||||
//
|
||||
|
||||
StoreOutputPartialM1
|
||||
cbz x13,StoreOutputPartialAddModeM1
|
||||
|
||||
StoreOutputPartialZeroModeM1
|
||||
tbz x5,#2,StoreOutputPartial2ZeroModeM1
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
|
||||
StoreOutputPartial2ZeroModeM1
|
||||
tbz x5,#1,StoreOutputPartial1ZeroModeM1
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
|
||||
StoreOutputPartial1ZeroModeM1
|
||||
tbz x5,#0,ExitKernelM1
|
||||
st1 {v16.s}[0],[x2]
|
||||
b ExitKernelM1
|
||||
|
||||
StoreOutputPartialAddModeM1
|
||||
tbz x5,#2,StoreOutputPartial2AddModeM1
|
||||
ld1 {v0.4s},[x2]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
st1 {v16.4s},[x2],#16
|
||||
mov v16.16b,v17.16b // shift remaining elements down
|
||||
|
||||
StoreOutputPartial2AddModeM1
|
||||
tbz x5,#1,StoreOutputPartial1AddModeM1
|
||||
ld1 {v0.2s},[x2]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
st1 {v16.2s},[x2],#8
|
||||
dup v16.4s,v16.s[2] // shift remaining elements down
|
||||
|
||||
StoreOutputPartial1AddModeM1
|
||||
tbz x5,#0,ExitKernelM1
|
||||
ld1 {v0.s}[0],[x2]
|
||||
add v16.4s,v16.4s,v0.4s
|
||||
st1 {v16.s}[0],[x2]
|
||||
b ExitKernelM1
|
||||
|
||||
NESTED_END MlasGemmU8X8KernelUdot
|
||||
|
||||
END
|
||||
|
|
@ -245,14 +245,6 @@ void
|
|||
|
||||
typedef MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE;
|
||||
|
||||
typedef
|
||||
void
|
||||
(MLASCALL MLAS_GEMM_U8X8_OPERATION)(
|
||||
const struct MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
|
||||
);
|
||||
|
||||
typedef MLAS_GEMM_U8X8_OPERATION* PMLAS_GEMM_U8X8_OPERATION;
|
||||
|
||||
typedef
|
||||
size_t
|
||||
(MLASCALL MLAS_GEMM_U8S8_KERNEL)(
|
||||
|
|
@ -265,7 +257,7 @@ size_t
|
|||
size_t ldc,
|
||||
const int32_t* RowSumVector,
|
||||
const int32_t* ColumnSumVector,
|
||||
int32_t DepthValue,
|
||||
const int32_t* ZeroPointB,
|
||||
bool ZeroMode
|
||||
);
|
||||
|
||||
|
|
@ -296,7 +288,7 @@ size_t
|
|||
size_t ldc,
|
||||
const int32_t* RowSumVector,
|
||||
const int32_t* ColumnSumVector,
|
||||
int32_t DepthValue,
|
||||
const int32_t* ZeroPointB,
|
||||
bool ZeroMode
|
||||
);
|
||||
|
||||
|
|
@ -697,26 +689,17 @@ MlasSgemmOperation(
|
|||
);
|
||||
|
||||
//
|
||||
// Quantized integer matrix/matrix multiply operation.
|
||||
// Quantized integer matrix/matrix dispatch structure.
|
||||
//
|
||||
|
||||
struct MLAS_GEMM_U8X8_KERNEL_SSE;
|
||||
struct MLAS_GEMM_U8S8_KERNEL_AVX2;
|
||||
struct MLAS_GEMM_U8U8_KERNEL_AVX2;
|
||||
struct MLAS_GEMM_U8X8_DISPATCH;
|
||||
|
||||
template<typename KernelType>
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemmU8X8Operation(
|
||||
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
|
||||
);
|
||||
|
||||
template<typename KernelType>
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemmU8X8PackedOperation(
|
||||
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
|
||||
);
|
||||
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchSse;
|
||||
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8S8DispatchAvx2;
|
||||
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8U8DispatchAvx2;
|
||||
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchNeon;
|
||||
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchUdot;
|
||||
extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchDefault;
|
||||
|
||||
//
|
||||
// Quantized depthwise convolution kernels.
|
||||
|
|
@ -767,12 +750,10 @@ struct MLAS_PLATFORM {
|
|||
PMLAS_SGEMM_KERNEL_M1_ROUTINE KernelM1TransposeBRoutine;
|
||||
PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE TransposePackB16x4Routine;
|
||||
PMLAS_GEMM_DOUBLE_KERNEL GemmDoubleKernel;
|
||||
PMLAS_GEMM_U8X8_OPERATION GemmU8S8Operation;
|
||||
PMLAS_GEMM_U8X8_OPERATION GemmU8S8PackedOperation;
|
||||
const MLAS_GEMM_U8X8_DISPATCH* GemmU8S8Dispatch;
|
||||
PMLAS_GEMM_U8S8_KERNEL GemmU8S8Kernel;
|
||||
PMLAS_GEMV_U8S8_KERNEL GemvU8S8Kernel;
|
||||
PMLAS_GEMM_U8X8_OPERATION GemmU8U8Operation;
|
||||
PMLAS_GEMM_U8X8_OPERATION GemmU8U8PackedOperation;
|
||||
const MLAS_GEMM_U8X8_DISPATCH* GemmU8U8Dispatch;
|
||||
PMLAS_GEMM_U8U8_KERNEL GemmU8U8Kernel;
|
||||
PMLAS_CONV_FLOAT_KERNEL ConvNchwFloatKernel;
|
||||
PMLAS_CONV_FLOAT_KERNEL ConvNchwcFloatKernel;
|
||||
|
|
@ -800,6 +781,10 @@ struct MLAS_PLATFORM {
|
|||
#else
|
||||
static constexpr uint32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
|
||||
#endif
|
||||
|
||||
#if defined(MLAS_TARGET_ARM64)
|
||||
const MLAS_GEMM_U8X8_DISPATCH* GemmU8X8Dispatch;
|
||||
#endif
|
||||
};
|
||||
|
||||
extern MLAS_PLATFORM MlasPlatform;
|
||||
|
|
|
|||
|
|
@ -17,6 +17,16 @@ Abstract:
|
|||
|
||||
#include "mlasi.h"
|
||||
|
||||
#if defined(MLAS_TARGET_ARM64) && defined(__linux__)
|
||||
#include <sys/auxv.h>
|
||||
#include <asm/hwcap.h>
|
||||
// N.B. Support building with older versions of asm/hwcap.h that do not define
|
||||
// this capability bit.
|
||||
#ifndef HWCAP_ASIMDDP
|
||||
#define HWCAP_ASIMDDP (1 << 20)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
//
|
||||
// Stores the platform information.
|
||||
//
|
||||
|
|
@ -120,8 +130,8 @@ Return Value:
|
|||
|
||||
this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse;
|
||||
this->GemmDoubleKernel = MlasGemmDoubleKernelSse;
|
||||
this->GemmU8S8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_SSE>;
|
||||
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_SSE>;
|
||||
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchSse;
|
||||
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchSse;
|
||||
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse;
|
||||
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelSse;
|
||||
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelSse;
|
||||
|
|
@ -206,12 +216,10 @@ Return Value:
|
|||
|
||||
if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) {
|
||||
|
||||
this->GemmU8S8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
|
||||
this->GemmU8S8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
|
||||
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2;
|
||||
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2;
|
||||
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2;
|
||||
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8U8_KERNEL_AVX2>;
|
||||
this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8U8_KERNEL_AVX2>;
|
||||
this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2;
|
||||
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2;
|
||||
|
||||
this->GemmFloatKernel = MlasGemmFloatKernelFma3;
|
||||
|
|
@ -229,7 +237,7 @@ Return Value:
|
|||
this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernelAvx2<int8_t>;
|
||||
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernelAvx2<uint8_t>;
|
||||
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
|
||||
|
||||
|
||||
//
|
||||
// Check if the processor supports Hybrid core architecture.
|
||||
//
|
||||
|
|
@ -251,8 +259,7 @@ Return Value:
|
|||
|
||||
if ((Cpuid7_1[0] & 0x10) != 0) {
|
||||
|
||||
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
|
||||
this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
|
||||
this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
|
||||
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
|
||||
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
|
||||
}
|
||||
|
|
@ -304,8 +311,7 @@ Return Value:
|
|||
|
||||
if ((Cpuid7[2] & 0x800) != 0) {
|
||||
|
||||
this->GemmU8U8Operation = MlasGemmU8X8Operation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
|
||||
this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>;
|
||||
this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
|
||||
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni;
|
||||
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni;
|
||||
}
|
||||
|
|
@ -326,6 +332,24 @@ Return Value:
|
|||
|
||||
#endif // MLAS_TARGET_AMD64_IX86
|
||||
|
||||
#if defined(MLAS_TARGET_ARM64)
|
||||
|
||||
this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchNeon;
|
||||
|
||||
#if defined(__linux__)
|
||||
|
||||
//
|
||||
// Check if the processor supports ASIMD dot product instructions.
|
||||
//
|
||||
|
||||
if ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0) {
|
||||
this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchUdot;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
size_t
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -18,7 +18,6 @@ Abstract:
|
|||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8X8KernelAvx2Common.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
|
|
@ -27,7 +26,7 @@ Abstract:
|
|||
//
|
||||
|
||||
.equ .LGemmU8S8CopyPackAFrame_PaddedMatrixAData, -72
|
||||
.equ .LGemmU8S8CopyPackAFrame_mask, -8
|
||||
.equ .LGemmU8S8CopyPackAFrame_Padding, -8
|
||||
.equ .LGemmU8S8CopyPackAFrame_SavedR13, 0
|
||||
.equ .LGemmU8S8CopyPackAFrame_SavedR12, 8
|
||||
.equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16
|
||||
|
|
@ -76,8 +75,7 @@ Return Value:
|
|||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2):
|
||||
FUNCTION_ENTRY MlasGemmU8S8CopyPackAAvx2
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
|
|
@ -101,9 +99,9 @@ C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2):
|
|||
and eax,15 # isolate unaligned count
|
||||
add eax,3
|
||||
shr eax,2 # align unaligned count to quad count
|
||||
mov DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp],eax
|
||||
vpbroadcastd xmm10,DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp]
|
||||
vpcmpgtd xmm10,xmm10,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
|
||||
neg rax
|
||||
lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4]
|
||||
vmovdqu xmm10,XMMWORD PTR [rbx+rax*4]
|
||||
|
||||
//
|
||||
// Zero initialize the padded stack buffers.
|
||||
|
|
@ -430,8 +428,7 @@ Return Value:
|
|||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2):
|
||||
FUNCTION_ENTRY MlasGemmU8S8CopyPackBAvx2
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
|
|
@ -698,258 +695,4 @@ C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2):
|
|||
vmovdqu YMMWORD PTR [r9+32],ymm1
|
||||
jmp .LCopyPackB.ExitRoutine
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 16).
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
ymm0 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 16).
|
||||
|
||||
ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
|
||||
ymm12 - Supplies a 256-bit vector with the broadcast word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
vpmaddubsw ymm3,ymm2,ymm0
|
||||
vpmaddwd ymm3,ymm3,ymm12
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
|
||||
vpmaddubsw ymm2,ymm2,ymm1
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
|
||||
.else
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm11 - Supplies the block accumulators.
|
||||
|
||||
ymm12 - Supplies a 256-bit vector with the broadcast word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \RowCount\() == 1
|
||||
vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]
|
||||
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vpmaddwd ymm3,ymm3,ymm12
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd ymm4,ymm4,ymm3
|
||||
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32]
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd ymm5,ymm5,ymm2
|
||||
.else
|
||||
vpaddd ymm5,ymm5,ymm3
|
||||
.endif
|
||||
.else
|
||||
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm11 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,64 # advance matrix B
|
||||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8S8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8S8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
|
||||
and the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value
|
||||
is accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8S8KernelAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8S8KernelAvx2):
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
|
||||
mov r11,rdi
|
||||
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
|
||||
vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF]
|
||||
vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001]
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
//
|
||||
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
mov r8d,4 # return 4 rows handled
|
||||
ProcessCountM 4, Fallthrough
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
mov eax,r8d
|
||||
vzeroupper
|
||||
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.end
|
||||
|
|
|
|||
|
|
@ -1,88 +0,0 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvx512Common.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains common kernel macros and structures for the quantized
|
||||
integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
|
||||
AVX512VNNI kernels.
|
||||
|
||||
--*/
|
||||
|
||||
#include "QgemmU8X8KernelAvx512Common.h"
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.if ((\RowCount\() & 1) == 0)
|
||||
sub rbp,4*4
|
||||
jb .LProcessRemainingBlocks\@
|
||||
|
||||
.LComputeBlockBy4Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0*64, 0
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 1*64, 4
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 2*64, 8
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 3*64, 12
|
||||
add rdi,4*4 # advance matrix A by 4 quads
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4*4 # advance matrix A plus 3 rows by 4 quads
|
||||
.endif
|
||||
add rsi,4*64 # advance matrix B
|
||||
sub rbp,4*4 # decrement quads remaining
|
||||
jae .LComputeBlockBy4Loop\@
|
||||
|
||||
.LProcessRemainingBlocks\@:
|
||||
add rbp,4*4 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit\@
|
||||
.endif
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,64 # advance matrix B
|
||||
sub rbp,4 # decrement quads remaining
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.LComputeBlockLoopExit\@:
|
||||
|
||||
.endm
|
||||
|
|
@ -1,136 +0,0 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvx512Core.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX512 core instructions (BW/DQ/VL).
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8S8KernelAvx512Common.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single cell of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
AccumReg - Supplies the register to accumulate into.
|
||||
|
||||
Mult1Reg - Supplies the first multiplication operand register.
|
||||
|
||||
Mult2Reg - Supplies the second multiplication operand register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
zmm4 - Supplies a scratch register for intermediate results.
|
||||
|
||||
zmm5 - Supplies a 512-bit vector with the broadcast word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\()
|
||||
vpmaddwd zmm4,zmm4,zmm5
|
||||
vpaddd \AccumReg\(),\AccumReg\(),zmm4
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \ColumnCount\() >= 48
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()]
|
||||
.elseif \ColumnCount\() >= 32
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
.else
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()]
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
|
||||
|
||||
.endm
|
||||
|
||||
//
|
||||
// Generate the GEMM kernel.
|
||||
//
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512Core
|
||||
|
||||
.end
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvx512Vnni.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX512VNNI instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8S8KernelAvx512Common.h"
|
||||
#include "AssembleAvx512Vnni.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \ColumnCount\() >= 48
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()]
|
||||
.elseif \ColumnCount\() >= 32
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
.else
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()]
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2"
|
||||
|
||||
.endm
|
||||
|
||||
//
|
||||
// Generate the GEMM kernel.
|
||||
//
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
|
||||
|
||||
.end
|
||||
|
|
@ -1,268 +0,0 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) 2020 Intel Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvxVnni.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVXVNNI instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8X8KernelAvx2Common.h"
|
||||
#include "AssembleAvxVnni.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
/*++
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 16).
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
ymm0 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 16).
|
||||
|
||||
ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
.if \ColumnCount\() == 16
|
||||
VpdpbusdsYmmYmmYmm \Vec1Reg\(),ymm2,ymm0
|
||||
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm1
|
||||
.else
|
||||
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm0
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15"
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,64 # advance matrix B
|
||||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8S8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8S8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
|
||||
and the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value
|
||||
is accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8S8KernelAvxVnni)
|
||||
C_UNDERSCORE(MlasGemmU8S8KernelAvxVnni):
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
|
||||
mov r11,rdi
|
||||
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
//
|
||||
|
||||
cmp r8,5
|
||||
ja .LProcessCountM6
|
||||
je .LProcessCountM5
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
.LProcessCountM6:
|
||||
mov r8d,6 # return 6 rows handled
|
||||
ProcessCountM 6, Fallthrough
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
mov eax,r8d
|
||||
vzeroupper
|
||||
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.LProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
.end
|
||||
|
|
@ -18,7 +18,6 @@ Abstract:
|
|||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8X8KernelAvx2Common.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
|
|
@ -27,7 +26,7 @@ Abstract:
|
|||
//
|
||||
|
||||
.equ .LGemmU8U8CopyPackAFrame_PaddedMatrixAData, -72
|
||||
.equ .LGemmU8U8CopyPackAFrame_mask, -8
|
||||
.equ .LGemmU8U8CopyPackAFrame_Padding, -8
|
||||
.equ .LGemmU8U8CopyPackAFrame_SavedR13, 0
|
||||
.equ .LGemmU8U8CopyPackAFrame_SavedR12, 8
|
||||
.equ .LGemmU8U8CopyPackAFrame_SavedRbx, 16
|
||||
|
|
@ -79,8 +78,7 @@ Return Value:
|
|||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
||||
FUNCTION_ENTRY MlasGemmU8U8CopyPackAAvx2
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
|
|
@ -102,9 +100,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
and eax,15 # isolate unaligned count
|
||||
inc eax
|
||||
shr eax,1 # align unaligned count to pair count
|
||||
mov DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp],eax
|
||||
vpbroadcastd ymm9,DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp]
|
||||
vpcmpgtd ymm9,ymm9,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
|
||||
neg rax
|
||||
lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4]
|
||||
vmovdqu ymm9,YMMWORD PTR [rbx+rax*4]
|
||||
|
||||
//
|
||||
// Zero initialize the padded stack buffers.
|
||||
|
|
@ -394,8 +392,7 @@ Return Value:
|
|||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
||||
FUNCTION_ENTRY MlasGemmU8U8CopyPackBAvx2
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
|
|
@ -599,273 +596,4 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
vmovdqu YMMWORD PTR [r9+32],ymm1
|
||||
jmp .LCopyPackB.ExitRoutine
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 16).
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
ymm0 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 16).
|
||||
|
||||
ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
vpmaddwd ymm3,ymm2,ymm0
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
|
||||
vpmaddwd ymm2,ymm2,ymm1
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
|
||||
.else
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()]
|
||||
EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15"
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0)
|
||||
sub rbp,2*4
|
||||
jb .LProcessRemainingBlocks\@
|
||||
|
||||
.LComputeBlockBy2Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 32, 4
|
||||
add rdi,2*4 # advance matrix A by 2 pairs
|
||||
.if \RowCount\() > 3
|
||||
add rbx,2*4 # advance matrix A plus 3 rows by 2 pairs
|
||||
.endif
|
||||
add rsi,2*32 # advance matrix B
|
||||
sub rbp,2*4
|
||||
jae .LComputeBlockBy2Loop\@
|
||||
|
||||
.LProcessRemainingBlocks\@:
|
||||
add rbp,2*4 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit\@
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rsi,32 # advance matrix B
|
||||
.else
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 pair
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 pair
|
||||
.endif
|
||||
add rsi,32 # advance matrix B
|
||||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
.endif
|
||||
|
||||
.LComputeBlockLoopExit\@:
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8U8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8U8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
|
||||
and the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value
|
||||
is accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8U8KernelAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8U8KernelAvx2):
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
|
||||
mov r11,rdi
|
||||
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
//
|
||||
|
||||
cmp r8,5
|
||||
ja .LProcessCountM6
|
||||
je .LProcessCountM5
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
.LProcessCountM6:
|
||||
mov r8d,6 # return 6 rows handled
|
||||
ProcessCountM 6, Fallthrough
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
mov eax,r8d
|
||||
vzeroupper
|
||||
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.LProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
.end
|
||||
|
|
|
|||
|
|
@ -1,178 +0,0 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8U8KernelAvx512Core.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX512 core instructions (BW/DQ/VL).
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8X8KernelAvx512Common.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single cell of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
AccumReg - Supplies the register to accumulate into.
|
||||
|
||||
Mult1Reg - Supplies the first multiplication operand register.
|
||||
|
||||
Mult2Reg - Supplies the second multiplication operand register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
zmm4 - Supplies a scratch register for intermediate results.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\()
|
||||
vpaddd \AccumReg\(),\AccumReg\(),zmm4
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \ColumnCount\() >= 48
|
||||
vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()]
|
||||
.elseif \ColumnCount\() >= 32
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
.else
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 pair
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 pair
|
||||
.endif
|
||||
add rsi,32 # advance matrix B
|
||||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.endm
|
||||
|
||||
//
|
||||
// Generate the GEMM kernel.
|
||||
//
|
||||
|
||||
GemmU8X8KernelAvx512Function U8U8, Avx512Core
|
||||
|
||||
.end
|
||||
868
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S
Normal file
868
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S
Normal file
|
|
@ -0,0 +1,868 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelAvx2.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX2 and AVX VNNI instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "AssembleAvxVnni.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8X8 kernel.
|
||||
//
|
||||
|
||||
.equ .LGemmU8X8KernelFrame_type, -8 # negative offset: slot below the saved registers
|
||||
.equ .LGemmU8X8KernelFrame_SavedR13, 0 # presumably rsp-relative after the prologue pushes - confirm
|
||||
.equ .LGemmU8X8KernelFrame_SavedR12, 8
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbx, 16
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbp, 24
|
||||
.equ .LGemmU8X8KernelFrame_ReturnAddress, 32 # four 8-byte saved register slots sit below this
|
||||
.equ .LGemmU8X8KernelFrame_ldc, 40 # first slot above the return address
|
||||
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 48
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 56
|
||||
.equ .LGemmU8X8KernelFrame_ZeroPointB, 64
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 72
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 16).
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
ymm0 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 16).
|
||||
|
||||
ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
|
||||
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRowU8S8Avx2 ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
vpmaddubsw ymm3,ymm2,ymm0 # unsigned A bytes times signed B bytes, adjacent pairs summed to words
|
||||
vpmaddwd ymm3,ymm3,ymm12 # multiply by the 0x0001 words to sum adjacent word pairs into dwords
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 # accumulate the first vector into the high block
|
||||
vpmaddubsw ymm2,ymm2,ymm1 # second vector of matrix B; overwrites the broadcast value in ymm2
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 # accumulate the second vector into the low block
|
||||
.else
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 # only one vector: accumulate into the low block
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
r8 - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm11 - Supplies the block accumulators.
|
||||
|
||||
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockU8S8Avx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \RowCount\() == 1
|
||||
vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()] # broadcast 4 bytes of matrix A to every dword lane
|
||||
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()] # single row: multiply straight from memory, ymm0 is not loaded
|
||||
vpmaddwd ymm3,ymm3,ymm12 # sum adjacent word pairs into dwords using the 0x0001 words
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd ymm4,ymm4,ymm3
|
||||
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32] # second vector of matrix B, 32 bytes further on
|
||||
vpmaddwd ymm2,ymm2,ymm12
|
||||
vpaddd ymm5,ymm5,ymm2
|
||||
.else
|
||||
vpaddd ymm5,ymm5,ymm3
|
||||
.endif
|
||||
.else
|
||||
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] # multiple rows: load matrix B once and reuse it for each row
|
||||
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm10, ymm11"
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 16).
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
ymm0 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 16).
|
||||
|
||||
ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRowU8S8AvxVnni ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
.if \ColumnCount\() == 16
|
||||
VpdpbusdsYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 # helper macro defined elsewhere; presumably emits vpdpbusds - confirm in AssembleAvxVnni.h
|
||||
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 # same operation against the second vector of matrix B
|
||||
.else
|
||||
VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 # only one vector: accumulate into the low block
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
r8 - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockU8S8AvxVnni ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] # first vector of matrix B, shared by every row below
|
||||
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm10, ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm12, ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm14, ymm15"
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple times
|
||||
and advance the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
Isa - Supplies the instruction set architecture string.
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
r8 - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators (ymm4-ymm11 for the Avx2 variant, where ymm12 holds the broadcasted word value 0x0001).
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoopU8S8 Isa, ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add r8,4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,64 # advance matrix B
|
||||
sub rbp,4 # one quad (4 bytes) of the row length consumed
|
||||
jnz .LComputeBlockBy1Loop\@ # repeat until the remaining row length reaches zero
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 16).
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
ymm0 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 16).
|
||||
|
||||
ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRowU8U8Avx2 ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
vpmaddwd ymm3,ymm2,ymm0 # multiply word pairs and sum each pair into a dword
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 # accumulate the first vector into the high block
|
||||
vpmaddwd ymm2,ymm2,ymm1 # second vector of matrix B; overwrites the broadcast value in ymm2
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 # accumulate the second vector into the low block
|
||||
.else
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 # only one vector: accumulate into the low block
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
r8 - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockU8U8Avx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()] # zero extend 16 bytes of matrix B to 16 words
|
||||
EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm10, ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm12, ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm14, ymm15"
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple times
|
||||
and advance the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
Isa - Supplies the instruction set architecture string.
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
r8 - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0)
|
||||
sub rbp,2*4 # pre-subtract 2 pairs so a borrow means fewer than 2 pairs remain
|
||||
jb .LProcessRemainingBlocks\@ # fewer than 2 pairs: skip the unrolled loop
|
||||
|
||||
.LComputeBlockBy2Loop\@:
|
||||
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
|
||||
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 32, 4
|
||||
add rdi,2*4 # advance matrix A by 2 pairs
|
||||
.if \RowCount\() > 3
|
||||
add r8,2*4 # advance matrix A plus 3 rows by 2 pairs
|
||||
.endif
|
||||
add rsi,2*32 # advance matrix B
|
||||
sub rbp,2*4
|
||||
jae .LComputeBlockBy2Loop\@ # no borrow: at least 2 more pairs remain
|
||||
|
||||
.LProcessRemainingBlocks\@:
|
||||
add rbp,2*4 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit\@ # nothing left over
|
||||
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rsi,32 # advance matrix B
|
||||
.else
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 pair
|
||||
.if \RowCount\() > 3
|
||||
add r8,4 # advance matrix A plus 3 rows by 1 pair
|
||||
.endif
|
||||
add rsi,32 # advance matrix B
|
||||
sub rbp,4 # one pair (4 bytes) of the row length consumed
|
||||
jnz .LComputeBlockBy1Loop\@ # repeat until the remaining row length reaches zero
|
||||
.endif
|
||||
|
||||
.LComputeBlockLoopExit\@:
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r11 - Supplies the address of the row sum buffer.
|
||||
|
||||
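r12 - Supplies the address of the column sum buffer.

r13 - Supplies the address of the per column zero points for matrix B, or zero if matrix B has no per column zero points.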
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the row and column sums.
|
||||
//
|
||||
|
||||
// Broadcast each row sum from the buffer at r11 into that row's odd-numbered
// accumulator (ymm5, ymm7, ..., ymm15).
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r11]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r11+4]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r11+8]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r11+12]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r11+16]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r11+20]"
|
||||
// Load the column sums from r12: ymm0/ymm1 for a 16 column block (r12 is then
// advanced), ymm1 only for an 8 column block (r12 is left unchanged).
.if \ColumnCount\() == 16
|
||||
vmovdqu ymm0,YMMWORD PTR [r12]
|
||||
vmovdqu ymm1,YMMWORD PTR [r12+32]
|
||||
add r12,16*4 # advance ColumnSumBuffer by 16 columns
|
||||
.else
|
||||
vmovdqu ymm1,YMMWORD PTR [r12]
|
||||
.endif
|
||||
// r13 holds the optional ZeroPointB pointer. When it is zero the row sums are
// added to the column sums unchanged; otherwise each row sum is first
// multiplied by the per-column zero points loaded into ymm2/ymm3.
test r13,r13 # per column zero points?
|
||||
jz .LSkipScaleByZeroPointB\@
|
||||
.if \ColumnCount\() == 16
|
||||
vmovdqu ymm2,YMMWORD PTR [r13]
|
||||
vmovdqu ymm3,YMMWORD PTR [r13+32]
|
||||
add r13,16*4 # advance ZeroPointB by 16 columns
|
||||
.else
|
||||
vmovdqu ymm3,YMMWORD PTR [r13]
|
||||
.endif
|
||||
// accumulator = ColumnSum + RowSum * ZeroPointB. The even register covers the
// first 8 columns of a 16 column block; the odd register covers the last 8 (or
// the only 8) columns. The even register is computed first because the odd
// register still holds the broadcast row sum and is overwritten in place.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld ymm4,ymm5,ymm2"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpmulld ymm5,ymm5,ymm3"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm0,ymm4"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm1,ymm5"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld ymm6,ymm7,ymm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmulld ymm7,ymm7,ymm3"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm0,ymm6"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm1,ymm7"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld ymm8,ymm9,ymm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmulld ymm9,ymm9,ymm3"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm0,ymm8"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm1,ymm9"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld ymm10,ymm11,ymm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmulld ymm11,ymm11,ymm3"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm0,ymm10"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm1,ymm11"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld ymm12,ymm13,ymm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmulld ymm13,ymm13,ymm3"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm0,ymm12"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm1,ymm13"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld ymm14,ymm15,ymm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmulld ymm15,ymm15,ymm3"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm0,ymm14"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm1,ymm15"
|
||||
jmp .LAccumulatorsInitialized\@
|
||||
|
||||
.LSkipScaleByZeroPointB\@:
|
||||
// Per-tensor case (r13 is zero): accumulator = RowSum + ColumnSum.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
|
||||
|
||||
.LAccumulatorsInitialized\@:
|
||||
|
||||
//
|
||||
// Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
//
|
||||
|
||||
// r8 is scratch in this macro: it addresses rows 3 and up of matrix A while
// the compute loop runs and is repointed at matrix C once the loop is done.
.if \RowCount\() > 3
|
||||
lea r8,[rcx*2+rcx]
|
||||
add r8,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
// Dispatch on the kernel type selector saved in the stack frame: a positive
// value selects the U8U8 AVX2 loop, a negative value the U8S8 AVX-VNNI loop and
// zero the U8S8 AVX2 loop. The U8S8 AVX2 loop is only emitted for row counts of
// 4 or fewer; for larger row counts every non-positive selector reaches the
// AVX-VNNI loop. The jl below reuses the flags of this cmp.
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
|
||||
jg .LProduceWithU8U8Avx2\@
|
||||
.if \RowCount\() <= 4
|
||||
jl .LProduceWithU8S8AvxVnni\@
|
||||
ComputeBlockLoopU8S8 Avx2, \ColumnCount\(), \RowCount\()
|
||||
jmp .LExitProduceOutputBlock\@
|
||||
.endif
|
||||
|
||||
.LProduceWithU8S8AvxVnni\@:
|
||||
ComputeBlockLoopU8S8 AvxVnni, \ColumnCount\(), \RowCount\()
|
||||
jmp .LExitProduceOutputBlock\@
|
||||
|
||||
.LProduceWithU8U8Avx2\@:
|
||||
ComputeBlockLoopU8U8 Avx2, \ColumnCount\(), \RowCount\()
|
||||
|
||||
.LExitProduceOutputBlock\@:
|
||||
// Leave r8 addressing row 3 of matrix C for the output code that follows the
// expansion of this macro.
.if \RowCount\() > 3
|
||||
lea r8,[rax*2+rax]
|
||||
add r8,rdx # compute matrix C plus 3 rows
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
rbx - Supplies the base address of matrix A, used to reload rdi after each
block of columns has been produced.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10b - Supplies the zero mode flag.
|
||||
|
||||
r11 - Supplies the address of the row sum buffer.
|
||||
|
||||
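r12 - Supplies the address of the column sum buffer.
r13 - Optionally supplies the address of the per-column zero point offsets
of matrix B (ZeroPointB), else zero for per-tensor quantization.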
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount
|
||||
|
||||
// More than 8 remaining columns use the 16 column path; 8 or fewer go straight
// to the 8 column path.
cmp r9,8
|
||||
jbe .LProcessRemainingCountN\@
|
||||
|
||||
.LProcessNextColumnLoop16xN\@:
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
// A borrow means fewer than 16 columns remained (9..15 given the checks that
// guard this loop): the first 8 are stored unmasked and the rest masked.
sub r9,16
|
||||
jb .LOutputMasked16xNBlock\@
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlock\@
|
||||
// Not ZeroMode: add the existing contents of matrix C to the accumulators.
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8+32]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2+32]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8+32],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax+32],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2+32],ymm15"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,rbx # reload matrix A
|
||||
cmp r9,8
|
||||
ja .LProcessNextColumnLoop16xN\@
|
||||
test r9,r9
|
||||
jnz .LProcessRemainingCountN\@
|
||||
|
||||
// Common exit for every path of this expansion: eax receives the number of
// rows handled before control transfers to the shared kernel epilogue.
.LExitProcessCountM\@:
|
||||
mov eax,\RowCount\()
|
||||
jmp .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN\@:
|
||||
ProduceOutputBlock 8, \RowCount\()
|
||||
cmp r9,8
|
||||
jb .LOutputMasked8xNBlock\@
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput8xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput8xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm15"
|
||||
jmp .LExitProcessCountM\@
|
||||
|
||||
// Partial 16 column block: write the even accumulators (first 8 columns)
// unmasked, then fall through to the masked path for the odd accumulators.
.LOutputMasked16xNBlock\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked16xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutputMasked16xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14"
|
||||
add rdx,8*4 # advance matrix C by 8 columns
|
||||
.if \RowCount\() > 3
|
||||
add r8,8*4 # advance matrix C plus 3 rows by 8 columns
|
||||
.endif
|
||||
add r9,8 # correct for over-subtract above
|
||||
|
||||
// Fewer than 8 columns remain in r9. The store mask is read from
// MlasMaskMoveTableAvx at dword index (8 - r9). rdi is reused as the table
// pointer; every path from here leaves the macro without reading matrix A.
// NOTE(review): the table is defined elsewhere; presumably the selected entry
// enables exactly the first r9 dword lanes - confirm against its definition.
.LOutputMasked8xNBlock\@:
|
||||
neg r9
|
||||
lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4]
|
||||
vmovdqu ymm0,YMMWORD PTR [rdi+r9*4]
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked8xNBlock\@
|
||||
// The even registers are free at this point and serve as temporaries for the
// masked loads of matrix C, which are then added to the odd accumulators.
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [r8]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [r8+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [r8+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
|
||||
|
||||
.LSkipAccumulateOutputMasked8xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [r8],ymm0,ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm15"
|
||||
jmp .LExitProcessCountM\@
|
||||
|
||||
.endm
|
||||
|
||||
//
|
||||
// Reduce code size for the various types of kernels by sharing the outer logic
|
||||
// and switching on the selector codes (using sign bit to discriminate).
|
||||
//
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8S8KernelAvxVnni
|
||||
|
||||
// Entry stub: pass a negative type selector in eax to the shared kernel, which
// makes ProduceOutputBlock take the U8S8 AVX-VNNI compute loop.
mov eax,-1
|
||||
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2)
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8U8KernelAvx2
|
||||
|
||||
// Entry stub: pass a positive type selector in eax to the shared kernel, which
// makes ProduceOutputBlock take the U8U8 AVX2 compute loop.
mov eax,1
|
||||
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2)
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8S8KernelAvx2
|
||||
|
||||
// Entry stub: pass a zero type selector in eax to the shared kernel, which
// selects the U8S8 AVX2 compute loop; the shared kernel limits this type to
// the row counts handled from its 4-or-fewer rows dispatch.
xor eax,eax
|
||||
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2)
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (rcx) - Supplies the number of packed columns from matrix A
|
||||
and the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A. These values have
|
||||
been pre-scaled by the zero point offset of matrix B if the offset is
|
||||
per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
scaled by the per-column zero point offsets of matrix B. These values are
|
||||
accumulated into every row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8X8KernelAvx2
|
||||
|
||||
// Save the registers this kernel modifies and restores in its epilogue.
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
|
||||
// eax carries the kernel type selector set by the entry stubs (-1, 0 or 1);
// store it in the frame so ProduceOutputBlock can dispatch on it.
mov DWORD PTR .LGemmU8X8KernelFrame_type[rsp],eax
|
||||
// Keep the base of matrix A in rbx; ProcessCountM reloads rdi from it after
// each block of columns.
mov rbx,rdi
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
|
||||
mov r11,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
|
||||
mov r12,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ZeroPointB[rsp]
|
||||
// NOTE(review): ymm12 is loaded with a vector of word ones; presumably this is
// consumed by the U8S8 AVX2 compute block (not visible here), which would be
// why that kernel type cannot use ymm12 as a row accumulator - confirm.
vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF]
|
||||
vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001]
|
||||
// A zero selector (U8S8 AVX2) skips the 5 and 6 row paths.
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
|
||||
je .LCheckCountM4OrMore # U8S8 AVX2 kernel requires extra registers
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
//
|
||||
|
||||
// CountM is in r8: 6 or more rows use the 6 row path, exactly 5 the 5 row path.
.LCheckCountM6OrMore:
|
||||
cmp r8,5
|
||||
ja .LProcessCountM6
|
||||
je .LProcessCountM5
|
||||
|
||||
// 4 or more rows use the 4 row path, 3 the 3 row path and 1 the 1 row path;
// any other value (2, and also 0) falls through to the 2 row path.
// NOTE(review): callers are presumably expected never to pass CountM == 0.
.LCheckCountM4OrMore:
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
// Every ProcessCountM expansion ends by jumping to .LExitKernel, so the
// expansions below never fall through into one another.
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
.LProcessCountM6:
|
||||
ProcessCountM 6
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
// eax already holds the number of rows handled, set by ProcessCountM.
.LExitKernel:
|
||||
vzeroupper
|
||||
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.LProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
.end
|
||||
|
|
@ -1,273 +0,0 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelAvx2Common.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains common kernel macros and structures for the quantized
|
||||
integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
|
||||
|
||||
--*/
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8S8 and U8U8 kernels.
|
||||
//
|
||||
|
||||
// Offsets are used relative to rsp. The mask slot has a negative offset, so it
// lies below rsp; ProcessCountM uses it as scratch to broadcast the remaining
// column count when it builds the store mask.
// NOTE(review): the four saved-register slots and the return address imply
// that the slots from ldc onward are the stack-passed arguments - confirm
// against the prologue of the kernels that include this file.
.equ .LGemmU8X8KernelFrame_mask, -8
|
||||
.equ .LGemmU8X8KernelFrame_SavedR13, 0
|
||||
.equ .LGemmU8X8KernelFrame_SavedR12, 8
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbx, 16
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbp, 24
|
||||
.equ .LGemmU8X8KernelFrame_ReturnAddress, 32
|
||||
.equ .LGemmU8X8KernelFrame_ldc, 40
|
||||
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 48
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 56
|
||||
.equ .LGemmU8X8KernelFrame_DepthValue, 64
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 72
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r12 - Supplies the address of the row sum buffer.
|
||||
|
||||
r13 - Supplies the address of the column sum buffer.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the sum of the global depth value constant,
|
||||
// the column sums, and the row sums.
|
||||
//
|
||||
|
||||
// ymm0/ymm1 = DepthValue + column sums read from r13. A 16 column block fills
// both registers and advances r13; an 8 column block fills ymm1 only and
// leaves r13 unchanged.
vpbroadcastd ymm1,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
|
||||
add r13,16*4 # advance ColumnSumBuffer by 16 columns
|
||||
.else
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
|
||||
.endif
|
||||
// Broadcast each row sum from r12 into that row's odd-numbered accumulator.
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]"
|
||||
// The even register (first 8 columns of a 16 column block) is formed before
// the odd register is updated in place, because the odd register still holds
// the broadcast row sum.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
|
||||
|
||||
//
|
||||
// Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
//
|
||||
|
||||
// rbx is scratch in this macro: it addresses rows 3 and up of matrix A during
// the compute loop and is repointed at matrix C afterwards.
.if \RowCount\() > 3
|
||||
lea rbx,[rcx*2+rcx]
|
||||
add rbx,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
// ComputeBlockLoop is not defined in this file.
// NOTE(review): presumably each kernel source that includes this header
// supplies its own ComputeBlockLoop macro - confirm.
ComputeBlockLoop \ColumnCount\(), \RowCount\()
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[rax*2+rax]
|
||||
add rbx,rdx # compute matrix C plus 3 rows
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Fallthrough - Supplies a non-blank value if the macro may fall through to
|
||||
the ExitKernel label.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
r11 - Supplies the base address of matrix A, used to reload rdi after each
block of columns has been produced.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10b - Supplies the zero mode flag.
|
||||
|
||||
r12 - Supplies the address of the row sum buffer.
|
||||
|
||||
r13 - Supplies the address of the column sum buffer.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount, Fallthrough
|
||||
|
||||
// More than 8 remaining columns use the 16 column path; 8 or fewer go straight
// to the 8 column path.
cmp r9,8
|
||||
jbe .LProcessRemainingCountN\@
|
||||
|
||||
.LProcessNextColumnLoop16xN\@:
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
// A borrow means fewer than 16 columns remained: the first 8 are stored
// unmasked and the rest masked.
sub r9,16
|
||||
jb .LOutputMasked16xNBlock\@
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlock\@
|
||||
// Not ZeroMode: add the existing contents of matrix C to the accumulators.
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,r11 # reload matrix A
|
||||
cmp r9,8
|
||||
ja .LProcessNextColumnLoop16xN\@
|
||||
test r9,r9
|
||||
jz .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN\@:
|
||||
ProduceOutputBlock 8, \RowCount\()
|
||||
cmp r9,8
|
||||
jb .LOutputMasked8xNBlock\@
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput8xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput8xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15"
|
||||
jmp .LExitKernel
|
||||
|
||||
// Partial 16 column block: write the even accumulators (first 8 columns)
// unmasked, then fall through to the masked path for the odd accumulators.
.LOutputMasked16xNBlock\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked16xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutputMasked16xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
|
||||
add rdx,8*4 # advance matrix C by 8 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,8*4 # advance matrix C plus 3 rows by 8 columns
|
||||
.endif
|
||||
add r9,8 # correct for over-subtract above
|
||||
|
||||
// Fewer than 8 columns remain in r9. The count is spilled to the scratch mask
// slot so it can be broadcast, then compared (greater-than, per dword lane)
// against MlasMaskMoveAvx to form the store mask in ymm0.
// NOTE(review): MlasMaskMoveAvx is defined elsewhere; presumably it holds the
// lane indices 0..7 so that exactly the first r9 lanes are enabled - confirm.
.LOutputMasked8xNBlock\@:
|
||||
mov DWORD PTR .LGemmU8X8KernelFrame_mask[rsp],r9d
|
||||
vpbroadcastd ymm0,DWORD PTR .LGemmU8X8KernelFrame_mask[rsp]
|
||||
vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked8xNBlock\@
|
||||
// The even registers are free at this point and serve as temporaries for the
// masked loads of matrix C, which are then added to the odd accumulators.
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
|
||||
|
||||
.LSkipAccumulateOutputMasked8xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15"
|
||||
// When the Fallthrough argument is blank an explicit jump to the epilogue is
// emitted; otherwise execution continues with whatever follows the expansion.
.ifb \Fallthrough\()
|
||||
jmp .LExitKernel
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
|
@ -1,403 +0,0 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelAvx512Common.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains common kernel macros and structures for the quantized
|
||||
integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and
|
||||
AVX512VNNI kernels.
|
||||
|
||||
--*/
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8S8 and U8U8 kernels.
|
||||
//
|
||||
|
||||
// Offsets are used relative to rsp. Five saved-register slots (8 bytes each)
// are followed by the return address.
// NOTE(review): the slots from ldc onward are presumably the stack-passed
// arguments - confirm against the prologue of the kernels that include this
// file.
.equ .LGemmU8X8KernelFrame_SavedR14, 0
|
||||
.equ .LGemmU8X8KernelFrame_SavedR13, 8
|
||||
.equ .LGemmU8X8KernelFrame_SavedR12, 16
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbx, 24
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbp, 32
|
||||
.equ .LGemmU8X8KernelFrame_ReturnAddress, 40
|
||||
.equ .LGemmU8X8KernelFrame_ldc, 48
|
||||
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 56
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 64
|
||||
.equ .LGemmU8X8KernelFrame_DepthValue, 72
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 80
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r12 - Supplies the address of the row sum buffer.
|
||||
|
||||
r13 - Supplies the address of the column sum buffer.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the sum of the global depth value constant,
|
||||
// the column sums, and the row sums.
|
||||
//
|
||||
|
||||
// zmm3 = broadcast DepthValue. The column sums read from r13 are added to it
// so that zmm0 always holds the LAST 16 columns of the block, zmm1 the 16
// before those and zmm2 the first 16 of a 48 column block. r13 is advanced
// only for blocks of 32 or more columns.
vpbroadcastd zmm3,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
|
||||
.if \ColumnCount\() >= 32
|
||||
.if \ColumnCount\() >= 48
|
||||
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
|
||||
.else
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
|
||||
.endif
|
||||
add_immed r13,\ColumnCount\()*4 # advance ColumnSumBuffer by N columns
|
||||
.else
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
|
||||
.endif
|
||||
// Per row, add the row sum (broadcast from r12 by the embedded {1to16} memory
// operand) to each column vector: zmm14-zmm19 take zmm0, zmm20-zmm25 take zmm1
// and zmm26-zmm31 take zmm2.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r12+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r12+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r12+20]{1to16}"
|
||||
|
||||
//
|
||||
// Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
//
|
||||
|
||||
// rbx is scratch in this macro: it addresses rows 3 and up of matrix A during
// the compute loop and is repointed at matrix C afterwards.
.if \RowCount\() > 3
|
||||
lea rbx,[rcx*2+rcx]
|
||||
add rbx,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
// ComputeBlockLoop is not defined in this file.
// NOTE(review): presumably each kernel source that includes this header
// supplies its own ComputeBlockLoop macro - confirm.
ComputeBlockLoop \ColumnCount\(), \RowCount\()
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
r11 - Supplies the address of matrix A.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10b - Supplies the zero mode flag.
|
||||
|
||||
r12 - Supplies the address of the row sum buffer.
|
||||
|
||||
r13 - Supplies the address of the column sum buffer.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount
|
||||
|
||||
// Pick the widest column block that fits: more than 32 columns remaining in
// r9 takes the 48 column path, 17 to 32 columns the 32 column path, and 16 or
// fewer the 16 column path.
cmp r9,32
|
||||
ja .LProcessNextColumnLoop48xN\@
|
||||
cmp r9,16
|
||||
jbe .LProcessRemainingCountN\@
|
||||
|
||||
.LProcessNextColumnLoop32xN\@:
|
||||
ProduceOutputBlock 32, \RowCount\()
|
||||
add rsi,r14 # advance matrix B by packed block stride
|
||||
|
||||
// The 48 column path jumps here after storing its first 16 columns, so this
// code stores either the middle 16 columns of a 48 wide block or the first 16
// columns of a 32 wide block (accumulators zmm20-zmm25).
.LOutput32xNBlock\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput32xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput32xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm23"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
|
||||
.endif
|
||||
sub r9,16
|
||||
|
||||
.LOutput16xNBlock\@:
|
||||
// With fewer than 16 columns left the subtraction below borrows. In that case
// the mask (1 << remaining) - 1 is built in k1 so that the masked loads and
// stores touch only the valid columns. rcx is reused as the shift count, which
// overwrites the matrix A row length, and r9 is cleared so that the exit test
// after the stores is taken.
sub r9,16
|
||||
jae .LOutput16xNBlockWithMask\@
|
||||
lea rcx,[r9+16] # correct for over-subtract above
|
||||
mov ebp,1
|
||||
shl ebp,cl
|
||||
dec ebp
|
||||
kmovw k1,ebp # update mask for remaining columns
|
||||
xor r9,r9 # no more columns remaining
|
||||
|
||||
.LOutput16xNBlockWithMask\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlockWithMask\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlockWithMask\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,r11 # reload matrix A
|
||||
// Loop back for the next column block, or leave once r9 has reached zero.
cmp r9,32
|
||||
ja .LProcessNextColumnLoop48xN\@
|
||||
cmp r9,16
|
||||
ja .LProcessNextColumnLoop32xN\@
|
||||
test r9,r9
|
||||
jz .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN\@:
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
jmp .LOutput16xNBlock\@
|
||||
|
||||
.LProcessNextColumnLoop48xN\@:
|
||||
ProduceOutputBlock 48, \RowCount\()
|
||||
lea rsi,[rsi+r14*2] # advance matrix B by packed block stride
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput48xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput48xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm29"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
|
||||
.endif
|
||||
sub r9,16
|
||||
jmp .LOutput32xNBlock\@
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates the common AVX512 code for the inner kernel to compute
|
||||
matrix multiplication.
|
||||
|
||||
Arguments:
|
||||
|
||||
Type - Supplies the kernel type string for function tags.
|
||||
|
||||
Isa - Supplies the instruction set architecture string for function tags.
|
||||
|
||||
--*/
|
||||
|
||||
.macro GemmU8X8KernelAvx512Function Type, Isa
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (rcx) - Supplies the number of packed columns from matrix A and
|
||||
the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
||||
zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\())
|
||||
C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()):
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
push r14
|
||||
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
|
||||
// r11 keeps the original matrix A address so that the column loop can reload
// rdi from it.
mov r11,rdi
|
||||
mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
|
||||
mov ebp,-1
|
||||
kmovw k1,ebp # update mask to write all columns
|
||||
.ifeqs "\Type\()", "U8S8"
|
||||
.ifeqs "\Isa\()", "Avx512Core"
|
||||
// ebp still holds -1 from the mask setup above, so negating it yields 1.
neg ebp
|
||||
vpbroadcastw zmm5,ebp # generate 512-bit word vector [0x0001]
|
||||
.endif
|
||||
// rcx was scaled to the matrix A row length in bytes above. The packed matrix
// B stride is 16 times that length for U8S8 and 8 times that length otherwise.
mov r14,rcx
|
||||
shl r14,4 # compute matrix B packed stride
|
||||
.else
|
||||
lea r14,[rcx*8] # compute matrix B packed stride
|
||||
.endif
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
//
|
||||
|
||||
// CountM values above 5 use the 6 row path. Values that match none of the
// compares below (0 or 2) fall through to the 2 row path.
cmp r8,5
|
||||
ja .LProcessCountM6
|
||||
je .LProcessCountM5
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
// Each ProcessCountM expansion ends with an unconditional jump, so execution
// does not fall through from one expansion into the label that follows it.
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
.LProcessCountM6:
|
||||
mov r8d,6 # return 6 rows handled
|
||||
ProcessCountM 6
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
// r8d is returned as received, except on the 6 row path which forces it to 6.
mov eax,r8d
|
||||
vzeroupper
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.LProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
.endm
|
||||
709
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S
Normal file
709
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S
Normal file
|
|
@ -0,0 +1,709 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelAvx512Core.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "AssembleAvx512Vnni.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8X8 kernel.
|
||||
//
|
||||
|
||||
// Byte offsets of the frame fields, applied to rsp by the code that reads them.
// Offsets 0 to 32 name the saved registers, 40 the return address, and 48
// upward the stack arguments. The type selector uses a negative offset, i.e. a
// slot below rsp.
// NOTE(review): the prologue that pushes these registers and stores the type
// selector is outside this excerpt -- confirm the layout against it.
.equ .LGemmU8X8KernelFrame_type, -8
|
||||
.equ .LGemmU8X8KernelFrame_SavedR14, 0
|
||||
.equ .LGemmU8X8KernelFrame_SavedR13, 8
|
||||
.equ .LGemmU8X8KernelFrame_SavedR12, 16
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbx, 24
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbp, 32
|
||||
.equ .LGemmU8X8KernelFrame_ReturnAddress, 40
|
||||
.equ .LGemmU8X8KernelFrame_ldc, 48
|
||||
.equ .LGemmU8X8KernelFrame_RowSumBuffer, 56
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 64
|
||||
.equ .LGemmU8X8KernelFrame_ZeroPointB, 72
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 80
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to load packed data from matrix B.
|
||||
|
||||
Arguments:
|
||||
|
||||
VecReg - Supplies the register to load the data into.
|
||||
|
||||
AddressOperand - Supplies the address operand.
|
||||
|
||||
--*/
|
||||
|
||||
.macro LoadPackedMatrixBU8S8 VecReg, AddressOperand
|
||||
|
||||
// U8S8 variant: load a full ZMMWORD (64 bytes) of packed matrix B unchanged.
vmovdqu32 \VecReg\(),ZMMWORD PTR \AddressOperand\()
|
||||
|
||||
.endm
|
||||
|
||||
// U8U8 variant of the packed matrix B load. VecReg and AddressOperand have the
// same meaning as for LoadPackedMatrixBU8S8.
.macro LoadPackedMatrixBU8U8 VecReg, AddressOperand
|
||||
|
||||
// Load a YMMWORD (32 bytes) and zero extend every byte to a 16-bit word.
vpmovzxbw \VecReg\(),YMMWORD PTR \AddressOperand\()
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single cell of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
AccumReg - Supplies the register to accumulate into.
|
||||
|
||||
Mult1Reg - Supplies the first multiplication operand register.
|
||||
|
||||
Mult2Reg - Supplies the second multiplication operand register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
zmm4 - Supplies a scratch register for intermediate results.
|
||||
|
||||
zmm13 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateCellU8S8Avx512Core AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
// vpmaddubsw treats the bytes of Mult1Reg as unsigned and those of Mult2Reg as
// signed, and adds adjacent byte products into 16-bit words with signed
// saturation.
vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\()
|
||||
// Multiplying by zmm13 and adding adjacent pairs widens the word sums to
// 32-bit lanes (zmm13 is expected to hold 0x0001 in every word).
vpmaddwd zmm4,zmm4,zmm13
|
||||
vpaddd \AccumReg\(),\AccumReg\(),zmm4
|
||||
|
||||
.endm
|
||||
|
||||
// VNNI variant of the U8S8 multiply/accumulate cell. It takes the same
// arguments as MultiplyAccumulateCellU8S8Avx512Core and needs no scratch
// register in this macro.
.macro MultiplyAccumulateCellU8S8Avx512Vnni AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
// NOTE(review): VpdpbusdsZmmZmmZmm is not defined in this excerpt; presumably
// it comes from AssembleAvx512Vnni.h and emits the vpdpbusds instruction --
// confirm.
VpdpbusdsZmmZmmZmm \AccumReg\(),\Mult1Reg\(),\Mult2Reg\()
|
||||
|
||||
.endm
|
||||
|
||||
// U8U8 multiply/accumulate cell. It takes the same arguments as the U8S8
// variants and uses zmm4 as a scratch register.
.macro MultiplyAccumulateCellU8U8Avx512Core AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
// Multiply 16-bit words and add adjacent products into 32-bit lanes, then add
// the result to the accumulator.
vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\()
|
||||
vpaddd \AccumReg\(),\AccumReg\(),zmm4
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
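Type - Supplies the kernel type string (U8S8 or U8U8) used to select the
|
||||
matrix B load and multiply/accumulate macros.
|
||||
|
||||
Isa - Supplies the instruction set architecture string used to select the
|
||||
multiply/accumulate macro.
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.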
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
r8 - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
// Matrix B vectors are assigned so that zmm2 always holds the last 16 columns
// of the block, zmm1 the 16 columns before those, and zmm0 the first 16
// columns of a 48 wide block. The accumulator groups follow the same order:
// zmm26-zmm31 pair with zmm0, zmm20-zmm25 with zmm1, and zmm14-zmm19 with zmm2.
.if \ColumnCount\() >= 48
|
||||
LoadPackedMatrixB\Type\() zmm0,"[rsi+\VectorOffset\()]"
|
||||
LoadPackedMatrixB\Type\() zmm1,"[rsi+r14+\VectorOffset\()]"
|
||||
LoadPackedMatrixB\Type\() zmm2,"[rsi+r14*2+\VectorOffset\()]"
|
||||
.elseif \ColumnCount\() >= 32
|
||||
LoadPackedMatrixB\Type\() zmm1,"[rsi+\VectorOffset\()]"
|
||||
LoadPackedMatrixB\Type\() zmm2,"[rsi+r14+\VectorOffset\()]"
|
||||
.else
|
||||
LoadPackedMatrixB\Type\() zmm2,"[rsi+\VectorOffset\()]"
|
||||
.endif
|
||||
// For each row, broadcast 4 bytes of matrix A into zmm3 and accumulate its
// products against every loaded matrix B vector. Rows 1 to 3 are addressed
// from rdi and rows 4 to 6 from r8.
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [r8+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [r8+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm19,zmm3,zmm2"
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple times
|
||||
and advance the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
Isa - Supplies the instruction set architecture string.
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
r8 - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoopU8S8 Isa, ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
// Only even row counts get the 4x unrolled loop. Odd row counts go straight to
// the single step loop further down.
.if ((\RowCount\() & 1) == 0)
|
||||
sub rbp,4*4
|
||||
jb .LProcessRemainingBlocks\@
|
||||
|
||||
// Each unrolled pass consumes 16 bytes (4 quads) of every matrix A row and
// 4*64 bytes of matrix B.
.LComputeBlockBy4Loop\@:
|
||||
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0*64, 0
|
||||
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 1*64, 4
|
||||
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 2*64, 8
|
||||
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 3*64, 12
|
||||
add rdi,4*4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add r8,4*4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,4*64 # advance matrix B
|
||||
sub rbp,4*4 # decrement quads remaining
|
||||
jae .LComputeBlockBy4Loop\@
|
||||
|
||||
.LProcessRemainingBlocks\@:
|
||||
add rbp,4*4 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit\@
|
||||
.endif
|
||||
|
||||
// Single step loop: 4 bytes of every matrix A row and 64 bytes of matrix B per
// pass.
// NOTE(review): the jnz exit test assumes the length left in rbp is a nonzero
// multiple of 4 on entry to this loop -- confirm against the callers.
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add r8,4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,64 # advance matrix B
|
||||
sub rbp,4 # decrement quads remaining
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.LComputeBlockLoopExit\@:
|
||||
|
||||
.endm
|
||||
|
||||
// U8U8 counterpart of ComputeBlockLoopU8S8. It takes the same arguments and
// has no unrolled variant.
.macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
// Each pass consumes 4 bytes of every matrix A row and 32 bytes of matrix B
// (the U8U8 loader reads a YMMWORD and widens it).
// NOTE(review): the jnz exit test assumes rcx is a nonzero multiple of 4 --
// confirm against the callers.
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock U8U8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 pair
|
||||
.if \RowCount\() > 3
|
||||
add r8,4 # advance matrix A plus 3 rows by 1 pair
|
||||
.endif
|
||||
add rsi,32 # advance matrix B
|
||||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r11 - Supplies the address of the row sum buffer.
|
||||
|
||||
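r12 - Supplies the address of the column sum buffer.
|
||||
|
||||
r13 - Supplies the address of the per-column zero points of matrix B, or zero
|
||||
if the row sums are used without per-column scaling.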
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the row and column sums.
|
||||
//
|
||||
|
||||
// Load the column sums from r12. The registers are filled in reverse order so
// that zmm0 always holds the sums for the last 16 columns of the block. r12 is
// advanced only for the 32 and 48 column cases.
.if \ColumnCount\() >= 32
|
||||
.if \ColumnCount\() >= 48
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [r12]
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [r12+64]
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [r12+128]
|
||||
.else
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [r12]
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [r12+64]
|
||||
.endif
|
||||
add_immed r12,\ColumnCount\()*4 # advance ColumnSumBuffer by N columns
|
||||
.else
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [r12]
|
||||
.endif
|
||||
// A nonzero r13 selects the per-column path: every accumulator starts as the
// row sum multiplied by the zero point of its column, plus the column sum.
// Otherwise the row sum is added to the column sum directly.
test r13,r13 # per column zero points?
|
||||
jz .LSkipScaleByZeroPointB\@
|
||||
.if \ColumnCount\() >= 32
|
||||
.if \ColumnCount\() >= 48
|
||||
vmovdqu32 zmm5,ZMMWORD PTR [r13]
|
||||
vmovdqu32 zmm4,ZMMWORD PTR [r13+64]
|
||||
vmovdqu32 zmm3,ZMMWORD PTR [r13+128]
|
||||
.else
|
||||
vmovdqu32 zmm4,ZMMWORD PTR [r13]
|
||||
vmovdqu32 zmm3,ZMMWORD PTR [r13+64]
|
||||
.endif
|
||||
add_immed r13,\ColumnCount\()*4 # advance ZeroPointB by N columns
|
||||
.else
|
||||
vmovdqu32 zmm3,ZMMWORD PTR [r13]
|
||||
.endif
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld zmm14,zmm3,DWORD PTR [r11]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld zmm20,zmm4,DWORD PTR [r11]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpmulld zmm26,zmm5,DWORD PTR [r11]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,zmm14"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,zmm20"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,zmm26"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld zmm15,zmm3,DWORD PTR [r11+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpmulld zmm21,zmm4,DWORD PTR [r11+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpmulld zmm27,zmm5,DWORD PTR [r11+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,zmm15"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,zmm21"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,zmm27"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld zmm16,zmm3,DWORD PTR [r11+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpmulld zmm22,zmm4,DWORD PTR [r11+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpmulld zmm28,zmm5,DWORD PTR [r11+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,zmm16"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,zmm22"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,zmm28"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld zmm17,zmm3,DWORD PTR [r11+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpmulld zmm23,zmm4,DWORD PTR [r11+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpmulld zmm29,zmm5,DWORD PTR [r11+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,zmm17"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,zmm23"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,zmm29"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld zmm18,zmm3,DWORD PTR [r11+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpmulld zmm24,zmm4,DWORD PTR [r11+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpmulld zmm30,zmm5,DWORD PTR [r11+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,zmm18"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,zmm24"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,zmm30"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld zmm19,zmm3,DWORD PTR [r11+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpmulld zmm25,zmm4,DWORD PTR [r11+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpmulld zmm31,zmm5,DWORD PTR [r11+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,zmm19"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,zmm25"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,zmm31"
|
||||
jmp .LAccumulatorsInitialized\@
|
||||
|
||||
.LSkipScaleByZeroPointB\@:
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r11]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r11]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r11]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r11+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r11+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r11+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r11+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r11+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r11+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r11+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r11+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r11+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r11+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r11+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r11+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r11+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r11+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r11+20]{1to16}"
|
||||
|
||||
.LAccumulatorsInitialized\@:
|
||||
|
||||
//
|
||||
// Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
//
|
||||
|
||||
.if \RowCount\() > 3
|
||||
lea r8,[rcx*2+rcx]
|
||||
add r8,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
// Select the compute loop from the type selector stored in the frame: zero
// runs the U8S8 Avx512Core loop, a positive value the U8U8 Avx512Core loop,
// and a negative value the U8S8 Avx512Vnni loop.
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
|
||||
je .LProduceWithU8S8Avx512Core\@
|
||||
jg .LProduceWithU8U8Avx512Core\@
|
||||
ComputeBlockLoopU8S8 Avx512Vnni, \ColumnCount\(), \RowCount\()
|
||||
jmp .LExitProduceOutputBlock\@
|
||||
|
||||
.LProduceWithU8U8Avx512Core\@:
|
||||
ComputeBlockLoopU8U8 Avx512Core, \ColumnCount\(), \RowCount\()
|
||||
jmp .LExitProduceOutputBlock\@
|
||||
|
||||
.LProduceWithU8S8Avx512Core\@:
|
||||
ComputeBlockLoopU8S8 Avx512Core, \ColumnCount\(), \RowCount\()
|
||||
|
||||
.LExitProduceOutputBlock\@:
|
||||
// r8 addressed matrix A plus 3 rows during the compute loop; repoint it at
// matrix C plus 3 rows for the code that follows this macro.
.if \RowCount\() > 3
|
||||
lea r8,[rax*2+rax]
|
||||
add r8,rdx # compute matrix C plus 3 rows
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
rbx - Supplies the saved address of matrix A, used to reload rdi after each
|
||||
block of columns.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10b - Supplies the zero mode flag.
|
||||
|
||||
r11 - Supplies the address of the row sum buffer.
|
||||
|
||||
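r12 - Supplies the address of the column sum buffer.
|
||||
|
||||
r13 - Supplies the address of the per-column zero points of matrix B, or zero
|
||||
if the row sums are used without per-column scaling.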
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount
|
||||
|
||||
// Pick the widest column block that fits: more than 32 columns remaining in
// r9 takes the 48 column path, 17 to 32 columns the 32 column path, and 16 or
// fewer the 16 column path.
cmp r9,32
|
||||
ja .LProcessNextColumnLoop48xN\@
|
||||
cmp r9,16
|
||||
jbe .LProcessRemainingCountN\@
|
||||
|
||||
.LProcessNextColumnLoop32xN\@:
|
||||
ProduceOutputBlock 32, \RowCount\()
|
||||
add rsi,r14 # advance matrix B by packed block stride
|
||||
|
||||
// The 48 column path jumps here after storing its first 16 columns, so this
// code stores either the middle 16 columns of a 48 wide block or the first 16
// columns of a 32 wide block (accumulators zmm20-zmm25).
.LOutput32xNBlock\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput32xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [r8]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [r8+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [r8+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput32xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm23"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm24"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm25"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
.if \RowCount\() > 3
|
||||
add r8,16*4 # advance matrix C plus 3 rows by 16 columns
|
||||
.endif
|
||||
sub r9,16
|
||||
|
||||
.LOutput16xNBlock\@:
|
||||
// With fewer than 16 columns left the subtraction below borrows. In that case
// the mask (1 << remaining) - 1 is built in k1 so that the masked loads and
// stores touch only the valid columns. rcx is reused as the shift count, which
// overwrites the matrix A row length, and r9 is cleared so that no further
// column block is started after the stores.
sub r9,16
|
||||
jae .LOutput16xNBlockWithMask\@
|
||||
lea rcx,[r9+16] # correct for over-subtract above
|
||||
mov ebp,1
|
||||
shl ebp,cl
|
||||
dec ebp
|
||||
kmovw k1,ebp # update mask for remaining columns
|
||||
xor r9,r9 # no more columns remaining
|
||||
|
||||
.LOutput16xNBlockWithMask\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlockWithMask\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [r8]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [r8+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [r8+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlockWithMask\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8]{k1},zmm17"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm18"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm19"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,rbx # reload matrix A
|
||||
// Loop back for the next column block while columns remain in r9.
cmp r9,32
|
||||
ja .LProcessNextColumnLoop48xN\@
|
||||
cmp r9,16
|
||||
ja .LProcessNextColumnLoop32xN\@
|
||||
test r9,r9
|
||||
jnz .LProcessRemainingCountN\@
|
||||
// All columns are done: eax is loaded with the row count of this expansion
// before leaving through .LExitKernel.
// NOTE(review): the exit code is outside this excerpt; presumably eax is the
// return value -- confirm.
mov eax,\RowCount\()
|
||||
jmp .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN\@:
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
jmp .LOutput16xNBlock\@
|
||||
|
||||
.LProcessNextColumnLoop48xN\@:
|
||||
ProduceOutputBlock 48, \RowCount\()
|
||||
lea rsi,[rsi+r14*2] # advance matrix B by packed block stride
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput48xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [r8]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [r8+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [r8+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput48xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm29"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm30"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm31"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
.if \RowCount\() > 3
|
||||
add r8,16*4 # advance matrix C plus 3 rows by 16 columns
|
||||
.endif
|
||||
sub r9,16
|
||||
jmp .LOutput32xNBlock\@
|
||||
|
||||
.endm
|
||||
|
||||
//
|
||||
// Reduce code size for the various types of kernels by sharing the outer logic
|
||||
// and switching on the selector codes (using sign bit to discriminate).
|
||||
//
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8U8KernelAvx512Core
|
||||
|
||||
mov eax,1
|
||||
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core)
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Core
|
||||
|
||||
xor eax,eax
|
||||
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core)
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Vnni
|
||||
|
||||
mov eax,-1
|
||||
jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core)
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PackedCountK (rcx) - Supplies the number of packed columns from matrix A and
|
||||
the number of packed rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumBuffer - Supplies the sum of each row from matrix A. These values have
|
||||
been pre-scaled by the zero point offset of matrix B if the offset is
|
||||
per-tensor (ZeroPointB is nullptr). Otherwise, these values must be
|
||||
scaled by the per-column zero point offsets of matrix B. These values are
|
||||
accumulated into every row of matrix C.
|
||||
|
||||
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
ZeroPointB - Optionally supplies the per-column zero point offsets of matrix
|
||||
B, else nullptr if the matrix B is using per-tensor quantization.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
FUNCTION_ENTRY MlasGemmU8X8KernelAvx512Core
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
push r14
|
||||
|
||||
mov DWORD PTR .LGemmU8X8KernelFrame_type[rsp],eax
|
||||
mov rbx,rdi
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
|
||||
mov r11,.LGemmU8X8KernelFrame_RowSumBuffer[rsp]
|
||||
mov r12,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ZeroPointB[rsp]
|
||||
mov ebp,-1
|
||||
kmovw k1,ebp # update mask to write all columns
|
||||
neg ebp
|
||||
vpbroadcastw zmm13,ebp # generate 512-bit word vector [0x0001]
|
||||
lea rbp,[rcx*8]
|
||||
lea r14,[rbp*2]
|
||||
cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0
|
||||
cmovg r14,rbp # select matrix B packed stride
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
//
|
||||
|
||||
cmp r8,5
|
||||
ja .LProcessCountM6
|
||||
je .LProcessCountM5
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
.LProcessCountM6:
|
||||
ProcessCountM 6
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
vzeroupper
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.LProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
.end
|
||||
|
|
@ -22,6 +22,32 @@ Abstract:
|
|||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro emits the assembler directives to annotate a new function.
|
||||
|
||||
Arguments:
|
||||
|
||||
FunctionName - Supplies the name of the function.
|
||||
|
||||
--*/
|
||||
|
||||
.macro FUNCTION_ENTRY FunctionName
|
||||
|
||||
.p2align 4
|
||||
#if defined(__APPLE__)
|
||||
.globl _\FunctionName\()
|
||||
_\FunctionName\():
|
||||
#else
|
||||
.globl \FunctionName\()
|
||||
.type \FunctionName\(),@function
|
||||
\FunctionName\():
|
||||
#endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates an optimization for "add reg,128" which can instead
|
||||
|
|
|
|||
|
|
@ -29,8 +29,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
|
|||
MatMulInteger);
|
||||
|
||||
Status MatMulInteger::Compute(OpKernelContext* ctx) const {
|
||||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
||||
const auto* a = ctx->Input<Tensor>(0);
|
||||
const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
|
||||
|
||||
|
|
@ -58,49 +56,33 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
|
|||
b_offset = *static_cast<const uint8_t*>(b_zero_point->DataRaw());
|
||||
}
|
||||
|
||||
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
|
||||
gemm_params.M = static_cast<size_t>(helper.M());
|
||||
gemm_params.N = static_cast<size_t>(helper.N());
|
||||
gemm_params.K = static_cast<size_t>(helper.K());
|
||||
gemm_params.lda = gemm_params.K;
|
||||
gemm_params.ZeroPointA = a_offset;
|
||||
gemm_params.ldb = gemm_params.N;
|
||||
gemm_params.ZeroPointB = &b_offset;
|
||||
gemm_params.ldc = gemm_params.N;
|
||||
|
||||
const auto* a_data = a->template Data<uint8_t>();
|
||||
auto* y_data = y->template MutableData<int32_t>();
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
if (packed_b_) {
|
||||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
MlasGemm(static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_offset,
|
||||
packed_b_.get(),
|
||||
b_offset,
|
||||
b_is_signed_,
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
thread_pool);
|
||||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
gemm_params.A = a_data + helper.LeftOffsets()[i];
|
||||
if (packed_b_) {
|
||||
gemm_params.B = packed_b_.get();
|
||||
gemm_params.BIsPacked = true;
|
||||
gemm_params.BIsSigned = b_is_signed_;
|
||||
} else if (b != nullptr) {
|
||||
gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + helper.RightOffsets()[i];
|
||||
gemm_params.BIsSigned = b->IsDataType<int8_t>();
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
if (b != nullptr) {
|
||||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
|
||||
const bool b_is_signed = b->IsDataType<int8_t>();
|
||||
MlasGemm(static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_offset,
|
||||
b_data + helper.RightOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
b_offset,
|
||||
b_is_signed,
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
thread_pool);
|
||||
}
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null.");
|
||||
gemm_params.C = y_data + helper.OutputOffsets()[i];
|
||||
MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ class MatMulIntegerBase : public OpKernel {
|
|||
public:
|
||||
MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override {
|
||||
is_packed = false;
|
||||
|
||||
|
|
@ -43,7 +42,6 @@ class MatMulIntegerBase : public OpKernel {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
bool b_is_signed_{true};
|
||||
|
|
|
|||
|
|
@ -76,20 +76,22 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
|
|||
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
|
||||
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
|
||||
|
||||
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
|
||||
gemm_params.M = static_cast<size_t>(helper.M());
|
||||
gemm_params.N = static_cast<size_t>(helper.N());
|
||||
gemm_params.K = static_cast<size_t>(helper.K());
|
||||
gemm_params.lda = gemm_params.K;
|
||||
gemm_params.ZeroPointA = *a_offset->template Data<uint8_t>();
|
||||
gemm_params.ldb = gemm_params.N;
|
||||
gemm_params.ZeroPointB = static_cast<const uint8_t*>(b_offset->DataRaw());
|
||||
gemm_params.BIsSigned = b->IsDataType<int8_t>();
|
||||
gemm_params.C = gemm_output;
|
||||
gemm_params.ldc = gemm_params.N;
|
||||
|
||||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
MlasGemm(static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
a->template Data<uint8_t>() + helper.LeftOffsets()[i],
|
||||
static_cast<size_t>(helper.K()),
|
||||
*a_offset->template Data<uint8_t>(),
|
||||
static_cast<const uint8_t*>(b->DataRaw()) + helper.RightOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
*static_cast<const uint8_t*>(b_offset->DataRaw()),
|
||||
b->IsDataType<int8_t>(),
|
||||
gemm_output,
|
||||
static_cast<size_t>(helper.N()),
|
||||
ctx->GetOperatorThreadPool());
|
||||
gemm_params.A = a->template Data<uint8_t>() + helper.LeftOffsets()[i];
|
||||
gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + helper.RightOffsets()[i];
|
||||
MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
|
||||
|
||||
MlasRequantizeOutput(gemm_output,
|
||||
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
|
||||
|
|
|
|||
|
|
@ -149,19 +149,19 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
MlasGemm(static_cast<size_t>(M / conv_attrs_.group),
|
||||
static_cast<size_t>(output_image_size),
|
||||
static_cast<size_t>(kernel_dim),
|
||||
Wdata + group_id * W_offset,
|
||||
static_cast<size_t>(kernel_dim),
|
||||
filter_offset,
|
||||
col_buffer_data == nullptr ? Xdata : col_buffer_data,
|
||||
static_cast<size_t>(output_image_size),
|
||||
input_offset,
|
||||
false,
|
||||
Ydata,
|
||||
static_cast<size_t>(output_image_size),
|
||||
thread_pool);
|
||||
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
|
||||
gemm_params.M = static_cast<size_t>(M / conv_attrs_.group);
|
||||
gemm_params.N = static_cast<size_t>(output_image_size);
|
||||
gemm_params.K = static_cast<size_t>(kernel_dim);
|
||||
gemm_params.A = Wdata + group_id * W_offset;
|
||||
gemm_params.lda = static_cast<size_t>(kernel_dim);
|
||||
gemm_params.ZeroPointA = filter_offset;
|
||||
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
|
||||
gemm_params.ldb = static_cast<size_t>(output_image_size);
|
||||
gemm_params.ZeroPointB = &input_offset;
|
||||
gemm_params.C = Ydata;
|
||||
gemm_params.ldc = static_cast<size_t>(output_image_size);
|
||||
MlasGemm(&gemm_params, thread_pool);
|
||||
|
||||
Xdata += X_offset;
|
||||
Ydata += Y_offset;
|
||||
|
|
|
|||
|
|
@ -43,10 +43,8 @@ class QLinearConv : public OpKernel {
|
|||
|
||||
ConvAttributes conv_attrs_;
|
||||
TensorShape W_shape_;
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
BufferUniquePtr packed_W_buffer_;
|
||||
size_t packed_W_size_;
|
||||
#endif
|
||||
BufferUniquePtr reordered_W_buffer_;
|
||||
bool is_W_signed_;
|
||||
bool is_W_packed_;
|
||||
|
|
@ -116,7 +114,6 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed
|
|||
|
||||
auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
const size_t group_count = static_cast<size_t>(conv_attrs_.group);
|
||||
const size_t group_output_channels = output_channels / group_count;
|
||||
const size_t kernel_dim = group_input_channels * kernel_size;
|
||||
|
|
@ -151,7 +148,6 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed
|
|||
return Status::OK();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
auto* reordered_W = static_cast<uint8_t*>(alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * output_channels * group_input_channels * kernel_size));
|
||||
reordered_W_buffer_ = BufferUniquePtr(reordered_W, BufferDeleter(alloc));
|
||||
|
|
@ -279,13 +275,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
// Handle the case of a dynamic weight filter.
|
||||
BufferUniquePtr reordered_W_buffer;
|
||||
uint8_t* reordered_W = nullptr;
|
||||
bool use_reordered_W = true;
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
if (packed_W_buffer_) {
|
||||
use_reordered_W = false;
|
||||
}
|
||||
#endif
|
||||
if (use_reordered_W) {
|
||||
if (!packed_W_buffer_) {
|
||||
if (W == nullptr) {
|
||||
// Weight was constant and reordered.
|
||||
reordered_W = static_cast<uint8_t*>(reordered_W_buffer_.get());
|
||||
|
|
@ -307,7 +297,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
int64_t group_output_channels = M / group_count;
|
||||
|
||||
// Test for depthwise convolution.
|
||||
const bool is_depthwise_conv = (use_reordered_W && group_input_channels == 1 && group_output_channels == 1);
|
||||
const bool is_depthwise_conv = (reordered_W != nullptr && group_input_channels == 1 && group_output_channels == 1);
|
||||
if (is_depthwise_conv) {
|
||||
// Update the input and output channels to the number of groups in order to
|
||||
// reuse as much of the below standard convolution path.
|
||||
|
|
@ -510,39 +500,25 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
worker_gemm_input = input_data + output_start * kernel_dim;
|
||||
}
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
|
||||
gemm_params.M = static_cast<size_t>(output_count);
|
||||
gemm_params.N = static_cast<size_t>(group_output_channels);
|
||||
gemm_params.K = static_cast<size_t>(kernel_dim);
|
||||
gemm_params.A = worker_gemm_input;
|
||||
gemm_params.lda = static_cast<size_t>(kernel_dim);
|
||||
gemm_params.ZeroPointA = X_zero_point_value;
|
||||
if (packed_W_buffer_) {
|
||||
MlasGemm(
|
||||
static_cast<size_t>(output_count),
|
||||
static_cast<size_t>(group_output_channels),
|
||||
static_cast<size_t>(kernel_dim),
|
||||
worker_gemm_input,
|
||||
static_cast<size_t>(kernel_dim),
|
||||
X_zero_point_value,
|
||||
static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_,
|
||||
W_zero_point_value,
|
||||
is_W_signed,
|
||||
worker_gemm_output + group_id * group_output_channels,
|
||||
static_cast<size_t>(M),
|
||||
nullptr);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
MlasGemm(
|
||||
static_cast<size_t>(output_count),
|
||||
static_cast<size_t>(group_output_channels),
|
||||
static_cast<size_t>(kernel_dim),
|
||||
worker_gemm_input,
|
||||
static_cast<size_t>(kernel_dim),
|
||||
X_zero_point_value,
|
||||
reordered_W + group_id * group_output_channels,
|
||||
static_cast<size_t>(M),
|
||||
W_zero_point_value,
|
||||
is_W_signed,
|
||||
worker_gemm_output + group_id * group_output_channels,
|
||||
static_cast<size_t>(M),
|
||||
nullptr);
|
||||
gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_;
|
||||
gemm_params.BIsPacked = true;
|
||||
} else {
|
||||
gemm_params.B = reordered_W + group_id * group_output_channels;
|
||||
gemm_params.ldb = static_cast<size_t>(M);
|
||||
}
|
||||
gemm_params.ZeroPointB = &W_zero_point_value;
|
||||
gemm_params.BIsSigned = is_W_signed;
|
||||
gemm_params.C = worker_gemm_output + group_id * group_output_channels;
|
||||
gemm_params.ldc = static_cast<size_t>(M);
|
||||
MlasGemm(&gemm_params, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -286,39 +286,23 @@ void ComputeGemm(const int M,
|
|||
C, ldc, scale_multiplier.data(), nullptr,
|
||||
beta == 1.0f ? MLAS_QGEMM_OUTPUT_MODE::AccumulateMode : MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
|
||||
scale_multiplier.size() == 1 ? MLAS_QUANTIZATION_GRANULARITY::PerMatrix : MLAS_QUANTIZATION_GRANULARITY::PerColumn);
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
if (weights.is_prepacked_) {
|
||||
MlasGemm(static_cast<size_t>(M),
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(K),
|
||||
a_data_quant,
|
||||
static_cast<size_t>(K),
|
||||
a_zero_point,
|
||||
weights.buffer_,
|
||||
b_zero_point,
|
||||
b_is_signed,
|
||||
C_buffer,
|
||||
ld_C_buffer,
|
||||
thread_pool,
|
||||
&output_processor);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
MlasGemm(static_cast<size_t>(M),
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(K),
|
||||
a_data_quant,
|
||||
static_cast<size_t>(K),
|
||||
a_zero_point,
|
||||
static_cast<const uint8_t*>(weights.buffer_),
|
||||
static_cast<size_t>(N),
|
||||
b_zero_point,
|
||||
b_is_signed,
|
||||
C_buffer,
|
||||
ld_C_buffer,
|
||||
thread_pool,
|
||||
&output_processor);
|
||||
MLAS_GEMM_U8X8_PARAMETERS gemm_params;
|
||||
gemm_params.M = static_cast<size_t>(M);
|
||||
gemm_params.N = static_cast<size_t>(N);
|
||||
gemm_params.K = static_cast<size_t>(K);
|
||||
gemm_params.A = a_data_quant;
|
||||
gemm_params.lda = static_cast<size_t>(K);
|
||||
gemm_params.ZeroPointA = a_zero_point;
|
||||
gemm_params.B = weights.buffer_;
|
||||
gemm_params.ldb = static_cast<size_t>(N);
|
||||
gemm_params.ZeroPointB = &b_zero_point;
|
||||
gemm_params.BIsPacked = weights.is_prepacked_;
|
||||
gemm_params.BIsSigned = b_is_signed;
|
||||
gemm_params.C = C_buffer;
|
||||
gemm_params.ldc = ld_C_buffer;
|
||||
gemm_params.OutputProcessor = &output_processor;
|
||||
MlasGemm(&gemm_params, thread_pool);
|
||||
}
|
||||
|
||||
namespace deepcpu {
|
||||
|
|
|
|||
|
|
@ -537,63 +537,7 @@ public:
|
|||
};
|
||||
|
||||
template<bool Packed>
|
||||
class MlasQgemmU8X8U8X8TestBase;
|
||||
|
||||
template<>
|
||||
class MlasQgemmU8X8U8X8TestBase<false> : public MlasTestBase
|
||||
{
|
||||
protected:
|
||||
void
|
||||
TestGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
size_t ldb,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc
|
||||
)
|
||||
{
|
||||
MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc, threadpool);
|
||||
}
|
||||
|
||||
void
|
||||
TestGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
size_t ldb,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
float CScale,
|
||||
const float* Bias
|
||||
)
|
||||
{
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
|
||||
MlasGemm(M, N, K,
|
||||
A, lda, offa,
|
||||
B, ldb, offb, BIsSigned,
|
||||
reinterpret_cast<int32_t*>(C), ldc,
|
||||
threadpool,
|
||||
&scale_bias_processor);
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
|
||||
template<>
|
||||
class MlasQgemmU8X8U8X8TestBase<true> : public MlasTestBase
|
||||
class MlasQgemmU8X8U8X8TestBase : public MlasTestBase
|
||||
{
|
||||
private:
|
||||
void*
|
||||
|
|
@ -628,8 +572,69 @@ protected:
|
|||
size_t ldc
|
||||
)
|
||||
{
|
||||
const void* PackedB = PackB(N, K, B, ldb, BIsSigned);
|
||||
MlasGemm(M, N, K, A, lda, offa, PackedB, offb, BIsSigned, C, ldc, threadpool);
|
||||
MLAS_GEMM_U8X8_PARAMETERS GemmParameters;
|
||||
|
||||
GemmParameters.M = M;
|
||||
GemmParameters.N = N;
|
||||
GemmParameters.K = K;
|
||||
GemmParameters.A = A;
|
||||
GemmParameters.lda = lda;
|
||||
GemmParameters.ZeroPointA = offa;
|
||||
GemmParameters.ZeroPointB = &offb;
|
||||
GemmParameters.BIsSigned = BIsSigned;
|
||||
GemmParameters.C = C;
|
||||
GemmParameters.ldc = ldc;
|
||||
|
||||
if (Packed) {
|
||||
GemmParameters.B = PackB(N, K, B, ldb, BIsSigned);
|
||||
GemmParameters.BIsPacked = true;
|
||||
} else {
|
||||
GemmParameters.B = B;
|
||||
GemmParameters.ldb = ldb;
|
||||
}
|
||||
|
||||
MlasGemm(&GemmParameters, threadpool);
|
||||
}
|
||||
|
||||
void
|
||||
TestGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
size_t ldb,
|
||||
const uint8_t* offb,
|
||||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc
|
||||
)
|
||||
{
|
||||
MLAS_GEMM_U8X8_PARAMETERS GemmParameters;
|
||||
|
||||
GemmParameters.M = M;
|
||||
GemmParameters.N = N;
|
||||
GemmParameters.K = K;
|
||||
GemmParameters.A = A;
|
||||
GemmParameters.lda = lda;
|
||||
GemmParameters.ZeroPointA = offa;
|
||||
GemmParameters.ZeroPointB = offb;
|
||||
GemmParameters.BIsSigned = BIsSigned;
|
||||
GemmParameters.PerColumnZeroPoints = true;
|
||||
GemmParameters.C = C;
|
||||
GemmParameters.ldc = ldc;
|
||||
|
||||
if (Packed) {
|
||||
GemmParameters.B = PackB(N, K, B, ldb, BIsSigned);
|
||||
GemmParameters.BIsPacked = true;
|
||||
} else {
|
||||
GemmParameters.B = B;
|
||||
GemmParameters.ldb = ldb;
|
||||
}
|
||||
|
||||
MlasGemm(&GemmParameters, threadpool);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -650,22 +655,37 @@ protected:
|
|||
const float* Bias
|
||||
)
|
||||
{
|
||||
const void* PackedB = PackB(N, K, B, ldb, BIsSigned);
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
|
||||
MlasGemm(M, N, K,
|
||||
A, lda, offa,
|
||||
PackedB, offb, BIsSigned,
|
||||
reinterpret_cast<int32_t*>(C), ldc,
|
||||
threadpool,
|
||||
&scale_bias_processor);
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR ScaleBiasProcessor(C, ldc, &CScale, Bias);
|
||||
|
||||
MLAS_GEMM_U8X8_PARAMETERS GemmParameters;
|
||||
|
||||
GemmParameters.M = M;
|
||||
GemmParameters.N = N;
|
||||
GemmParameters.K = K;
|
||||
GemmParameters.A = A;
|
||||
GemmParameters.lda = lda;
|
||||
GemmParameters.ZeroPointA = offa;
|
||||
GemmParameters.ZeroPointB = &offb;
|
||||
GemmParameters.BIsSigned = BIsSigned;
|
||||
GemmParameters.C = reinterpret_cast<int32_t*>(C);
|
||||
GemmParameters.ldc = ldc;
|
||||
GemmParameters.OutputProcessor = &ScaleBiasProcessor;
|
||||
|
||||
if (Packed) {
|
||||
GemmParameters.B = PackB(N, K, B, ldb, BIsSigned);
|
||||
GemmParameters.BIsPacked = true;
|
||||
} else {
|
||||
GemmParameters.B = B;
|
||||
GemmParameters.ldb = ldb;
|
||||
}
|
||||
|
||||
MlasGemm(&GemmParameters, threadpool);
|
||||
}
|
||||
|
||||
private:
|
||||
MatrixGuardBuffer<uint8_t> BufferBPacked;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template<typename xint8_t, typename OutputType, bool Packed>
|
||||
class MlasQgemmU8X8Test;
|
||||
|
||||
|
|
@ -690,6 +710,23 @@ private:
|
|||
Test(M, N, K, A, K, offa, B, N, offb, C, CReference, N);
|
||||
}
|
||||
|
||||
void
|
||||
Test(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
uint8_t offa
|
||||
)
|
||||
{
|
||||
const uint8_t* A = BufferA.GetBuffer(K * M);
|
||||
const uint8_t* B = BufferB.GetBuffer(N * K);
|
||||
const uint8_t* ZeroPointB = BufferZeroPointB.GetBuffer(N);
|
||||
int32_t* C = BufferC.GetBuffer(N * M);
|
||||
int32_t* CReference = BufferCReference.GetBuffer(N * M);
|
||||
|
||||
Test(M, N, K, A, K, offa, B, N, ZeroPointB, C, CReference, N);
|
||||
}
|
||||
|
||||
void
|
||||
Test(
|
||||
size_t M,
|
||||
|
|
@ -720,6 +757,36 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
Test(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
size_t ldb,
|
||||
const uint8_t *offb,
|
||||
int32_t* C,
|
||||
int32_t* CReference,
|
||||
size_t ldc
|
||||
)
|
||||
{
|
||||
std::fill_n(C, M * N, -1);
|
||||
std::fill_n(CReference, M * N, -1);
|
||||
|
||||
this->TestGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc);
|
||||
ReferenceQgemm(M, N, K, A, lda, offa, (const xint8_t*)B, ldb, (const xint8_t*)offb, CReference, ldc);
|
||||
|
||||
for (size_t f = 0; f < M * N; f++) {
|
||||
if (C[f] != CReference[f]) {
|
||||
printf("mismatch M=%zu, N=%zu, K=%zu, offa=%d!\n", M, N, K, int(offa));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ReferenceQgemm(
|
||||
size_t M,
|
||||
|
|
@ -755,8 +822,44 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
ReferenceQgemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const xint8_t* B,
|
||||
size_t ldb,
|
||||
const xint8_t* offb,
|
||||
int32_t* C,
|
||||
size_t ldc
|
||||
)
|
||||
{
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
|
||||
const uint8_t* a = A + (m * lda);
|
||||
const xint8_t* b = B + n;
|
||||
int32_t* c = C + (m * ldc) + n;
|
||||
int32_t sum = 0;
|
||||
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += ((int32_t(*b) - offb[n]) * (int32_t(*a) - offa));
|
||||
b += ldb;
|
||||
a += 1;
|
||||
}
|
||||
|
||||
*c = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MatrixGuardBuffer<uint8_t> BufferA;
|
||||
MatrixGuardBuffer<uint8_t> BufferB;
|
||||
MatrixGuardBuffer<uint8_t> BufferZeroPointB;
|
||||
MatrixGuardBuffer<int32_t> BufferC;
|
||||
MatrixGuardBuffer<int32_t> BufferCReference;
|
||||
const bool BIsSigned = std::is_signed<xint8_t>::value;
|
||||
|
|
@ -769,12 +872,15 @@ public:
|
|||
{
|
||||
for (size_t b = 1; b < 16; b++) {
|
||||
Test(b, b, b, 14, 211);
|
||||
Test(b, b, b, 21);
|
||||
}
|
||||
for (size_t b = 1; b < 16; b++) {
|
||||
Test(b, b, b, 14, 211);
|
||||
Test(b, b, b, 17);
|
||||
}
|
||||
for (size_t b = 16; b <= 256; b <<= 1) {
|
||||
Test(b, b, b, 34, 1);
|
||||
Test(b, b, b, 1);
|
||||
}
|
||||
for (size_t b = 256; b < 320; b += 32) {
|
||||
Test(b, b, b, 85, 173);
|
||||
|
|
@ -786,6 +892,7 @@ public:
|
|||
}
|
||||
Test(43, 500, 401, 183, 223);
|
||||
Test(1023, 1023, 1023, 5, 8);
|
||||
Test(1023, 1023, 1023, 7);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -3130,7 +3237,6 @@ RunThreadedTests(
|
|||
printf("QGEMM U8U8=float tests.\n");
|
||||
onnxruntime::make_unique<MlasQgemmU8X8Test<uint8_t, float, false>>()->ExecuteShort();
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
if (MlasGemmPackBSize(128, 128, true) > 0) {
|
||||
printf("QGEMM U8S8=int32_t packed tests.\n");
|
||||
onnxruntime::make_unique<MlasQgemmU8X8Test<int8_t, int32_t, true>>()->ExecuteShort();
|
||||
|
|
@ -3143,7 +3249,6 @@ RunThreadedTests(
|
|||
printf("QGEMM U8U8=float packed tests.\n");
|
||||
onnxruntime::make_unique<MlasQgemmU8X8Test<uint8_t, float, true>>()->ExecuteShort();
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("Conv2D tests.\n");
|
||||
onnxruntime::make_unique<MlasConv2DTest>()->ExecuteShort();
|
||||
|
|
|
|||
Loading…
Reference in a new issue