diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 4feeb5f830..0b819b77b8 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -98,7 +98,7 @@ void TryEnableQNNSaver(ProviderOptions& qnn_options) { void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, - float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, + float fp32_abs_err, std::optional log_severity, bool verify_outputs, std::function* ep_graph_checker) { EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; @@ -108,7 +108,8 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); - logging_manager.SetDefaultLoggerSeverity(log_severity); + + ScopedDefaultLoggerSeverity scoped_log_severity{logging_manager, log_severity}; onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 676460e108..86d08560e7 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -5,6 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) #include +#include #include #include #include @@ -23,6 +24,32 @@ namespace onnxruntime { namespace test { +// Class that sets (on construction) and resets (on destruction) a default logger's severity. +// The specified severity is optional. If not provided, the instance doesn't modify the severity. +class ScopedDefaultLoggerSeverity { + public: + ScopedDefaultLoggerSeverity(logging::LoggingManager& logging_manager, std::optional severity) + : logging_manager_{logging_manager}, + original_severity_{} { + if (severity.has_value()) { + original_severity_ = logging_manager_.DefaultLogger().GetSeverity(); + logging_manager_.SetDefaultLoggerSeverity(*severity); + } + } + + ~ScopedDefaultLoggerSeverity() { + if (original_severity_.has_value()) { + logging_manager_.SetDefaultLoggerSeverity(*original_severity_); + } + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ScopedDefaultLoggerSeverity); + + private: + logging::LoggingManager& logging_manager_; + std::optional original_severity_; +}; + // Signature for function that builds a float32 model. using GetTestModelFn = std::function; @@ -525,7 +552,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ProviderOptions qnn_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, QDQTolerance tolerance = QDQTolerance(), - logging::Severity log_severity = logging::Severity::kERROR, + std::optional log_severity = {}, const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}, std::function* qnn_ep_graph_checker = nullptr) { @@ -537,7 +564,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Uncomment to dump LOGGER() output to stdout. // logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(log_severity); + ScopedDefaultLoggerSeverity scoped_log_severity{logging_manager, log_severity}; // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -738,14 +765,15 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float tolerance = 0.004, - logging::Severity log_severity = logging::Severity::kERROR, + std::optional log_severity = {}, const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}) { // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); - logging_manager.SetDefaultLoggerSeverity(log_severity); + + ScopedDefaultLoggerSeverity scoped_log_severity{logging_manager, log_severity}; // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -1082,7 +1110,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f, - logging::Severity log_severity = logging::Severity::kERROR, + std::optional log_severity = {}, bool verify_outputs = true, std::function* ep_graph_checker = nullptr);