mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Merge remote-tracking branch 'upstream/master' into DmlDev
This commit is contained in:
commit
6ee6fdbbba
4 changed files with 19 additions and 24 deletions
|
|
@ -12,6 +12,19 @@
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename TA, typename TB, typename TY>
|
||||
void QGemmWithEigen(const TA* A_data, const TB* B_data, TY* Y_data, int M, int N, int K, TA a_offset, TB b_offset) {
|
||||
auto A = ConstEigenMatrixMapRowMajor<TA>(A_data, M, K);
|
||||
auto B = ConstEigenMatrixMapRowMajor<TB>(B_data, K, N);
|
||||
|
||||
auto A_row_sum = (A.template cast<TY>().rowwise().sum()) * static_cast<TY>(b_offset);
|
||||
auto B_col_sum = (B.template cast<TY>().colwise().sum()) * static_cast<TY>(a_offset);
|
||||
EigenMatrixMapRowMajor<TY>(Y_data, M, N) = A.template cast<TY>() * B.template cast<TY>() + static_cast<TY>(K * a_offset * b_offset) * ConstEigenMatrixMapRowMajor<TY>::Ones(M, N);
|
||||
EigenMatrixMapRowMajor<TY>(Y_data, M, N).colwise() -= A_row_sum;
|
||||
EigenMatrixMapRowMajor<TY>(Y_data, M, N).rowwise() -= B_col_sum;
|
||||
}
|
||||
|
||||
template <>
|
||||
void QGemm<uint8_t, int8_t, int32_t>(
|
||||
int M,
|
||||
|
|
@ -29,20 +42,12 @@ void QGemm<uint8_t, int8_t, int32_t>(
|
|||
#ifdef MLAS_SUPPORTS_GEMM_U8X8
|
||||
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(M);
|
||||
ORT_UNUSED_PARAMETER(N);
|
||||
ORT_UNUSED_PARAMETER(K);
|
||||
ORT_UNUSED_PARAMETER(lhs_data);
|
||||
ORT_UNUSED_PARAMETER(lda);
|
||||
ORT_UNUSED_PARAMETER(lhs_offset);
|
||||
ORT_UNUSED_PARAMETER(rhs_data);
|
||||
ORT_UNUSED_PARAMETER(ldb);
|
||||
ORT_UNUSED_PARAMETER(rhs_offset);
|
||||
ORT_UNUSED_PARAMETER(result_data);
|
||||
ORT_UNUSED_PARAMETER(ldc);
|
||||
ORT_UNUSED_PARAMETER(thread_pool);
|
||||
|
||||
ORT_NOT_IMPLEMENTED("MatMulInteger: activation uint8 and weight int8 not supported on ARM");
|
||||
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For Eigen only RowMajor*RowMajor=RowMajor format is supported");
|
||||
|
||||
QGemmWithEigen<uint8_t, int8_t, int32_t>(lhs_data, rhs_data, result_data, M, N, K, lhs_offset, rhs_offset);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -104,8 +109,7 @@ void QGemm(
|
|||
#endif
|
||||
}
|
||||
|
||||
template
|
||||
void QGemm<uint8_t, int8_t>(
|
||||
template void QGemm<uint8_t, int8_t>(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
|
|
@ -121,8 +125,7 @@ void QGemm<uint8_t, int8_t>(
|
|||
const float* bias,
|
||||
concurrency::ThreadPool* thread_pool);
|
||||
|
||||
template
|
||||
void QGemm<uint8_t, uint8_t>(
|
||||
template void QGemm<uint8_t, uint8_t>(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
|
|
|
|||
|
|
@ -49,13 +49,11 @@ void TestDynamicQuantizeMatMul(const std::vector<int64_t>& A_dims,
|
|||
}
|
||||
|
||||
// Runs the int8 DynamicQuantizeMatMul model against a 4x128 A input and a
// 128x128 B input. The body is compiled only when MLAS_SUPPORTS_GEMM_U8X8 is
// defined; otherwise the test is an empty pass.
TEST(DynamicQuantizeMatMul, Int8_test) {
#ifdef MLAS_SUPPORTS_GEMM_U8X8
  std::vector<int64_t> A_dims{4, 128};
  std::vector<int64_t> B_dims{128, 128};

  TestDynamicQuantizeMatMul<int8_t>(A_dims, B_dims, "testdata/dynamic_quantize_matmul_int8.onnx");
#endif
}
|
||||
|
||||
TEST(DynamicQuantizeMatMul, UInt8_test) {
|
||||
|
|
|
|||
|
|
@ -135,11 +135,9 @@ void RunQAttention(const std::vector<float>& input_data, // input:
|
|||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
} else {
|
||||
#ifdef MLAS_SUPPORTS_GEMM_U8X8
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -302,13 +302,9 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze
|
|||
}
|
||||
}
|
||||
|
||||
// Runs RunMatMulIntegerU8S8Test for an (M, N, K) problem twice: once with
// non_zero_zp = false and once with non_zero_zp = true. When
// MLAS_SUPPORTS_GEMM_U8X8 is not defined the macro expands to a no-op.
//
// Both variants are wrapped in do { } while (0) so the expansion is exactly
// one statement: it requires the caller's trailing semicolon and stays intact
// under an unbraced if/else or loop (the bare two-statement form would only
// guard the first call).
#ifdef MLAS_SUPPORTS_GEMM_U8X8
#define RUN_MATMUL_INTEGER_U8S8(M, N, K)                      \
  do {                                                        \
    RunMatMulIntegerU8S8Test(M, N, K, false /*non_zero_zp*/); \
    RunMatMulIntegerU8S8Test(M, N, K, true /*non_zero_zp*/);  \
  } while (0)
#else
#define RUN_MATMUL_INTEGER_U8S8(M, N, K) \
  do {                                   \
  } while (0)
#endif  // MLAS_SUPPORTS_GEMM_U8X8
|
||||
|
||||
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_Scalar) {
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 1, 32);
|
||||
|
|
|
|||
Loading…
Reference in a new issue