mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update vsinpu ep cross-compiling patch (#21963)
- Block the bf16 && ummla gemm functions because we cannot support these features yet
This commit is contained in:
parent
dd2425932d
commit
d4290f6e7f
1 changed files with 216 additions and 2 deletions
|
|
@ -1,5 +1,5 @@
|
|||
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
|
||||
index 66f4aea606..481109e560 100644
|
||||
index c02ac2096d..2bc51298f0 100644
|
||||
--- a/cmake/onnxruntime_mlas.cmake
|
||||
+++ b/cmake/onnxruntime_mlas.cmake
|
||||
@@ -361,7 +361,7 @@ else()
|
||||
|
|
@ -12,7 +12,7 @@ index 66f4aea606..481109e560 100644
|
|||
${mlas_platform_srcs}
|
||||
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
|
||||
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
|
||||
index 675f7c7a13..eb7ed77911 100644
|
||||
index e46105324a..414c46a1ce 100644
|
||||
--- a/onnxruntime/core/mlas/inc/mlas.h
|
||||
+++ b/onnxruntime/core/mlas/inc/mlas.h
|
||||
@@ -82,6 +82,9 @@ Abstract:
|
||||
|
|
@ -33,3 +33,217 @@ index 675f7c7a13..eb7ed77911 100644
|
|||
#endif //
|
||||
#endif // ARM64
|
||||
#endif // Visual Studio 16 or earlier does not support fp16 intrinsic
|
||||
@@ -1635,6 +1639,7 @@ MlasHalfGemmConvertPackB(
|
||||
);
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
/**
|
||||
* @brief Whether current CPU supports Bfloat16(bf16) acceleration.
|
||||
*/
|
||||
@@ -1746,6 +1751,7 @@ MlasSBGemmPackBSize(size_t N, size_t K);
|
||||
void MLASCALL
|
||||
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
|
||||
#endif
|
||||
+#endif
|
||||
|
||||
/**
|
||||
* @brief Indirect Depthwise convolution for fp16
|
||||
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
|
||||
index 4239e2ecae..3df7e5573d 100644
|
||||
--- a/onnxruntime/core/mlas/lib/mlasi.h
|
||||
+++ b/onnxruntime/core/mlas/lib/mlasi.h
|
||||
@@ -361,6 +361,7 @@ size_t
|
||||
#else
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)(
|
||||
const float* A,
|
||||
const bfloat16_t* B,
|
||||
@@ -373,6 +374,7 @@ typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)(
|
||||
const float* Bias
|
||||
);
|
||||
#endif
|
||||
+#endif
|
||||
|
||||
typedef
|
||||
size_t
|
||||
@@ -763,8 +765,10 @@ extern "C" {
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero;
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd;
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero;
|
||||
MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd;
|
||||
+#endif
|
||||
#endif
|
||||
MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero;
|
||||
MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd;
|
||||
@@ -899,8 +903,10 @@ extern "C" {
|
||||
#define MLAS_QGEMM_THREAD_COMPLEXITY 65536
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024))
|
||||
#endif
|
||||
+#endif
|
||||
|
||||
//
|
||||
// Single-threaded single precision matrix/matrix multiply operation.
|
||||
@@ -2570,4 +2576,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH
|
||||
static_assert(std::is_same_v<UnpackedType, uint8_t> || std::is_same_v<UnpackedType, int8_t>);
|
||||
*Output = static_cast<uint8_t>(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF));
|
||||
}
|
||||
-
|
||||
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
|
||||
index ed437f20f7..8c9d0a75fd 100644
|
||||
--- a/onnxruntime/core/mlas/lib/platform.cpp
|
||||
+++ b/onnxruntime/core/mlas/lib/platform.cpp
|
||||
@@ -20,7 +20,7 @@ Abstract:
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
|
||||
-#if defined(MLAS_TARGET_POWER)
|
||||
+#if defined(MLAS_TARGET_POWER)
|
||||
#if defined(__linux__)
|
||||
#include <sys/auxv.h>
|
||||
#elif defined(_AIX)
|
||||
@@ -536,7 +536,7 @@ Return Value:
|
||||
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
|
||||
}
|
||||
|
||||
-#if defined(__linux__)
|
||||
+#if defined(__linux__) && !defined(USE_VSINPU)
|
||||
//
|
||||
// Check if the processor supports ASIMD I8MM instructions.
|
||||
//
|
||||
diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h
|
||||
index de7fd72fad..4f75dbd6fa 100644
|
||||
--- a/onnxruntime/core/mlas/lib/sbgemm.h
|
||||
+++ b/onnxruntime/core/mlas/lib/sbgemm.h
|
||||
@@ -31,6 +31,7 @@ Abstract:
|
||||
--*/
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -396,4 +397,5 @@ MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t Bat
|
||||
}
|
||||
);
|
||||
}
|
||||
+#endif
|
||||
#endif // defined(__aarch64__) && defined(__linux__)
|
||||
diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc
|
||||
index 6a71283f9d..d8bd348854 100644
|
||||
--- a/onnxruntime/core/providers/cpu/math/matmul.cc
|
||||
+++ b/onnxruntime/core/providers/cpu/math/matmul.cc
|
||||
@@ -132,7 +132,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
-#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if defined(__aarch64__) && defined(__linux__) && !defined(USE_VSINPU)
|
||||
bool GemmPackBBfloat16(AllocatorPtr& alloc,
|
||||
const Tensor& tensor_b,
|
||||
bool trans_b,
|
||||
@@ -180,6 +180,7 @@ Status MatMul<float>::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc
|
||||
if (input_idx == 1) {
|
||||
size_t packed_b_size;
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
size_t dim1 = 0;
|
||||
size_t dim2 = 0;
|
||||
TensorShape b_shape = tensor.Shape();
|
||||
@@ -192,6 +193,7 @@ Status MatMul<float>::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc
|
||||
if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) {
|
||||
is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_);
|
||||
} else
|
||||
+#endif
|
||||
#endif
|
||||
{
|
||||
is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_);
|
||||
@@ -257,6 +259,7 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
|
||||
const size_t lda = helper.Lda(trans_a);
|
||||
const size_t ldb = helper.Ldb(trans_b);
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) {
|
||||
std::vector<MLAS_SBGEMM_DATA_PARAMS> data(max_len);
|
||||
for (size_t i = 0; i < max_len; i++) {
|
||||
@@ -273,6 +276,7 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
|
||||
}
|
||||
MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool);
|
||||
} else
|
||||
+#endif
|
||||
#endif
|
||||
{
|
||||
std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
|
||||
diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h
|
||||
index b9bbe36583..2f570502d2 100644
|
||||
--- a/onnxruntime/core/providers/cpu/math/matmul.h
|
||||
+++ b/onnxruntime/core/providers/cpu/math/matmul.h
|
||||
@@ -31,8 +31,10 @@ class MatMul<float> final : public OpKernel {
|
||||
trans_batch_b_ = trans_batch_b_attr != 0;
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16);
|
||||
use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported();
|
||||
+#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -57,12 +59,14 @@ class MatMul<float> final : public OpKernel {
|
||||
bool trans_batch_b_;
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
// fastmath mode state
|
||||
bool use_fastmath_mode_;
|
||||
// sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2
|
||||
// so a minimum of 32 elements is defined to outweigh the additional prepacking overhead
|
||||
const size_t kFastMathModeKernelsizeThreshold = 32;
|
||||
#endif
|
||||
+#endif
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp
|
||||
index f85fe97776..6039b7fa9e 100644
|
||||
--- a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp
|
||||
+++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp
|
||||
@@ -16,6 +16,7 @@ Abstract:
|
||||
--*/
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
|
||||
#include "test_sbgemm.h"
|
||||
|
||||
@@ -138,4 +139,5 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe
|
||||
}
|
||||
return SBGemmRegistLongExecute() > 0;
|
||||
});
|
||||
+#endif
|
||||
#endif // defined(__aarch64__) && defined(__linux__)
|
||||
diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h
|
||||
index 13701e2e3d..7e432f53c2 100644
|
||||
--- a/onnxruntime/test/mlas/unittest/test_sbgemm.h
|
||||
+++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h
|
||||
@@ -16,6 +16,7 @@ Abstract:
|
||||
--*/
|
||||
|
||||
#if defined(__aarch64__) && defined(__linux__)
|
||||
+#if !defined(USE_VSINPU)
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -278,4 +279,5 @@ class MlasSBGemmTest : public MlasTestBase {
|
||||
}
|
||||
};
|
||||
|
||||
+#endif
|
||||
#endif // defined(__aarch64__) && defined(__linux__)
|
||||
|
|
|
|||
Loading…
Reference in a new issue