onnxruntime/onnxruntime/core/mlas/inc/mlas.h
Chen Fu 8dce83a818
Fuse 'Add' operator into FP16 Conv (#15213)
### Description
Adds 'Add' functionality to the FP16 Conv operator. It takes a tensor that
has the same shape as the output tensor and adds it to the result
tensor.


### Motivation and Context
Needed to run Resnet 50
2023-04-07 09:51:03 -07:00

1691 lines
42 KiB
C++

/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
mlas.h
Abstract:
This module contains the public data structures and procedure prototypes
for the Microsoft Machine Learning algebra subprogram library.
--*/
#pragma once
#include <cstddef>
#include <cstdlib>
#include <cstdint>
//
// Define the calling convention for Windows targets.
//
// MLASCALL expands to __stdcall when _MSC_VER >= 800 or _STDCALL_SUPPORTED is
// defined, and to nothing otherwise, so the same prototypes can be used with
// every toolchain.
//
#if (_MSC_VER >= 800) || defined(_STDCALL_SUPPORTED)
#define MLASCALL __stdcall
#else
#define MLASCALL
#endif
//
// Define the target architecture.
//
// Each MLAS_TARGET_* macro below is derived from macros predefined by the
// compiler for the architecture being built.
//
// 64-bit x86. An ARM64EC build is excluded even when _M_AMD64 is defined.
#if (defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)
#define MLAS_TARGET_AMD64
#endif
// 32-bit x86.
#if defined(_M_IX86) || defined(__i386__)
#define MLAS_TARGET_IX86
#endif
// Either of the two x86 targets above.
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_IX86)
#define MLAS_TARGET_AMD64_IX86
#endif
// 64-bit ARM.
#if defined(_M_ARM64) || defined(__aarch64__)
#define MLAS_TARGET_ARM64
#endif
// ARM64EC.
#if defined(_M_ARM64EC)
#define MLAS_TARGET_ARM64EC
#endif
// 32-bit ARM.
#if defined(_M_ARM) || defined(__arm__)
#define MLAS_TARGET_ARM
#endif
// Any of the three ARM targets above.
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) || defined(MLAS_TARGET_ARM)
#define MLAS_TARGET_ARM_ANY
#endif
// POWER is selected when the compiler reports VSX support.
#if defined(__VSX__)
#define MLAS_TARGET_POWER
#endif
// WebAssembly: exactly one of the SIMD or SCALAR variants is defined in
// addition to MLAS_TARGET_WASM, depending on __wasm_simd128__.
#if defined(__wasm__)
#define MLAS_TARGET_WASM
#if defined(__wasm_simd128__)
#define MLAS_TARGET_WASM_SIMD
#else
#define MLAS_TARGET_WASM_SCALAR
#endif
#endif
//
// Define the support levels for the target architecture.
//
// Double precision GEMM support is advertised only for AMD64 and POWER targets.
#if defined(MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER)
#define MLAS_SUPPORTS_GEMM_DOUBLE
#endif
// FP16 vector intrinsics are enabled for ARM64/ARM64EC targets when the
// compiler is not MSVC, or is MSVC with _MSC_VER >= 1930.
#if (!defined(_MSC_VER)) || (_MSC_VER >= 1930)
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC)
#if !defined(__APPLE__)
// FP16 is temporarily disabled on Apple ARM64 because compiling the source
// files requires a hardware-specific compilation flag. When building a
// universal binary for Apple, that flag would cause trouble for the x64
// target.
#define MLAS_F16VEC_INTRINSICS_SUPPORTED
#endif // !__APPLE__
#endif // ARM64
#endif // Visual Studio 16 or earlier does not support fp16 intrinsic
//
// Basic Linear Algebra Subprograms (BLAS) types.
//
// The definitions are wrapped in the CBLAS_ENUM_DEFINED_H guard so they are
// skipped when that macro has already been defined before this header is
// included.
//
#ifndef CBLAS_ENUM_DEFINED_H
#define CBLAS_ENUM_DEFINED_H
// Transpose operation applied to a matrix operand.
typedef enum { CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113 } CBLAS_TRANSPOSE;
// Upper/lower triangle selector.
typedef enum { CblasUpper=121, CblasLower=122 } CBLAS_UPLO;
// Unit/non-unit diagonal selector.
typedef enum { CblasNonUnit=131, CblasUnit=132 } CBLAS_DIAG;
// Left/right side selector.
typedef enum { CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#endif
//
// Forward declare the thread pool implementation class and half precision floating point.
//
// N.B. Avoid including ONNX Runtime headers here to keep the dependencies for
// standalone MLAS test executables smaller.
//
// Both types are incomplete in this header, so they can only be used through
// pointers and references here.
//
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}  // namespace concurrency
struct MLFloat16;
}  // namespace onnxruntime

// Thread pool type accepted by the MLAS entry points declared in this header.
using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool;
//
// Platform routines.
//
/**
 * @brief Query the buffer alignment preferred by the library.
 *
 * NOTE(review): the unit (presumably bytes) and how the value is chosen are
 * decided by the implementation, which is not part of this header - confirm.
 *
 * @return The preferred alignment.
 */
size_t
MLASCALL
MlasGetPreferredBufferAlignment(
void
);
#ifdef MLAS_TARGET_AMD64_IX86
/**
 * @brief Return whether the current CPU has the over-saturation problem
 *        when computing u8s8 matrix multiplication.
 *
 * Only declared for x86 targets (MLAS_TARGET_AMD64_IX86).
 *
 * https://www.intel.com/content/www/us/en/develop/documentation/onednn-developer-guide-and-reference/top/advanced-topics/nuances-of-int8-computations.html
 *
 * @return true if the CPU is affected, false otherwise.
 */
bool
MLASCALL
MlasPlatformU8S8Overflow(
void
);
#endif
//
// Activation routines.
//
// Identifies the activation function described by an MLAS_ACTIVATION.
// Enumerators take consecutive values starting at 0, so the final entry,
// MlasActivationKindCount, equals the number of activation kinds listed
// before it and is not itself an activation.
//
enum MLAS_ACTIVATION_KIND {
MlasIdentityActivation,
MlasReluActivation,
MlasLeakyReluActivation,
MlasTanhActivation,
MlasLogisticActivation,
MlasClipActivation,
MlasHardSigmoidActivation,
MlasActivationKindCount,
};
//
// Describes an activation function and its parameters.
//
// Parameters is a union: the named structs and the raw Values[2] array share
// the same storage, so a given activation's parameters can be accessed either
// by name or by index.
//
struct MLAS_ACTIVATION {
MLAS_ACTIVATION_KIND ActivationKind;
union {
// Parameters for MlasLeakyReluActivation (presumably; matched by name).
struct {
float alpha;
} LeakyRelu;
// Parameters for MlasClipActivation (presumably; matched by name).
struct {
float minimum;
float maximum;
} Clip;
// Parameters for MlasHardSigmoidActivation (presumably; matched by name).
struct {
float alpha;
float beta;
} HardSigmoid;
// Untyped view of the same storage.
float Values[2];
} Parameters;
};
/**
 * @brief Apply an activation to a single precision buffer.
 *
 * NOTE(review): the implementation is not visible in this header. From the
 * signature this presumably processes an M x N matrix in place in Buffer with
 * row stride ldc, optionally adding Bias first - confirm.
 *
 * @param Activation Supplies the activation description.
 * @param Buffer     Supplies the (non-const) buffer to process.
 * @param Bias       Supplies the bias values (presumably optional).
 * @param M          Supplies the number of rows.
 * @param N          Supplies the number of columns.
 * @param ldc        Supplies the leading dimension of Buffer.
 */
void
MLASCALL
MlasActivation(
const MLAS_ACTIVATION* Activation,
float* Buffer,
const float* Bias,
size_t M,
size_t N,
size_t ldc
);
//
// Matrix/matrix multiply routines.
// C := alpha * op(A) * op(B) + beta * C
// op(X) = X or op(X) = transpose(X) or op(X) = conjg(transpose(X))
//
/**
 * @brief Supplies the matrix data for the single precision GEMM functions.
 *
 * Every member has a default, so a default-constructed block describes
 * alpha = 1, beta = 0, null matrices and an unpacked B.
 */
