Revert "Revert NeuralSpeed code for x64 MatMulNBits (#19382)" (#19474)

This reverts commit 0d10c7f3c1.
This commit is contained in:
Changming Sun 2024-02-09 09:24:54 -08:00 committed by GitHub
parent 3d2ddf96e3
commit 1007d8f3d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1023 additions and 0 deletions

View file

@ -202,6 +202,16 @@
"comments": "mp11"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a",
"repositoryUrl": "https://github.com/intel/neural-speed.git"
},
"comments": "neural_speed"
}
},
{
"component": {
"type": "git",

View file

@ -88,6 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@ -901,6 +902,10 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
endif()
# Expose the ORT_NEURAL_SPEED preprocessor symbol to this target when the Neural Speed
# integration is active, so sources can guard bestla-specific code with #ifdef.
# NOTE(review): USE_NEURAL_SPEED is set by cmake/external/neural_speed.cmake; confirm that
# file is always included before this function is invoked for the targets that need it.
if(USE_NEURAL_SPEED)
target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
endif()
set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if (onnxruntime_USE_CUDA)
# Suppress a "conversion_function_not_usable" warning in gsl/span
@ -1188,6 +1193,13 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()
# Neural Speed (bestla) integration. Only attempted when the user enabled the option and
# this is not a minimal build.
if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD)
# The included module decides whether the current compiler/platform pair is supported and,
# if so, sets USE_NEURAL_SPEED and fetches the dependency.
include(neural_speed)
# USE_NEURAL_SPEED stays unset on unsupported toolchains, in which case nothing is linked.
if (USE_NEURAL_SPEED)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla)
endif()
endif()
# TVM EP
if (onnxruntime_USE_TVM)
if (NOT TARGET tvm)

View file

@ -35,6 +35,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939
onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11
#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459)
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035

15
cmake/external/neural_speed.cmake vendored Normal file
View file

@ -0,0 +1,15 @@
# Enable the Neural Speed (bestla) dependency only for the toolchain/platform pairs listed
# here: GCC targeting x86_64, or MSVC targeting x64. On any other combination
# USE_NEURAL_SPEED is left untouched and nothing is fetched.
if(("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${onnxruntime_target_platform}" STREQUAL "x86_64")
   OR ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND "${onnxruntime_target_platform}" STREQUAL "x64"))
  set(USE_NEURAL_SPEED TRUE)
endif()

if(USE_NEURAL_SPEED)
  # Archive URL and SHA1 come from the DEP_* variables defined for the neural_speed entry.
  FetchContent_Declare(
    neural_speed
    URL ${DEP_URL_neural_speed}
    URL_HASH SHA1=${DEP_SHA1_neural_speed}
  )
  # Turn off the dependency's OpenMP switch before it is configured.
  set(BTLA_USE_OPENMP OFF)
  onnxruntime_fetchcontent_makeavailable(neural_speed)
endif()

View file

@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
)
endif()
# Source files that implement the Neural Speed (bestla) backed quantized GEMM path.
set(onnxruntime_cpu_neural_speed_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h"
)
# When Neural Speed is not enabled, drop those files from the CPU contrib-op source list so
# they are not compiled.
if(NOT USE_NEURAL_SPEED)
list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs})
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs})
list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs})
@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL)
target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical")
endif()
# Give onnxruntime_providers the bestla include directories, but only when contrib ops are
# built and the Neural Speed integration is active.
if(NOT onnxruntime_DISABLE_CONTRIB_OPS AND USE_NEURAL_SPEED)
  onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla)
endif()
if (MSVC)
target_compile_options(onnxruntime_providers PRIVATE "/bigobj")
# if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)

View file

@ -10,6 +10,10 @@
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#ifdef ORT_NEURAL_SPEED
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
#endif
namespace onnxruntime {
namespace contrib {
@ -19,6 +23,16 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
static_cast<int64_t>(CompMostAccurate),
static_cast<int64_t>(CompLeastAccurate));
#if defined(ORT_NEURAL_SPEED)
ORT_UNUSED_PARAMETER(nbits);
ORT_UNUSED_PARAMETER(block_size);
// Neural Speed APIs already expect a minimum accuracy level so just use the given value.
return accuracy_level;
#else // defined(ORT_NEURAL_SPEED)
// Find a supported accuracy level that is not less accurate than the one given.
// CompMostAccurate is always supported with the fallback implementation.
// Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed.
@ -31,6 +45,8 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
}
return effective_accuracy_level;
#endif // defined(ORT_NEURAL_SPEED)
}
} // namespace
@ -45,6 +61,17 @@ class MatMulNBits final : public OpKernel {
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
const Tensor* tensor_zero_point = nullptr;
bool B_constant = info.TryGetConstantInput(1, &tensor_B);
bool scale_constant = info.TryGetConstantInput(2, &tensor_scale);
bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point);
is_asym_ = info.GetInputCount() >= 4;
all_constant_ = B_constant && scale_constant;
all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_;
#endif
}
Status Compute(OpKernelContext* context) const override;
@ -65,6 +92,13 @@ class MatMulNBits final : public OpKernel {
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_{0};
#if defined(ORT_NEURAL_SPEED)
bool is_asym_{false};
bool all_constant_{false};
#endif // defined(ORT_NEURAL_SPEED)
};
Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
@ -72,6 +106,54 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
#if defined(ORT_NEURAL_SPEED)
if (!all_constant_) {
return Status::OK();
}
MLAS_THREADPOOL* pool = NULL;
if (nbits_ != 4) {
return Status::OK();
}
auto comp_type = static_cast<NS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
auto nbits = static_cast<int>(nbits_);
if (input_idx == 1) {
packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type);
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.Data<uint8_t>();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
std::memset(packed_b_.get(), 0, packed_b_size_);
NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false,
comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 2 && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_,
comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 3 && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_,
comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
#else // defined(ORT_NEURAL_SPEED)
if (input_idx == 1) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
@ -91,6 +173,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}
#endif // defined(ORT_NEURAL_SPEED)
return Status::OK();
}
@ -98,10 +182,30 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
#if defined(ORT_NEURAL_SPEED)
// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 2) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 3) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
#else // defined(ORT_NEURAL_SPEED)
if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
#endif // defined(ORT_NEURAL_SPEED)
return Status::OK();
}
@ -112,6 +216,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();
#if defined(ORT_NEURAL_SPEED)
if (packed_b_) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
Tensor* y = ctx->Output(0, helper.OutputShape());
// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0) return Status::OK();
auto* y_data = y->MutableData<float>();
const size_t max_len = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);
std::vector<NS_SQNBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len);
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
for (size_t i = 0; i < max_len; i++) {
gemm_params[i].A = a_data + helper.LeftOffsets()[i];
gemm_params[i].lda = lda;
gemm_params[i].B = packed_b_.get();
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool);
return Status::OK();
}
#endif // defined(ORT_NEURAL_SPEED)
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const auto* scales_data = scales->Data<float>();

