diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 473ec51524..c3e43f897c 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -33,6 +33,7 @@ constexpr size_t A = 0, }; typedef enum { + Level0, /*!< input fp32, accumulator fp32 */ Level1, /*!< input fp32, accumulator fp32 */ Level2, /*!< input fp16, accumulator fp16 */ Level3, /*!< input bf16, accumulator fp32 */ diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 87a9ef762b..6dedce24e7 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -274,11 +274,12 @@ void TestMatMulNBitsTyped() { base_opts.block_size = block_size; base_opts.accuracy_level = accuracy_level; - if constexpr (std::is_same::value) { + if (base_opts.accuracy_level == 4) { + base_opts.output_abs_error = 0.1f; + base_opts.output_rel_error = 0.02f; + } else if constexpr (std::is_same::value) { base_opts.output_abs_error = 0.055f; base_opts.output_rel_error = 0.02f; - } else if (base_opts.accuracy_level == 4) { - base_opts.output_abs_error = 0.1f; } {