struct MLAS_SGEMM_DATA_PARAMS {
const float* A = nullptr; /**< Supplies the address of matrix A */
size_t lda = 0; /**< Supplies the first dimension of matrix A. */
const float* B = nullptr; /**< Supplies the address of matrix B */
size_t ldb = 0; /**< Supplies the first dimension of matrix B. */
float* C = nullptr; /**< Supplies the address of matrix C */
size_t ldc = 0; /**< Supplies the first dimension of matrix C. */
float alpha = 1.0f; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */
float beta = 0.0f; /**< Supplies the scalar beta multiplier (see SGEMM definition) */
bool BIsPacked = false; /**< Whether B is pre-packed */
};
/**
 * @brief Batched single precision matrix/matrix multiply operation (SGEMM)
 *
 * @param TransA     Supplies the transpose operation for matrix A.
 * @param TransB     Supplies the transpose operation for matrix B.
 * @param M          Supplies the number of rows of matrix A and matrix C.
 * @param N          Supplies the number of columns of matrix B and matrix C.
 * @param K          Supplies the number of columns of matrix A and the number
 *                   of rows of matrix B.
 * @param Data       An array of matrix data parameter blocks, one per
 *                   multiplication.
 * @param BatchSize  Supplies the number of multiplications in this batch.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
 */
void
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_SGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);
/**
 * @brief Single precision matrix/matrix multiply operation (SGEMM) for one
 *        parameter block.
 *
 * Forwards to MlasGemmBatch with a batch of exactly one entry.
 *
 * @param TransA     Supplies the transpose operation for matrix A.
 * @param TransB     Supplies the transpose operation for matrix B.
 * @param M          Supplies the number of rows of matrix A and matrix C.
 * @param N          Supplies the number of columns of matrix B and matrix C.
 * @param K          Supplies the number of columns of matrix A and the number
 *                   of rows of matrix B.
 * @param Data       Supplies the matrix data parameters.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
 */
inline
void
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_SGEMM_DATA_PARAMS& Data,
    MLAS_THREADPOOL* ThreadPool
    )
{
    // A single multiplication is a batch holding one parameter block.
    constexpr size_t SingleBatch = 1;
    const MLAS_SGEMM_DATA_PARAMS* const Params = &Data;

    MlasGemmBatch(TransA, TransB, M, N, K, Params, SingleBatch, ThreadPool);
}
/**
 * @brief Single precision matrix/matrix multiply operation (SGEMM) taking the
 *        operands as individual arguments.
 *
 * Packs the arguments into an MLAS_SGEMM_DATA_PARAMS block (B is treated as
 * not pre-packed) and runs it as a batch of one.
 *
 * @param TransA     Supplies the transpose operation for matrix A.
 * @param TransB     Supplies the transpose operation for matrix B.
 * @param M          Supplies the number of rows of matrix A and matrix C.
 * @param N          Supplies the number of columns of matrix B and matrix C.
 * @param K          Supplies the number of columns of matrix A and the number
 *                   of rows of matrix B.
 * @param alpha      Supplies the scalar alpha multiplier (see SGEMM definition)
 * @param A          Supplies the address of matrix A
 * @param lda        Supplies the first dimension of matrix A.
 * @param B          Supplies the address of matrix B
 * @param ldb        Supplies the first dimension of matrix B.
 * @param beta       Supplies the scalar beta multiplier (see SGEMM definition)
 * @param C          Supplies the address of matrix C
 * @param ldc        Supplies the first dimension of matrix C.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
 */
inline
void
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    float alpha,
    const float* A,
    size_t lda,
    const float* B,
    size_t ldb,
    float beta,
    float* C,
    size_t ldc,
    MLAS_THREADPOOL* ThreadPool
    )
{
    MLAS_SGEMM_DATA_PARAMS Params;

    // Operands, grouped per matrix.
    Params.A = A;
    Params.lda = lda;
    Params.B = B;
    Params.ldb = ldb;
    Params.C = C;
    Params.ldc = ldc;

    // Scalar multipliers.
    Params.alpha = alpha;
    Params.beta = beta;

    // Params.BIsPacked keeps its default (false): B is a plain matrix here.
    MlasGemmBatch(TransA, TransB, M, N, K, &Params, 1, ThreadPool);
}
/**
 * @brief Single precision matrix/matrix multiply operation (SGEMM) with a
 *        pre-packed matrix B.
 *
 * @param TransA     Supplies the transpose operation for matrix A.
 * @param M          Supplies the number of rows of matrix A and matrix C.
 * @param N          Supplies the number of columns of matrix B and matrix C.
 * @param K          Supplies the number of columns of matrix A and the number
 *                   of rows of matrix B.
 * @param alpha      Supplies the scalar alpha multiplier (see SGEMM definition).
 * @param A          Supplies the address of matrix A.
 * @param lda        Supplies the first dimension of matrix A.
 * @param PackedB    Supplies the address of packed matrix B.
 * @param beta       Supplies the scalar beta multiplier (see SGEMM definition).
 * @param C          Supplies the address of matrix C.
 * @param ldc        Supplies the first dimension of matrix C.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
 */
inline
void
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    size_t M,
    size_t N,
    size_t K,
    float alpha,
    const float* A,
    size_t lda,
    const void* PackedB,
    float beta,
    float* C,
    size_t ldc,
    MLAS_THREADPOOL* ThreadPool
    )
{
    MLAS_SGEMM_DATA_PARAMS Params;

    Params.alpha = alpha;
    Params.beta = beta;

    Params.A = A;
    Params.lda = lda;

    // The packed buffer travels through the float* B slot; BIsPacked tells
    // the callee how to interpret it, and no leading dimension applies.
    Params.B = static_cast<const float*>(PackedB);
    Params.ldb = 0;
    Params.BIsPacked = true;

    Params.C = C;
    Params.ldc = ldc;

    // The transpose flag for B does not matter when B is packed.
    const CBLAS_TRANSPOSE IgnoredTransB = CblasTrans;

    MlasGemmBatch(TransA, IgnoredTransB, M, N, K, &Params, 1, ThreadPool);
}
/**
 * @brief Supplies the matrix data for the double precision GEMM functions.
 *
 * Every member has a default, so a default-constructed block describes
 * alpha = 1, beta = 0 and null matrices.
 */
struct MLAS_DGEMM_DATA_PARAMS {
const double* A = nullptr; /**< Supplies the address of matrix A */
size_t lda = 0; /**< Supplies the first dimension of matrix A. */
const double* B = nullptr; /**< Supplies the address of matrix B */
size_t ldb = 0; /**< Supplies the first dimension of matrix B. */
double* C = nullptr; /**< Supplies the address of matrix C */
size_t ldc = 0; /**< Supplies the first dimension of matrix C. */
double alpha = 1.0; /**< Supplies the scalar alpha multiplier (see DGEMM definition) */
double beta = 0.0; /**< Supplies the scalar beta multiplier (see DGEMM definition) */
};
/**
 * @brief Batched double precision matrix/matrix multiply operation (DGEMM)
 *
 * @param TransA     Supplies the transpose operation for matrix A.
 * @param TransB     Supplies the transpose operation for matrix B.
 * @param M          Supplies the number of rows of matrix A and matrix C.
 * @param N          Supplies the number of columns of matrix B and matrix C.
 * @param K          Supplies the number of columns of matrix A and the number
 *                   of rows of matrix B.
 * @param Data       An array of matrix data parameter blocks, one per
 *                   multiplication.
 * @param BatchSize  Supplies the number of multiplications in this batch.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
 */
void
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_DGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);
/**
 * @brief Double precision matrix/matrix multiply operation (DGEMM) for one
 *        parameter block.
 *
 * Forwards to MlasGemmBatch with a batch of exactly one entry.
 *
 * @param TransA     Supplies the transpose operation for matrix A.
 * @param TransB     Supplies the transpose operation for matrix B.
 * @param M          Supplies the number of rows of matrix A and matrix C.
 * @param N          Supplies the number of columns of matrix B and matrix C.
 * @param K          Supplies the number of columns of matrix A and the number
 *                   of rows of matrix B.
 * @param Data       Supplies the matrix data parameters.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
 */
inline
void
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_DGEMM_DATA_PARAMS& Data,
    MLAS_THREADPOOL* ThreadPool
    )
{
    // A single multiplication is a batch holding one parameter block.
    constexpr size_t SingleBatch = 1;
    const MLAS_DGEMM_DATA_PARAMS* const Params = &Data;

    MlasGemmBatch(TransA, TransB, M, N, K, Params, SingleBatch, ThreadPool);
}
/**
 * @brief Double precision matrix/matrix multiply operation (DGEMM) taking the
 *        operands as individual arguments.
 *
 * Packs the arguments into an MLAS_DGEMM_DATA_PARAMS block and runs it as a
 * batch of one.
 *
 * @param TransA     Supplies the transpose operation for matrix A.
 * @param TransB     Supplies the transpose operation for matrix B.
 * @param M          Supplies the number of rows of matrix A and matrix C.
 * @param N          Supplies the number of columns of matrix B and matrix C.
 * @param K          Supplies the number of columns of matrix A and the number
 *                   of rows of matrix B.
 * @param alpha      Supplies the scalar alpha multiplier (see DGEMM definition)
 * @param A          Supplies the address of matrix A
 * @param lda        Supplies the first dimension of matrix A.
 * @param B          Supplies the address of matrix B
 * @param ldb        Supplies the first dimension of matrix B.
 * @param beta       Supplies the scalar beta multiplier (see DGEMM definition)
 * @param C          Supplies the address of matrix C
 * @param ldc        Supplies the first dimension of matrix C.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
 */