View file

@ -0,0 +1,45 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
--*/
#pragma once
#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h"
namespace bestla {

// Short names for the bestla GEMM core instantiations used by the ONNX Runtime glue code.
// The numeric template arguments are passed to bestla as-is; see the bestla core templates
// for their meaning.

// SCoreRowN* cores.
using tAVX2 = gemm::SCoreRowNAvx2<24, 4>;
using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>;

// HCoreRowN* cores.
using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>;
using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>;

// ICoreRowN* cores.
using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>;
using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>;
using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>;
using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>;

// ICoreRowN*KBlock cores.
using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>;
using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>;
using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>;
using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>;

// Weight prologues, parameterised by GEMM core and ISA.
template <class GC_T, BTLA_ISA ISA_T>
using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger<GC_T, ISA_T>;
template <class GC_T, BTLA_ISA ISA_T>
using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat<GC_T, ISA_T>;

// Adapter that lets bestla run its parallel sections on a caller-supplied thread pool.
// The pool is held as an opaque pointer; its concrete type is only known to the
// implementation file.
class ORTThreading : public parallel::IThreading {
 public:
  // `tp` is the type-erased thread pool pointer; it is stored in mTp.
  explicit ORTThreading(void* tp);

  // Runs `func` through the wrapped pool (defined out of line).
  void parallel_for(const parallel::thread_func& func) const override;

  // Not supported by this adapter: asserts in debug builds, no-op otherwise.
  void set_threads(int nthreads) override {
    static_cast<void>(nthreads);
    assert(0);
  }

  // Not supported by this adapter: asserts in debug builds, no-op otherwise.
  void sync() const override { assert(0); }

  // Opaque thread pool handle supplied at construction.
  void* mTp;
};

}  // namespace bestla

View file

