From 08abab0b141ee172b51209f5f35245eae824e4c4 Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Thu, 28 Nov 2024 01:40:04 +0000 Subject: [PATCH] [CPU] Fix mamtulnbits accuracy level (#22963) ### Description Fix mamtulnbits accuracy level ### Motivation and Context --- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | 1 + onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 473ec51524..c3e43f897c 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -33,6 +33,7 @@ constexpr size_t A = 0, }; typedef enum { + Level0, /*!< input fp32, accumulator fp32 */ Level1, /*!< input fp32, accumulator fp32 */ Level2, /*!< input fp16, accumulator fp16 */ Level3, /*!< input bf16, accumulator fp32 */ diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 87a9ef762b..6dedce24e7 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -274,11 +274,12 @@ void TestMatMulNBitsTyped() { base_opts.block_size = block_size; base_opts.accuracy_level = accuracy_level; - if constexpr (std::is_same::value) { + if (base_opts.accuracy_level == 4) { + base_opts.output_abs_error = 0.1f; + base_opts.output_rel_error = 0.02f; + } else if constexpr (std::is_same::value) { base_opts.output_abs_error = 0.055f; base_opts.output_rel_error = 0.02f; - } else if (base_opts.accuracy_level == 4) { - base_opts.output_abs_error = 0.1f; } {