inline
void
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    double alpha,
    const double* A,
    size_t lda,
    const double* B,
    size_t ldb,
    double beta,
    double* C,
    size_t ldc,
    MLAS_THREADPOOL* ThreadPool
    )
{
    MLAS_DGEMM_DATA_PARAMS Params;

    // Operands, grouped per matrix.
    Params.A = A;
    Params.lda = lda;
    Params.B = B;
    Params.ldb = ldb;
    Params.C = C;
    Params.ldc = ldc;

    // Scalar multipliers.
    Params.alpha = alpha;
    Params.beta = beta;

    constexpr size_t SingleBatch = 1;
    MlasGemmBatch(TransA, TransB, M, N, K, &Params, SingleBatch, ThreadPool);
}
// Granularity of quantization parameters.
// NOTE(review): by their names, PerMatrix presumably means one parameter for
// the whole matrix and PerColumn one parameter per column - confirm.
enum class MLAS_QUANTIZATION_GRANULARITY {
PerMatrix,
PerColumn,
};
// Selects how a quantized GEMM output processor treats the existing contents
// of the output buffer.
enum class MLAS_QGEMM_OUTPUT_MODE {
ZeroMode, // overwrite the output buffer
AccumulateMode, // accumulate to the output buffer
};
/**
 * @brief Interface for objects that post-process the int32_t result matrix of
 *        a quantized GEMM.
 *
 * Implementations override Process(); the destructor is virtual so that
 * instances can be destroyed through a pointer to this interface.
 */
class MLAS_QGEMM_OUTPUT_PROCESSOR {
public:
    virtual
    void
    Process(
        const int32_t* C,   // Supplies the address of matrix to process
        size_t StartM,      // Supplies the start row index of matrix
        size_t StartN,      // Supplies the start col index of matrix
        size_t CountM,      // Supplies the element count per row to process
        size_t CountN,      // Supplies the element count per col to process
        size_t ldc          // Supplies the leading dimension of matrix
        ) const = 0;

    virtual ~MLAS_QGEMM_OUTPUT_PROCESSOR() = default;
};
/**
 * @brief Quantized GEMM output processor that writes to a float output buffer
 *        using a scale and an optional bias.
 *
 * The constructor only records its arguments; the pointers are stored as-is,
 * so the buffers they refer to must remain valid while Process() may be
 * called. Process() and ProcessImpl() are defined outside this header.
 *
 * NOTE(review): presumably Process() converts the int32_t tile to float,
 * multiplies by Scale, adds Bias when it is non-null and either overwrites or
 * accumulates into Output according to the output mode - confirm against the
 * implementation.
 */
class MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR {
public:
// Mode defaults to ZeroMode and QuantGran to PerMatrix.
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR(
float* Output,
size_t LeadingDimensionOutput,
const float* Scale,
const float* Bias,
MLAS_QGEMM_OUTPUT_MODE Mode = MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
MLAS_QUANTIZATION_GRANULARITY QuantGran = MLAS_QUANTIZATION_GRANULARITY::PerMatrix) :
Output_(Output),
LeadingDimensionOutput_(LeadingDimensionOutput),
Scale_(Scale),
Bias_(Bias),
OutputMode_(Mode),
QuantGran_(QuantGran)
{
}
// Processes the tile of C described by (StartM, StartN, CountM, CountN, ldc).
void
Process(
const int32_t* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const override;
private:
// Helper templated on bias presence, output mode and quantization
// granularity (presumably so Process() can dispatch to a specialized variant).
template<bool HasBias, MLAS_QGEMM_OUTPUT_MODE Mode, MLAS_QUANTIZATION_GRANULARITY QuantGran>
inline
void
ProcessImpl(
const int32_t* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const;
private:
// Values captured by the constructor, in declaration (and initialization) order.
float* Output_;
size_t LeadingDimensionOutput_;
const float* Scale_;
const float* Bias_;
MLAS_QGEMM_OUTPUT_MODE OutputMode_;
MLAS_QUANTIZATION_GRANULARITY QuantGran_;
};
/**
 * @brief Supplies matrix shape and data type information to the quantized
 *        GEMM functions.
 *
 * NOTE: AIsSigned == true is not supported on non-ARM devices for now.
 *       AIsSigned == true is supported on ARM devices when BIsSigned is also true.
 */
struct MLAS_GEMM_QUANT_SHAPE_PARAMS {
size_t M = 0; /**< Supplies the row size of matrix A */
size_t N = 0; /**< Supplies the column size of matrix B */
size_t K = 0; /**< Supplies the column size of matrix A and row size of matrix B */
bool AIsSigned = false; /**< Indicates whether type of A is int8_t or uint8_t.*/
bool BIsSigned = false; /**< Indicates whether type of B is int8_t or uint8_t */
bool IsAccumulateMode = false; /**< Indicates whether to accumulate to matrix C or override matrix C */
};
/**
 * @brief Supplies the matrix data for one quantized GEMM multiplication.
 *
 * Every member has a default; all pointers default to nullptr.
 *
 * NOTE(review): the per-field descriptions below are inferred from the field
 * names and types; confirm them against the implementation.
 */
struct MLAS_GEMM_QUANT_DATA_PARAMS {
    const uint8_t* A = nullptr;         /**< address of matrix A */
    size_t lda = 0;                     /**< leading dimension of matrix A */
    uint8_t ZeroPointA = 0;             /**< zero point of matrix A */
    const void* B = nullptr;            /**< address of matrix B (plain or packed, see BIsPacked) */
    size_t ldb = 0;                     /**< leading dimension of matrix B */
    const uint8_t* ZeroPointB = nullptr; /**< zero point(s) of matrix B */
    bool BIsPacked = false;             /**< whether B is pre-packed */
    bool PerColumnZeroPoints = false;   /**< whether ZeroPointB supplies one value per column */
    int32_t* C = nullptr;               /**< address of the int32_t result matrix C */
    size_t ldc = 0;                     /**< leading dimension of matrix C */
    const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr; /**< optional post-processor for C */
};
/**
 * @brief Batched quantized GEMM, for multiplying multiple pairs of matrices.
 *        Note: only uniform batching is supported, so the shapes and types of
 *        the inputs must be the same: M, N, K and BIsSigned must be the same
 *        across all parameter blocks.
 *
 * @param [IN] Shape       A single shape descriptor for all the multiplications
 * @param [IN] DataParams  Array of data descriptors for the matrices.
 * @param [IN] BatchN      Size of the parameters array, also the number of
 *                         multiplications to perform
 * @param [IN] ThreadPool  Optional thread pool for parallel processing
 */
void
MLASCALL
MlasGemmBatch(
const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_QUANT_DATA_PARAMS* DataParams,
const size_t BatchN,
MLAS_THREADPOOL* ThreadPool
);
inline
void
MlasGemm(
const MLAS_GEMM_QUANT_SHAPE_PARAMS &Shape,
const MLAS_GEMM_QUANT_DATA_PARAMS &DataParams,
MLAS_THREADPOOL *ThreadPool)
{
MlasGemmBatch(Shape, &DataParams, 1, ThreadPool);
}
//
// Symmetric QGEMM has limited buffer overrun.
// Currently only supported in ARM64
//
// The constant is 30 on ARM64 targets and 0 everywhere else.
// NOTE(review): presumably the number of bytes the symmetric QGEMM kernel may
// access past the end of a buffer, which callers must pad for - confirm.
//
#if defined(MLAS_TARGET_ARM64)
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 30;
#else
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 0;
#endif
/**
 * @brief Supply data parameters for symmetric quantized GEMM.
 *        B matrix zero point must be zero, and it must be
 *        pre-packed, with column sums scaled by (-ZeroPointA)
 *
 * Every member has a default; all pointers default to nullptr.
 */
struct MLAS_SYMM_QGEMM_DATA_PARAMS {
    const void* A = nullptr;    /**< address of matrix A */
    size_t lda = 0;             /**< leading dimension of matrix A */
    const void* B = nullptr;    /**< address of the pre-packed matrix B */
    void* C = nullptr;          /**< address of the result matrix C */
    size_t ldc = 0;             /**< leading dimension of matrix C */
    // TODO!! add re-quantization parameters
};
/**
 * @brief Batched QGEMM. Similar to MlasGemmBatch, but the right hand side
 *        matrix must be symmetrically quantized and prepacked.
 *
 * @param [IN] Shape       A single shape descriptor for all multiplications.
 *                         Currently A and B must be signed, and accumulation
 *                         mode is not supported
 * @param [IN] DataParams  Array of data descriptors, one for each multiplication.
 *                         B must be prepacked
 * @param [IN] BatchN      Number of multiplications
 * @param [IN] ThreadPool  Thread pool passed to the implementation
 */
void
MLASCALL
MlasSymmQgemmBatch(
const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape,
const MLAS_SYMM_QGEMM_DATA_PARAMS* DataParams,
const size_t BatchN,
MLAS_THREADPOOL* ThreadPool
);
//
// Buffer packing routines.
//
/**
 * @brief Size query for the single precision packed B buffer.
 *
 * NOTE(review): presumably returns the number of bytes MlasGemmPackB needs
 * for an N-column, K-row matrix B - confirm against the implementation.
 *
 * @param N Number of columns of matrix B (presumably).
 * @param K Number of rows of matrix B (presumably).
 */
size_t
MLASCALL
MlasGemmPackBSize(
size_t N,
size_t K
);
/**
 * @brief Pack a single precision matrix B into PackedB.
 *
 * NOTE(review): the packed layout and the required size of PackedB
 * (presumably MlasGemmPackBSize(N, K)) are defined by the implementation -
 * confirm.
 *
 * @param TransB  Supplies the transpose operation for matrix B.
 * @param N       Number of columns of matrix B (presumably).
 * @param K       Number of rows of matrix B (presumably).
 * @param B       Supplies the address of matrix B.
 * @param ldb     Supplies the first dimension of matrix B.
 * @param PackedB Receives the packed matrix.
 */
void
MLASCALL
MlasGemmPackB(
CBLAS_TRANSPOSE TransB,
size_t N,
size_t K,
const float* B,
size_t ldb,
void* PackedB
);
/**
 * @brief Size query for the quantized packed B buffer.
 *
 * NOTE(review): presumably returns the number of bytes the quantized
 * MlasGemmPackB needs for the given shape and signedness - confirm.
 *
 * @param N         Number of columns of matrix B (presumably).
 * @param K         Number of rows of matrix B (presumably).
 * @param AIsSigned Whether matrix A holds signed 8-bit values.
 * @param BIsSigned Whether matrix B holds signed 8-bit values.
 */
size_t
MLASCALL
MlasGemmPackBSize(
size_t N,
size_t K,
bool AIsSigned,
bool BIsSigned
);
/**
 * @brief Pack a quantized matrix B into PackedB.
 *
 * NOTE(review): the packed layout and the required size of PackedB are
 * defined by the implementation - confirm.
 *
 * @param N         Number of columns of matrix B (presumably).
 * @param K         Number of rows of matrix B (presumably).
 * @param B         Supplies the address of matrix B, passed as uint8_t
 *                  regardless of BIsSigned.
 * @param ldb       Supplies the first dimension of matrix B.
 * @param AIsSigned Whether matrix A holds signed 8-bit values.
 * @param BIsSigned Whether matrix B holds signed 8-bit values.
 * @param PackedB   Receives the packed matrix.
 */
void
MLASCALL
MlasGemmPackB(
size_t N,
size_t K,
const uint8_t* B,
size_t ldb,
bool AIsSigned,
bool BIsSigned,
void* PackedB
);
/**
 * @brief For symmetric quantized GEMM, returns the size of the
 *        packing buffer needed for the right hand side
 * @param N          Number of columns
 * @param K          Number of rows
 * @param AIsSigned  Whether the left hand side is signed int8_t
 * @return size of the packing buffer,
 *         0 if the operation is not supported
 */
size_t
MLASCALL
MlasSymmQgemmPackBSize(
size_t N,
size_t K,
bool AIsSigned
);
/**
 * @brief Pack the right hand side matrix for symmetric quantized GEMM.
 *
 * NOTE(review): the packed layout is defined by the implementation;
 * presumably PackedB must hold MlasSymmQgemmPackBSize(N, K, AIsSigned)
 * bytes - confirm.
 *
 * @param N          Number of columns (presumably).
 * @param K          Number of rows (presumably).
 * @param B          Supplies the address of the int8_t matrix B.
 * @param ldb        Supplies the first dimension of matrix B.
 * @param AIsSigned  Whether the left hand side is signed int8_t.
 * @param ZeroPointA Zero point of the left hand side matrix.
 * @param PackedB    Receives the packed matrix.
 */
void
MLASCALL
MlasSymmQgemmPackB(
size_t N,
size_t K,
const int8_t* B,
size_t ldb,
bool AIsSigned,
int32_t ZeroPointA,
void* PackedB
);
//
// Convolution routines.
//
// Identifies the algorithm recorded in MLAS_CONV_PARAMETERS::Algorithm.
// MlasConvAlgorithmDepthwise exists only on scalar WebAssembly builds
// (MLAS_TARGET_WASM_SCALAR).
//
enum MLAS_CONV_ALGORITHM {
MlasConvAlgorithmGemmDirect,
MlasConvAlgorithmExpandThenGemm,
MlasConvAlgorithmExpandThenGemmSegmented,
#if defined(MLAS_TARGET_WASM_SCALAR)
MlasConvAlgorithmDepthwise,
#endif
};
//
// Parameter block describing a convolution.
//
// MlasConvPrepare takes a non-const pointer to this structure and MlasConv a
// const pointer, so presumably the former fills it in and the latter consumes
// it - confirm. The shape arrays hold up to three spatial dimensions;
// Padding has two entries per spatial dimension.
//
struct MLAS_CONV_PARAMETERS {
const MLAS_ACTIVATION* Activation;
size_t Dimensions;
size_t BatchCount;
size_t GroupCount;
size_t InputChannels;
size_t InputShape[3];
size_t KernelShape[3];
size_t DilationShape[3];
size_t Padding[6];
size_t StrideShape[3];
size_t FilterCount;
size_t OutputShape[3];
size_t InputSize;
size_t OutputSize;
size_t K;
float Beta;
MLAS_CONV_ALGORITHM Algorithm;
ptrdiff_t ThreadCount;
// Algorithm-specific state; the member names mirror MLAS_CONV_ALGORITHM
// enumerators, so presumably Algorithm selects the active member.
union {
struct {
CBLAS_TRANSPOSE TransB;
size_t ldb;
} GemmDirect;
struct {
size_t ThreadStrideN;
} ExpandThenGemmSegmented;
} u;
};
/**
 * @brief Prepare a convolution parameter block.
 *
 * NOTE(review): the implementation is not visible here. Presumably this
 * validates and copies the supplied shapes into Parameters, selects an
 * algorithm, and reports through WorkingBufferSize how large a working
 * buffer MlasConv needs (unit to be confirmed).
 *
 * @param Parameters         Receives the prepared parameter block.
 * @param Dimensions         Number of spatial dimensions.
 * @param BatchCount         Number of batches.
 * @param GroupCount         Number of groups.
 * @param InputChannels      Number of input channels.
 * @param InputShape         Input spatial shape.
 * @param KernelShape        Kernel spatial shape.
 * @param DilationShape      Dilation per spatial dimension.
 * @param Padding            Padding values.
 * @param StrideShape        Stride per spatial dimension.
 * @param OutputShape        Output spatial shape.
 * @param FilterCount        Number of filters.
 * @param Activation         Activation to record in the parameter block.
 * @param WorkingBufferSize  Receives the working buffer size.
 * @param Beta               Beta value to record in the parameter block.
 * @param ThreadPool         Thread pool passed to the implementation.
 */
void MLASCALL
MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters,
size_t Dimensions,
size_t BatchCount,
size_t GroupCount,
size_t InputChannels,
const int64_t* InputShape,
const int64_t* KernelShape,
const int64_t* DilationShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
size_t FilterCount,
const MLAS_ACTIVATION* Activation,
size_t* WorkingBufferSize,
float Beta,
MLAS_THREADPOOL* ThreadPool);
/**
 * @brief Single precision convolution.
 *
 * NOTE(review): presumably Parameters must have been set up by
 * MlasConvPrepare and WorkingBuffer sized as that call reported - confirm.
 *
 * @param Parameters    Supplies the convolution parameter block.
 * @param Input         Supplies the input tensor.
 * @param Filter        Supplies the filter tensor.
 * @param Bias          Supplies the bias vector (presumably optional).
 * @param WorkingBuffer Supplies scratch storage for the operation.
 * @param Output        Receives the output tensor.
 * @param ThreadPool    Thread pool passed to the implementation.
 */
void
MLASCALL
MlasConv(
const MLAS_CONV_PARAMETERS* Parameters,
const float* Input,
const float* Filter,
const float* Bias,
float* WorkingBuffer,
float* Output,
MLAS_THREADPOOL* ThreadPool
);
/**
 * @brief Quantized depthwise convolution producing int32_t output.
 *
 * NOTE(review): Input is an array of pointers (an indirection buffer,
 * presumably one pointer per kernel element per output position) - confirm
 * the exact layout against the implementation.
 *
 * @param Input           Array of pointers to the input data.
 * @param InputZeroPoint  Zero point of the input.
 * @param InputIsSigned   Whether the input is signed 8-bit.
 * @param Filter          Supplies the filter data.
 * @param FilterZeroPoint Zero point of the filter.
 * @param FilterIsSigned  Whether the filter is signed 8-bit.
 * @param Output          Receives the int32_t results.
 * @param Channels        Number of channels.
 * @param OutputCount     Number of output positions.
 * @param KernelSize      Number of kernel elements.
 */
void
MLASCALL
MlasConvDepthwise(
const void* const* Input,
int32_t InputZeroPoint,
bool InputIsSigned,
const void* Filter,
int32_t FilterZeroPoint,
bool FilterIsSigned,
int32_t* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize
);
//
// Symmetric quantized integer convolution routines.
//
/**
 * @brief Size query for the packed filter of a symmetric quantized
 *        convolution.
 *
 * NOTE(review): presumably returns the size MlasConvSymPackW expects in its
 * PackedWSize argument, and possibly 0 when unsupported - confirm.
 */
size_t
MlasConvSymPackWSize(
size_t GroupCount,
size_t InputChannels,
size_t OutputChannels,
size_t KernelSize,
bool InputIsSigned
);
/**
 * @brief Pack the int8_t filter W of a symmetric quantized convolution into
 *        PackedW.
 *
 * NOTE(review): the packed layout is defined by the implementation;
 * PackedWSize is presumably the value returned by MlasConvSymPackWSize for
 * the same arguments - confirm.
 */
void
MlasConvSymPackW(
size_t GroupCount,
size_t InputChannels,
size_t OutputChannels,
size_t KernelSize,
const int8_t* W,
int8_t* PackedW,
size_t PackedWSize,
bool InputIsSigned
);
/**
 * @brief Adjust an input zero point for the symmetric convolution kernels.
 *
 * NOTE(review): the adjustment itself is defined by the implementation and
 * cannot be described from this header - confirm before relying on it.
 *
 * @param zero_point_value Zero point of the input.
 * @param InputIsSigned    Whether the input is signed 8-bit.
 * @return The adjusted zero point.
 */
int32_t
MlasConvSymFixupInputZeroPoint(
int32_t zero_point_value,
bool InputIsSigned
);
//
// Convolution operators (or maybe others in the future) need to do their
// own job partitioning. Since the filters (right hand side B matrix) are
// usually small, activations are divided horizontally. These routines
// provide the kernel stride units needed to perform that division.
//
// Returns the kernel stride unit for the symmetric convolution kernel.
int32_t
MlasConvSymGetKernelOutputCount(
bool InputIsSigned
);
// Returns the kernel stride unit for the symmetric depthwise convolution
// kernel (presumably the depthwise counterpart of
// MlasConvSymGetKernelOutputCount - confirm).
int32_t
MlasConvSymDepthwiseGetKernelOutputCnt(
bool InputIsSigned
);
/**
 * @brief Returns the stride M of the depthwise conv kernel
 *
 * The most optimized path is the symmetric conv. See
 * MlasConvSymDepthwiseGetKernelOutputCnt(bool)
 *
 * These kernels are implemented in qdwconv.cpp using
 * intrinsics, all of them with a stride value of 1. A
 * slightly bigger value is used to improve cache reuse.
 *
 * This needs to be changed if the depthwise kernels are
 * optimized.
 *
 * @return The stride unit, currently always 4.
 */
inline
int32_t
MlasConvDepthwiseGetKernelOutputCnt()
{
    constexpr int32_t KernelOutputStride = 4;
    return KernelOutputStride;
}
// Returns the kernel stride unit for symmetric QGEMM (presumably used for the
// same job partitioning as the convolution variants - confirm).
int32_t
MlasSymmQgemmGetKernelOutputCnt();
// Returns the kernel stride unit for QGEMM with the given operand signedness
// (presumably used for the same job partitioning as the convolution
// variants - confirm).
int32_t
MlasQgemmGetKernelOutputCnt(
bool AIsSigned,
bool BIsSigned
);
//
// Parameter block for the symmetric quantized convolution routines MlasConvSym
// and MlasConvSymDepthwise.
//
// NOTE(review): the field descriptions are inferred from names and types;
// confirm against the implementation.
//
struct MLAS_CONV_SYM_PARAMS {
const void* InputDirect; // input accessed directly
const void* const* InputIndirection; // input accessed through an array of pointers
const void* Filter; // filter data (presumably packed with MlasConvSymPackW)
void* Output; // output buffer
size_t InputChannels;
size_t OutputChannels;
size_t OutputCount;
size_t KernelSize;
const int32_t* Bias;
const float* Scale;
bool PerChannelScale; // presumably: Scale holds one value per output channel
int32_t OutputZeroPoint;
bool InputIsSigned;
};
// Symmetric quantized convolution described by Params. The implementation is
// not visible in this header.
void
MlasConvSym(
const MLAS_CONV_SYM_PARAMS& Params
);
// Symmetric quantized depthwise convolution described by Params. The
// implementation is not visible in this header.
void
MlasConvSymDepthwise(
const MLAS_CONV_SYM_PARAMS& Params
);
//
// Pooling routines.
//
// Identifies the pooling operation. Enumerators take consecutive values
// starting at 0, so the final entry, MlasPoolingKindCount, equals the number
// of pooling kinds listed before it and is not itself a pooling kind.
//
enum MLAS_POOLING_KIND {
MlasMaximumPooling,
MlasAveragePoolingExcludePad,
MlasAveragePoolingIncludePad,
MlasPoolingKindCount,
};
/**
 * @brief Single precision pooling.
 *
 * NOTE(review): the expected layout of the shape arrays (for example
 * whether InputShape includes the batch and channel dimensions) is defined
 * by the implementation - confirm.
 *
 * @param PoolingKind Selects the pooling operation.
 * @param Dimensions  Number of spatial dimensions.
 * @param InputShape  Supplies the input shape.
 * @param KernelShape Supplies the kernel shape.
 * @param Padding     Supplies the padding values.
 * @param StrideShape Supplies the strides.
 * @param OutputShape Supplies the output shape.
 * @param Input       Supplies the input tensor.
 * @param Output      Receives the output tensor.
 * @param ThreadPool  Thread pool passed to the implementation.
 */
void
MLASCALL
MlasPool(
MLAS_POOLING_KIND PoolingKind,
size_t Dimensions,
const int64_t* InputShape,
const int64_t* KernelShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
const float* Input,
float* Output,
MLAS_THREADPOOL* ThreadPool
);
/**
 * @brief Maximum pooling over 8-bit data accessed through an array of input
 *        pointers.
 *
 * Only the template declaration is visible here; which T8Bits types are
 * instantiated is decided by the implementation.
 *
 * @param Input       Array of pointers to the input data.
 * @param Output      Receives the pooled output.
 * @param Channels    Number of channels.
 * @param OutputCount Number of output positions.
 * @param KernelSize  Number of kernel elements.
 */
template<typename T8Bits>
void
MLASCALL
MlasMaximumPool(
const T8Bits* const* Input,
T8Bits* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize
);
//
// Miscellaneous compute routines.
//
// Computes the error function over N input elements (presumably elementwise,
// Output[i] = erf(Input[i]) - confirm).
void
MLASCALL
MlasComputeErf(
const float* Input,
float* Output,
size_t N
);
// Computes the exponential function over N input elements (presumably
// elementwise, Output[i] = exp(Input[i]) - confirm).
void
MLASCALL
MlasComputeExp(
const float* Input,
float* Output,
size_t N
);
// Computes the logistic function over N input elements (presumably
// elementwise - confirm).
void
MLASCALL
MlasComputeLogistic(
const float* Input,
float* Output,
size_t N
);
/**
 * @brief Softmax or log-softmax over a single precision buffer.
 *
 * NOTE(review): presumably the input is treated as N rows of D elements with
 * the reduction along D - confirm against the implementation.
 *
 * @param Input      Supplies the input buffer.
 * @param Output     Receives the output buffer.
 * @param N          Row count (presumably).
 * @param D          Elements per row (presumably).
 * @param LogSoftmax Selects log-softmax instead of softmax when true.
 * @param ThreadPool Thread pool passed to the implementation.
 */
void
MLASCALL
MlasComputeSoftmax(
const float* Input,
float* Output,
size_t N,
size_t D,
bool LogSoftmax,
MLAS_THREADPOOL* ThreadPool
);
// Computes the hyperbolic tangent over N input elements (presumably
// elementwise, Output[i] = tanh(Input[i]) - confirm).
void
MLASCALL
MlasComputeTanh(
const float* Input,
float* Output,
size_t N
);
//
// Half-precision floating-point routines.
//
// Converts Count half precision values, supplied as raw 16-bit words in
// Source, to single precision values in Destination. Declared with C linkage.
//
extern "C"
void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
);
//
// Transpose routines.
//
// One overload is declared per element type: uint8_t, int8_t, uint16_t,
// uint32_t and float. All share the same parameter list.
//
// NOTE(review): presumably Input is an M x N matrix and Output receives the
// N x M transpose - confirm the orientation against the implementation.
//
// 8-bit unsigned elements.
void
MLASCALL
MlasTranspose(
const uint8_t* Input,
uint8_t* Output,
size_t M,
size_t N
);
// 8-bit signed elements.
void
MLASCALL
MlasTranspose(
const int8_t* Input,
int8_t* Output,
size_t M,
size_t N
);
// 16-bit elements.
void
MLASCALL
MlasTranspose(
const uint16_t* Input,
uint16_t* Output,
size_t M,
size_t N
);
// 32-bit integer elements.
void
MLASCALL
MlasTranspose(
const uint32_t* Input,
uint32_t* Output,
size_t M,
size_t N
);
// Single precision elements.
void
MLASCALL
MlasTranspose(
const float* Input,
float* Output,
size_t M,
size_t N
);
//
// Buffer reordering routines.
//
// In the routines below S is the source buffer and D the destination buffer.
//
// NOTE(review): by its name this presumably reorders an NCHW input into the
// blocked NCHWc layout - confirm against the implementation.
void
MLASCALL
MlasReorderInputNchw(
const float* S,
float* D,
size_t InputChannels,
size_t InputSize
);
// Reorders the source buffer S into the destination buffer D.
// NOTE(review): by its name this presumably reorders an NHWC input into the
// blocked NCHWc layout; the meaning of RowCount versus FullRowCount is
// defined by the implementation - confirm.
void
MLASCALL
MlasReorderInputNhwc(
const float* S,
float* D,
size_t InputChannels,
size_t RowCount,
size_t FullRowCount
);
// Reorders the source buffer S into the destination buffer D for the given
// output shape.
// NOTE(review): by its name this presumably converts a blocked NCHWc output
// back to NCHW - confirm against the implementation.
void
MLASCALL
MlasReorderOutputNchw(
const int64_t* OutputShape,
const float* S,
float* D,
MLAS_THREADPOOL* ThreadPool
);
// Reorders the source buffer S into the destination buffer D for the given
// output shape.
// NOTE(review): by its name this presumably converts a blocked NCHWc output
// to NHWC - confirm against the implementation.
void
MLASCALL
MlasReorderOutputNhwc(
const int64_t* OutputShape,
const float* S,
float* D
);
// Reorders the filter in S into D for the given filter shape.
// NOTE(review): by its name this presumably converts an OIHW filter to a
// layout blocked on both input and output channels - confirm.
void
MLASCALL
MlasReorderFilterOIHWBiBo(
const int64_t* FilterShape,
const float* S,
float* D
);
// Reorders the filter in S into D for the given filter shape.
// NOTE(review): by its name this presumably converts an OIHW filter to a
// layout blocked on output channels only - confirm.
void
MLASCALL
MlasReorderFilterOIHWBo(
const int64_t* FilterShape,
const float* S,
float* D
);
//
// Single precision NCHWc routines.
//
// Returns the NCHWc block size.
// NOTE(review): presumably the number of channels per block, which may depend
// on the platform - confirm against the implementation.
size_t
MLASCALL
MlasNchwcGetBlockSize(
void
);
/**
 * @brief Single precision convolution on NCHWc data.
 *
 * NOTE(review): buffer layouts and the exact meaning of ZeroMode
 * (presumably: overwrite Output instead of accumulating into it) are defined
 * by the implementation - confirm.
 *
 * @param InputShape    Supplies the input shape.
 * @param KernelShape   Supplies the kernel shape.
 * @param DilationShape Supplies the dilations.
 * @param Padding       Supplies the padding values.
 * @param StrideShape   Supplies the strides.
 * @param OutputShape   Supplies the output shape.
 * @param GroupCount    Number of groups.
 * @param Input         Supplies the input tensor.
 * @param Filter        Supplies the filter tensor.
 * @param Bias          Supplies the bias vector (presumably optional).
 * @param Output        Receives the output tensor.
 * @param Activation    Supplies the activation description.
 * @param ZeroMode      See the note above.
 * @param ThreadPool    Thread pool passed to the implementation.
 */
void
MLASCALL
MlasNchwcConv(
const int64_t* InputShape,
const int64_t* KernelShape,
const int64_t* DilationShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
size_t GroupCount,
const float* Input,
const float* Filter,
const float* Bias,
float* Output,
const MLAS_ACTIVATION* Activation,
bool ZeroMode,
MLAS_THREADPOOL* ThreadPool
);
/**
 * @brief Single precision pooling on NCHWc data.
 *
 * NOTE(review): the buffer layouts are defined by the implementation -
 * confirm.
 *
 * @param PoolingKind   Selects the pooling operation.
 * @param InputShape    Supplies the input shape.
 * @param KernelShape   Supplies the kernel shape.
 * @param DilationShape Supplies the dilations.
 * @param Padding       Supplies the padding values.
 * @param StrideShape   Supplies the strides.
 * @param OutputShape   Supplies the output shape.
 * @param Input         Supplies the input tensor.
 * @param Output        Receives the output tensor.
 * @param ThreadPool    Thread pool passed to the implementation.
 */
void
MLASCALL
MlasNchwcPool(
MLAS_POOLING_KIND PoolingKind,
const int64_t* InputShape,
const int64_t* KernelShape,
const int64_t* DilationShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
const float* Input,
float* Output,
MLAS_THREADPOOL* ThreadPool
);
// Nearest-neighbor upsampling on NCHWc data, with integer Scales.
// NOTE(review): the layout of InputShape and Scales (presumably one scale per
// spatial dimension) is defined by the implementation - confirm.
void
MLASCALL
MlasNchwcUpsampleNearest(
const int64_t* InputShape,
const int64_t* Scales,
const float* Input,
float* Output
);
// Linear upsampling on NCHWc data.
// NOTE(review): InterpolationHeight is a single value while InterpolationWidth
// is an array, so presumably each call produces one output row of OutputWidth
// elements - confirm against the implementation.
void
MLASCALL
MlasNchwcUpsampleLinear(
size_t InputHeight,
size_t InputWidth,
size_t OutputWidth,
float InterpolationHeight,
const float* InterpolationWidth,
const float* Input,
float* Output
);
//
// Linear quantization routines.
//
/**
 * @brief Linearly quantize N single precision values to OutputType.
 *
 * Only the template declaration is visible here; which OutputType values are
 * instantiated, and the rounding/saturation behavior, are defined by the
 * implementation.
 *
 * @param Input     Supplies the input values.
 * @param Output    Receives the quantized values.
 * @param N         Number of elements.
 * @param Scale     Quantization scale.
 * @param ZeroPoint Quantization zero point.
 */
