diff --git a/onnxruntime/core/util/gemmlowp_common.cc b/onnxruntime/core/util/gemmlowp_common.cc index d4db43d715..1117b60ae9 100644 --- a/onnxruntime/core/util/gemmlowp_common.cc +++ b/onnxruntime/core/util/gemmlowp_common.cc @@ -40,9 +40,19 @@ MakeOutputPipelineWithOutBias(std::int32_t result_offset, return std::make_tuple(quantize_down_stage, saturating_cast_stage); } -void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data, - const int lhs_offset, const int rhs_offset, const int result_offset, - int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias) { +void GemmlowpMultiplyu8u8_u8( + const uint8_t* lhs_data, + const uint8_t* rhs_data, + uint8_t* result_data, + const int lhs_offset, + const int rhs_offset, + const int result_offset, + int m, + int n, + int k, + int32_t int_multiplier, + int32_t right_shift, + const int32_t* bias) { // TODO exp ColMajor order for rhs and result. That may be faster const auto matOrder = gemmlowp::MapOrder::RowMajor; gemmlowp::MatrixMap lhs(lhs_data, m, k); @@ -64,13 +74,23 @@ void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, u } } -void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data, - const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool* ) { - +void GemmlowpMultiplyu8u8_s32( + const uint8_t* lhs_data, + const uint8_t* rhs_data, + int32_t* result_data, + const int lhs_offset, + const int rhs_offset, + int m, + int n, + int k, + int lda, + int ldb, + int ldc, + concurrency::ThreadPool*) { const auto matOrder = gemmlowp::MapOrder::RowMajor; - gemmlowp::MatrixMap lhs(lhs_data, m, k); - gemmlowp::MatrixMap rhs(rhs_data, k, n); - gemmlowp::MatrixMap result(result_data, m, n); + gemmlowp::MatrixMap lhs(lhs_data, m, k, lda); + gemmlowp::MatrixMap rhs(rhs_data, k, n, ldb); + gemmlowp::MatrixMap result(result_data, m, n, ldc); gemmlowp::GemmContext gemm_context; @@ -80,4 +100,4 @@ void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline); } -} +} // namespace onnxruntime diff --git a/onnxruntime/core/util/gemmlowp_common.h b/onnxruntime/core/util/gemmlowp_common.h index 01d7632bcd..ccd3f8503a 100644 --- a/onnxruntime/core/util/gemmlowp_common.h +++ b/onnxruntime/core/util/gemmlowp_common.h @@ -29,11 +29,12 @@ void inline QuantizeMultiplier(float fp_multiplier, std::int32_t* integer_multip } void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data, - const int lhs_offset, const int rhs_offset, const int result_offset, - int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias = nullptr); + const int lhs_offset, const int rhs_offset, const int result_offset, + int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias = nullptr); void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data, - const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool*); + const int lhs_offset, const int rhs_offset, + int m, int n, int k, int lda, int ldb, int ldc, concurrency::ThreadPool*); } // namespace onnxruntime diff --git a/onnxruntime/core/util/math_cpuonly.h b/onnxruntime/core/util/math_cpuonly.h index bf61e22e8e..827b555138 100644 --- a/onnxruntime/core/util/math_cpuonly.h +++ b/onnxruntime/core/util/math_cpuonly.h @@ -63,25 +63,42 @@ namespace onnxruntime { // common Eigen types that we will often use template using EigenMatrixMap = Eigen::Map>; + template using EigenArrayMap = Eigen::Map>; + template using EigenVectorMap = Eigen::Map>; + template using EigenVectorArrayMap = Eigen::Map>; + template using ConstEigenMatrixMap = Eigen::Map>; + template using ConstEigenArrayMap = Eigen::Map>; + template using ConstEigenVectorMap = Eigen::Map>; + template using ConstEigenVectorArrayMap = Eigen::Map>; + template using EigenMatrixMapRowMajor = Eigen::Map>; + template using ConstEigenMatrixMapRowMajor = Eigen::Map>; +template +using EigenMatrixMapRowMajorOuterStride = + Eigen::Map, 0, Eigen::OuterStride<>>; + +template +using ConstEigenMatrixMapRowMajorOuterStride = + Eigen::Map, 0, Eigen::OuterStride<>>; + template auto EigenMap(Tensor& t) -> EigenVectorMap { return EigenVectorMap(t.template MutableData(), t.Shape().Size()); diff --git a/onnxruntime/core/util/qmath.cc b/onnxruntime/core/util/qmath.cc index 556130e844..e3aa0b761f 100644 --- a/onnxruntime/core/util/qmath.cc +++ b/onnxruntime/core/util/qmath.cc @@ -14,15 +14,29 @@ namespace onnxruntime { template -void QGemmWithEigen(const TA* A_data, const TB* B_data, TY* Y_data, int M, int N, int K, TA a_offset, TB b_offset) { - auto A = ConstEigenMatrixMapRowMajor(A_data, M, K); - auto B = ConstEigenMatrixMapRowMajor(B_data, K, N); +void QGemmWithEigen( + const TA* A_data, + const TB* B_data, + TY* Y_data, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + TA a_offset, + TB b_offset) { + auto A = ConstEigenMatrixMapRowMajorOuterStride(A_data, M, K, Eigen::OuterStride<>(lda)); + auto B = ConstEigenMatrixMapRowMajorOuterStride(B_data, K, N, Eigen::OuterStride<>(ldb)); auto A_row_sum = (A.template cast().rowwise().sum()) * static_cast(b_offset); auto B_col_sum = (B.template cast().colwise().sum()) * static_cast(a_offset); - EigenMatrixMapRowMajor(Y_data, M, N) = A.template cast() * B.template cast() + static_cast(K * a_offset * b_offset) * ConstEigenMatrixMapRowMajor::Ones(M, N); - EigenMatrixMapRowMajor(Y_data, M, N).colwise() -= A_row_sum; - EigenMatrixMapRowMajor(Y_data, M, N).rowwise() -= B_col_sum; + EigenMatrixMapRowMajorOuterStride(Y_data, M, N, Eigen::OuterStride<>(ldc)) = + A.template cast() * B.template cast() + + static_cast(K * a_offset * b_offset) * ConstEigenMatrixMapRowMajor::Ones(M, N); + + EigenMatrixMapRowMajorOuterStride(Y_data, M, N, Eigen::OuterStride<>(ldc)).colwise() -= A_row_sum; + EigenMatrixMapRowMajorOuterStride(Y_data, M, N, Eigen::OuterStride<>(ldc)).rowwise() -= B_col_sum; } void QGemm( @@ -42,12 +56,14 @@ void QGemm( #ifdef MLAS_SUPPORTS_GEMM_U8X8 MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, result_data, ldc, thread_pool); #else - ORT_ENFORCE(lda == K && ldb == N && ldc == N, "Only RowMajor*RowMajor=RowMajor format is supported"); - if (rhs_signed) { - QGemmWithEigen(lhs_data, reinterpret_cast(rhs_data), result_data, M, N, K, lhs_offset, static_cast(rhs_offset)); + QGemmWithEigen(lhs_data, reinterpret_cast(rhs_data), result_data, + M, N, K, lda, ldb, ldc, + lhs_offset, static_cast(rhs_offset)); } else { - GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool); + GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, + lhs_offset, rhs_offset, + M, N, K, lda, ldb, ldc, thread_pool); } #endif }