mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Support OuterStride for QGemm when MLAS_SUPPORTS_GEMM_U8X8 undefined (#5374)
Quantized GEMM on ARM doesn't support the case that leading dimension is not equal to column size. The PR adds support of this case.
This commit is contained in:
parent
668ab04917
commit
24f99b3be8
4 changed files with 77 additions and 23 deletions
|
|
@ -40,9 +40,19 @@ MakeOutputPipelineWithOutBias(std::int32_t result_offset,
|
|||
return std::make_tuple(quantize_down_stage, saturating_cast_stage);
|
||||
}
|
||||
|
||||
void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data,
|
||||
const int lhs_offset, const int rhs_offset, const int result_offset,
|
||||
int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias) {
|
||||
void GemmlowpMultiplyu8u8_u8(
|
||||
const uint8_t* lhs_data,
|
||||
const uint8_t* rhs_data,
|
||||
uint8_t* result_data,
|
||||
const int lhs_offset,
|
||||
const int rhs_offset,
|
||||
const int result_offset,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int32_t int_multiplier,
|
||||
int32_t right_shift,
|
||||
const int32_t* bias) {
|
||||
// TODO exp ColMajor order for rhs and result. That may be faster
|
||||
const auto matOrder = gemmlowp::MapOrder::RowMajor;
|
||||
gemmlowp::MatrixMap<const uint8_t, matOrder> lhs(lhs_data, m, k);
|
||||
|
|
@ -64,13 +74,23 @@ void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, u
|
|||
}
|
||||
}
|
||||
|
||||
void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data,
|
||||
const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool* ) {
|
||||
|
||||
void GemmlowpMultiplyu8u8_s32(
|
||||
const uint8_t* lhs_data,
|
||||
const uint8_t* rhs_data,
|
||||
int32_t* result_data,
|
||||
const int lhs_offset,
|
||||
const int rhs_offset,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
concurrency::ThreadPool*) {
|
||||
const auto matOrder = gemmlowp::MapOrder::RowMajor;
|
||||
gemmlowp::MatrixMap<const uint8_t, matOrder> lhs(lhs_data, m, k);
|
||||
gemmlowp::MatrixMap<const uint8_t, matOrder> rhs(rhs_data, k, n);
|
||||
gemmlowp::MatrixMap<std::int32_t, matOrder> result(result_data, m, n);
|
||||
gemmlowp::MatrixMap<const uint8_t, matOrder> lhs(lhs_data, m, k, lda);
|
||||
gemmlowp::MatrixMap<const uint8_t, matOrder> rhs(rhs_data, k, n, ldb);
|
||||
gemmlowp::MatrixMap<std::int32_t, matOrder> result(result_data, m, n, ldc);
|
||||
|
||||
gemmlowp::GemmContext gemm_context;
|
||||
|
||||
|
|
@ -80,4 +100,4 @@ void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data,
|
|||
&gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline);
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -29,11 +29,12 @@ void inline QuantizeMultiplier(float fp_multiplier, std::int32_t* integer_multip
|
|||
}
|
||||
|
||||
void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data,
|
||||
const int lhs_offset, const int rhs_offset, const int result_offset,
|
||||
int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias = nullptr);
|
||||
const int lhs_offset, const int rhs_offset, const int result_offset,
|
||||
int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias = nullptr);
|
||||
|
||||
void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data,
|
||||
const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool*);
|
||||
const int lhs_offset, const int rhs_offset,
|
||||
int m, int n, int k, int lda, int ldb, int ldc, concurrency::ThreadPool*);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -63,25 +63,42 @@ namespace onnxruntime {
|
|||
// common Eigen types that we will often use
|
||||
template <typename T>
|
||||
using EigenMatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
||||
|
||||
template <typename T>
|
||||
using EigenArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
||||
|
||||
template <typename T>
|
||||
using EigenVectorMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>>;
|
||||
|
||||
template <typename T>
|
||||
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
|
||||
|
||||
template <typename T>
|
||||
using ConstEigenMatrixMap = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
||||
|
||||
template <typename T>
|
||||
using ConstEigenArrayMap = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
||||
|
||||
template <typename T>
|
||||
using ConstEigenVectorMap = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>;
|
||||
|
||||
template <typename T>
|
||||
using ConstEigenVectorArrayMap = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
|
||||
|
||||
template <typename T>
|
||||
using EigenMatrixMapRowMajor = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
|
||||
template <typename T>
|
||||
using ConstEigenMatrixMapRowMajor = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
|
||||
template <typename T>
|
||||
using EigenMatrixMapRowMajorOuterStride =
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>, 0, Eigen::OuterStride<>>;
|
||||
|
||||
template <typename T>
|
||||
using ConstEigenMatrixMapRowMajorOuterStride =
|
||||
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>, 0, Eigen::OuterStride<>>;
|
||||
|
||||
template <typename T>
|
||||
auto EigenMap(Tensor& t) -> EigenVectorMap<T> {
|
||||
return EigenVectorMap<T>(t.template MutableData<T>(), t.Shape().Size());
|
||||
|
|
|
|||
|
|
@ -14,15 +14,29 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
template <typename TA, typename TB, typename TY>
|
||||
void QGemmWithEigen(const TA* A_data, const TB* B_data, TY* Y_data, int M, int N, int K, TA a_offset, TB b_offset) {
|
||||
auto A = ConstEigenMatrixMapRowMajor<TA>(A_data, M, K);
|
||||
auto B = ConstEigenMatrixMapRowMajor<TB>(B_data, K, N);
|
||||
void QGemmWithEigen(
|
||||
const TA* A_data,
|
||||
const TB* B_data,
|
||||
TY* Y_data,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
TA a_offset,
|
||||
TB b_offset) {
|
||||
auto A = ConstEigenMatrixMapRowMajorOuterStride<TA>(A_data, M, K, Eigen::OuterStride<>(lda));
|
||||
auto B = ConstEigenMatrixMapRowMajorOuterStride<TB>(B_data, K, N, Eigen::OuterStride<>(ldb));
|
||||
|
||||
auto A_row_sum = (A.template cast<TY>().rowwise().sum()) * static_cast<TY>(b_offset);
|
||||
auto B_col_sum = (B.template cast<TY>().colwise().sum()) * static_cast<TY>(a_offset);
|
||||
EigenMatrixMapRowMajor<TY>(Y_data, M, N) = A.template cast<TY>() * B.template cast<TY>() + static_cast<TY>(K * a_offset * b_offset) * ConstEigenMatrixMapRowMajor<TY>::Ones(M, N);
|
||||
EigenMatrixMapRowMajor<TY>(Y_data, M, N).colwise() -= A_row_sum;
|
||||
EigenMatrixMapRowMajor<TY>(Y_data, M, N).rowwise() -= B_col_sum;
|
||||
EigenMatrixMapRowMajorOuterStride<TY>(Y_data, M, N, Eigen::OuterStride<>(ldc)) =
|
||||
A.template cast<TY>() * B.template cast<TY>() +
|
||||
static_cast<TY>(K * a_offset * b_offset) * ConstEigenMatrixMapRowMajor<TY>::Ones(M, N);
|
||||
|
||||
EigenMatrixMapRowMajorOuterStride<TY>(Y_data, M, N, Eigen::OuterStride<>(ldc)).colwise() -= A_row_sum;
|
||||
EigenMatrixMapRowMajorOuterStride<TY>(Y_data, M, N, Eigen::OuterStride<>(ldc)).rowwise() -= B_col_sum;
|
||||
}
|
||||
|
||||
void QGemm(
|
||||
|
|
@ -42,12 +56,14 @@ void QGemm(
|
|||
#ifdef MLAS_SUPPORTS_GEMM_U8X8
|
||||
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, result_data, ldc, thread_pool);
|
||||
#else
|
||||
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "Only RowMajor*RowMajor=RowMajor format is supported");
|
||||
|
||||
if (rhs_signed) {
|
||||
QGemmWithEigen<uint8_t, int8_t, int32_t>(lhs_data, reinterpret_cast<const int8_t*>(rhs_data), result_data, M, N, K, lhs_offset, static_cast<int8_t>(rhs_offset));
|
||||
QGemmWithEigen<uint8_t, int8_t, int32_t>(lhs_data, reinterpret_cast<const int8_t*>(rhs_data), result_data,
|
||||
M, N, K, lda, ldb, ldc,
|
||||
lhs_offset, static_cast<int8_t>(rhs_offset));
|
||||
} else {
|
||||
GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool);
|
||||
GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data,
|
||||
lhs_offset, rhs_offset,
|
||||
M, N, K, lda, ldb, ldc, thread_pool);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue