From 1007d8f3d1904ff2efc4e0647e795939fe049464 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 9 Feb 2024 09:24:54 -0800 Subject: [PATCH] Revert "Revert NeuralSpeed code for x64 MatMulNBits (#19382)" (#19474) This reverts commit 0d10c7f3c1111cfff064e7990aa897ac9fd05c82. --- cgmanifests/generated/cgmanifest.json | 10 + cmake/CMakeLists.txt | 12 + cmake/deps.txt | 1 + cmake/external/neural_speed.cmake | 15 + cmake/onnxruntime_providers_cpu.cmake | 15 + .../cpu/quantization/matmul_nbits.cc | 144 ++++++ .../cpu/quantization/neural_speed_defs.h | 45 ++ .../cpu/quantization/neural_speed_gemm.cc | 438 ++++++++++++++++++ .../cpu/quantization/neural_speed_gemm.h | 129 ++++++ .../cpu/quantization/neural_speed_wrapper.h | 39 ++ .../test/contrib_ops/matmul_4bits_test.cc | 175 +++++++ 11 files changed, 1023 insertions(+) create mode 100644 cmake/external/neural_speed.cmake create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index fc4ea25603..efd901787f 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -202,6 +202,16 @@ "comments": "mp11" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a", + "repositoryUrl": "https://github.com/intel/neural-speed.git" + }, + "comments": "neural_speed" + } + }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0ccd874cee..90fe8276ea 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,6 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -901,6 +902,10 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() + if(USE_NEURAL_SPEED) + target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) + endif() + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) # Suppress a "conversion_function_not_usable" warning in gsl/span @@ -1188,6 +1193,13 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() +if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) + include(neural_speed) + if (USE_NEURAL_SPEED) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) + endif() +endif() + # TVM EP if (onnxruntime_USE_TVM) if (NOT TARGET tvm) diff --git a/cmake/deps.txt b/cmake/deps.txt index 17c3cbf9a6..cb431f8c77 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -35,6 +35,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 +neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake new file mode 100644 index 0000000000..ed71135140 --- /dev/null +++ b/cmake/external/neural_speed.cmake @@ -0,0 +1,15 @@ +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + set(USE_NEURAL_SPEED TRUE) +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + set(USE_NEURAL_SPEED TRUE) +endif() + +if(USE_NEURAL_SPEED) + FetchContent_Declare( + neural_speed + URL ${DEP_URL_neural_speed} + URL_HASH SHA1=${DEP_SHA1_neural_speed} + ) + set(BTLA_USE_OPENMP OFF) + onnxruntime_fetchcontent_makeavailable(neural_speed) +endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index f60faa4d39..b81a5c79ac 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() + set(onnxruntime_cpu_neural_speed_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" + ) + if(NOT USE_NEURAL_SPEED) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") endif() +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(USE_NEURAL_SPEED) + onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) + endif() +endif() + if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index e8d8bbca66..166f5c8f52 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -10,6 +10,10 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#ifdef ORT_NEURAL_SPEED +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#endif + namespace onnxruntime { namespace contrib { @@ -19,6 +23,16 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level static_cast(CompMostAccurate), static_cast(CompLeastAccurate)); +#if defined(ORT_NEURAL_SPEED) + + ORT_UNUSED_PARAMETER(nbits); + ORT_UNUSED_PARAMETER(block_size); + + // Neural Speed APIs already expect a minimum accuracy level so just use the given value. + return accuracy_level; + +#else // defined(ORT_NEURAL_SPEED) + // Find a supported accuracy level that is not less accurate than the one given. // CompMostAccurate is always supported with the fallback implementation. // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. @@ -31,6 +45,8 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } return effective_accuracy_level; + +#endif // defined(ORT_NEURAL_SPEED) } } // namespace @@ -45,6 +61,17 @@ class MatMulNBits final : public OpKernel { accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); +#ifdef ORT_NEURAL_SPEED + const Tensor* tensor_B = nullptr; + const Tensor* tensor_scale = nullptr; + const Tensor* tensor_zero_point = nullptr; + bool B_constant = info.TryGetConstantInput(1, &tensor_B); + bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); + bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + is_asym_ = info.GetInputCount() >= 4; + all_constant_ = B_constant && scale_constant; + all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; +#endif } Status Compute(OpKernelContext* context) const override; @@ -65,6 +92,13 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; + +#if defined(ORT_NEURAL_SPEED) + + bool is_asym_{false}; + bool all_constant_{false}; + +#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, @@ -72,6 +106,54 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; +#if defined(ORT_NEURAL_SPEED) + + if (!all_constant_) { + return Status::OK(); + } + MLAS_THREADPOOL* pool = NULL; + if (nbits_ != 4) { + return Status::OK(); + } + auto comp_type = static_cast(accuracy_level_); + auto nbits = static_cast(nbits_); + if (input_idx == 1) { + packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); + if (packed_b_size_ == 0) return Status::OK(); + auto qptr = tensor.Data(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + std::memset(packed_b_.get(), 0, packed_b_size_); + NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, + comp_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 2 && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, + comp_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 3 && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, + comp_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + +#else // defined(ORT_NEURAL_SPEED) + if (input_idx == 1) { const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { @@ -91,6 +173,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } +#endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } @@ -98,10 +182,30 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; +#if defined(ORT_NEURAL_SPEED) + + // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } + if (input_idx == 2) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 3) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + +#else // defined(ORT_NEURAL_SPEED) + + if (input_idx == 1) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + +#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } @@ -112,6 +216,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); +#if defined(ORT_NEURAL_SPEED) + + if (packed_b_) { + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + std::vector gemm_params(max_len); + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + for (size_t i = 0; i < max_len; i++) { + gemm_params[i].A = a_data + helper.LeftOffsets()[i]; + gemm_params[i].lda = lda; + gemm_params[i].B = packed_b_.get(); + gemm_params[i].C = y_data + helper.OutputOffsets()[i]; + gemm_params[i].ldc = N; + } + auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + // workspace for activation process(dynamic quantization and others) + auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); + NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool); + return Status::OK(); + } + +#endif // defined(ORT_NEURAL_SPEED) + const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); const auto* scales_data = scales->Data(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h new file mode 100644 index 0000000000..864abffd13 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h @@ -0,0 +1,45 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +--*/ + +#pragma once + +#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" + +namespace bestla { + +using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; +using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; +using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; +using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; +using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; +using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; + +template +using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; +template +using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; + +class ORTThreading : public parallel::IThreading { + public: + explicit ORTThreading(void* tp); + void parallel_for(const parallel::thread_func& func) const override; + void set_threads(int nthreads) override { + (void)(nthreads); + assert(0); + } + void sync() const override { assert(0); } + void* mTp; +}; + +} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc new file mode 100644 index 0000000000..73aaa4ae61 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc @@ -0,0 +1,438 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.cpp + +Abstract: + + GEMM template combinations of neural_speed. +--*/ + +#include "contrib_ops/cpu/quantization/neural_speed_defs.h" +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#include "core/platform/threadpool.h" + +using ThreadPool = onnxruntime::concurrency::ThreadPool; + +namespace bestla { + +ORTThreading::ORTThreading(void* tp) + : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} + +void ORTThreading::parallel_for(const parallel::thread_func& func) const { + ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, + [&](ptrdiff_t tid) { func(static_cast(tid)); }); +} + +template +static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + if (M <= 16) { + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->IsAsym()) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); + } + typename Launcher::Param args{gp, + {A, lda_, &reduceA}, + {B}, + {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), + reduceA.template RPtr(), reduceA.lda}, + {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } else { + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + static Launcher kernel; + typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } +} + +template +static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + using Parallel = parallel::gemm::SchedulerKBlockS; + using Launcher = + wrapper::gemm::LauncherIntKBlock; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); +} + +template +static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(A); + (void)(N); + (void)(C); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using ProA = prologue_a::gemm::ActivationKBlockBaseF32; + static ProA proA; + if (B->IsAsym()) { + auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + // using ProA = prologue_a::gemm::ActivationBase; + return 0; + } +} + +template +static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + (void)(N); + (void)(lda); + (void)(ldc); + (void)(A); + (void)(C); + using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; + static ProA proA; + auto quanA = + proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); + return quanA.mSize; +} + +} // namespace bestla + +using namespace bestla; + +static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, + void* ThreadPool) { + GetCPUDevice(); + bestla::ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, + DataParams[i].ldc, WorkSpace, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && + BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && + BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + } + } + } + return size; +} + +template +static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { + static T proB; + auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), + BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto wptr = reinterpret_cast(ptr); + auto BlkSize = wptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} + +template +static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, + const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, + void* ThreadPool) { + static T proB; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + BTLA_DTYPE::BF16, IsAsym); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + proB.reduceWeight(&stor, &orth); + } +} + +static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + [[fallthrough]]; + default: + return 0; + } +} + +static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, + size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, + K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, + lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, + ldb, ThreadPool); + return true; + } + [[fallthrough]]; + default: + return false; + } +} + +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, + NS_SQNBIT_COMPUTE_TYPE CompType) { + if (nbits == 4) { + auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } + return 0; +} + +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + if (nbits == 4) { + if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +} + +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +} + +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +} + +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { + // PackedWeight is created by bestla + return; + } +} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h new file mode 100644 index 0000000000..ebcb3027a2 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.h + +Abstract: + + Prepack-weight GEMM APIs of neural_speed. +--*/ + +#pragma once + +#include +#include + +/** + * @brief Define compute types of block quantization + */ +enum NS_SQNBIT_COMPUTE_TYPE { + NSCompUndef = 0, /*!< undef */ + NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ + NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ + NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ + NSCompInt8 = 4 /*!< input int8, accumulator int32 */ +}; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, + NS_SQNBIT_COMPUTE_TYPE comp_type); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where + * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp(is_asym is true). + * @param thread_pool + */ +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, + NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param thread_pool + */ +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); + +/** + * @brief Batched GEMM: C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + * @return + */ +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h new file mode 100644 index 0000000000..d3902f9bd6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h @@ -0,0 +1,39 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +//----------------------------------------------------------------------------- +#pragma once +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-value" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" + +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4457) +#pragma warning(disable : 4189) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4702) +#endif + +#include "bestla/bestla_prologue_a.h" +#include "bestla/bestla_wrapper.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index d22da2a3da..2ad20eafc2 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -149,10 +149,17 @@ TEST(MatMulNBits, Float32) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { +#ifdef ORT_NEURAL_SPEED + for (auto accuracy_level : {0, 1, 4}) { + RunTest(M, N, K, block_size, accuracy_level, false, false); + RunTest(M, N, K, block_size, accuracy_level, true, false); + } +#else for (auto accuracy_level : {0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); } +#endif } } } @@ -185,6 +192,174 @@ TEST(MatMulNBits, Float16Large) { #endif +void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, int block_size, bool is_asym, + int64_t acc_lvl) { + // (M x K) X (K x N) + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("accuracy_level", acc_lvl); + test.AddAttribute("block_size", int64_t(block_size)); + test.AddAttribute("bits", QBits); + test.AddAttribute("N", N); + test.AddAttribute("K", K); + + std::vector input0_vals(M * K); + float fv = -135.f; + for (auto& f : input0_vals) { + f = fv / 127; + fv++; + if (fv > 135.f) { + fv = -135.f; + } + } + + size_t kblks = K / block_size; + std::vector input1_vals(N * K / 2); + for (size_t i = 0; i < input1_vals.size(); i++) { + input1_vals[i] = uint8_t(i); + } + std::vector input2_vals(N * kblks, 0.002f); + for (size_t i = 0; i < N * kblks; i++) { + input2_vals[i] += (i % 100) * 0.00003f; + } + std::vector input3_vals(N * kblks / 2, static_cast(0x88)); + + std::vector input1_f_vals(N * K); + if (is_asym) { + for (size_t i = 0; i < N * kblks; i += 2) { + input3_vals[i / 2] = static_cast(i + 1); + } + for (int64_t i = 0; i < K; i += 2) { + for (int64_t j = 0; j < N; j++) { + auto srcv = input1_vals[j * K / 2 + i / 2]; + auto koff = i % (block_size * 2); + auto zpv = input3_vals[j * kblks / 2 + i / block_size / 2]; + auto zp0 = koff < block_size ? (zpv & 0xf) - 8 : ((zpv & 0xf0) >> 4) - 8; + auto src0 = (srcv & 0xf) - 8; + auto src1 = ((srcv & 0xf0) >> 4) - 8; + auto scale0 = input2_vals[j * kblks + i / block_size]; + auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; + input1_f_vals[i * N + j] = (static_cast(src0) - zp0) * scale0; + input1_f_vals[(i + 1) * N + j] = (static_cast(src1) - zp0) * scale1; + } + } + } else { + for (int64_t i = 0; i < K; i += 2) { + for (int64_t j = 0; j < N; j++) { + auto srcv = input1_vals[j * K / 2 + i / 2]; + auto src0 = (srcv & 0xf) - 8; + auto src1 = ((srcv & 0xf0) >> 4) - 8; + auto scale0 = input2_vals[j * kblks + i / block_size]; + auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; + input1_f_vals[i * N + j] = static_cast(src0) * scale0; + input1_f_vals[(i + 1) * N + j] = static_cast(src1) * scale1; + } + } + } + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_vals[m * K + k] * input1_f_vals[k * N + n]; + } + expected_vals[m * N + n] = sum; + } + } + + test.AddInput("A", {M, K}, input0_vals, false); + + test.AddInput("B", {N, static_cast(kblks), static_cast(block_size / 2)}, input1_vals, + true); + test.AddInput("scales", {N, static_cast(kblks)}, input2_vals, true); + if (is_asym) { + test.AddInput("zero_points", {N, static_cast(kblks / 2)}, input3_vals, true); + } + test.AddOutput("Y", {M, N}, expected_vals, false); + if (acc_lvl == 4) { + test.SetOutputAbsErr("Y", 0.1f); + } + + OrtValue b, scale, zp; + Tensor::InitOrtValue(DataTypeImpl::GetType(), + TensorShape({N, static_cast(kblks), static_cast(block_size / 2)}), + input1_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks)}), + input2_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), scale); + if (is_asym) { + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks / 2)}), + input3_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), zp); + } + SessionOptions so; + // Set up B as a shared initializer to be shared between sessions + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + ASSERT_EQ(so.AddInitializer("scales", &scale), Status::OK()); + if (is_asym) { + ASSERT_EQ(so.AddInitializer("zero_points", &zp), Status::OK()); + } + + // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP + // and we want to ensure that it is available in this build + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + // Assert that no pre-packed weights have been shared thus far + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); + // Assert that the number of elements in the shared container + // is the same as the number of weights that have been pre-packed + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); + + // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements + // that have been pre-packed will be zero in which case we do not continue with the testing + // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. + if (number_of_pre_packed_weights_counter_session_1 == 0) return; + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + // Assert that the same number of weights were pre-packed in both sessions + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Assert that the number of pre-packed weights that were shared equals + // the number of pre-packed weights in the second session + ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, + static_cast(number_of_shared_pre_packed_weights_counter)); + } +} + +#ifdef ORT_NEURAL_SPEED +TEST(MatMulNBits, SharedPrepackedWeights) { + RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, true, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, false, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 4); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 1024, false, 4); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 4096, false, 4); +} +#endif } // namespace test } // namespace onnxruntime