@ -0,0 +1,438 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
neural_speed_gemm.cpp
Abstract:
GEMM template combinations of neural_speed.
--*/
#include "contrib_ops/cpu/quantization/neural_speed_defs.h"
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
#include "core/platform/threadpool.h"
using ThreadPool = onnxruntime::concurrency::ThreadPool;
namespace bestla {
// Builds the adapter around a type-erased thread pool pointer. The base class receives the
// value reported by ThreadPool::DegreeOfParallelism for that pool, and the raw pointer is
// kept in mTp for later dispatch. Other code in this file also constructs the adapter with
// nullptr.
ORTThreading::ORTThreading(void* tp)
    : IThreading(ThreadPool::DegreeOfParallelism(static_cast<ThreadPool*>(tp))), mTp(tp) {}
// Invokes `func` once per task index in [0, mThreadNum) through the wrapped thread pool.
// The pool hands out ptrdiff_t indices; bestla expects an int thread id.
void ORTThreading::parallel_for(const parallel::thread_func& func) const {
  auto* pool = reinterpret_cast<ThreadPool*>(mTp);
  auto run_one = [&](ptrdiff_t tid) { func(static_cast<int>(tid)); };
  ThreadPool::TrySimpleParallelFor(pool, mThreadNum, run_one);
}
// Runs one GEMM C = A * B where B is a bestla packed n-bit weight and the GEMM core
// `GemmCore_T` is used for the computation.
//
// M, N, K    problem dimensions (narrowed to int for bestla).
// A, lda     float input matrix and its leading dimension.
// B          packed weight storage; its block size and asymmetry flag drive the setup.
// C, ldc     float output matrix and its leading dimension.
// WorkSpace  scratch buffer; only used on the M <= 16 path when B is asymmetric.
// th         threading object used to run the GEMM.
//
// Two launcher configurations are used depending on M (threshold 16).
template <class GemmCore_T>
static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace,
parallel::IThreading* th) {
auto M_ = static_cast<int>(M);
auto N_ = static_cast<int>(N);
auto K_ = static_cast<int>(K);
auto lda_ = static_cast<int>(lda);
auto ldc_ = static_cast<int>(ldc);
utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize);
if (M <= 16) {
// Small-M path: k-block scheduler/launcher with a per-block fp32 epilogue.
using Parallel = parallel::gemm::SchedulerKBlock<GemmCore_T>;
using Launcher =
wrapper::gemm::LauncherKBlock<GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationKBlockBaseF32,
prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::CompFp32BlockEpilogue,
epilogue::gemm::AccumulatorWriteBackFp32>;
// Function-local static: one launcher instance per template instantiation.
static Launcher kernel;
auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
if (B->IsAsym()) {
// Asymmetric weights need a reduction over A; it is written into WorkSpace.
reduceA.assign(WorkSpace);
// The reduction runs on an ORTThreading built without a pool rather than on `th`.
ORTThreading single(nullptr);
kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single);
}
typename Launcher::Param args{gp,
{A, lda_, &reduceA},
{B},
{B->template SPtr<int8_t>(), B->SDtype(), B->CStep(), B->template ZPtr<int8_t>(),
reduceA.template RPtr<float>(), reduceA.lda},
{C, ldc_, nullptr}};
parallel::GemmRun<Parallel>(kernel, args, th);
} else {
// Larger-M path: base scheduler/launcher, no activation reduction and no workspace use.
using Parallel = parallel::gemm::SchedulerBase<GemmCore_T>;
using Launcher =
wrapper::gemm::LauncherBase<GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationBase,
prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::AccumulatorWriteBackFp32>;
static Launcher kernel;
typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}};
parallel::GemmRun<Parallel>(kernel, args, th);
}
}
// Runs one GEMM C = A * B with the int8 compute path: A is first quantized per k-block into
// WorkSpace, then multiplied with the packed n-bit weight B using `GemmCore_T`.
//
// M, N, K    problem dimensions (narrowed to int for bestla).
// A, lda     float input matrix and its leading dimension.
// B          packed weight storage; supplies the block size and asymmetry flag.
// C, ldc     float output matrix and its leading dimension.
// WorkSpace  scratch buffer that receives the quantized activation; always used here.
// th         threading object used for the GEMM (and for quantization when M > 16).
template <class GemmCore_T>
static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace,
parallel::IThreading* th) {
using Parallel = parallel::gemm::SchedulerKBlockS<GemmCore_T>;
using Launcher =
wrapper::gemm::LauncherIntKBlock<GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationF32KBlockQuantize,
prologue_b::gemm::WeightKBlockNInteger,
epilogue::gemm::AccumulatorWriteBackFp32>;
auto M_ = static_cast<int>(M);
auto N_ = static_cast<int>(N);
auto K_ = static_cast<int>(K);
auto lda_ = static_cast<int>(lda);
auto ldc_ = static_cast<int>(ldc);
// Function-local static: one launcher instance per template instantiation.
static Launcher kernel;
// Describe the quantized-activation layout and bind it to the caller's scratch buffer.
auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym());
quanA.assign(WorkSpace);
if (M <= 16) {
// Small M: quantize with an ORTThreading built without a pool instead of `th`.
ORTThreading single(nullptr);
kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single);
} else {
kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th);
}
utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize);
typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}};
parallel::GemmRun<Parallel>(kernel, args, th);
}
// Returns the scratch size that NSSQ4GemmCompF32 needs for the same arguments.
// Scratch is only required on the small-M path (M <= 16) with asymmetric weights, where it
// holds the activation reduction; every other case needs none.
// N, A, lda, C and ldc are accepted for signature symmetry and are not used.
template <class GemmCore_T>
static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda,
                                            storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) {
  static_cast<void>(A);
  static_cast<void>(N);
  static_cast<void>(C);
  static_cast<void>(lda);
  static_cast<void>(ldc);

  // The larger-M launcher uses the plain activation prologue, which needs no workspace.
  if (M > 16) {
    return 0;
  }

  using ProA = prologue_a::gemm::ActivationKBlockBaseF32<GemmCore_T, GemmCore_T::ISA>;
  static ProA proA;
  if (!B->IsAsym()) {
    return 0;
  }

  const auto rows = static_cast<int>(M);
  const auto depth = static_cast<int>(K);
  auto reduceA = proA.createStorage(rows, depth, B->mBlockSize);
  return reduceA.mSize;
}
// Returns the scratch size that NSSQ4GemmCompInt8 needs for the same arguments, i.e. the
// size of the quantized-activation storage for an M x K input with B's block size and
// asymmetry flag. N, A, lda, C and ldc are accepted for signature symmetry and are not used.
template <class GemmCore_T>
static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda,
                                             storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) {
  static_cast<void>(N);
  static_cast<void>(lda);
  static_cast<void>(ldc);
  static_cast<void>(A);
  static_cast<void>(C);

  using ProA = prologue_a::gemm::ActivationF32KBlockQuantize<GemmCore_T, GemmCore_T::ISA>;
  static ProA proA;

  const auto rows = static_cast<int>(M);
  const auto depth = static_cast<int>(K);
  const auto block = static_cast<int>(B->mBlockSize);
  auto quanA = proA.createStorage(rows, depth, block, B->IsAsym());
  return quanA.mSize;
}
} // namespace bestla
using namespace bestla;
// Executes a batch of GEMMs against bestla packed weights, one batch entry after another.
//
// M, N, K     dimensions shared by every entry in the batch.
// BatchN      number of entries in DataParams.
// DataParams  per-entry A/B/C pointers and leading dimensions.
// WorkSpace   scratch buffer handed to every entry in turn.
// ThreadPool  type-erased thread pool, wrapped in ORTThreading for bestla.
//
// Returns false as soon as an entry's B cannot be parsed as a bestla packed weight
// (remaining entries are not run); returns true otherwise.
// NOTE(review): an entry whose packed weight parses but matches none of the kernel
// branches below is skipped without computing anything, and the function still returns
// true. Confirm whether that should be reported as a failure.
static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace,
void* ThreadPool) {
// Presumably provides the CPU feature object `_cd` used below; it is defined outside this file.
GetCPUDevice();
bestla::ORTThreading orth(ThreadPool);
bool processed = true;
for (size_t i = 0; i < BatchN; i++) {
auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
// Owns the parsed descriptor so it is released at the end of each iteration.
auto uptr = std::unique_ptr<bestla::storage::gemm::IWeightBase>(ptr);
if (ptr) {
// Recover the kernel configuration that was recorded in the packed weight.
auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
auto btype = static_cast<gemm::CompType>(gemm::CompTypeHelper::get_B(CType));
if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
auto kptr = reinterpret_cast<bestla::storage::gemm::StorageWeightKBlockNInteger*>(ptr);
auto BlkSize = kptr->mBlockSize;
// fp32 compute kernels: the packed NTile, CPU support and block-size divisibility must all match.
if (btype == gemm::CompType::tFP32 && PackRow == 1) {
if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
bestla::NSSQ4GemmCompF32<bestla::tAVX512F>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth);
} else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
bestla::NSSQ4GemmCompF32<bestla::tAVX2>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C,
DataParams[i].ldc, WorkSpace, &orth);
}
}
// int8 compute kernels, tried in the order AMX_INT8, AVX512_VNNI, AVX_VNNI.
if (btype == gemm::CompType::tS8 && PackRow == 4) {
if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() &&
BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
bestla::NSSQ4GemmCompInt8<bestla::tAMX_INT8_SS_KBlock>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace,
&orth);
} else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() &&
BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
bestla::NSSQ4GemmCompInt8<bestla::tAVX512_VNNI_KBlock>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace,
&orth);
} else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() &&
BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
bestla::NSSQ4GemmCompInt8<bestla::tAVX_VNNI_KBlock>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth);
}
}
}
} else {
// B is not a bestla packed weight: report failure and stop.
processed = false;
break;
}
}
return processed;
}
// Computes the scratch size needed to run NSSQ4GemmBatchDriver on the same batch.
// Because the driver hands the same WorkSpace to every entry in turn, the result is the
// maximum over all entries rather than the sum.
//
// Entries whose B cannot be parsed, or that match no kernel branch, contribute 0.
// The dispatch conditions mirror the ones in NSSQ4GemmBatchDriver and must stay in sync.
static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) {
// Presumably provides the CPU feature object `_cd` used below; it is defined outside this file.
GetCPUDevice();
size_t size = 0;
for (size_t i = 0; i < BatchN; i++) {
auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
// Owns the parsed descriptor so it is released at the end of each iteration.
auto uptr = std::unique_ptr<storage::gemm::IWeightBase>(ptr);
if (ptr) {
if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
auto kptr = reinterpret_cast<storage::gemm::StorageWeightKBlockNInteger*>(ptr);
// Recover the kernel configuration that was recorded in the packed weight.
auto NTile =
gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
auto btype = static_cast<gemm::CompType>(gemm::CompTypeHelper::get_B(CType));
auto BlkSize = kptr->mBlockSize;
// fp32 compute kernels.
if (btype == gemm::CompType::tFP32 && PackRow == 1) {
if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
size = std::max(NSSQ4GemmCompF32WorkspaceSize<tAVX512F>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc),
size);
} else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
size = std::max(NSSQ4GemmCompF32WorkspaceSize<tAVX2>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc),
size);
}
}
// int8 compute kernels.
if (btype == gemm::CompType::tS8 && PackRow == 4) {
if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
size = std::max(NSSQ4GemmCompInt8WorkspaceSize<tAMX_INT8_SS_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
} else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() &&
BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
size = std::max(NSSQ4GemmCompInt8WorkspaceSize<tAVX512_VNNI_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
} else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
size = std::max(NSSQ4GemmCompInt8WorkspaceSize<tAVX_VNNI_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
}
}
}
}
}
return size;
}
// Returns the size of the packed-weight storage that weight prologue `T` creates for an
// N x K matrix quantized with the given block size. The storage is always described as
// S4_CLIP / F32 / BF16 (the three dtype arguments passed to createStorage), with zero
// points included when `isAsym` is true.
template <typename T>
static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) {
  static T proB;
  const auto cols = static_cast<int>(N);
  const auto rows = static_cast<int>(K);
  const auto blk = static_cast<int>(block_size);
  // TODO(Yu) support more scale dtype
  auto layout =
      proB.createStorage(cols, rows, blk, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym);
  return layout.mSize;
}
// Unpacks a bestla packed 4-bit weight back into a float matrix.
//
// FpData      destination float buffer.
// PackedBuf   packed weight blob previously produced by the packing routines.
// N, K, ldb   matrix dimensions and leading dimension of the destination.
// ThreadPool  type-erased thread pool, wrapped in ORTThreading for bestla.
//
// Returns false when PackedBuf cannot be parsed as a bestla packed weight, true otherwise.
// NOTE(review): true is also returned when the blob parses but matches none of the kernel
// branches below, in which case FpData is left untouched.
static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) {
auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf);
// Owns the parsed descriptor so it is released when the function returns.
auto uptr = std::unique_ptr<storage::gemm::IWeightBase>(ptr);
ORTThreading orth(ThreadPool);
auto N_ = static_cast<int>(N);
auto K_ = static_cast<int>(K);
auto ldb_ = static_cast<int>(ldb);
// Presumably provides the CPU feature object `_cd` used below; it is defined outside this file.
GetCPUDevice();
if (ptr) {
// Recover the kernel configuration that was recorded in the packed weight.
auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
auto btype = static_cast<gemm::CompType>(gemm::CompTypeHelper::get_B(CType));
if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
auto wptr = reinterpret_cast<storage::gemm::StorageWeightKBlockNInteger*>(ptr);
auto BlkSize = wptr->mBlockSize;
// Weights packed for the fp32 compute kernels.
if (btype == gemm::CompType::tFP32 && PackRow == 1) {
if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
static tWeiNInt<tAVX512F, tAVX512F::ISA> proB;
proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
} else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
static tWeiNInt<tAVX2, tAVX2::ISA> proB;
proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
}
}
// Weights packed for the int8 compute kernels.
if (btype == gemm::CompType::tS8 && PackRow == 4) {
if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
static tWeiNInt<tAMX_INT8_SS_KBlock, tAMX_INT8_SS_KBlock::ISA> proB;
proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
} else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() &&
BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
static tWeiNInt<tAVX512_VNNI_KBlock, tAVX512_VNNI_KBlock::ISA> proB;
proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
} else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
static tWeiNInt<tAVX_VNNI_KBlock, tAVX_VNNI_KBlock::ISA> proB;
proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
}
}
}
return true;
}
return false;
}
template <typename T>
static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale,
const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb,
void* ThreadPool) {
static T proB;
auto N_ = static_cast<int>(N);
auto K_ = static_cast<int>(K);
auto stor = proB.createStorage(N_, K_, static_cast<int>(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32,
BTLA_DTYPE::BF16, IsAsym);
stor.assign(reinterpret_cast<int8_t*>(PackedBuf));
ORTThreading orth(ThreadPool);
proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast<int>(ldb), Scale, Zp, &stor, &orth);
if (lastCall) {
proB.reduceWeight(&stor, &orth);
}
}
// Returns the number of bytes needed to hold the packed form of an N x K 4-bit weight for
// the requested compute type, or 0 when no kernel available on this CPU supports the
// combination.
//
// N, K      weight dimensions.
// BlkSize   quantization block size; must be non-zero and divide K, otherwise 0 is returned.
// isAsym    true when the weight carries zero points.
// CompType  requested compute type. Int8 kernels are only considered for symmetric
//           weights; every other case uses the fp32 kernels (AVX512F first, then AVX2).
static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) {
  GetCPUDevice();
  // A zero block size is never valid and would make the modulo below undefined behaviour,
  // so report it as unsupported.
  if (BlkSize == 0 || K % BlkSize != 0) {
    return 0;
  }
  // from low precision to high precision
  switch (CompType) {
    case NSCompInt8:
      if (!isAsym) {  // asym int8 is not optimized, so fall through to others.
        if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
          return NSQ4BuSize<tWeiNInt<tAMX_INT8_SS_KBlock, tAMX_INT8_SS_KBlock::ISA>>(BlkSize, N, K, isAsym);
        }
        if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
          return NSQ4BuSize<tWeiNInt<tAVX512_VNNI_KBlock, tAVX512_VNNI_KBlock::ISA>>(BlkSize, N, K, isAsym);
        }
        if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
          return NSQ4BuSize<tWeiNInt<tAVX_VNNI_KBlock, tAVX_VNNI_KBlock::ISA>>(BlkSize, N, K, isAsym);
        }
      }
      [[fallthrough]];
    case NSCompBf16:
    case NSCompFp16:
    case NSCompFp32:
    case NSCompUndef:
      if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
        return NSQ4BuSize<tWeiNInt<tAVX512F, tAVX512F::ISA>>(BlkSize, N, K, isAsym);
      }
      if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
        return NSQ4BuSize<tWeiNInt<tAVX2, tAVX2::ISA>>(BlkSize, N, K, isAsym);
      }
      [[fallthrough]];
    default:
      return 0;
  }
}
// Packs one piece (quantized data, scales or zero points) of a 4-bit weight into PackedBuf,
// choosing the weight prologue with the same CPU-feature and block-size rules that
// NSQ4GemmPackBSize uses to size the buffer. The two selection ladders must stay in sync so
// that the buffer size matches the layout being written.
//
// Returns true when a kernel was selected and the packing step ran, false when no kernel
// supports the combination (PackedBuf is then left untouched).
// NOTE(review): unlike NSQ4GemmPackBSize, this function does not check that K is a multiple
// of BlkSize; callers are expected to have obtained a non-zero size first.
static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N,
size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall,
NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) {
// Presumably provides the CPU feature object `_cd` used below; it is defined outside this file.
GetCPUDevice();
// Cases deliberately fall through from the int8 kernels to the fp32 kernels.
switch (CompType) {
case NSCompInt8:
if (!isAsym) { // asym int8 is not optimized, so fall through to others.
if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAMX_INT8_SS_KBlock, tAMX_INT8_SS_KBlock::ISA>>(
PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool);
return true;
}
if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAVX512_VNNI_KBlock, tAVX512_VNNI_KBlock::ISA>>(
PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool);
return true;
}
if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAVX_VNNI_KBlock, tAVX_VNNI_KBlock::ISA>>(PackedBuf, BlkSize, QData, Scale, Zp, N,
K, isAsym, lastCall, ldb, ThreadPool);
return true;
}
}
[[fallthrough]];
// Every non-int8 compute type shares the fp32 kernels.
case NSCompBf16:
case NSCompFp16:
case NSCompFp32:
case NSCompUndef:
if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAVX512F, tAVX512F::ISA>>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym,
lastCall, ldb, ThreadPool);
return true;
}
if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAVX2, tAVX2::ISA>>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall,
ldb, ThreadPool);
return true;
}
[[fallthrough]];
default:
return false;
}
}
// Public entry point: size in bytes of the packed buffer for an n-bit weight, or 0 when the
// combination is not supported. Only nbits == 4 is handled; any other bit width yields 0.
size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym,
                            NS_SQNBIT_COMPUTE_TYPE CompType) {
  if (nbits != 4) {
    return 0;
  }
  return NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType);
}
// Public entry point: packs one piece of an n-bit weight into PackedBuf. Only nbits == 4 is
// handled; other bit widths, and combinations for which no kernel is available, are silent
// no-ops because the function has no way to report failure.
void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K,
                      size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall,
                      NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) {
  if (nbits != 4) {
    return;
  }
  // The success flag is intentionally discarded; see the note above.
  static_cast<void>(
      NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool));
}
// Public entry point: expands a packed weight blob into the float buffer FpData.
// Only nbits == 4 blobs can have been packed, so no bit-width check is needed here.
// The helper's success flag is discarded because this function returns void.
void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) {
  static_cast<void>(NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool));
}
// Public entry point: scratch size required by NSSQNBitsGemmBatchPackedB for the same
// batch. Only nbits == 4 blobs can have been packed, so the 4-bit helper is used directly.
size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN,
                                       const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) {
  const size_t workspace_bytes = NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams);
  return workspace_bytes;
}
// Public entry point: runs the batch of GEMMs described by DataParams against bestla packed
// weights, using WorkSpace as scratch. Only nbits == 4 blobs can have been packed, so the
// 4-bit driver is used directly. The driver's success flag is discarded because this
// function returns void.
void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN,
                               const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace,
                               void* ThreadPool) {
  auto* scratch = reinterpret_cast<int8_t*>(WorkSpace);
  static_cast<void>(NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, scratch, ThreadPool));
}