template<typename OutputType>
void
MLASCALL
MlasQuantizeLinear(
const float* Input,
OutputType* Output,
size_t N,
float Scale,
OutputType ZeroPoint
);
/**
 * @brief Requantize a block of the intermediate buffer to the output buffer,
 *        optionally adding the supplied bias
 *
 * @param Input                   Input matrix
 * @param InputLeadingDimension   Input matrix leading dimension
 * @param Output                  Output matrix
 * @param OutputLeadingDimension  Output matrix leading dimension
 * @param Bias                    Optional bias vector, to be added
 *                                to the input before quantization
 * @param Scale                   Quantization scale
 * @param PerColumnScale          true if scale is per-column
 * @param ZeroPoint               quantization zero point value
 * @param StartM                  presumably the first row of the block - confirm
 * @param StartN                  presumably the first column of the block - confirm
 * @param CountM                  presumably the number of rows in the block - confirm
 * @param CountN                  presumably the number of columns in the block - confirm
 */
template<typename OutputType>
void
MLASCALL
MlasRequantizeOutput(
const int32_t* Input,
size_t InputLeadingDimension,
OutputType* Output,
size_t OutputLeadingDimension,
const int32_t* Bias,
const float* Scale,
bool PerColumnScale,
OutputType ZeroPoint,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN
);
/**
 * @brief Quantized GEMM output processor that requantizes each int32_t tile
 *        into an 8-bit output buffer through MlasRequantizeOutput.
 *
 * The output buffer is held as void*; OutputIsSigned selects whether it is
 * treated as int8_t or uint8_t, and the zero point is narrowed to the same
 * type. All pointers are stored as-is, so the buffers they refer to must
 * remain valid while Process() may be called.
 */