View file

@ -0,0 +1,129 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
neural_speed_gemm.h
Abstract:
Prepack-weight GEMM APIs of neural_speed.
--*/
#pragma once
#include <stdint.h>
#include <cstddef>
/**
* @brief Compute types for block-quantized GEMM.
*
* Each value names the input data type and accumulator data type requested by the caller.
* The MatMulNBits kernel obtains a value of this type by casting its accuracy level
* attribute, so the numeric values must not be changed.
*/
enum NS_SQNBIT_COMPUTE_TYPE {
NSCompUndef = 0, /*!< undef */
NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */
NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */
NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */
NSCompInt8 = 4 /*!< input int8, accumulator int32 */
};
/**
* @brief Data parameters for one NBits GEMM: C = A * B
*
* A and C must be float32 matrices.
* B must be a packed nbits blob.
* All members except C are [in] parameters; C is written by the GEMM.
* Every member defaults to null/zero, so each field must be filled in before use.
*/
struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS {
const float* A = nullptr; /**< address of A (float32 matrix)*/
const void* B = nullptr; /**< address of B (packed nbits blob)*/
float* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldc = 0; /**< leading dimension of C*/
};
/**
 * @brief Compute the size in bytes of the packing buffer needed for the given parameter combination.
 *
 * @param N the number of columns of matrix B.
 * @param K the number of rows of matrix B.
 * @param block_size size of the block to quantize; elements from the same block share the same
 * scale and zero point
 * @param nbits number of bits used for weight quantization
 * @param is_asym flag for asymmetric quantization
 * @param comp_type specifies the input data type and the accumulator data type
 * @return size of the packing buffer, or 0 if the operation is not yet supported.
 */
size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym,
NS_SQNBIT_COMPUTE_TYPE comp_type);
/**
 * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers.
 *
 * @param PackedBuf packed data buffer (destination of the packing)
 * @param QData quantized data buffer
 * @param Scale scale pointer
 * @param Zp zero point pointer
 * @param N the number of columns of matrix B.
 * @param K the number of rows of matrix B.
 * @param ldb leading dimension of B
 * @param block_size size of the block to quantize; elements from the same block share the same
 * scale and zero point
 * @param nbits number of bits used for weight quantization. This is a required argument (the
 * declaration has no default value); 4 is the typical value.
 * @param is_asym flag for asymmetric quantization
 * @param last_call flag that activates the epilogue process of packB. OpKernel::PrePack queries the
 * input tensors one by one: QData, Scale, then Zp (if is_asym is true). The kernel, however, prefers
 * to pack all tensors into one blob so that they can share common attributes such as block_size.
 * The kernel also has some pre-computations that speed up inference and require all blob data to be
 * ready. Therefore set this flag to true on the call that passes Scale (when is_asym is false) or
 * on the call that passes Zp (when is_asym is true).
 * @param comp_type specifies the input data type and the accumulator data type
 * @param thread_pool opaque thread pool pointer handed to the implementation
 */
void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K,
size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call,
NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool);
/**
 * @brief Unpack a packed B blob and dequantize it to fp32.
 *
 * @param FpData destination buffer for the unpacked float32 data
 * @param PackedBuf quantized and packed data
 * @param N the number of columns of matrix B.
 * @param K the number of rows of matrix B.
 * @param ldb leading dimension of B
 * @param thread_pool opaque thread pool pointer handed to the implementation
 */