class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
{
public:
    MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR(
        void* Output,
        size_t OutputLeadingDimension,
        const int32_t* Bias,
        const float* Scale,
        bool PerColumnScale,
        int32_t ZeroPoint,
        bool OutputIsSigned
        ) :
        Output_(Output),
        OutputLeadingDimension_(OutputLeadingDimension),
        Bias_(Bias),
        Scale_(Scale),
        PerColumnScale_(PerColumnScale),
        ZeroPoint_(ZeroPoint),
        OutputIsSigned_(OutputIsSigned)
    {
    }

    /**
     * @brief Requantize the tile of C described by the arguments.
     *
     * C and ldc are passed as the input matrix and its leading dimension;
     * the start/count values are forwarded unchanged.
     */
    void
    Process(
        const int32_t* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
        ) const override
    {
        if (!OutputIsSigned_) {
            RequantizeAs<uint8_t>(C, StartM, StartN, CountM, CountN, ldc);
            return;
        }

        RequantizeAs<int8_t>(C, StartM, StartN, CountM, CountN, ldc);
    }

private:
    // Forwards one tile to MlasRequantizeOutput, viewing the output buffer
    // and the zero point as OutputType.
    template<typename OutputType>
    void
    RequantizeAs(
        const int32_t* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
        ) const
    {
        OutputType* const TypedOutput = static_cast<OutputType*>(Output_);
        const OutputType TypedZeroPoint = static_cast<OutputType>(ZeroPoint_);

        MlasRequantizeOutput(C, ldc, TypedOutput, OutputLeadingDimension_,
                             Bias_, Scale_, PerColumnScale_, TypedZeroPoint,
                             StartM, StartN, CountM, CountN);
    }

private:
    void* Output_;
    size_t OutputLeadingDimension_;
    const int32_t* Bias_;
    const float* Scale_;
    bool PerColumnScale_;
    int32_t ZeroPoint_;
    bool OutputIsSigned_;
};
/**
 * @brief Find the minimum and maximum of N single precision values.
 *
 * NOTE(review): behavior for N == 0 and for NaN inputs is defined by the
 * implementation - confirm.
 *
 * @param Input Supplies the input values.
 * @param Min   Receives the minimum value.
 * @param Max   Receives the maximum value.
 * @param N     Number of elements.
 */
void
MLASCALL
MlasFindMinMaxElement(
const float* Input,
float* Min,
float* Max,
size_t N
);
// Returns a padding element count for the QLinear routines.
// NOTE(review): presumably the number of elements a buffer of ElementCount
// elements of ElementSize bytes must be allocated with so vectorized kernels
// can safely over-read - confirm against the implementation.
size_t
MLASCALL
MlasQLinearSafePaddingElementCount(
size_t ElementSize,
size_t ElementCount
);
/**
 * @brief Quantized global average pooling, NCHW variant.
 *
 * Only the template declaration is visible here.
 * NOTE(review): presumably each of the Channels images of ImageSize elements
 * is averaged to a single output value, and AccumulateBuffer is caller-owned
 * scratch whose required size is defined by the implementation - confirm.
 *
 * @param Input            Supplies the quantized input.
 * @param ScaleInput       Scale of the input.
 * @param ZeroPointInput   Zero point of the input.
 * @param Output           Receives the quantized output.
 * @param ScaleOutput      Scale of the output.
 * @param ZeroPointOutput  Zero point of the output.
 * @param Channels         Number of channels.
 * @param ImageSize        Number of elements per image.
 * @param AccumulateBuffer Supplies int32_t scratch storage.
 */
template<typename T8Bits>
void
MLASCALL
MlasQLinearGlobalAveragePoolNchw(
const T8Bits* Input,
float ScaleInput,
int32_t ZeroPointInput,
T8Bits* Output,
float ScaleOutput,
int32_t ZeroPointOutput,
size_t Channels,
size_t ImageSize,
int32_t* AccumulateBuffer
);
/**
 * @brief Quantized global average pooling, NHWC variant.
 *
 * Only the template declaration is visible here.
 * NOTE(review): the roles of Stride, AccumulateBuffer and ZeroBuffer
 * (presumably caller-owned scratch and a zero-filled helper buffer) and
 * their required sizes are defined by the implementation - confirm.
 *
 * @param Input            Supplies the quantized input.
 * @param ScaleInput       Scale of the input.
 * @param ZeroPointInput   Zero point of the input.
 * @param Output           Receives the quantized output.
 * @param ScaleOutput      Scale of the output.
 * @param ZeroPointOutput  Zero point of the output.
 * @param Batch            Number of batches.
 * @param ImageSize        Number of elements per image.
 * @param Stride           See the note above.
 * @param Channels         Number of channels.
 * @param AccumulateBuffer Supplies int32_t scratch storage.
 * @param ZeroBuffer       See the note above.
 */
template <typename T8Bits>
void
MLASCALL
MlasQLinearGlobalAveragePoolNhwc(
const T8Bits* Input,
float ScaleInput,
int32_t ZeroPointInput,
T8Bits* Output,
float ScaleOutput,
int32_t ZeroPointOutput,
size_t Batch,
size_t ImageSize,
size_t Stride,
size_t Channels,
int32_t* AccumulateBuffer,
const T8Bits* ZeroBuffer
);
//
// Quantized elementwise addition.
//
// InputA is of size N.
// InputB is of size 1 if IsScalarB == true, otherwise it is of size N.
//
// Each operand and the result carry their own scale and zero point. Only the
// template declaration is visible here.
//
template<typename DataType>
void
MLASCALL
MlasQLinearAdd(
const DataType* InputA,
float ScaleA,
int32_t ZeroPointA,
const DataType* InputB,
float ScaleB,
int32_t ZeroPointB,
float ScaleC,
int32_t ZeroPointC,
DataType* OutputC,
size_t N,
bool IsScalarB
);
//
// Quantized elementwise multiplication.
//
// The parameter list is identical to MlasQLinearAdd; presumably the same size
// convention applies (InputA has N elements, InputB has 1 element when
// IsScalarB is true and N otherwise) - confirm. Only the template declaration
// is visible here.
//
template<typename DataType>
void
MLASCALL
MlasQLinearMul(
const DataType* InputA,
float ScaleA,
int32_t ZeroPointA,
const DataType* InputB,
float ScaleB,
int32_t ZeroPointB,
float ScaleC,
int32_t ZeroPointC,
DataType* OutputC,
size_t N,
bool IsScalarB
);
//
// Half precision routines
//
// MLAS_FP16 aliases the forward-declared onnxruntime::MLFloat16, so it is an
// incomplete type in this header and is only used through pointers.
// Any type with size=2 should work
using MLAS_FP16 = onnxruntime::MLFloat16;
// Size of one half precision element, taken as sizeof(uint16_t).
constexpr size_t FP16_SIZE = sizeof(uint16_t);
/**
 * @brief Whether the current CPU supports FP16 acceleration.
 *
 * @return true if FP16 acceleration is supported, false otherwise.
 */
bool MLASCALL
MlasFp16AccelerationSupported();
/**
 * @brief Interface for half gemm post processors.
 *
 * Example implementations of this interface include activations,
 * conversion from half precision to single precision, etc.
 *
 * Half GEMM is computed tile by tile. When a tile of the result matrix
 * is produced, the method Process() is called to process this tile.
 * The parameters of this method describe the location and shape of the
 * tile.
 */
class MLAS_HALF_GEMM_POSTPROCESSOR {
public:
    virtual
    void
    Process(
        MLAS_FP16* C,   /**< the address of matrix to process */
        size_t StartM,  /**< the start row index of matrix */
        size_t StartN,  /**< the start col index of matrix */
        size_t CountM,  /**< the element count per row to process */
        size_t CountN,  /**< the element count per col to process */
        size_t ldc      /**< the leading dimension of matrix */
        ) const = 0;

    virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() = default;
};
/**
 * @brief Half precision activation functions, with optional sum tensor.
 *        The supplied sum tensor must have the same layout as the GEMM output
 *        tensor, and it will be added to the final result.
 *
 * The activation is held by reference and the sum buffer by pointer; neither
 * is copied, so both must outlive this processor. In particular, do not
 * construct it from a temporary MLAS_ACTIVATION. Process() is defined outside
 * this header.
 */
class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR
{
public:
// SumBuf defaults to nullptr, meaning no sum tensor is supplied.
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR(
const MLAS_ACTIVATION& Activation,
const MLAS_FP16* SumBuf = nullptr)
: Activation_(Activation), SumBuf_(SumBuf)
{}
// Processes the tile of C described by (StartM, StartN, CountM, CountN, ldc).
void Process(
MLAS_FP16* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const override;
private:
const MLAS_ACTIVATION& Activation_; // not owned; must outlive this object
const MLAS_FP16* SumBuf_; // not owned; may be nullptr
};
inline
void
MlasFp16Activation(
const MLAS_ACTIVATION* Activation,
MLAS_FP16* Buffer,
size_t M,
size_t N,
size_t ldc
)
{
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(*Activation);
proc.Process(Buffer, 0, 0, M, N, ldc);
}
/**
 * @brief Convert half gemm result matrix to single precision float matrix
 *
 * The activation is held by reference and the output buffer by pointer;
 * neither is copied, so both must outlive this processor. Process() is
 * defined outside this header.
 *
 * NOTE(review): presumably the activation is applied as part of the
 * conversion - confirm against the implementation.
 */
class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
public:
MLAS_HALF_GEMM_2FLOAT_PROCESSOR(
const MLAS_ACTIVATION& Activation,
float* Output, /**< address of the output matrix, row major */
size_t RowStride /**< row stride of the output matrix */
) : Activation_(Activation),
Output_(Output),
RowStride_(RowStride)
{}
// Processes the tile of C described by (StartM, StartN, CountM, CountN, ldc).
void
Process(
MLAS_FP16* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const override;
private:
const MLAS_ACTIVATION& Activation_; // not owned; must outlive this object
float* Output_; // not owned
const size_t RowStride_;
};
/**
 * @brief Data parameters for half precision GEMM routine
 *
 * All except C are [in] parameters. Every member has a default, so a
 * default constructed block has null pointers, zero leading dimensions and
 * both fp32 flags cleared.
 */
struct MLAS_HALF_GEMM_DATA_PARAMS {
const void* A = nullptr; /**< address of A; untyped because A may be fp16 or fp32, see AIsfp32 */
const void* B = nullptr; /**< address of B; untyped because B may be fp16 or fp32, see BIsfp32 */
const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */
MLAS_FP16* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed */
size_t ldc = 0; /**< leading dimension of C */
const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; /**< post processor for the result; defaults to nullptr (NOTE(review): presumably nullptr means no post processing - confirm) */
bool AIsfp32 = false; /**< matrix A is fp32, needs to be cast into fp16 */
bool BIsfp32 = false; /**< matrix B is fp32, needs to be cast into fp16 */
};
/**
 * @brief Half precision Batched GEMM: C = A * B + Bias
 *        Either A or B can be fp32 or fp16
 *
 * Note: We only support uniform batching, so shapes and types of the
 *       input must be same across all parameter blocks.
 *
 * @param[in]  M            row size of matrix A and C
 * @param[in]  N            column size of matrix B and C
 * @param[in]  K            column size of matrix A and row size of matrix B
 * @param[in]  BatchN       number of batches
 * @param[in]  DataParams   An array (size BatchN) of parameter blocks. The
 *                          blocks themselves are const and are not modified;
 *                          the result is delivered through the C member of
 *                          each block.
 * @param[in]  ThreadPool   thread pool to use; defaults to nullptr
 *                          (NOTE(review): presumably nullptr runs the work
 *                          on the calling thread - confirm)
 */
void
MLASCALL
MlasHalfGemmBatch(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr
);
/**
 * @brief For half precision GEMM, returns size of the
 *        packing buffer needed for right hand side
 *
 * @param[in] N             Number of columns
 * @param[in] K             Number of rows
 * @param[in] float2half    Whether the input is float that
 *                          needs to be converted to half precision
 * @return size of the packing buffer (NOTE(review): presumably in
 *         bytes - confirm), 0 if operation not supported
 */
size_t
MLASCALL
MlasHalfGemmPackBSize(
size_t N,
size_t K,
bool float2half
);
/**
 * @brief For half precision GEMM, pack the right hand
 *        side matrix B
 *
 * @param[in]  N        Number of columns
 * @param[in]  K        Number of rows
 * @param[in]  B        Address of matrix B (half precision)
 * @param[in]  ldb      leading dimension of input matrix B
 * @param[out] PackedB  Address of the packed matrix; caller supplied
 *                      (NOTE(review): presumably sized with
 *                      MlasHalfGemmPackBSize(N, K, false) - confirm)
 */
void
MLASCALL
MlasHalfGemmPackB(
size_t N,
size_t K,
const MLAS_FP16* B,
size_t ldb,
void* PackedB
);
/**
 * @brief For half precision GEMM, convert the float matrix B
 *        to half precision and pack it into a packing buffer
 *
 * @param[in]  N        Number of columns
 * @param[in]  K        Number of rows
 * @param[in]  B        Address of matrix B (single precision)
 * @param[in]  ldb      leading dimension of input matrix B
 * @param[out] PackedB  Address of the packed matrix; caller supplied
 *                      (NOTE(review): presumably sized with
 *                      MlasHalfGemmPackBSize(N, K, true) - confirm)
 */
void
MLASCALL
MlasHalfGemmConvertPackB(
size_t N,
size_t K,
const float* B,
size_t ldb,
void* PackedB
);
/**
 * @brief Indirect Depthwise convolution for fp16
 *
 * @param Input         Supplies the indirect buffer for NHWC input
 * @param Filter        Supplies the address for filter tensor
 * @param Output        Supplies the address for the result tensor
 * @param Channels      # of input channels
 * @param OutputCount   # of output pixels
 * @param KernelSize    # kernel size
 * @param PostProc      Supplies a half gemm post processor
 *                      (NOTE(review): presumably run on the output, and
 *                      presumably may be nullptr - confirm against the
 *                      implementation)
 */
void
MLASCALL
MlasConvDepthwise(
const MLAS_FP16* const* Input,
const MLAS_FP16* Filter,
MLAS_FP16* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize,
MLAS_HALF_GEMM_POSTPROCESSOR* PostProc
);
/**
 * @brief Transpose overload for half precision data.
 *
 * Reinterprets both buffers as uint16_t and forwards to the uint16_t
 * overload of MlasTranspose with M and N unchanged.
 *
 * @param Input   source matrix
 * @param Output  destination matrix
 * @param M       forwarded unchanged to the uint16_t overload
 * @param N       forwarded unchanged to the uint16_t overload
 */
inline
void
MlasTranspose(
    const MLAS_FP16* Input,
    MLAS_FP16* Output,
    size_t M,
    size_t N
    )
{
    // View the fp16 elements as 16-bit words so the uint16_t overload is
    // the one selected.
    const uint16_t* src_words = reinterpret_cast<const uint16_t*>(Input);
    uint16_t* dst_words = reinterpret_cast<uint16_t*>(Output);

    MlasTranspose(src_words, dst_words, M, N);
}
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
/**
 * @brief Max Pooling for fp16 NHWC
 *
 * Only declared when MLAS_F16VEC_INTRINSICS_SUPPORTED is defined.
 *
 * @param Input         Indirect buffer to activations
 * @param Output        Address of the result tensor
 * @param Channels      C in NHWC
 * @param OutputCount   Number of output pixels
 * @param KernelSize    Size of the kernel
 */
void
MLASCALL
MlasNhwcMaxPool(
const MLAS_FP16* const* Input,
MLAS_FP16* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize
);
/**
 * @brief Avg Pooling for fp16 NHWC
 *
 * Only declared when MLAS_F16VEC_INTRINSICS_SUPPORTED is defined.
 *
 * @param Input         Indirect buffer to activations
 * @param Output        Address of the output data
 * @param Channels      C in NHWC
 * @param OutputCount   Number of output pixels
 * @param KernelSize    Size of the kernel
 */
void
MLASCALL
MlasNhwcAvgPool(
const MLAS_FP16* const* Input,
MLAS_FP16* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize
);
#endif