void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool);
/**
 * @brief Get the workspace size required by the computation.
 *
 * @param[in] M row size of matrix A and C
 * @param[in] N column size of matrix B and C
 * @param[in] K column size of matrix A and row size of matrix B
 * @param[in] BatchN number of batches
 * @param[in] DataParams An array (size BatchN) of parameter blocks; taken as pointer-to-const
 * @return Workspace size in bytes
 */
size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams);
/**
 * @brief Batched GEMM: C = A * B
 * A, C must be a float32 matrix
 * B must be a packed nbits blob
 *
 * This function returns void; the result is delivered through the C member of each parameter block.
 *
 * @param[in] M row size of matrix A and C
 * @param[in] N column size of matrix B and C
 * @param[in] K column size of matrix A and row size of matrix B
 * @param[in] BatchN number of batches
 * @param[inout] DataParams An array (size BatchN) of parameter blocks. The blocks themselves are
 * const; the result matrix that each block's C member points to is the output.
 * @param[in] WorkSpace temporary buffer (presumably sized via NSSQNBitsGemmBatchWorkspaceSize -
 * TODO confirm)
 * @param[in] ThreadPool opaque thread pool pointer; defaults to nullptr
 */
void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace,
void* ThreadPool = nullptr);

View file

@ -0,0 +1,39 @@
//-----------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
//-----------------------------------------------------------------------------
#pragma once
// Include the bestla headers with a set of compiler warnings disabled.
// The diagnostic state is saved (push) before the includes and restored (pop) right after them,
// so the suppressions apply only to the two bestla headers and not to the including file.
#if defined(__GNUC__)
// GCC-style diagnostics silenced for the bestla headers.
// NOTE(review): other compilers that define __GNUC__ also take this branch and may not recognize
// every option listed here - confirm against the supported toolchains.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wsign-compare"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#pragma GCC diagnostic ignored "-Wunused-variable"
#pragma GCC diagnostic ignored "-Wunused-value"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
#pragma GCC diagnostic ignored "-Wunused-but-set-parameter"
#elif defined(_MSC_VER)
// MSVC warnings silenced for the bestla headers:
//   C4457 - declaration hides a function parameter
//   C4189 - local variable is initialized but not referenced
//   C4100 - unreferenced formal parameter
//   C4244 - conversion with possible loss of data
//   C4267 - conversion from size_t with possible loss of data
//   C4702 - unreachable code
#pragma warning(push)
#pragma warning(disable : 4457)
#pragma warning(disable : 4189)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4267)
#pragma warning(disable : 4702)
#endif
#include "bestla/bestla_prologue_a.h"
#include "bestla/bestla_wrapper.h"
// Restore the diagnostic state saved above.
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#elif defined(_MSC_VER)
#pragma warning(pop)
#endif

View file

@ -149,10 +149,17 @@ TEST(MatMulNBits, Float32) {
for (auto N : {1, 2, 32, 288}) {
for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
for (auto block_size : {16, 32, 64, 128}) {
#ifdef ORT_NEURAL_SPEED
for (auto accuracy_level : {0, 1, 4}) {
RunTest(M, N, K, block_size, accuracy_level, false, false);
RunTest(M, N, K, block_size, accuracy_level, true, false);
}
#else
for (auto accuracy_level : {0}) {
RunTest(M, N, K, block_size, accuracy_level, false, false);
RunTest(M, N, K, block_size, accuracy_level, true, false);
}
#endif
}
}
}
@ -185,6 +192,174 @@ TEST(MatMulNBits, Float16Large) {
#endif
// Builds a MatMulNBits test case whose B, scales and (optionally) zero_points are supplied as
// shared initializers, runs it in two sessions on the CPU EP, and checks that the weights
// pre-packed by the first session are shared with the second one.
//
// M, N, K    - GEMM dimensions: (M x K) X (K x N)
// block_size - value of the "block_size" attribute; also used to derive the number of blocks along K
// is_asym    - when true, a "zero_points" input is generated and added
// acc_lvl    - value of the "accuracy_level" attribute; level 4 relaxes the output tolerance
void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, int block_size, bool is_asym,
int64_t acc_lvl) {
// (M x K) X (K x N)
OpTester test("MatMulNBits", 1, kMSDomain);
test.AddAttribute<int64_t>("accuracy_level", acc_lvl);
test.AddAttribute<int64_t>("block_size", int64_t(block_size));
test.AddAttribute<int64_t>("bits", QBits);
test.AddAttribute<int64_t>("N", N);
test.AddAttribute<int64_t>("K", K);
// Input A: a repeating ramp that walks fv from -135 to 135 in steps of 1, each stored as fv / 127.
std::vector<float> input0_vals(M * K);
float fv = -135.f;
for (auto& f : input0_vals) {
f = fv / 127;
fv++;
if (fv > 135.f) {
fv = -135.f;
}
}
// Number of quantization blocks along K (integer division).
size_t kblks = K / block_size;
// Quantized B: N * K / 2 bytes, each byte holding two 4-bit values (see the nibble unpacking below).
// Bytes are filled with the running index truncated to uint8_t.
std::vector<uint8_t> input1_vals(N * K / 2);
for (size_t i = 0; i < input1_vals.size(); i++) {
input1_vals[i] = uint8_t(i);
}
// Scales: one float per (row of N, block), starting at 0.002 plus a small offset that repeats every 100 entries.
std::vector<float> input2_vals(N * kblks, 0.002f);
for (size_t i = 0; i < N * kblks; i++) {
input2_vals[i] += (i % 100) * 0.00003f;
}
// Zero points: N * kblks / 2 bytes, two 4-bit zero points per byte; 0x88 puts 8 in both nibbles.
std::vector<uint8_t> input3_vals(N * kblks / 2, static_cast<uint8_t>(0x88));
// Dequantized reference copy of B, stored as K x N (indexed k * N + n).
std::vector<float> input1_f_vals(N * K);
if (is_asym) {
// Overwrite the zero-point bytes with varying values so the asymmetric path is actually exercised.
for (size_t i = 0; i < N * kblks; i += 2) {
input3_vals[i / 2] = static_cast<uint8_t>(i + 1);
}
// Dequantize two consecutive K elements (one byte of B) at a time.
for (int64_t i = 0; i < K; i += 2) {
for (int64_t j = 0; j < N; j++) {
auto srcv = input1_vals[j * K / 2 + i / 2];
// koff < block_size means the block index along K is even, which selects the low nibble of the
// zero-point byte; otherwise the high nibble is used.
auto koff = i % (block_size * 2);
auto zpv = input3_vals[j * kblks / 2 + i / block_size / 2];
auto zp0 = koff < block_size ? (zpv & 0xf) - 8 : ((zpv & 0xf0) >> 4) - 8;
// 8 is subtracted from both the quantized values and the zero point, so it cancels in (src - zp0).
auto src0 = (srcv & 0xf) - 8;
auto src1 = ((srcv & 0xf0) >> 4) - 8;
auto scale0 = input2_vals[j * kblks + i / block_size];
auto scale1 = input2_vals[j * kblks + (i + 1) / block_size];
// The same zp0 is applied to both elements; this assumes i and i + 1 fall in the same block
// (i.e. block_size is even) - TODO confirm.
input1_f_vals[i * N + j] = (static_cast<float>(src0) - zp0) * scale0;
input1_f_vals[(i + 1) * N + j] = (static_cast<float>(src1) - zp0) * scale1;
}
}
} else {
// Symmetric case: no zero_points input is used; each nibble is offset by 8 and scaled.
for (int64_t i = 0; i < K; i += 2) {
for (int64_t j = 0; j < N; j++) {
auto srcv = input1_vals[j * K / 2 + i / 2];
auto src0 = (srcv & 0xf) - 8;
auto src1 = ((srcv & 0xf0) >> 4) - 8;
auto scale0 = input2_vals[j * kblks + i / block_size];
auto scale1 = input2_vals[j * kblks + (i + 1) / block_size];
input1_f_vals[i * N + j] = static_cast<float>(src0) * scale0;
input1_f_vals[(i + 1) * N + j] = static_cast<float>(src1) * scale1;
}
}
}
// Expected output: plain float matrix product of A (M x K) and the dequantized B (K x N).
std::vector<float> expected_vals(M * N);
for (int64_t m = 0; m < M; m++) {
for (int64_t n = 0; n < N; n++) {
float sum = 0.0f;
for (int64_t k = 0; k < K; k++) {
sum += input0_vals[m * K + k] * input1_f_vals[k * N + n];
}
expected_vals[m * N + n] = sum;
}
}
// NOTE(review): the trailing bool on AddInput presumably marks the input as an initializer
// (false for A, true for the weights) - confirm against OpTester.
test.AddInput<float>("A", {M, K}, input0_vals, false);
test.AddInput<uint8_t>("B", {N, static_cast<int64_t>(kblks), static_cast<int64_t>(block_size / 2)}, input1_vals,
true);
test.AddInput<float>("scales", {N, static_cast<int64_t>(kblks)}, input2_vals, true);
if (is_asym) {
test.AddInput<uint8_t>("zero_points", {N, static_cast<int64_t>(kblks / 2)}, input3_vals, true);
}
test.AddOutput<float>("Y", {M, N}, expected_vals, false);
// Accuracy level 4 is given a looser absolute tolerance on the output.
if (acc_lvl == 4) {
test.SetOutputAbsErr("Y", 0.1f);
}
// Wrap the same weight buffers (no copy: the vectors' data pointers are passed) as OrtValues so
// they can be registered as shared initializers on the session options.
OrtValue b, scale, zp;
Tensor::InitOrtValue(DataTypeImpl::GetType<uint8_t>(),
TensorShape({N, static_cast<int64_t>(kblks), static_cast<int64_t>(block_size / 2)}),
input1_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b);
Tensor::InitOrtValue(DataTypeImpl::GetType<float>(), TensorShape({N, static_cast<int64_t>(kblks)}),
input2_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), scale);
if (is_asym) {
Tensor::InitOrtValue(DataTypeImpl::GetType<uint8_t>(), TensorShape({N, static_cast<int64_t>(kblks / 2)}),
input3_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), zp);
}
SessionOptions so;
// Set up B, scales and (if asymmetric) zero_points as shared initializers to be shared between sessions
ASSERT_EQ(so.AddInitializer("B", &b), Status::OK());
ASSERT_EQ(so.AddInitializer("scales", &scale), Status::OK());
if (is_asym) {
ASSERT_EQ(so.AddInitializer("zero_points", &zp), Status::OK());
}
// We want all sessions running using this OpTester to be able to share pre-packed weights if applicable
test.EnableSharingOfPrePackedWeightsAcrossSessions();
// Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP
// and we want to ensure that it is available in this build
// A fresh provider vector is built for every session run.
auto cpu_ep = []() -> std::vector<std::unique_ptr<IExecutionProvider>> {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
return execution_providers;
};
size_t number_of_pre_packed_weights_counter_session_1 = 0;
size_t number_of_shared_pre_packed_weights_counter = 0;
// Session 1
{
auto ep_vec = cpu_ep();
test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {},
&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter);
// Assert that no pre-packed weights have been shared thus far
ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast<size_t>(0));
}
auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared();
// Assert that the number of elements in the shared container
// is the same as the number of weights that have been pre-packed
ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container);
// On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements
// that have been pre-packed will be zero in which case we do not continue with the testing
// of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all.
if (number_of_pre_packed_weights_counter_session_1 == 0) return;
// Session 2
{
size_t number_of_pre_packed_weights_counter_session_2 = 0;
auto ep_vec = cpu_ep();
test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {},
&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter);
// Assert that the same number of weights were pre-packed in both sessions
ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2);
// Assert that the number of pre-packed weights that were shared equals
// the number of pre-packed weights in the second session
ASSERT_EQ(number_of_pre_packed_weights_counter_session_2,
static_cast<size_t>(number_of_shared_pre_packed_weights_counter));
}
}
#ifdef ORT_NEURAL_SPEED
TEST(MatMulNBits, SharedPrepackedWeights) {
  // One row per scenario; M, N and K are fixed at 2, 4096 and 4096 for every run.
  struct PrepackCase {
    int block_size;
    bool is_asym;
    int64_t acc_lvl;
  };
  const PrepackCase prepack_cases[] = {
      {32, true, 1},
      {32, false, 1},
      {128, false, 1},
      {128, false, 4},
      {1024, false, 4},
      {4096, false, 4},
  };
  // Scenarios run in the listed order, exactly one helper call per row.
  for (const PrepackCase& prepack_case : prepack_cases) {
    RunSharedPrepackedWeightsTest(2, 4096, 4096, prepack_case.block_size, prepack_case.is_asym,
                                  prepack_case.acc_lvl);
  }
}
#endif
} // namespace test
} // namespace onnxruntime