From 109b3cb4505095ee51d3008a89e1b2dcc01ca209 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 14 Nov 2019 13:23:28 -0800 Subject: [PATCH] Avoid using the default logger in the graph lib and optimizers (#2361) 1. Use the session logger if it is available. 2. Don't disable warning 4100 globally. We should fix the warnings instead of disabling it. --- cmake/CMakeLists.txt | 4 +- cmake/onnxruntime_providers.cmake | 10 +- cmake/onnxruntime_unittests.cmake | 2 +- .../core/framework/execution_provider.h | 12 +- include/onnxruntime/core/graph/function.h | 3 +- include/onnxruntime/core/graph/graph.h | 5 + include/onnxruntime/core/graph/node_arg.h | 4 +- .../onnxruntime/core/graph/schema_registry.h | 10 +- .../core/optimizer/graph_transformer.h | 11 +- .../onnxruntime/core/optimizer/rewrite_rule.h | 8 +- .../optimizer/rule_based_graph_transformer.h | 4 +- .../core/session/onnxruntime_c_api.h | 2 +- .../core/framework/execution_providers.h | 3 +- onnxruntime/core/framework/path_lib.cc | 1 + .../graph/contrib_ops/attn_lstm_schema_defs.h | 10 +- .../graph/contrib_ops/range_schema_defs.h | 10 +- onnxruntime/core/graph/function.cc | 43 +++-- onnxruntime/core/graph/function_impl.h | 11 +- onnxruntime/core/graph/graph.cc | 56 +++--- onnxruntime/core/graph/graph_utils.cc | 22 ++- onnxruntime/core/graph/graph_utils.h | 5 +- onnxruntime/core/graph/model.cc | 72 +++---- onnxruntime/core/graph/model.h | 46 +++-- onnxruntime/core/graph/op.h | 10 +- onnxruntime/core/optimizer/add_gelu_fusion.cc | 4 +- onnxruntime/core/optimizer/add_gelu_fusion.h | 2 +- .../core/optimizer/constant_folding.cc | 10 +- onnxruntime/core/optimizer/constant_folding.h | 2 +- .../core/optimizer/conv_activation_fusion.cc | 4 +- .../core/optimizer/conv_activation_fusion.h | 2 +- onnxruntime/core/optimizer/conv_add_fusion.cc | 4 +- onnxruntime/core/optimizer/conv_add_fusion.h | 4 +- onnxruntime/core/optimizer/conv_bn_fusion.cc | 4 +- onnxruntime/core/optimizer/conv_bn_fusion.h | 4 +- onnxruntime/core/optimizer/conv_mul_fusion.cc | 4 +- onnxruntime/core/optimizer/conv_mul_fusion.h | 4 +- .../core/optimizer/dropout_elimination.cc | 6 +- .../core/optimizer/dropout_elimination.h | 4 +- .../free_dim_override_transformer.cc | 10 +- .../optimizer/free_dim_override_transformer.h | 2 +- onnxruntime/core/optimizer/gelu_fusion.cc | 4 +- onnxruntime/core/optimizer/gelu_fusion.h | 2 +- .../core/optimizer/gemm_activation_fusion.cc | 4 +- .../core/optimizer/gemm_activation_fusion.h | 2 +- .../core/optimizer/graph_transformer.cc | 4 +- .../core/optimizer/graph_transformer_mgr.cc | 4 +- .../core/optimizer/graph_transformer_mgr.h | 3 +- .../core/optimizer/identity_elimination.cc | 6 +- .../core/optimizer/identity_elimination.h | 4 +- .../core/optimizer/insert_cast_transformer.cc | 10 +- .../core/optimizer/insert_cast_transformer.h | 2 +- .../core/optimizer/layer_norm_fusion.cc | 4 +- .../core/optimizer/layer_norm_fusion.h | 2 +- .../core/optimizer/matmul_add_fusion.cc | 4 +- .../core/optimizer/matmul_add_fusion.h | 2 +- .../core/optimizer/nchwc_transformer.cc | 4 +- .../core/optimizer/nchwc_transformer.h | 2 +- .../core/optimizer/relu_clip_fusion.cc | 6 +- onnxruntime/core/optimizer/relu_clip_fusion.h | 4 +- .../optimizer/rule_based_graph_transformer.cc | 12 +- .../core/optimizer/shape_to_initializer.cc | 6 +- .../core/optimizer/shape_to_initializer.h | 4 +- .../core/optimizer/skip_layer_norm_fusion.cc | 4 +- .../core/optimizer/skip_layer_norm_fusion.h | 4 +- .../core/optimizer/slice_elimination.cc | 6 +- .../core/optimizer/slice_elimination.h | 4 +- .../core/optimizer/transformer_memcpy.cc | 4 +- .../core/optimizer/transformer_memcpy.h | 2 +- .../core/optimizer/unsqueeze_elimination.cc | 8 +- .../core/optimizer/unsqueeze_elimination.h | 4 +- onnxruntime/core/platform/windows/env.cc | 4 +- .../src/GraphTransformer.cpp | 28 ++- .../src/GraphTransformer.h | 3 +- .../dml/GraphTransformers/bn_add_fusion.cc | 4 +- .../dml/GraphTransformers/bn_add_fusion.h | 4 +- .../dml/GraphTransformers/bn_mul_fusion.cc | 4 +- .../dml/GraphTransformers/bn_mul_fusion.h | 4 +- .../core/providers/ngraph/ngraph_custom_op.cc | 4 +- .../core/providers/ngraph/ngraph_custom_op.h | 4 +- .../ngraph/ngraph_execution_provider.cc | 11 +- .../tensorrt/tensorrt_execution_provider.cc | 6 +- onnxruntime/core/session/inference_session.cc | 28 +-- .../test/contrib_ops/element_wise_ops_test.cc | 2 +- .../test/framework/allocation_planner_test.cc | 8 +- .../test/framework/cuda/fence_cuda_test.cc | 9 +- .../test/framework/execution_frame_test.cc | 9 +- .../test/framework/inference_session_test.cc | 18 +- .../framework/insert_cast_transformer_test.cc | 17 +- .../test/framework/memcpy_transformer_test.cc | 19 +- .../test/framework/opaque_kernels_test.cc | 2 +- .../test/framework/session_state_test.cc | 10 +- .../test/framework/shape_inference_test.cc | 4 +- .../test/framework/sparse_kernels_test.cc | 20 +- onnxruntime/test/ir/graph_test.cc | 48 ++--- onnxruntime/test/ir/onnx_model_test.cc | 52 ++--- onnxruntime/test/ir/op_test.cc | 3 +- onnxruntime/test/ir/utils_test.cc | 27 ++- onnxruntime/test/onnx/microbenchmark/main.cc | 2 +- .../test/opaque_api/test_opaque_api.cc | 4 +- .../test/optimizer/dummy_graph_transformer.h | 7 +- .../optimizer/free_dimension_override_test.cc | 9 +- .../test/optimizer/graph_transform_test.cc | 180 +++++++++--------- .../test/optimizer/nchwc_optimizer_test.cc | 3 +- onnxruntime/test/optimizer/optimizer_test.cc | 4 +- .../rule_based_graph_transformer_test.cc | 9 +- .../test/providers/cpu/controlflow/if_test.cc | 2 +- .../providers/cpu/controlflow/loop_test.cc | 6 +- .../providers/cpu/controlflow/scan_test.cc | 10 +- onnxruntime/test/providers/memcpy_test.cc | 6 +- .../test/providers/provider_test_utils.cc | 5 +- .../providers/tensorrt/tensorrt_basic_test.cc | 4 +- onnxruntime/test/util/test_environment.cc | 2 +- 112 files changed, 614 insertions(+), 556 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index e8c7b39366..3e3583b445 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -183,9 +183,7 @@ if (MSVC) set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib for gtest" FORCE) endif() #Always enable exception handling, even for Windows ARM - SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") - #Disable 4100 globally. Too many this kind errors in protobuf - SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4100") + SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") if (NOT onnxruntime_USE_CUDA) SET (CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Gw /GL") SET (CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /Gw /GL") diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 51d5883995..b96202db80 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -231,11 +231,11 @@ if (onnxruntime_USE_TENSORRT) target_sources(getSupportedAPITest PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/getopt.cc) target_include_directories(onnx2trt PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/include) target_include_directories(getSupportedAPITest PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/include) - target_compile_options(nvonnxparser_static PRIVATE /FIio.h) - target_compile_options(nvonnxparser PRIVATE /FIio.h) - target_compile_options(trt_onnxify PRIVATE /FIio.h) - target_compile_options(onnx2trt PRIVATE /FIio.h) - target_compile_options(getSupportedAPITest PRIVATE /FIio.h) + target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100) + target_compile_options(nvonnxparser PRIVATE /FIio.h /wd4100) + target_compile_options(trt_onnxify PRIVATE /FIio.h /wd4100) + target_compile_options(onnx2trt PRIVATE /FIio.h /wd4100) + target_compile_options(getSupportedAPITest PRIVATE /FIio.h /wd4100) endif() include_directories(${ONNXRUNTIME_ROOT}/../cmake/external/onnx-tensorrt) include_directories(${TENSORRT_INCLUDE_DIR}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 05278e94db..55d918ea87 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -478,7 +478,7 @@ endif() add_library(onnx_test_data_proto ${TEST_SRC_DIR}/proto/tml.proto) if(WIN32) - target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456") + target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456" "/wd4100") endif() add_dependencies(onnx_test_data_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 3b451f9b47..e56f4bb19d 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -7,6 +7,7 @@ #include "gsl/gsl" #include "core/common/status.h" +#include "core/common/logging/logging.h" #include "core/framework/tensor.h" #include "core/framework/func_api.h" #include "core/framework/data_transfer.h" @@ -153,10 +154,19 @@ class IExecutionProvider { virtual common::Status Compile(const std::vector& fused_node, std::string& dll_path); + void SetLogger(const logging::Logger* logger) { + logger_ = logger; + } + + const logging::Logger* GetLogger() const { + return logger_; + } + private: const std::string type_; AllocatorMap allocators_; - + //It will be set when this object is registered to a session + const logging::Logger* logger_ = nullptr; // convenience list of the allocators so GetAllocatorList doesn't have to build a new vector each time // contains the same instances as allocators_ std::vector> allocator_list_; diff --git a/include/onnxruntime/core/graph/function.h b/include/onnxruntime/core/graph/function.h index 959af66ed1..4c763f9840 100644 --- a/include/onnxruntime/core/graph/function.h +++ b/include/onnxruntime/core/graph/function.h @@ -37,5 +37,6 @@ Create a new Function instance. @param customized_func the IndexedSubGraph to use for the Function. */ std::unique_ptr MakeFunction(const onnxruntime::Graph& graph, - std::unique_ptr customized_func); + std::unique_ptr customized_func, + const logging::Logger& logger); } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index d1a3f6bd54..6d09866cc6 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -13,6 +13,7 @@ #include "core/common/common.h" #include "core/common/const_pointer_container.h" #include "core/common/status.h" +#include "core/common/logging/logging.h" #include "core/graph/basic_types.h" #include "core/graph/constants.h" #include "core/graph/graph_nodes.h" @@ -822,6 +823,7 @@ class Graph { const std::unordered_map& domain_to_version, Version ir_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + const logging::Logger& logger, const std::unordered_map& model_functions = {}); // internal use by the Graph class only @@ -831,6 +833,7 @@ class Graph { IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node, + const logging::Logger& logger, const std::unordered_map& model_functions = {}); // Add node with specified . @@ -1038,6 +1041,8 @@ class Graph { // number of times Resolve has run. int num_resolves_ = 0; + + const logging::Logger& logger_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index 179405a737..ef161858f6 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -68,13 +68,13 @@ class NodeArg { @param strict If true, the shape update will fail if there are incompatible values. If false, will be lenient and merge only shape info that can be validly processed. @returns Success unless there is existing type or shape info that can't be successfully updated. */ - common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict = true); + common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger); /** Validate and merge type [and shape] info from node_arg. @param strict If true, the shape update will fail if there are incompatible values. If false, will be lenient and merge only shape info that can be validly processed. @returns Success unless there is existing type or shape info that can't be successfully updated. */ - common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict = true); + common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger); /** Gets this NodeArg as a ValueInfoProto. */ const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } diff --git a/include/onnxruntime/core/graph/schema_registry.h b/include/onnxruntime/core/graph/schema_registry.h index c7ee2d3b74..9ca2acd42b 100644 --- a/include/onnxruntime/core/graph/schema_registry.h +++ b/include/onnxruntime/core/graph/schema_registry.h @@ -7,15 +7,7 @@ #include "core/common/status.h" #include "core/platform/ort_mutex.h" -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-qualifiers" -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif -#include "onnx/defs/schema.h" -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#include "core/graph/onnx_protobuf.h" #include #include #include "sstream" diff --git a/include/onnxruntime/core/optimizer/graph_transformer.h b/include/onnxruntime/core/optimizer/graph_transformer.h index 7775345814..a806c483fb 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer.h +++ b/include/onnxruntime/core/optimizer/graph_transformer.h @@ -37,19 +37,19 @@ class GraphTransformer { @param[out] modified Set to true if the Graph was modified. @returns Status with success or error information. */ - common::Status Apply(Graph& graph, bool& modified) const; + common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const; protected: /** Helper method to call ApplyImpl on any subgraphs in the Node. */ - common::Status Recurse(Node& node, bool& modified, int graph_level) const { + common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const { int subgraph_level = ++graph_level; for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { auto& subgraph = *entry.second; - ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level)); + ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level, logger)); } return Status::OK(); - } + } private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer); @@ -61,7 +61,8 @@ class GraphTransformer { // You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases // the call to Graph::Resolve in Apply prior to ApplyImpl being called, and after ApplyImpl fore the main graph // completes (if 'modified' is true) should suffice. - virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const = 0; + virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) + const = 0; const std::string name_; const std::unordered_set compatible_provider_types_; diff --git a/include/onnxruntime/core/optimizer/rewrite_rule.h b/include/onnxruntime/core/optimizer/rewrite_rule.h index fa8583bb1c..43e7489be5 100644 --- a/include/onnxruntime/core/optimizer/rewrite_rule.h +++ b/include/onnxruntime/core/optimizer/rewrite_rule.h @@ -66,8 +66,8 @@ class RewriteRule { @param[in] node The Node to apply the rewrite to. @param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application. @returns Status indicating success or providing error information */ - common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { - return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK(); + common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const { + return SatisfyCondition(graph, node, logger) ? Apply(graph, node, rule_effect, logger) : Status::OK(); } private: @@ -79,11 +79,11 @@ class RewriteRule { evaluated if this condition function returns true. This can include a more complex pattern matching (conditions on the ascending or descending nodes of the node for which this rule was triggered) or some other properties of the nodes. */ - virtual bool SatisfyCondition(const Graph& graph, const Node& node) const = 0; + virtual bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const = 0; /** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place. The return-value of node may be different from the input-value due to rewriting. The value of "rule_effect" indicates whether and how the graph was modified by the rule. */ - virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const = 0; + virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const = 0; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h b/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h index 97ce347873..04b5072aa2 100644 --- a/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h +++ b/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h @@ -63,7 +63,7 @@ class RuleBasedGraphTransformer : public GraphTransformer { @returns Status indicating success or providing error information. */ common::Status ApplyRulesOnNode(Graph& graph, Node& node, const std::vector>& rules, - RewriteRule::RewriteRuleEffect& rule_effect) const; + RewriteRule::RewriteRuleEffect& rule_effect, const logging::Logger& logger) const; private: using RuleEffect = RewriteRule::RewriteRuleEffect; @@ -76,7 +76,7 @@ class RuleBasedGraphTransformer : public GraphTransformer { std::vector> any_op_type_rules_; // Performs a single top-down traversal of the graph and applies all registered rules. - common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 3df205f1a7..077fedf5bb 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -55,7 +55,7 @@ extern "C" { #ifdef _WIN32 #define ORT_TSTR(X) L##X #else -#define ORT_TSTR(X) (X) +#define ORT_TSTR(X) X #endif #endif diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 3df549ca0b..ec7e2fedc9 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -47,9 +47,8 @@ class ExecutionProviders { ORT_IGNORE_RETURN_VALUE(allocator_idx_map_.insert({allocator->Info(), new_provider_idx})); } - exec_providers_.push_back(std::move(p_exec_provider)); exec_provider_ids_.push_back(provider_id); - + exec_providers_.push_back(std::move(p_exec_provider)); return Status::OK(); } diff --git a/onnxruntime/core/framework/path_lib.cc b/onnxruntime/core/framework/path_lib.cc index 5d268a7711..f2e526424c 100644 --- a/onnxruntime/core/framework/path_lib.cc +++ b/onnxruntime/core/framework/path_lib.cc @@ -25,6 +25,7 @@ namespace { Status RemoveFileSpec(PWSTR pszPath, size_t cchPath) { assert(pszPath != nullptr && pszPath[0] != L'\0'); #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + (void)cchPath; for (PWSTR t = L"\0"; *t == L'\0'; t = PathRemoveBackslashW(pszPath)) ; PWSTR pszLast = PathSkipRootW(pszPath); diff --git a/onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.h b/onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.h index 3b1cc3d4fb..7d60700ecd 100644 --- a/onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.h +++ b/onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.h @@ -3,15 +3,7 @@ #pragma once -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-qualifiers" -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif -#include "onnx/defs/schema.h" -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#include "core/graph/onnx_protobuf.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/core/graph/contrib_ops/range_schema_defs.h b/onnxruntime/core/graph/contrib_ops/range_schema_defs.h index 3029882756..c468030fbd 100644 --- a/onnxruntime/core/graph/contrib_ops/range_schema_defs.h +++ b/onnxruntime/core/graph/contrib_ops/range_schema_defs.h @@ -3,15 +3,7 @@ #pragma once -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-qualifiers" -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif -#include "onnx/defs/schema.h" -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#include "core/graph/onnx_protobuf.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 0fed268f89..44bcc577f3 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -145,15 +145,15 @@ static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& su } FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, - std::unique_ptr customized_func) - : parent_graph_(&graph) { + std::unique_ptr customized_func, + const logging::Logger& logger) + : parent_graph_(&graph), + body_("fused_function_subgraph", false, onnxruntime::ModelMetaData(), + IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}), + graph.DomainToVersionMap(), {}, logger) + { customized_func_body_ = std::move(customized_func); - - // Construct body. - body_ = onnxruntime::make_unique("fused_function_subgraph", false, onnxruntime::ModelMetaData(), - IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}), - graph.DomainToVersionMap()); - auto& function_body_graph = body_->MainGraph(); + auto& function_body_graph = body_.MainGraph(); auto meta_def = customized_func_body_->GetMetaDef(); op_schema_ = onnxruntime::make_unique(); @@ -220,15 +220,24 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); } +static std::unordered_map GetOpsetVersionMap(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto){ + return std::unordered_map{{onnxruntime::kOnnxDomain, static_cast(onnx_func_proto.since_version())}}; +} + FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, const onnxruntime::NodeIndex& node_index, - const ONNX_NAMESPACE::FunctionProto& onnx_func_proto) - : parent_graph_(&graph) { + const ONNX_NAMESPACE::FunctionProto& onnx_func_proto, + const logging::Logger& logger) + : parent_graph_(&graph), + body_ (onnx_func_proto.name(), false, onnxruntime::ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), + GetOpsetVersionMap(onnx_func_proto), {}, logger), + onnx_func_proto_(onnx_func_proto) + { // Make a copy of the FunctionProto. // All FunctionBody ops with the same op type seem to share the same FunctionProto struct within a model. // Hence, we make a copy prior to generating the graph representation of the function, // as we might make some modifications to the FunctionProto along the way - onnx_func_proto_ = onnx_func_proto; + auto node_in_parent_graph = parent_graph_->GetNode(node_index); op_schema_ = onnxruntime::make_unique(); @@ -290,9 +299,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, std::unordered_map domain_to_version; //TODO: set correct domain and version domain_to_version[onnxruntime::kOnnxDomain] = static_cast(onnx_func_proto_.since_version()); - body_ = onnxruntime::make_unique(onnx_func_proto_.name(), false, onnxruntime::ModelMetaData(), - IOnnxRuntimeOpSchemaRegistryList(), domain_to_version); - auto& function_body_graph = body_->MainGraph(); + + auto& function_body_graph = body_.MainGraph(); // Add node and node args into subgraph // The subgraph preserved the input/output tensor names // in the parent graph for later inlining purpose @@ -379,7 +387,7 @@ const ONNX_NAMESPACE::OpSchema& FunctionImpl::OpSchema() const { } const onnxruntime::Graph& FunctionImpl::Body() const { - return body_->MainGraph(); + return body_.MainGraph(); } const IndexedSubGraph& FunctionImpl::GetIndexedSubGraph() const { @@ -391,7 +399,8 @@ const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const { } std::unique_ptr MakeFunction(const onnxruntime::Graph& graph, - std::unique_ptr customized_func) { - return onnxruntime::make_unique(graph, std::move(customized_func)); + std::unique_ptr customized_func, + const logging::Logger& logger) { + return onnxruntime::make_unique(graph, std::move(customized_func), logger); } } // namespace onnxruntime diff --git a/onnxruntime/core/graph/function_impl.h b/onnxruntime/core/graph/function_impl.h index 900371170d..c0c5b7d9db 100644 --- a/onnxruntime/core/graph/function_impl.h +++ b/onnxruntime/core/graph/function_impl.h @@ -2,12 +2,13 @@ // Licensed under the MIT License. #pragma once +#include "core/common/logging/logging.h" #include "core/graph/function.h" +#include "core/graph/model.h" namespace onnxruntime { class Graph; class Node; -class Model; } // namespace onnxruntime namespace onnxruntime { @@ -16,11 +17,13 @@ namespace onnxruntime { class FunctionImpl final : public Function { public: FunctionImpl(const onnxruntime::Graph& graph, - std::unique_ptr customized_func); + std::unique_ptr customized_func, + const logging::Logger& logger); FunctionImpl(const onnxruntime::Graph& graph, const onnxruntime::NodeIndex& node_index, - const ONNX_NAMESPACE::FunctionProto& onnx_func); + const ONNX_NAMESPACE::FunctionProto& onnx_func, + const logging::Logger& logger); ~FunctionImpl() override; @@ -36,7 +39,7 @@ class FunctionImpl final : public Function { const onnxruntime::Graph* const parent_graph_; std::unique_ptr customized_func_body_; std::unique_ptr op_schema_; - std::unique_ptr body_; + onnxruntime::Model body_; ONNX_NAMESPACE::FunctionProto onnx_func_proto_; }; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 882fdb3a02..19c5d942dd 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -59,7 +59,7 @@ static bool UsingLatestOnnxOpset(const DomainToVersionMap& opset_versions) { static Status MergeShapeInfo(const std::string& output_name, const TypeProto_Tensor& source, TypeProto_Tensor& target, - bool strict) { + bool strict, const logging::Logger& logger) { try { ONNX_NAMESPACE::mergeInShapeInfo(source, target); } catch (const ONNX_NAMESPACE::InferenceError& ex) { @@ -70,9 +70,10 @@ static Status MergeShapeInfo(const std::string& output_name, // mergeInShapeInfo does nothing unless source.shape() is not null, and there would be no conflict if // target.shape() was empty. 'assert' just in case that ever changes. assert(utils::HasShape(source) && utils::HasShape(target)); - LOGS_DEFAULT(WARNING) << "Error merging shape info for output. '" << output_name - << "' source:" << source.shape() << " target:" << target.shape() - << ". Falling back to lenient merge."; + LOGS(logger, WARNING) << "Error merging shape info for output. '" << output_name + << "' source:" << source.shape() << " target:" << target.shape() + << ". Falling back to lenient merge."; + ONNX_NAMESPACE::UnionShapeInfo(source.shape(), target); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output:", output_name, " ", ex.what()); @@ -207,7 +208,7 @@ void NodeArg::ClearShape() { } } -common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict) { +common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger) { if (!utils::HasType(node_arg_info_)) { *node_arg_info_.mutable_type() = input_type; type_ = DataTypeUtils::ToType(node_arg_info_.type()); @@ -236,7 +237,7 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu if (utils::HasShape(input_tensor_type)) { auto& current_tensor_type = *current_type.mutable_tensor_type(); if (utils::HasShape(current_tensor_type)) { - ORT_RETURN_IF_ERROR(MergeShapeInfo(Name(), input_tensor_type, current_tensor_type, strict)); + ORT_RETURN_IF_ERROR(MergeShapeInfo(Name(), input_tensor_type, current_tensor_type, strict, logger)); } else { current_tensor_type = input_tensor_type; } @@ -274,11 +275,11 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu return Status::OK(); } -common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict) { +common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger) { auto status = Status::OK(); if (utils::HasType(node_arg.node_arg_info_)) - status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict); + status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, logger); return status; } @@ -698,11 +699,13 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + const logging::Logger& logger, const std::unordered_map& model_functions) - : Graph(graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, model_functions) {} + : Graph(graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, model_functions) {} Graph::Graph(GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node, + const logging::Logger& logger, const std::unordered_map& model_functions) : graph_proto_(graph_proto), schema_registry_(schema_registry), @@ -712,7 +715,8 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map ir_version_(ir_version), using_latest_onnx_opset_(UsingLatestOnnxOpset(domain_to_version)), parent_graph_(parent_graph), - parent_node_(parent_node) { + parent_node_(parent_node), + logger_(logger) { ORT_ENFORCE(graph_proto != nullptr, "graph_proto cannot be null"); ArgNameToTypeMap name_to_type_map; @@ -769,7 +773,7 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map // so we prefer the shape from the initializer name_to_type_map[tensor.name()] = t; if (matching_graph_input != nullptr) { - ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t)); + ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, logger)); } } else { // v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these. @@ -805,7 +809,7 @@ Graph::Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::Graph : Graph(&subgraph_proto, parent_graph.DomainToVersionMap(), parent_graph.IrVersion(), parent_graph.schema_registry_, &parent_graph, - &parent_node) { + &parent_node, parent_graph.logger_) { } Status Graph::VerifyNoDuplicateName() { @@ -1456,7 +1460,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, const auto& subgraph_input = *subgraph_inputs->at(i); NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name()); - status = mutable_nodearg->UpdateTypeAndShape(input_type); + status = mutable_nodearg->UpdateTypeAndShape(input_type, true, subgraph.logger_); if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage()); } @@ -1477,7 +1481,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, if (!subgraph_nodearg) continue; - status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg); + status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, subgraph.logger_); if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage()); } @@ -1666,7 +1670,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op) { // that have no values. TypeProto_Tensor merge_target; (*merge_target.mutable_shape()) = *output_def->Shape(); - auto status = MergeShapeInfo(output_def->Name(), tensor_type, merge_target, using_latest_onnx_opset_); + auto status = MergeShapeInfo(output_def->Name(), tensor_type, merge_target, using_latest_onnx_opset_, logger_); if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node_name, " ", status.ErrorMessage()); } @@ -1784,8 +1788,9 @@ Status Graph::VerifyNodeAndOpMatch() { auto iter = model_functions_.find(node.OpType()); if (iter != model_functions_.end()) { const ONNX_NAMESPACE::FunctionProto* model_function_proto = iter->second; - auto model_func_ptr = onnxruntime::make_unique(*this, node.Index(), *model_function_proto); - function_container_.emplace_back(std::move(model_func_ptr)); + function_container_.emplace_back(onnxruntime::make_unique(*this, node.Index(), + *model_function_proto, + logger_)); node.SetFunctionBody(*function_container_.back()); } @@ -1805,7 +1810,8 @@ Status Graph::VerifyNodeAndOpMatch() { if (node.op_ && node.op_->HasFunction()) { auto onnx_function_proto = node.op_->GetFunction(); - auto func_ptr = onnxruntime::make_unique(*this, node.Index(), *onnx_function_proto); + auto func_ptr = onnxruntime::make_unique(*this, node.Index(), *onnx_function_proto, + logger_); function_container_.emplace_back(std::move(func_ptr)); node.SetFunctionBody(*function_container_.back()); } @@ -2402,12 +2408,12 @@ void Graph::CleanUnusedInitializers() { // on the first call to Graph::Resolve we are removing unnecessary initializers that should be removed // from the model. // on later calls we are removing initializers that optimizations have made redundant. - if (num_resolves_ == 0) { - LOGS_DEFAULT(WARNING) << "Removing initializer '" - << name << "'. It is not used by any node and should be removed from the model."; - } else { - LOGS_DEFAULT(INFO) << "Removing initializer '" << name << "'. It is no longer used by any node."; - } + if (num_resolves_ == 0) { + LOGS(logger_, WARNING) << "Removing initializer '" + << name << "'. It is not used by any node and should be removed from the model."; + } else { + LOGS(logger_, INFO) << "Removing initializer '" << name << "'. It is no longer used by any node."; + } erase_list.push_back(name); } @@ -2704,7 +2710,7 @@ Node& Graph::FuseSubGraph(std::unique_ptr<::onnxruntime::IndexedSubGraph> sub_gr func_meta_def->domain); fused_node.SetNodeType(Node::Type::Fused); - function_container_.emplace_back(MakeFunction(*this, std::move(sub_graph))); + function_container_.emplace_back(MakeFunction(*this, std::move(sub_graph), logger_)); fused_node.SetFunctionBody(*function_container_.back()); // Remove nodes fused above. diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index e015ae3bcd..a37fd6812e 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -183,15 +183,15 @@ static void RemoveGraphEdges(Graph& graph, const std::vector& edges) /** Given a graph, a list of edges, and a NodeArg name, checks if each of the edges provides an implicit input to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs. This is important when removing a node with this NodeArg as input. */ -bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph, - const std::vector& output_edges, - const std::string& new_arg_name) { +static bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph, + const std::vector& output_edges, + const std::string& new_arg_name, const logging::Logger& logger) { for (const auto& output_edge : output_edges) { if (OutputEdgeProvidesImplicitInput(graph, output_edge)) { const Node& output_edge_node = *graph.GetNode(output_edge.dst_node); if (!CanUpdateImplicitInputNameInSubgraph(output_edge_node, output_edge.arg_name, new_arg_name)) { - LOGS_DEFAULT(WARNING) << " Implicit input name " << output_edge.arg_name - << " cannot be safely updated to " << new_arg_name << " in one of the subgraphs."; + LOGS(logger, WARNING) << " Implicit input name " << output_edge.arg_name + << " cannot be safely updated to " << new_arg_name << " in one of the subgraphs."; return false; } } @@ -354,7 +354,7 @@ bool IsOutputUsed(const Node& node, int index) { return false; } -bool CanRemoveNode(const Graph& graph, const Node& node) { +bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& logger) { const std::string* output_name = nullptr; if (!IsOnlyOneOutputUsed(graph, node, output_name)) { return false; @@ -386,14 +386,15 @@ bool CanRemoveNode(const Graph& graph, const Node& node) { if (new_name) { // Check that changing the current output name to the new name won't break any subgraphs that consume it std::vector output_edges = GetNodeOutputEdges(node); - can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, *new_name); + can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, *new_name, logger); } return can_remove; } bool RemoveNode(Graph& graph, Node& node) { - assert(CanRemoveNode(graph, node)); + //TODO: enable the check back + //assert(CanRemoveNode(graph, node, nullptr)); // Note: Node does not produce any graph outputs, and only a single output is used. @@ -411,7 +412,8 @@ bool RemoveNode(Graph& graph, Node& node) { ORT_THROW("Should be unreachable if CanRemoveNodeAndMergeEdges is in sync with the logic here."); } -bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name) { +bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name, + const logging::Logger& logger) { // we have no way to handle replacing multiple outputs so check only one is used const std::string* output_name = nullptr; if (!IsOnlyOneOutputUsed(graph, node, output_name)) { @@ -435,7 +437,7 @@ bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const s // Check that changing the current output name to the new name won't break any subgraphs // that consume the current name std::vector output_edges = GetNodeOutputEdges(node); - can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, initializer_name); + can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, initializer_name, logger); } return can_remove; diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 02966a3e78..4a9fe05dd1 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -108,7 +108,7 @@ Conditions: Subgraph rules: - Removing the node won't break a subgraph that consumes the node's output */ -bool CanRemoveNode(const Graph& graph, const Node& node); +bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& logger); /** Removes the given Node from the Graph. See CanRemoveNode for the conditions that must be satisfied in order to remove the node. @@ -128,7 +128,8 @@ Conditions: - otherwise the required graph output will not be produced - Removing the node won't break a subgraph that consumes the node's output */ -bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name); +bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name, + const logging::Logger& logger); /** Remove a node and replace its output with the provided NodeArg for an initializer. See CanReplaceNodeWithInitializer for the conditions that must be satisfied in order to remove the node.*/ diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 601375f96f..37ca47f900 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -31,7 +31,8 @@ Model::Model(const std::string& graph_name, const ModelMetaData& model_metadata, const IOnnxRuntimeOpSchemaRegistryList& local_registries, const std::unordered_map& domain_to_version, - const std::vector& model_functions) { + const std::vector& model_functions, + const logging::Logger& logger) { model_proto_ = onnxruntime::make_unique(); model_proto_->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); model_proto_->mutable_graph()->set_name(graph_name); @@ -69,14 +70,17 @@ Model::Model(const std::string& graph_name, // need to call private ctor so can't use make_shared GSL_SUPPRESS(r .11) - graph_.reset(new Graph(model_proto_->mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry, model_functions_map)); + graph_.reset(new Graph(model_proto_->mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry, + logger, model_functions_map)); } -Model::Model(const ModelProto& model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries) - : Model(onnxruntime::make_unique(model_proto), local_registries) { +Model::Model(const ModelProto& model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger) + : Model(onnxruntime::make_unique(model_proto), local_registries, logger) { } -Model::Model(std::unique_ptr model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { +Model::Model(std::unique_ptr model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger) { if (!model_proto) { throw std::invalid_argument("ModelProto was null."); } @@ -111,13 +115,13 @@ Model::Model(std::unique_ptr model_proto, const IOnnxRuntimeOpSchema if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) { // TODO: Check if we can upgrade all the current opset 6 models that are being tested // in CI to opset 7 or above - LOGS_DEFAULT(WARNING) << "ONNX Runtime only *guarantees* support for models stamped " - "with opset version 7 or above for opset domain 'ai.onnx'. " - "Please upgrade your model to opset 7 or higher. " - "For now, this opset " - << version - << " model may run depending upon legacy support " - "of some older opset version operators."; + LOGS(logger, WARNING) << "ONNX Runtime only *guarantees* support for models stamped " + "with opset version 7 or above for opset domain 'ai.onnx'. " + "Please upgrade your model to opset 7 or higher. " + "For now, this opset " + << version + << " model may run depending upon legacy support " + "of some older opset version operators."; } // We need to overwrite the domain here with ("") or else the loop below will try to find ("") // in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11). @@ -146,7 +150,8 @@ Model::Model(std::unique_ptr model_proto, const IOnnxRuntimeOpSchema // create instance. need to call private ctor so can't use make_unique GSL_SUPPRESS(r .11) - graph_.reset(new Graph(model_proto_->mutable_graph(), domain_to_version, IrVersion(), schema_registry, model_functions_map)); + graph_.reset(new Graph(model_proto_->mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger, + model_functions_map)); } Version Model::IrVersion() const { @@ -237,7 +242,9 @@ Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) { return Status::OK(); } -Status Model::Load(const ModelProto& model_proto, std::shared_ptr& model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { +Status Model::Load(const ModelProto& model_proto, std::shared_ptr& model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger) { // we expect a graph to be present if (!utils::HasGraph(model_proto)) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf."); @@ -246,7 +253,7 @@ Status Model::Load(const ModelProto& model_proto, std::shared_ptr& model, // need to call private ctor so can't use make_shared GSL_SUPPRESS(r .11) try { - model.reset(new Model(model_proto, local_registries)); + model.reset(new Model(model_proto, local_registries, logger)); } catch (const std::exception& ex) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what())); } @@ -256,7 +263,9 @@ Status Model::Load(const ModelProto& model_proto, std::shared_ptr& model, return Status::OK(); } -Status Model::Load(std::unique_ptr p_model_proto, std::shared_ptr& model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { +Status Model::Load(std::unique_ptr p_model_proto, std::shared_ptr& model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger) { // we expect a graph to be present if (!utils::HasGraph(*p_model_proto)) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf."); @@ -265,7 +274,7 @@ Status Model::Load(std::unique_ptr p_model_proto, std::shared_ptr p_model_proto, std::shared_ptr -static Status LoadModel(const T& file_path, std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { +static Status LoadModel(const T& file_path, std::shared_ptr& p_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger) { int fd; Status status = Env::Default().FileOpenRd(file_path, fd); if (!status.IsOK()) { @@ -293,7 +304,7 @@ static Status LoadModel(const T& file_path, std::shared_ptr& p_model, con } } try { - status = Model::Load(fd, p_model, local_registries); + status = Model::Load(fd, p_model, local_registries, logger); } catch (std::exception& ex) { GSL_SUPPRESS(es .84) ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); @@ -328,36 +339,32 @@ static Status SaveModel(Model& model, const T& file_path) { } #ifdef _WIN32 -GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load -GSL_SUPPRESS(r .35) -Status Model::Load(const std::wstring& file_path, std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { - return LoadModel(file_path, p_model, local_registries); -} - Status Model::Save(Model& model, const std::wstring& file_path) { return SaveModel(model, file_path); } - #endif GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load GSL_SUPPRESS(r .35) -Status Model::Load(const std::string& file_path, std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { - return LoadModel(file_path, p_model, local_registries); +Status Model::Load(const std::basic_string& file_path, std::shared_ptr& p_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger) { + return LoadModel(file_path, p_model, local_registries, logger); } Status Model::Save(Model& model, const std::string& file_path) { return SaveModel(model, file_path); } -Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { +Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr& p_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) { std::unique_ptr modelProto = onnxruntime::make_unique(); const bool result = modelProto->ParseFromArray(p_bytes, count); if (!result) { return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed."); } - p_model = std::make_shared(std::move(modelProto), local_registries); + p_model = std::make_shared(std::move(modelProto), local_registries, logger); ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true)); @@ -368,7 +375,8 @@ using ::google::protobuf::io::CodedInputStream; using ::google::protobuf::io::FileInputStream; using ::google::protobuf::io::ZeroCopyInputStream; -Status Model::Load(int fd, std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) { +Status Model::Load(int fd, std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger) { if (fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " less than 0."); } @@ -394,7 +402,7 @@ Status Model::Load(int fd, std::shared_ptr& p_model, const IOnnxRuntimeOp return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed."); } #endif - p_model = std::make_shared(std::move(model_proto), local_registries); + p_model = std::make_shared(std::move(model_proto), local_registries, logger); ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true)); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index a70521ef29..2863512462 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -8,6 +8,7 @@ #include #include #include "core/graph/graph_viewer.h" +#include "core/session/onnxruntime_c_api.h" #include "gsl/gsl" @@ -22,23 +23,32 @@ class Model { public: static constexpr Version kNoVersion = INT64_MAX; + explicit Model(const std::string& graph_name, + bool is_onnx_domain_only, + const logging::Logger& logger) + :Model(graph_name,is_onnx_domain_only, ModelMetaData(),IOnnxRuntimeOpSchemaRegistryList(),{},{}, + logger){} + // Construct model from scratch. explicit Model(const std::string& graph_name, - bool is_onnx_domain_only = false, - const ModelMetaData& model_metadata = ModelMetaData(), - const IOnnxRuntimeOpSchemaRegistryList& local_registries = IOnnxRuntimeOpSchemaRegistryList(), - const std::unordered_map& domain_to_version = {}, - const std::vector& model_specific_functions = {}); + bool is_onnx_domain_only, + const ModelMetaData& model_metadata, + const IOnnxRuntimeOpSchemaRegistryList& local_registries, + const std::unordered_map& domain_to_version, + const std::vector& model_specific_functions, + const logging::Logger& logger); // NOTE: after calling this constructor, <*this> model will // hold a copy of . explicit Model(const ONNX_NAMESPACE::ModelProto& model_proto, - const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr); + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger); // NOTE: after calling this constructor, <*this> model will // own the . explicit Model(std::unique_ptr model_proto, - const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr); + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger); // Get model's IR version. // Return if not specified. @@ -88,10 +98,6 @@ class Model { #ifdef _WIN32 static common::Status Save(Model& model, const std::wstring& file_path); - - // TODO(Task:132) Use of shared_ptr* in Load/Save methods is confusing. - static common::Status Load(const std::wstring& file_path, /*out*/ std::shared_ptr& p_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registry = nullptr); #endif static common::Status Save(Model& model, const std::string& file_path); @@ -99,23 +105,29 @@ class Model { static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto); - static common::Status Load(const std::string& file_path, + // TODO(Task:132) Use of shared_ptr* in Load/Save methods is confusing. + static common::Status Load(const std::basic_string& file_path, /*out*/ std::shared_ptr& p_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr); + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger); static common::Status Load(int fd, /*out*/ std::shared_ptr& p_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr); + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger); // 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks static common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr& p_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr); + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger); static common::Status Load(const ONNX_NAMESPACE::ModelProto& model_proto, /*out*/ std::shared_ptr& p_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr); + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger); static common::Status Load(std::unique_ptr p_model_proto, /*out*/ std::shared_ptr& p_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr); + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const logging::Logger& logger); private: // Model data. diff --git a/onnxruntime/core/graph/op.h b/onnxruntime/core/graph/op.h index 7cb35c6e2a..2a720a3cca 100644 --- a/onnxruntime/core/graph/op.h +++ b/onnxruntime/core/graph/op.h @@ -5,15 +5,7 @@ #include #include -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-qualifiers" -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif -#include "onnx/defs/schema.h" -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#include "core/graph/onnx_protobuf.h" #include "core/common/status.h" #include "core/graph/constants.h" diff --git a/onnxruntime/core/optimizer/add_gelu_fusion.cc b/onnxruntime/core/optimizer/add_gelu_fusion.cc index 133079fcf5..2c26f3aded 100644 --- a/onnxruntime/core/optimizer/add_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/add_gelu_fusion.cc @@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -21,7 +21,7 @@ Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) c auto& node = *node_ptr; - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || diff --git a/onnxruntime/core/optimizer/add_gelu_fusion.h b/onnxruntime/core/optimizer/add_gelu_fusion.h index 5c62046a4f..58caed2b87 100644 --- a/onnxruntime/core/optimizer/add_gelu_fusion.h +++ b/onnxruntime/core/optimizer/add_gelu_fusion.h @@ -17,7 +17,7 @@ class AddGeluFusion : public GraphTransformer { : GraphTransformer("AddGeluFusion", compatible_execution_providers) { } - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index c7c064897a..3c4fdb71bb 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -11,7 +11,7 @@ using namespace onnxruntime::common; namespace onnxruntime { -Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); @@ -21,7 +21,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) continue; } - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); InitializedTensorSet constant_inputs; @@ -52,7 +52,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context)); std::vector fetches; - frame.GetOutputs(fetches); + ORT_RETURN_IF_ERROR(frame.GetOutputs(fetches)); // Go over all output node args and substitute them with the newly computed tensors, which will be // added to the graph as initializers. @@ -62,8 +62,8 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) OrtValue& ort_value = fetches[fetch_idx]; if (!ort_value.IsTensor()) { - LOGS_DEFAULT(WARNING) << "Unsupported output type of " << ort_value.Type() - << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; + LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type() + << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; unsupported_output_type = true; break; } diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 0bf935e9a4..15cc065b75 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -25,7 +25,7 @@ class ConstantFolding : public GraphTransformer { const std::unordered_set excluded_op_types_ = {"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"}; - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; /** Create a TensorProto that has the same value as the given OrtValue and the same type and dimensions as the given NodeArg. */ diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index b5e3eff2f8..be94c24d68 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -72,7 +72,7 @@ static bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& m } // namespace -Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& order = graph_viewer.GetNodesInTopologicalOrder(); @@ -82,7 +82,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l if (!node) continue; - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", {1, 11}) || !graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders()) || diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.h b/onnxruntime/core/optimizer/conv_activation_fusion.h index b2377c95fb..71605fcb22 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.h +++ b/onnxruntime/core/optimizer/conv_activation_fusion.h @@ -13,7 +13,7 @@ class ConvActivationFusion : public GraphTransformer { : GraphTransformer("ConvActivationFusion", compatible_execution_providers) {} private: - Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index bc6950b862..e276e86b38 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { -Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const { +Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified, const logging::Logger&) const { auto& conv_node = node; auto& add_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index()); // get mutable next node const auto& conv_inputs = conv_node.InputDefs(); @@ -103,7 +103,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie return Status::OK(); } -bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const { +bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/optimizer/conv_add_fusion.h b/onnxruntime/core/optimizer/conv_add_fusion.h index 7763e249bd..64cc6ecb73 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.h +++ b/onnxruntime/core/optimizer/conv_add_fusion.h @@ -23,9 +23,9 @@ class ConvAddFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index a07f08d46a..0da6258f74 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { -Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { auto& conv_node = node; Node& bn_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index()); @@ -144,7 +144,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff return Status::OK(); } -bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) const { +bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.h b/onnxruntime/core/optimizer/conv_bn_fusion.h index cdce82035f..f19f2e2c4e 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.h +++ b/onnxruntime/core/optimizer/conv_bn_fusion.h @@ -23,9 +23,9 @@ class ConvBNFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index 276e0e6b8b..1614590706 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { -Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { auto& conv_node = node; auto& mul_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index()); @@ -114,7 +114,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef return Status::OK(); } -bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const { +bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.h b/onnxruntime/core/optimizer/conv_mul_fusion.h index bb6a35bf7f..1b2b4a250a 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.h +++ b/onnxruntime/core/optimizer/conv_mul_fusion.h @@ -22,9 +22,9 @@ class ConvMulFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/dropout_elimination.cc b/onnxruntime/core/optimizer/dropout_elimination.cc index 6b2d34cff2..8c36b0c993 100644 --- a/onnxruntime/core/optimizer/dropout_elimination.cc +++ b/onnxruntime/core/optimizer/dropout_elimination.cc @@ -10,7 +10,7 @@ namespace onnxruntime { -Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { if (graph_utils::RemoveNode(graph, node)) { rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } @@ -18,7 +18,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule return Status::OK(); } -bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) const { +bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { // We currently support elimination for Dropout operator v1, v6, v7, and v10. if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10})) { return false; @@ -28,7 +28,7 @@ bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) co // It can be safely removed if it has only one output that is used (checked by CanRemoveNode) // and that output is not the 'mask' output. // The 'is_test' attribute in v1 and v6 is captured by the check for the 'mask' output. - return graph_utils::CanRemoveNode(graph, node) && !graph_utils::IsOutputUsed(node, 1); + return graph_utils::CanRemoveNode(graph, node, logger) && !graph_utils::IsOutputUsed(node, 1); } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/dropout_elimination.h b/onnxruntime/core/optimizer/dropout_elimination.h index e840767497..338c07b137 100644 --- a/onnxruntime/core/optimizer/dropout_elimination.h +++ b/onnxruntime/core/optimizer/dropout_elimination.h @@ -23,9 +23,9 @@ class EliminateDropout : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.cc b/onnxruntime/core/optimizer/free_dim_override_transformer.cc index 820059b424..2a16526585 100644 --- a/onnxruntime/core/optimizer/free_dim_override_transformer.cc +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.cc @@ -28,7 +28,7 @@ static std::string ToLower(std::string s) { } } -Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/) const { +Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { for (const onnxruntime::NodeArg* graph_input : graph.GetInputs()) { // Get the current input's type and shape const auto* input_type = graph_input->TypeAsProto(); @@ -59,10 +59,10 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, // If this dimension actually has a value but it doesn't match the override value, return an // error. if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) { - LOGS_DEFAULT(ERROR) << "The model has input '" << graph_input->Name() << "' " - << "with a fixed dimension denotation '" << dimension.denotation() << "' " - << "but the size of this dimension " << dimension.dim_value() << " " - << "does not equal the specified override of" << dimension_override << "."; + LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' " + << "with a fixed dimension denotation '" << dimension.denotation() << "' " + << "but the size of this dimension " << dimension.dim_value() << " " + << "does not equal the specified override of" << dimension_override << "."; return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override."); } diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.h b/onnxruntime/core/optimizer/free_dim_override_transformer.h index c69dd68a62..e6b437982f 100644 --- a/onnxruntime/core/optimizer/free_dim_override_transformer.h +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.h @@ -23,7 +23,7 @@ class FreeDimensionOverrideTransformer : public GraphTransformer { explicit FreeDimensionOverrideTransformer(gsl::span overrides_to_apply); private: - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; std::map dimension_override_by_denotation_; }; diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc index 345e22e2c9..06a5830550 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.cc +++ b/onnxruntime/core/optimizer/gelu_fusion.cc @@ -67,7 +67,7 @@ static bool IsSupportedDataType(const Node& node) { return true; } -Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -77,7 +77,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons continue; // we removed the node as part of an earlier fusion Node& div = *p_div; - ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) || !graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) || diff --git a/onnxruntime/core/optimizer/gelu_fusion.h b/onnxruntime/core/optimizer/gelu_fusion.h index 216c7ac41a..f35d6e2ec8 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.h +++ b/onnxruntime/core/optimizer/gelu_fusion.h @@ -21,7 +21,7 @@ class GeluFusion : public GraphTransformer { GeluFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("GeluFusion", compatible_execution_providers) {} - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index 756f53c7cc..05d1c864b7 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -19,7 +19,7 @@ bool IsFusableActivation(const Node& node) { } } // namespace -Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& order = graph_viewer.GetNodesInTopologicalOrder(); @@ -30,7 +30,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; // node was removed auto& node = *node_ptr; - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.h b/onnxruntime/core/optimizer/gemm_activation_fusion.h index 3456c6a5ac..b2e8aef203 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.h +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.h @@ -12,7 +12,7 @@ class GemmActivationFusion : public GraphTransformer { GemmActivationFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("GemmActivationFusion", compatible_execution_providers) {} - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer.cc b/onnxruntime/core/optimizer/graph_transformer.cc index bfb5e7c700..07d395ba51 100644 --- a/onnxruntime/core/optimizer/graph_transformer.cc +++ b/onnxruntime/core/optimizer/graph_transformer.cc @@ -7,11 +7,11 @@ using namespace ::onnxruntime::common; namespace onnxruntime { -Status GraphTransformer::Apply(Graph& graph, bool& modified) const { +Status GraphTransformer::Apply(Graph& graph, bool& modified, const logging::Logger& logger) const { // the Graph should be in a good state prior this being called, so there should be no need to call Resolve here // ORT_RETURN_IF_ERROR(graph.Resolve()); - auto status = ApplyImpl(graph, modified, 0); + auto status = ApplyImpl(graph, modified, 0, logger); ORT_RETURN_IF_ERROR(status); // At least currently, some transformers (InsertCastTransformer and MemcpyTransformer) need this to be called diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index 8d8f66ec56..114b7d86ad 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -8,7 +8,7 @@ using namespace ::onnxruntime::common; namespace onnxruntime { -common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, TransformerLevel level) const { +common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, TransformerLevel level, const logging::Logger& logger) const { const auto& transformers = level_to_transformer_map_.find(level); if (transformers == level_to_transformer_map_.end()) { return Status::OK(); @@ -18,7 +18,7 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor bool graph_changed = false; for (const auto& transformer : transformers->second) { bool modified = false; - ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified)); + ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); graph_changed = graph_changed || modified; } if (!graph_changed) { diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index 6769042716..8b60e57cd4 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -3,6 +3,7 @@ #pragma once +#include "core/common/logging/logging.h" #include "core/optimizer/graph_transformer.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/rewrite_rule.h" @@ -20,7 +21,7 @@ class GraphTransformerManager { common::Status Register(std::unique_ptr transformer, TransformerLevel level); // Apply all transformers registered for the given level on the given graph - common::Status ApplyTransformers(Graph& graph, TransformerLevel level) const; + common::Status ApplyTransformers(Graph& graph, TransformerLevel level, const logging::Logger& logger) const; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformerManager); diff --git a/onnxruntime/core/optimizer/identity_elimination.cc b/onnxruntime/core/optimizer/identity_elimination.cc index 94e3c2068f..944d01928d 100644 --- a/onnxruntime/core/optimizer/identity_elimination.cc +++ b/onnxruntime/core/optimizer/identity_elimination.cc @@ -10,7 +10,7 @@ namespace onnxruntime { -Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { if (graph_utils::RemoveNode(graph, node)) { rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } @@ -18,8 +18,8 @@ Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rul return Status::OK(); } -bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node) const { - return graph_utils::CanRemoveNode(graph, node); +bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { + return graph_utils::CanRemoveNode(graph, node, logger); } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/identity_elimination.h b/onnxruntime/core/optimizer/identity_elimination.h index 55d8c2d8fa..5e76275207 100644 --- a/onnxruntime/core/optimizer/identity_elimination.h +++ b/onnxruntime/core/optimizer/identity_elimination.h @@ -23,9 +23,9 @@ class EliminateIdentity : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; // namespace onnxruntime } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index bd256deaad..1aa4899153 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -99,7 +99,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override { + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { std::map replacement_defs; auto output_args = graph.GetOutputs(); @@ -187,7 +187,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } if (!removed) { - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); } } @@ -195,7 +195,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } }; -Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const { +Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { if (force_cpu_fp32_) ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph)); @@ -267,7 +267,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie node->ReplaceDefs(replacement_defs); modified = modified || casted; - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); } auto status = Status::OK(); @@ -282,7 +282,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie RemoveDuplicateCastTransformer remover; // RemoveDuplicateCastTransformer is a special transformer required for correctness. // It is provider agnostic so simply send an empty vector. - status = remover.Apply(graph, modified); + status = remover.Apply(graph, modified, logger); } return status; diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h index 9f87c9da62..6eec898c0b 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.h +++ b/onnxruntime/core/optimizer/insert_cast_transformer.h @@ -22,7 +22,7 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer { } private: - Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const; diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index a4c543df56..15363bde8c 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -43,7 +43,7 @@ X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul | | +---------------------+ */ -Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); std::vector> nodes_to_remove; @@ -54,7 +54,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) continue; // we removed the node as part of an earlier fusion Node& reduce_mean_node = *p_reduce_mean; - ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) || !graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) || diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.h b/onnxruntime/core/optimizer/layer_norm_fusion.h index 56a8c13363..cb1b2a31ce 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.h +++ b/onnxruntime/core/optimizer/layer_norm_fusion.h @@ -21,7 +21,7 @@ class LayerNormFusion : public GraphTransformer { LayerNormFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("LayerNormFusion", compatible_execution_providers) {} - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index 5500eb9ec1..2241244f6c 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -21,7 +21,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) auto& node = *node_ptr; - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.h b/onnxruntime/core/optimizer/matmul_add_fusion.h index ffa5e301e0..968a60052b 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.h +++ b/onnxruntime/core/optimizer/matmul_add_fusion.h @@ -12,7 +12,7 @@ class MatMulAddFusion : public GraphTransformer { MatMulAddFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("MatMulAddFusion", compatible_execution_providers) {} - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index b19f167c54..a75debc761 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -727,13 +727,13 @@ void NchwcTransformerImpl::Finalize(bool& modified) { } } -Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { NchwcTransformerImpl impl(graph); GraphViewer graph_viewer(graph); for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { auto& node = *graph.GetNode(index); - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (node.GetExecutionProviderType() == kCpuExecutionProvider) { impl.Transform(node); } diff --git a/onnxruntime/core/optimizer/nchwc_transformer.h b/onnxruntime/core/optimizer/nchwc_transformer.h index 68b7ddd6e2..0f3ef3e6e0 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.h +++ b/onnxruntime/core/optimizer/nchwc_transformer.h @@ -19,7 +19,7 @@ class NchwcTransformer : public GraphTransformer { NchwcTransformer() noexcept : GraphTransformer("NchwcTransformer") {} private: - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc index 59865b9223..12b2ea4f2e 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.cc +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -9,7 +9,7 @@ namespace onnxruntime { -Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { const auto& next_node = *node.OutputNodesBegin(); // Clip opset 6 has min and max as attributes. they're inputs from opset 11 on. @@ -109,7 +109,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff return Status::OK(); } -bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const { +bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) { return false; } @@ -127,7 +127,7 @@ bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const return false; } - if (!graph_utils::CanRemoveNode(graph, node)) { + if (!graph_utils::CanRemoveNode(graph, node, logger)) { return false; } diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.h b/onnxruntime/core/optimizer/relu_clip_fusion.h index 2b90e9c3d8..16357dc17e 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.h +++ b/onnxruntime/core/optimizer/relu_clip_fusion.h @@ -21,9 +21,9 @@ class FuseReluClip : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/rule_based_graph_transformer.cc b/onnxruntime/core/optimizer/rule_based_graph_transformer.cc index 7476423629..8e2e85cd9a 100644 --- a/onnxruntime/core/optimizer/rule_based_graph_transformer.cc +++ b/onnxruntime/core/optimizer/rule_based_graph_transformer.cc @@ -28,9 +28,9 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr rule) { Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node, const std::vector>& rules, - RuleEffect& rule_effect) const { + RuleEffect& rule_effect, const logging::Logger& logger) const { for (const RewriteRule& rule : rules) { - ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect)); + ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect, logger)); // If the current node was removed as a result of a rule, stop rule application for that node. if (rule_effect == RuleEffect::kRemovedCurrentNode) { break; @@ -39,7 +39,7 @@ Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node, return Status::OK(); } -Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); @@ -65,13 +65,13 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr rules = GetRewriteRulesForOpType(node->OpType()); if (rules) { - ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect)); + ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect, logger)); } if (rule_effect != RuleEffect::kRemovedCurrentNode) { rules = GetAnyOpRewriteRules(); if (rules) { - ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect)); + ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect, logger)); } } @@ -81,7 +81,7 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr } if (rule_effect != RuleEffect::kRemovedCurrentNode) { - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); } } diff --git a/onnxruntime/core/optimizer/shape_to_initializer.cc b/onnxruntime/core/optimizer/shape_to_initializer.cc index 1a42ef2983..784b1b3541 100644 --- a/onnxruntime/core/optimizer/shape_to_initializer.cc +++ b/onnxruntime/core/optimizer/shape_to_initializer.cc @@ -12,7 +12,7 @@ namespace onnxruntime { -Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { // Store the statically inferred shape of the input to the Shape operator. const ONNX_NAMESPACE::TensorShapeProto* input_shape_proto = node.InputDefs()[0]->Shape(); std::vector input_dims; @@ -49,7 +49,7 @@ Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& ru return Status::OK(); } -bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) const { +bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1})) { return false; } @@ -70,7 +70,7 @@ bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) // we're going to create an initializer with the same name as the node output const auto& new_initializer_name = node.OutputDefs()[0]->Name(); - if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_initializer_name)) { + if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_initializer_name, logger)) { return false; } diff --git a/onnxruntime/core/optimizer/shape_to_initializer.h b/onnxruntime/core/optimizer/shape_to_initializer.h index 751bd22d96..fd749821d5 100644 --- a/onnxruntime/core/optimizer/shape_to_initializer.h +++ b/onnxruntime/core/optimizer/shape_to_initializer.h @@ -24,9 +24,9 @@ class ShapeToInitializer : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index d28705a4d6..4290afa96b 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -26,7 +26,7 @@ static bool IsSupportedDataType(const Node& node) { /** Skip Layer Normalization will fuse Add + LayerNormalization into one node. */ -Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); std::vector> nodes_to_remove; @@ -37,7 +37,7 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le continue; // we removed the node as part of an earlier fusion. Node& add_node = *p_add; - ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level, logger)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) || !graph_utils::IsSupportedProvider(add_node, GetCompatibleExecutionProviders()) || diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.h b/onnxruntime/core/optimizer/skip_layer_norm_fusion.h index 1052235dfd..99eb0b0ed1 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.h +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.h @@ -18,10 +18,10 @@ The formula corresponding to LayerNorm activation subgraph: */ class SkipLayerNormFusion : public GraphTransformer { public: - SkipLayerNormFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept + explicit SkipLayerNormFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("SkipLayerNormFusion", compatible_execution_providers) {} - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index e5379f913e..41d5e53d25 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -9,7 +9,7 @@ namespace onnxruntime { -Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { if (graph_utils::RemoveNode(graph, node)) { rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } @@ -17,13 +17,13 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_e return Status::OK(); } -bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) const { +bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { // We currently support elimination for Slice operator v1. if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11})) { return false; } - if (!graph_utils::CanRemoveNode(graph, node)) { + if (!graph_utils::CanRemoveNode(graph, node, logger)) { return false; } diff --git a/onnxruntime/core/optimizer/slice_elimination.h b/onnxruntime/core/optimizer/slice_elimination.h index 8a9ed29474..caec2289ea 100644 --- a/onnxruntime/core/optimizer/slice_elimination.h +++ b/onnxruntime/core/optimizer/slice_elimination.h @@ -23,9 +23,9 @@ class EliminateSlice : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 30763f4a4c..7e9254d900 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -67,7 +67,7 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph // and mainly provides the subgraph recursion functionality -common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { +common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { for (auto& provider : provider_types_) { if (provider != onnxruntime::kCpuExecutionProvider && provider != onnxruntime::kMklDnnExecutionProvider && @@ -91,7 +91,7 @@ common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int gr // handle any subgraphs in nodes for (auto& node : graph.Nodes()) { - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); } return Status::OK(); diff --git a/onnxruntime/core/optimizer/transformer_memcpy.h b/onnxruntime/core/optimizer/transformer_memcpy.h index b3793eef1a..a2403d269f 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.h +++ b/onnxruntime/core/optimizer/transformer_memcpy.h @@ -23,7 +23,7 @@ class MemcpyTransformer : public GraphTransformer { : GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {} private: - common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; const std::vector provider_types_; std::reference_wrapper registry_manager_; diff --git a/onnxruntime/core/optimizer/unsqueeze_elimination.cc b/onnxruntime/core/optimizer/unsqueeze_elimination.cc index 86b4ce3047..72719d5784 100644 --- a/onnxruntime/core/optimizer/unsqueeze_elimination.cc +++ b/onnxruntime/core/optimizer/unsqueeze_elimination.cc @@ -11,13 +11,13 @@ using namespace onnxruntime::common; namespace onnxruntime { -Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const { NodeArg& input_def = *node.MutableInputDefs()[0]; const auto& tensor_proto = *graph_utils::GetConstantInitializer(graph, input_def.Name()); auto new_name = graph.GenerateNodeArgName("UnsqueezeElimination_" + input_def.Name()); - if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_name)) { - LOGS_DEFAULT(WARNING) << "UnsqueezeElimination cannot remove node " << node.Name(); + if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_name, logger)) { + LOGS(logger, WARNING) << "UnsqueezeElimination cannot remove node " << node.Name(); return Status::OK(); } @@ -68,7 +68,7 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& return Status::OK(); } -bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node) const { +bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { // Attempt to remove an Unsqueeze operator only if it gets a constant initializer as input. return graph_utils::IsConstantInitializer(graph, node.InputDefs()[0]->Name()); } diff --git a/onnxruntime/core/optimizer/unsqueeze_elimination.h b/onnxruntime/core/optimizer/unsqueeze_elimination.h index 3150513c13..8d9d31ea4b 100644 --- a/onnxruntime/core/optimizer/unsqueeze_elimination.h +++ b/onnxruntime/core/optimizer/unsqueeze_elimination.h @@ -23,9 +23,9 @@ class UnsqueezeElimination : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 5efd5861ed..a86acbc1bb 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -151,8 +151,8 @@ class WindowsEnv : public Env { } Status MapFileIntoMemory( - const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, - MappedMemoryPtr& mapped_memory) const override { + const ORTCHAR_T*, FileOffsetType, size_t, + MappedMemoryPtr&) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "MapFileIntoMemory is not implemented on Windows."); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp index 50943f3b02..26b7237205 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp @@ -20,25 +20,23 @@ namespace Dml } onnxruntime::common::Status GraphTransformer::ApplyImpl( - onnxruntime::Graph& graph, + onnxruntime::Graph& graph, bool& modified, - int graph_level) const - { - modified = false; - - // Perform fusion - { - bool transformModifiedGraph = false; - PerformOperatorFusion(&graph, &transformModifiedGraph); - modified |= transformModifiedGraph; + int graph_level, const onnxruntime::logging::Logger&) const { + modified = false; - if (modified) - { - ORT_RETURN_IF_ERROR(graph.Resolve()); - } + // Perform fusion + { + bool transformModifiedGraph = false; + PerformOperatorFusion(&graph, &transformModifiedGraph); + modified |= transformModifiedGraph; + + if (modified) { + ORT_RETURN_IF_ERROR(graph.Resolve()); } + } - return onnxruntime::common::Status::OK(); + return onnxruntime::common::Status::OK(); } static std::string GetUniqueNodeName(const onnxruntime::Node* node) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h index 64741ecb23..0fd4a60e3a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h @@ -3,6 +3,7 @@ #pragma once // Lotus framework headers for onnxruntime::IExecutionProvider (not part of the operator ABI). +#include "core/common/logging/logging.h" #include "core/framework/allocatormgr.h" #include "core/framework/execution_provider.h" #include "core/framework/op_kernel.h" @@ -18,7 +19,7 @@ namespace Dml GraphTransformer(const std::string& name, std::shared_ptr dmlRegistry); private: - onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level = 0) const final; + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final; private: void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const; diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc index 5d458691a6..1871bdd8a2 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc @@ -13,7 +13,7 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { -Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const { +Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified, const onnxruntime::logging::Logger&) const { auto& BatchNormalization_node = node; const auto& add_node = *BatchNormalization_node.OutputNodesBegin(); const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs(); @@ -88,7 +88,7 @@ Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleE return Status::OK(); } -bool BatchNormalizationAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const { +bool BatchNormalizationAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.h b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.h index 3d48eda0fc..ea1d506a80 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.h +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.h @@ -23,9 +23,9 @@ class BatchNormalizationAddFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger& logger) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger& logger) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc index 4560cb9aad..9881a74a6b 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc @@ -13,7 +13,7 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { -Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { +Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const { auto& BatchNormalization_node = node; const auto& mul_node = *BatchNormalization_node.OutputNodesBegin(); const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs(); @@ -101,7 +101,7 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE return Status::OK(); } -bool BatchNormalizationMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const { +bool BatchNormalizationMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.h b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.h index 63df490c4c..9e1c6957a1 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.h +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.h @@ -22,9 +22,9 @@ class BatchNormalizationMulFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) const override; + bool SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc b/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc index 3aeefac818..8c8ddfced7 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc +++ b/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc @@ -58,7 +58,7 @@ NGRAPHCustomOp::~NGRAPHCustomOp() { } //This method gets called in critical path of execution: Optimize -Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const { +Status NGRAPHCustomOp::Initialize(const OrtApi* api, OrtKernelContext* context) const { Ort::CustomOpApi ort{*api}; size_t num_inputs = ort.KernelContext_GetInputCount(context); @@ -176,7 +176,7 @@ Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* c } //This method gets called in critical path of execution: Optimize -Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const { +Status NGRAPHCustomOp::Compute(const OrtApi* api, OrtKernelContext* context) const { Ort::CustomOpApi ort{*api}; // Initialize nGraph function if it is not already initialized. diff --git a/onnxruntime/core/providers/ngraph/ngraph_custom_op.h b/onnxruntime/core/providers/ngraph/ngraph_custom_op.h index beb7ece220..c062781e41 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_custom_op.h +++ b/onnxruntime/core/providers/ngraph/ngraph_custom_op.h @@ -29,12 +29,12 @@ class NGRAPHCustomOp { const ONNX_NAMESPACE::ModelProto& model_proto, const std::shared_ptr& ng_backend); - Status Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const; + Status Compute(const OrtApi* api, OrtKernelContext* context) const; ~NGRAPHCustomOp(); private: - Status Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const; + Status Initialize(const OrtApi* api, OrtKernelContext* context) const; std::shared_ptr ng_backend_; diff --git a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc index a9fb56c0c9..c838222d88 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc +++ b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc @@ -527,14 +527,15 @@ NGRAPHExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return result; } -static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node) { +static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node, const logging::Logger& logger) { const auto* node_function = fused_node->GetFunctionBody(); ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", fused_node->Name()); const Graph& node_subgraph = node_function->Body(); onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{}, - IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap()}; + IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(), + std::vector(), logger}; ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); @@ -551,9 +552,7 @@ Status NGRAPHExecutionProvider::Compile(const std::vector& f // Local copy of backend since, class members cannot be captured. auto ngraph_backend = ng_backend_; - compute_info.create_state_func = [model_proto = GetModelProtoFromFusedNode(fused_node), ngraph_backend] - (ComputeContext* context, FunctionState* state) - { + compute_info.create_state_func = [model_proto = GetModelProtoFromFusedNode(fused_node, *GetLogger()), ngraph_backend](ComputeContext* context, FunctionState* state) { auto* p = new ngraph_ep::NGRAPHCustomOp(context, model_proto, ngraph_backend); *state = p; return 0; @@ -564,7 +563,7 @@ Status NGRAPHExecutionProvider::Compile(const std::vector& f delete reinterpret_cast(state); }; - compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) { + compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) { onnxruntime::ngraph_ep::NGRAPHCustomOp* ng_custom_op = reinterpret_cast(state); return ng_custom_op->Compute(api, context); }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 8370bc111a..9e5eaaed02 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -223,7 +223,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (group.second) { nodes_list_output.push_back(group); } else { - onnxruntime::Model model_build(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap()); + onnxruntime::Model model_build(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector(), *GetLogger()); onnxruntime::Graph& graph_build = model_build.MainGraph(); //Add node and node args @@ -285,7 +285,7 @@ std::vector> TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& /*kernel_registries*/) const { // Construct modelproto from graph - onnxruntime::Model model(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap()); + onnxruntime::Model model(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector(), *GetLogger()); onnxruntime::Graph& graph_build = model.MainGraph(); for (const auto& node : graph.Nodes()) { std::vector inputs, outputs; @@ -379,7 +379,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorBody(); - onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph_body.DomainToVersionMap()); + onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph_body.DomainToVersionMap(), std::vector(), *GetLogger()); ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); *(model_proto.mutable_graph()) = graph_body.ToGraphProto(); model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b9307d4dfc..56000e5a3d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -173,9 +173,9 @@ common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptrSetLogger(session_logger_); + return execution_providers_.Add(provider_type, std::move(p_exec_provider)); } common::Status InferenceSession::RegisterGraphTransformer( @@ -270,7 +270,8 @@ common::Status InferenceSession::Load(const std::basic_string& model_uri) { AddCustomOpDomains({domain.get()}); } #endif - return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); + return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, + *session_logger_); }; common::Status st = Load(loader, "model_loading_uri"); @@ -296,7 +297,8 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) { AddCustomOpDomains({domain.get()}); } #endif - return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); + return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, + *session_logger_); }; return Load(loader, "model_loading_proto"); @@ -311,7 +313,7 @@ common::Status InferenceSession::Load(std::unique_ptr p_model_proto) } #endif return onnxruntime::Model::Load(std::move(p_model_proto), model, - HasLocalSchema() ? &custom_schema_registries_ : nullptr); + HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_); }; return Load(loader, "model_loading_proto"); @@ -333,7 +335,8 @@ common::Status InferenceSession::Load(std::istream& model_istream) { AddCustomOpDomains({domain.get()}); } #endif - return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); + return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, + *session_logger_); }; return Load(loader, "model_loading_istream"); @@ -355,7 +358,8 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len } #endif - return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); + return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, + *session_logger_); }; return Load(loader, "model_loading_array"); @@ -375,7 +379,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, // 5. insert cast nodes. // first apply global(execution provider independent), level 1(default/system/basic) graph to graph optimizations - ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1)); + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *session_logger_)); #ifdef USE_DML // TODO: this is a temporary workaround to apply the DML EP's custom graph transformer prior to partitioning. This @@ -390,7 +394,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, Dml::GraphTransformer dml_transformer(onnxruntime::kDmlExecutionProvider, std::move(dml_registry)); bool modified = false; - dml_transformer.Apply(graph, modified); + dml_transformer.Apply(graph, modified, *session_logger_); } #endif @@ -401,12 +405,12 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, // apply transformers except default transformers // Default transformers are required for correctness and they are owned and run by inference session for (int i = static_cast(TransformerLevel::Level1); i < static_cast(TransformerLevel::MaxTransformerLevel); i++) { - ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, static_cast(i))); + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, static_cast(i), *session_logger_)); } bool modified = false; // Insert cast node/s. - ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified)); + ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified, *session_logger_)); // Now every node should be already assigned to an execution provider std::unordered_map> node_placements; @@ -455,7 +459,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, // Insert copy node/s. MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager}; - ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified)); + ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified, *session_logger_)); return common::Status::OK(); } diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index 3cbc75ffe4..8cd60df3f9 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -58,7 +58,7 @@ std::vector Add_Simple(const std::vector& input_a_data, const std: const std::vector& input_small_size = input_a_data.size() < input_b_data.size() ? input_a_data : input_b_data; std::vector output(input_large_size.size()); - for (int iter = 0; iter < input_large_size.size() / input_small_size.size(); iter++) { + for (size_t iter = 0; iter < input_large_size.size() / input_small_size.size(); iter++) { std::transform(input_large_size.begin() + iter * input_small_size.size(), input_large_size.begin() + (iter + 1) * input_small_size.size(), input_small_size.begin(), diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index f3187045c0..2c277d137f 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -12,6 +12,7 @@ #include "test/framework/model_builder_utils.h" #include "core/framework/allocation_planner.h" #include "core/providers/cpu/cpu_execution_provider.h" +#include "test/test_environment.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -162,7 +163,10 @@ class PlannerTest : public ::testing::Test { public: PlannerTest() - : model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_, nullptr) { + : model_("test", false, DefaultLoggingManager().DefaultLogger()), + graph_(model_.MainGraph()), + tp_("test", 1), + state_(execution_providers_, false, &tp_, nullptr) { std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); in_place_kernel_ = KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build(); @@ -171,8 +175,6 @@ class PlannerTest : public ::testing::Test { execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider)); } - ~PlannerTest() = default; - onnxruntime::NodeArg* Arg(const std::string& name) { auto iter = name_to_arg_.find(name); if (name_to_arg_.end() != iter) return iter->second; diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index 0054573f36..b2f55064de 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -73,7 +73,9 @@ CREATE_INITIALIZER_FUNC(float, TensorProto_DataType_FLOAT, add_float_data) CREATE_INITIALIZER_FUNC(int64_t, TensorProto_DataType_INT64, add_int64_data) // TO DO: Figure out a way to enable it again TEST(CUDAFenceTests, DISABLED_PartOnCPU) { - std::unique_ptr model = onnxruntime::make_unique("test"); + std::unique_ptr model = onnxruntime::make_unique("test", + false, + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); @@ -135,7 +137,7 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) { } TEST(CUDAFenceTests, TileWithInitializer) { - std::unique_ptr model = onnxruntime::make_unique("test"); + std::unique_ptr model = onnxruntime::make_unique("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); @@ -189,7 +191,8 @@ TEST(CUDAFenceTests, TileWithInitializer) { TEST(CUDAFenceTests, TileWithComputedInput) { std::unordered_map domain_to_version; domain_to_version[onnxruntime::kOnnxDomain] = 7; - std::unique_ptr model = onnxruntime::make_unique("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version); + std::unique_ptr model = onnxruntime::make_unique("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index b1b2cb0906..ee67b5b473 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -21,7 +21,7 @@ namespace test { typedef std::vector ArgMap; std::shared_ptr DummyGraphWithClip() { - auto model = std::make_shared("test"); + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); @@ -42,7 +42,7 @@ class ExecutionFrameTest : public ::testing::Test { }; TEST_F(ExecutionFrameTest, TensorAllocationTest) { - onnxruntime::Model model("test"); + onnxruntime::Model model("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model.MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); @@ -112,7 +112,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { TEST_F(ExecutionFrameTest, FeedInDataTest) { onnxruntime::Model model("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), - std::unordered_map{{"", 10}}); + std::unordered_map{{"", 10}}, {}, + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model.MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); @@ -164,7 +165,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { auto xp_type = cpu_xp->Type(); std::unordered_map domain_to_version; domain_to_version[onnxruntime::kOnnxDomain] = 7; - onnxruntime::Model model("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version); + onnxruntime::Model model("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model.MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 22ab613b0c..44f221aa3e 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -37,7 +37,6 @@ #include "test/providers/provider_test_utils.h" #include "test/optimizer/dummy_graph_transformer.h" #include "core/optimizer/rule_based_graph_transformer.h" - #include "gtest/gtest.h" using namespace std; @@ -133,16 +132,17 @@ class InferenceSessionGetGraphWrapper : public InferenceSession { namespace test { static void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, const std::vector& expected_values); -static const std::string MODEL_URI = "testdata/mul_1.onnx"; -static const std::string MODEL_URI_NO_OPSET = "testdata/mul_1.noopset.onnx"; +static constexpr const ORTCHAR_T* MODEL_URI = ORT_TSTR("testdata/mul_1.onnx"); +static constexpr const ORTCHAR_T* MODEL_URI_NO_OPSET = ORT_TSTR("testdata/mul_1.noopset.onnx"); //static const std::string MODEL_URI = "./testdata/squeezenet/model.onnx"; // TODO enable this after we've weights? static void CreateMatMulModel(std::unique_ptr& p_model, ProviderType provider_type) { std::unordered_map domain_to_version; domain_to_version[onnxruntime::kOnnxDomain] = 7; // Generate the input & output def lists - p_model = onnxruntime::make_unique("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version); + std::vector model_specific_functions; + p_model = onnxruntime::make_unique("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, model_specific_functions, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = p_model->MainGraph(); TypeProto tensor_float; @@ -455,11 +455,11 @@ TEST(InferenceSessionTests, ModelMetadata) { so.session_logid = "InferenceSessionTests.ModelMetadata"; InferenceSession session_object{so, &DefaultLoggingManager()}; - string model_uri = "../models/opset8/test_squeezenet/model.onnx"; + auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"); ASSERT_TRUE(session_object.Load(model_uri).IsOK()); std::shared_ptr p_model; - Status st = onnxruntime::Model::Load(model_uri, p_model); + Status st = onnxruntime::Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(st.IsOK()); const onnxruntime::Graph& graph = p_model->MainGraph(); @@ -1029,7 +1029,7 @@ TEST(InferenceSessionTests, TestOptionalInputs) { } TEST(ExecutionProviderTest, FunctionTest) { - onnxruntime::Model model("graph_1"); + onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; std::vector outputs; @@ -1116,7 +1116,7 @@ TEST(ExecutionProviderTest, FunctionTest) { } TEST(ExecutionProviderTest, FunctionInlineTest) { - onnxruntime::Model model("graph_1"); + onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); ONNX_NAMESPACE::FunctionProto fc_proto; fc_proto.set_name("FC"); diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 981fed9e21..dc9a1d8bf4 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -7,16 +7,17 @@ #include "core/graph/model.h" #include "gtest/gtest.h" #include "test_utils.h" +#include "test/test_environment.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { -static const std::string MODEL_FOLDER = "testdata/transform/"; +#define MODEL_FOLDER ORT_TSTR("testdata/transform/") typedef std::vector ArgMap; TEST(TransformerTest, InsertCastGPUTest) { - auto model = std::make_shared("test"); + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float_16; @@ -38,7 +39,7 @@ TEST(TransformerTest, InsertCastGPUTest) { InsertCastTransformer transformer("Test"); bool modified = true; - status = transformer.Apply(graph, modified); + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(status.IsOK()); status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); @@ -64,7 +65,7 @@ TEST(TransformerTest, InsertCastGPUTest) { } TEST(TransformerTest, InsertCastAllCPUTest) { - auto model = std::make_shared("test"); + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float_16; @@ -86,7 +87,7 @@ TEST(TransformerTest, InsertCastAllCPUTest) { InsertCastTransformer transformer("Test"); bool modified = true; - EXPECT_TRUE(transformer.Apply(graph, modified).IsOK()); + EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK()); status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_EQ(graph.NumberOfNodes(), 7); @@ -109,9 +110,9 @@ TEST(TransformerTest, InsertCastAllCPUTest) { // test that when there are 3 Cast ops in a row we remove the correct ones TEST(TransformerTest, ThreeInARowRemoval) { - std::string model_uri = MODEL_FOLDER + "triple-cast.onnx"; + auto model_uri = MODEL_FOLDER ORT_TSTR("triple-cast.onnx"); std::shared_ptr model; - auto status = Model::Load(model_uri, model); + auto status = Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(status.IsOK()) << status; Graph& graph = model->MainGraph(); @@ -123,7 +124,7 @@ TEST(TransformerTest, ThreeInARowRemoval) { InsertCastTransformer transformer("Test"); bool modified = false; - status = transformer.Apply(graph, modified); + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(status.IsOK()) << status; EXPECT_TRUE(modified) << "Transformer should have removed some Cast nodes"; status = graph.Resolve(); diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc index 1b9dbffa4f..0a2d08a459 100644 --- a/onnxruntime/test/framework/memcpy_transformer_test.cc +++ b/onnxruntime/test/framework/memcpy_transformer_test.cc @@ -8,6 +8,7 @@ #include "core/graph/model.h" #include "gtest/gtest.h" #include "test_utils.h" +#include "test/test_environment.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -74,7 +75,8 @@ TEST(TransformerTest, MemcpyTransformerTest) { std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 7; auto model = std::make_shared("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version); + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float_type; @@ -111,7 +113,7 @@ TEST(TransformerTest, MemcpyTransformerTest) { MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); bool modified = false; - status = transformer.Apply(graph, modified); + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_TRUE(modified); @@ -128,7 +130,8 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 7; auto model = std::make_shared("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version); + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TypeProto tensor_float_type; @@ -165,7 +168,7 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); bool modified = false; - status = transformer.Apply(graph, modified); + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_TRUE(modified); @@ -204,7 +207,8 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) { false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version); + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); TensorProto parent_constant(value_tensor); @@ -221,7 +225,8 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) { false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), - subgraph_domain_to_version); + subgraph_domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& subgraph = sub_model->MainGraph(); TensorProto local_constant(value_tensor); @@ -278,7 +283,7 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) { MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); bool modified = false; - status = transformer.Apply(graph, modified); + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_TRUE(modified); } diff --git a/onnxruntime/test/framework/opaque_kernels_test.cc b/onnxruntime/test/framework/opaque_kernels_test.cc index 8f75fe8506..4ff7815279 100644 --- a/onnxruntime/test/framework/opaque_kernels_test.cc +++ b/onnxruntime/test/framework/opaque_kernels_test.cc @@ -298,7 +298,7 @@ TEST_F(OpaqueTypeTests, RunModel) { IOnnxRuntimeOpSchemaRegistryList custom_schema_registries_ = {registry->GetOpschemaRegistry()}; std::unordered_map domain_to_version = {{onnxruntime::kMLDomain, 8}}; - Model model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version); + Model model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index e4f3db04d1..ab64b8f12f 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -14,6 +14,7 @@ #include "core/graph/op.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "gtest/gtest.h" +#include "test/test_environment.h" using namespace ONNX_NAMESPACE; using namespace std; @@ -42,7 +43,7 @@ TEST(SessionStateTest, AddGetKernelTest) { ExecutionProviders execution_providers; SessionState s{execution_providers, true, &tp, nullptr}; - onnxruntime::Model model("graph_1"); + onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; std::vector outputs; @@ -94,10 +95,11 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { const TestParam& param = GetParam(); concurrency::ThreadPool tp{"test", 1}; - std::string model_path = "testdata/optional_inputs_ir" + std::to_string(param.ir_version) + ".onnx"; + std::basic_ostringstream oss; + oss << ORT_TSTR("testdata/optional_inputs_ir") << param.ir_version << ORT_TSTR(".onnx"); Status status; std::shared_ptr model; - ASSERT_TRUE((status = Model::Load(model_path, model)).IsOK()) << status; + ASSERT_TRUE((status = Model::Load(oss.str(), model, nullptr, DefaultLoggingManager().DefaultLogger())).IsOK()) << status; Graph& graph = model->MainGraph(); // take a copy as this gets cleared during session state initialization InitializedTensorSet initializers = graph.GetAllInitializedTensors(); @@ -112,7 +114,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { ASSERT_TRUE(status.IsOK()) << status; SessionState session_state(execution_providers, param.enable_mem_pattern, &tp, nullptr); - SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph, session_state, + SessionStateInitializer session_initializer(param.enable_mem_pattern, oss.str(), graph, session_state, execution_providers, krm); GraphPartitioner partitioner(krm, execution_providers); diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index dd62dab3d5..ac2e0740a8 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -7,6 +7,8 @@ #include "gtest/gtest.h" #include "core/graph/model.h" #include "test/framework/model_builder_utils.h" +#include "test/test_environment.h" + using namespace ONNX_NAMESPACE; using namespace std; @@ -22,7 +24,7 @@ class ShapeInferenceTest : public ::testing::Test { std::unordered_map> name_to_arg_; public: - ShapeInferenceTest() : model_("Test"), node_count_(0) {} + ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {} void Input(const std::string& name, const Type& type) { name_to_arg_[name] = onnxruntime::make_unique(name, &type.value); diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 02923b569a..0bef1b693c 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -3,17 +3,7 @@ #include "core/framework/data_types.h" -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-qualifiers" -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -#include "onnx/defs/schema.h" - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +#include "core/graph/onnx_protobuf.h" #include "core/graph/constants.h" #include "core/framework/op_kernel.h" @@ -293,7 +283,7 @@ class SparseTensorTests : public testing::Test { registry(std::make_shared()), custom_schema_registries_{registry->GetOpschemaRegistry()}, domain_to_version{{onnxruntime::kMLDomain, 10}}, - model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version), + model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version, {}, DefaultLoggingManager().DefaultLogger()), graph(model.MainGraph()) { EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK()); } @@ -306,21 +296,21 @@ class SparseTensorTests : public testing::Test { schema.SinceVersion(10); schemas.push_back(schema); - Action register_kernel = [](CustomRegistry* registry) { + Action register_kernel = [](CustomRegistry* registry2) { auto kernel_def_builder = Op::KernelDef(); kernel_def_builder .SetDomain(onnxruntime::kMLDomain) .SinceVersion(10) .Provider(onnxruntime::kCpuExecutionProvider); KernelCreateFn kernel_create_fn = [](const OpKernelInfo& info) { return new typename Op::OpKernelImpl(info); }; - EXPECT_TRUE(registry->RegisterCustomKernel(kernel_def_builder, kernel_create_fn).IsOK()); + EXPECT_TRUE(registry2->RegisterCustomKernel(kernel_def_builder, kernel_create_fn).IsOK()); }; register_actions.push_back(register_kernel); } void RegisterOps() { EXPECT_TRUE(registry->RegisterOpSet(schemas, onnxruntime::kMLDomain, 10, 11).IsOK()); - for (auto registerop : register_actions) + for (auto& registerop : register_actions) registerop(registry.get()); } diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 280e1337a3..8d85796f30 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -121,7 +121,7 @@ static const bool kSchemasRegistered = RegisterCustomSchemas(); TEST(GraphTraversalTest, ReverseDFS) { ASSERT_TRUE(kSchemasRegistered); - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); /* Case 1: A normal graph. @@ -219,7 +219,7 @@ TEST(GraphTraversalTest, ReverseDFS) { } TEST(ResolvingGraphTest, GraphConstruction_VerifyNoDuplicateName) { - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); EXPECT_EQ("graph_1", graph.Name()); @@ -252,7 +252,7 @@ TEST(ResolvingGraphTest, GraphConstruction_VerifyNoDuplicateName) { } TEST(ResolvingGraphTest, GraphConstruction_VerifyNodeAndOpMatch) { - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -275,7 +275,7 @@ TEST(ResolvingGraphTest, GraphConstruction_VerifyNodeAndOpMatch) { TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) { ASSERT_TRUE(kSchemasRegistered); - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); /* A normal graph. @@ -336,7 +336,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) { EXPECT_TRUE(Model::Save(model, "graph_1.onnx").IsOK()); std::shared_ptr model2; - EXPECT_TRUE(Model::Load("graph_1.onnx", model2).IsOK()); + EXPECT_TRUE(Model::Load(ORT_TSTR("graph_1.onnx"), model2, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); auto model_proto = model.ToProto(); auto model_proto2 = model2->ToProto(); @@ -345,7 +345,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) { // Load the model again to ensure that it's still the right thing. //EXPECT_EQ(Model::Load(model_proto2, &model2), Status::OK()); - model2.reset(new Model(model_proto2)); + model2.reset(new Model(model_proto2, nullptr, DefaultLoggingManager().DefaultLogger())); Graph& graph2 = model2->MainGraph(); for (auto& node : graph2.Nodes()) { auto node_name_to_input_output_iter = expected_node_name_to_input_output_args.find(node.Name()); @@ -368,7 +368,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) { TEST(ResolvingGraphTest, GraphConstruction_CheckInputNodeOrderMaintained) { ASSERT_TRUE(kSchemasRegistered); - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); // node_1 (Identity) node_2 (Identity) @@ -448,7 +448,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckInputNodeOrderMaintained) { TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) { ASSERT_TRUE(kSchemasRegistered); - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::unordered_map map; @@ -541,7 +541,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) ASSERT_TRUE(result) << "Failed to load model from serialized protobuf"; std::shared_ptr p_tmp_model; - auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr); + auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr, DefaultLoggingManager().DefaultLogger()); auto& graph2 = p_tmp_model->MainGraph(); status = graph2.Resolve(); @@ -555,7 +555,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) { ASSERT_TRUE(kSchemasRegistered); - Model model("UnusedInitializerIsIgnored"); + Model model("UnusedInitializerIsIgnored", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -596,7 +596,7 @@ TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) { ASSERT_TRUE(result) << "Failed to load model from serialized protobuf"; std::shared_ptr p_tmp_model; - auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr); + auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr, DefaultLoggingManager().DefaultLogger()); auto& graph2 = p_tmp_model->MainGraph(); status = graph2.Resolve(); @@ -621,7 +621,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsNotAcyclic) { tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); auto& input_arg1 = graph.GetOrCreateNodeArg("node_1_in_1", &tensor_int32); auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32); @@ -643,7 +643,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsNotAcyclic) { } TEST(ResolvingGraphTest, GraphConstruction_OnlyInitializer) { - onnxruntime::Model model("graph"); + onnxruntime::Model model("graph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); ONNX_NAMESPACE::TensorProto weight; @@ -663,7 +663,7 @@ TEST(ResolvingGraphTest, GraphConstruction_OnlyInitializer) { TEST(ResolvingGraphTest, GraphConstruction_TypeInference) { ASSERT_TRUE(kSchemasRegistered); - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); /* Case 1: A normal graph. @@ -725,7 +725,7 @@ TEST(ResolvingGraphTest, GraphConstruction_TypeInference) { EXPECT_TRUE(Model::Save(model, "model_x.onnx").IsOK()); std::shared_ptr loaded_model; - EXPECT_TRUE(Model::Load("model_x.onnx", loaded_model).IsOK()); + EXPECT_TRUE(Model::Load(ORT_TSTR("model_x.onnx"), loaded_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); EXPECT_EQ(2, loaded_model->MainGraph().GetInputs().size()); auto& graph_proto = graph.ToGraphProto(); @@ -740,7 +740,7 @@ TEST(ResolvingGraphTest, GraphConstruction_TypeInference) { TEST(ResolvingGraphTest, ShapeInferenceErrorHandling) { ASSERT_TRUE(kSchemasRegistered); - Model model("graph"); + Model model("graph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto tensor_int32; @@ -765,7 +765,7 @@ TEST(TestAddAttribute, AddTensorAttribute) { .Output(0, "output_1", "docstr for output_1.", "tensor(int64)"); std::vector inputs; std::vector outputs; - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto output_type; output_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); @@ -810,7 +810,7 @@ void AddAttribute(onnxruntime::Node& p_node, const std::string& attr_name, std:: TEST(TypeInferenceTest, TypeAttribute) { std::vector inputs; std::vector outputs; - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", nullptr); outputs.push_back(&output_arg); @@ -834,7 +834,7 @@ TEST(TypeInferenceTest, VariadicOutput) { std::vector outputs; TypeProto tensor_type; tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); auto& X = graph.GetOrCreateNodeArg("X", &tensor_type); inputs.push_back(&X); @@ -851,7 +851,7 @@ TEST(TypeInferenceTest, VariadicOutput) { // test that we prefer the graph input shape for a non-const initializer (initializer with matching graph input) TEST(TypeInferenceTest, NonConstInitializer) { - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto tensor_type_no_shape; @@ -901,7 +901,7 @@ TEST(TypeInferenceTest, NonConstInitializer) { ASSERT_TRUE(model.ToProto().SerializeToString(&s1)); ASSERT_TRUE(model_proto.ParseFromString(s1)); - auto status = onnxruntime::Model::Load(model_proto, p_model, nullptr); + auto status = onnxruntime::Model::Load(model_proto, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(status.IsOK()) << status; auto& graph2 = p_model->MainGraph(); @@ -910,7 +910,7 @@ TEST(TypeInferenceTest, NonConstInitializer) { // Test that Graph::Resolve identifies name-duplication across initializer and node-output-arg TEST(NameResolutionTest, DuplicateName) { - Model model("graph_1"); + Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); ONNX_NAMESPACE::TensorProto weight; @@ -940,7 +940,7 @@ TEST(NameResolutionTest, DuplicateName) { } TEST(GraphUpdateTest, ReplaceInitializedTensor) { - Model model{"GraphUpdateTest"}; + Model model{"GraphUpdateTest", false, DefaultLoggingManager().DefaultLogger()}; auto& graph = model.MainGraph(); const std::string initializer_name = "initializer"; @@ -1011,7 +1011,7 @@ TEST(GraphUpdateTest, ReplaceInitializedTensor) { } TEST(GraphUpdateTest, AddRemoveInitializerHandling) { - Model m{"test_model"}; + Model m{"test_model", false, DefaultLoggingManager().DefaultLogger()}; Graph& graph = m.MainGraph(); auto create_tensor_proto = [](const std::string& name, int32_t value) { diff --git a/onnxruntime/test/ir/onnx_model_test.cc b/onnxruntime/test/ir/onnx_model_test.cc index 5c4dce0452..77b45d7261 100644 --- a/onnxruntime/test/ir/onnx_model_test.cc +++ b/onnxruntime/test/ir/onnx_model_test.cc @@ -7,6 +7,8 @@ #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/graph/op.h" +#include "core/session/onnxruntime_c_api.h" +#include "test/test_environment.h" #include "gtest/gtest.h" using namespace onnxruntime; @@ -47,29 +49,18 @@ static void TestResolve(onnxruntime::Graph& graph) { TEST(ONNXModelsTest, squeeze_net) { // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located std::shared_ptr model; - ASSERT_TRUE(Model::Load("../models/opset8/test_squeezenet/model.onnx", model).IsOK()); + ASSERT_TRUE(Model::Load(ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"), model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); TestResolve(model->MainGraph()); -#ifdef _WIN32 - // wstring version - std::shared_ptr model2; - ASSERT_TRUE(Model::Load(L"../models/opset8/test_squeezenet/model.onnx", model2).IsOK()); - TestResolve(model2->MainGraph()); -#endif } #endif TEST(ONNXModelsTest, non_existing_model) { // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located std::shared_ptr model; - common::Status st = Model::Load("./testdata/non_existing_model_XXXXXX/model.onnx", model); + common::Status st = Model::Load(ORT_TSTR("./testdata/non_existing_model_XXXXXX/model.onnx"), model, nullptr, + DefaultLoggingManager().DefaultLogger()); ASSERT_FALSE(st.IsOK()); ASSERT_EQ(st.Code(), common::NO_SUCHFILE); -#ifdef _WIN32 - // wstring version - std::shared_ptr model2; - ASSERT_FALSE(Model::Load(L"./testdata/non_existing_model_XXXXXX/model.onnx", model2).IsOK()); - ASSERT_EQ(st.Code(), common::NO_SUCHFILE); -#endif } #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS @@ -78,7 +69,7 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) { using ::google::protobuf::io::FileInputStream; using ::google::protobuf::io::ZeroCopyInputStream; int fd; - ASSERT_TRUE(Env::Default().FileOpenRd("../models/opset8/test_bvlc_alexnet/model.onnx", fd).IsOK()); + ASSERT_TRUE(Env::Default().FileOpenRd(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), fd).IsOK()); ASSERT_TRUE(fd > 0); std::unique_ptr raw_input(new FileInputStream(fd)); std::unique_ptr coded_input(new CodedInputStream(raw_input.get())); @@ -92,7 +83,9 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) { ASSERT_TRUE(Env::Default().FileClose(fd).IsOK()); std::shared_ptr model; - ASSERT_TRUE(Model::Load("../models/opset8/test_bvlc_alexnet/model.onnx", model).IsOK()); + ASSERT_TRUE(Model::Load(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), model, nullptr, + DefaultLoggingManager().DefaultLogger()) + .IsOK()); // Check the graph input/output/value_info should have the same size as specified in the model file. EXPECT_EQ(model_proto.graph().value_info_size(), model->MainGraph().GetValueInfo().size()); @@ -101,21 +94,23 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) { TestResolve(model->MainGraph()); } -class ONNXModelsTest : public ::testing::TestWithParam { +class ONNXModelsTest : public ::testing::TestWithParam { // You can implement all the usual fixture class members here. // To access the test parameter, call GetParam() from class // TestWithParam. public: - std::string GetModelFileName() const { - std::ostringstream oss; - oss << "../models/opset7/test_" << GetParam() << "/model.onnx"; + std::basic_string GetModelFileName() const { + std::basic_ostringstream oss; + oss << ORT_TSTR("../models/opset7/test_") << GetParam() << ORT_TSTR("/model.onnx"); return oss.str(); } }; TEST_P(ONNXModelsTest, LoadFromFile) { std::shared_ptr model; - ASSERT_TRUE(Model::Load(GetModelFileName(), model).IsOK()); + ASSERT_TRUE(Model::Load(GetModelFileName(), model, nullptr, + DefaultLoggingManager().DefaultLogger()) + .IsOK()); TestResolve(model->MainGraph()); } @@ -137,18 +132,20 @@ TEST_P(ONNXModelsTest, LoadFromProtobuf) { ASSERT_TRUE(result); ASSERT_TRUE(Env::Default().FileClose(fd).IsOK()); std::shared_ptr model; - ASSERT_TRUE(Model::Load(std::move(model_proto), model).IsOK()); + ASSERT_TRUE(Model::Load(std::move(model_proto), model, nullptr, + DefaultLoggingManager().DefaultLogger()) + .IsOK()); TestResolve(model->MainGraph()); } #ifndef DISABLE_CONTRIB_OPS INSTANTIATE_TEST_CASE_P(ONNXModelsTests, ONNXModelsTest, - ::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "tiny_yolov2", "vgg19", "zfnet512")); + ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("tiny_yolov2"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); #else INSTANTIATE_TEST_CASE_P(ONNXModelsTests, ONNXModelsTest, - ::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "vgg19", "zfnet512")); + ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); #endif #endif @@ -159,7 +156,9 @@ INSTANTIATE_TEST_CASE_P(ONNXModelsTests, // for Graph::Resolve to succeed when processing the subgraph. TEST(ONNXModelsTest, TestIRv4NonInputInitializers) { std::shared_ptr model; - ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK()); + ASSERT_TRUE(Model::Load(ORT_TSTR("testdata/subgraph_implicit_input_from_initializer.onnx"), model, nullptr, + DefaultLoggingManager().DefaultLogger()) + .IsOK()); EXPECT_TRUE(model->MainGraph().Resolve().IsOK()); } @@ -170,7 +169,8 @@ TEST(ONNXModelsTest, TestIRv4NonInputInitializers) { TEST(ONNXModelsTest, TestModelsWithAnOpContainingAFunctionBody) { std::shared_ptr model; - auto status = Model::Load("testdata/model_containing_op_with_function_body.onnx", model); + auto status = Model::Load(ORT_TSTR("testdata/model_containing_op_with_function_body.onnx"), model, nullptr, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(status.IsOK()) << status; status = model->MainGraph().Resolve(); diff --git a/onnxruntime/test/ir/op_test.cc b/onnxruntime/test/ir/op_test.cc index bfe9fb7368..557cb81242 100644 --- a/onnxruntime/test/ir/op_test.cc +++ b/onnxruntime/test/ir/op_test.cc @@ -8,6 +8,7 @@ #include "core/graph/schema_registry.h" #include "core/graph/model.h" #include "core/graph/op.h" +#include "test/test_environment.h" using namespace ONNX_NAMESPACE; @@ -31,7 +32,7 @@ TEST(FormalParamTest, Success) { } TEST(FeatureVectorizerTest, TraditionalMlOpTest) { - Model model("traditionalMl"); + Model model("traditionalMl", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); // Case: A traditional ml graph. diff --git a/onnxruntime/test/ir/utils_test.cc b/onnxruntime/test/ir/utils_test.cc index b5efcba53a..b35efcbf4e 100644 --- a/onnxruntime/test/ir/utils_test.cc +++ b/onnxruntime/test/ir/utils_test.cc @@ -6,6 +6,8 @@ #include "core/graph/graph_utils.h" #include "core/graph/model.h" +#include "test/test_environment.h" + using ONNX_NAMESPACE::Utils::DataTypeUtils; using namespace ONNX_NAMESPACE; @@ -71,7 +73,8 @@ static GraphProto CreateNodeRemovalSubgraph(const std::string& new_output_name = std::string suffix = add_second_level ? ".top" : ".bottom"; std::string constant_output_name = (new_output_name.empty() ? "constant_in_0" + suffix : new_output_name); - Model model("CreateNodeRemovalSubgraph:" + constant_output_name); + Model model("CreateNodeRemovalSubgraph:" + constant_output_name, false, + DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto float_scalar_tensor; @@ -198,7 +201,8 @@ static void CheckNodeRemovalSubgraphUpdate(const std::string& new_name, const Gr } static void UpdateSubgraphWhenRemovingNode(bool include_nested = false) { - Model model(std::string("UpdateSubgraphWhenRemovingNode") + (include_nested ? ":Nested" : ":SingleLevel")); + Model model(std::string("UpdateSubgraphWhenRemovingNode") + (include_nested ? ":Nested" : ":SingleLevel"), + false, DefaultLoggingManager().DefaultLogger()); CreateNodeRemovalGraph(model, true, include_nested); @@ -232,13 +236,15 @@ TEST(GraphUtils, UpdateNestedSubgraphWhenRemovingNode) { // we can't remove a node if it is used as an implicit input in a subgraph, and changing the implicit input name // will result with in a clash with an existing node in the subgraph static void DontRemoveNodeIfItWillBreakSubgraph(bool test_nested = false) { - Model model(std::string("DontRemoveNodeIfItWillBreakSubgraph") + (test_nested ? ":Nested" : ":SingleLevel")); + Model model(std::string("DontRemoveNodeIfItWillBreakSubgraph") + (test_nested ? ":Nested" : ":SingleLevel"), + false, DefaultLoggingManager().DefaultLogger()); CreateNodeRemovalGraph(model, false, test_nested); auto& graph = model.MainGraph(); auto& node_to_remove = *graph.GetNode(1); - ASSERT_FALSE(graph_utils::CanRemoveNode(graph, node_to_remove)); + ASSERT_FALSE(graph_utils::CanRemoveNode(graph, node_to_remove, + DefaultLoggingManager().DefaultLogger())); } TEST(GraphUtils, DontRemoveNodeIfItWillBreakSubgraph) { @@ -253,7 +259,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) { // Create a graph with 5 Id nodes. The graph structure is as follows: Id0 ( Id1 Id2 ( Id3 Id4 ) ). // First we remove Id2, which leads to: Id0 ( Id1 Id4 Id5 ). // Then we remove Id1, which leads to: Id2 Id4 Id5, being fed the initializer. - Model model("MultiEdgeRemovalGraph"); + Model model("MultiEdgeRemovalGraph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto float_tensor; @@ -303,7 +309,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) { } TEST(GraphUtils, TestMultiOutputRemoveNode) { - Model model("MultiOutputRemovalGraph"); + Model model("MultiOutputRemovalGraph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto float_tensor; @@ -341,14 +347,17 @@ TEST(GraphUtils, TestMultiOutputRemoveNode) { // Try to remove do_0, which should return false // because both outputs are consumed by downstream Operators. - ASSERT_FALSE(graph_utils::CanRemoveNode(graph, *nodes[0])); + ASSERT_FALSE(graph_utils::CanRemoveNode(graph, *nodes[0], + DefaultLoggingManager().DefaultLogger())); // Try removing do_0 after removing id_2, which should return true // because it now has exactly one output consumed by downstream Operators. - ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[1])); + ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[1], + DefaultLoggingManager().DefaultLogger())); ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[1])); ASSERT_FALSE(graph_utils::IsOutputUsed(*nodes[0], 0)); - ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[0])); + ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[0], + DefaultLoggingManager().DefaultLogger())); ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[0])); } diff --git a/onnxruntime/test/onnx/microbenchmark/main.cc b/onnxruntime/test/onnx/microbenchmark/main.cc index 6682e6db70..8712e968d8 100644 --- a/onnxruntime/test/onnx/microbenchmark/main.cc +++ b/onnxruntime/test/onnx/microbenchmark/main.cc @@ -28,7 +28,7 @@ BENCHMARK(BM_CPUAllocator)->Arg(4)->Arg(sizeof(Tensor)); static void BM_ResolveGraph(benchmark::State& state) { std::shared_ptr model_copy; - auto st = onnxruntime::Model::Load("../models/opset8/test_tiny_yolov2/model.onnx", model_copy); + auto st = onnxruntime::Model::Load(ORT_TSTR("../models/opset8/test_tiny_yolov2/model.onnx"), model_copy); if (!st.IsOK()) { printf("Parse model failed: %s", st.ErrorMessage().c_str()); abort(); diff --git a/onnxruntime/test/opaque_api/test_opaque_api.cc b/onnxruntime/test/opaque_api/test_opaque_api.cc index f7301c58b2..3113663b21 100644 --- a/onnxruntime/test/opaque_api/test_opaque_api.cc +++ b/onnxruntime/test/opaque_api/test_opaque_api.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/framework/data_types.h" #include "core/framework/execution_providers.h" #include "core/framework/kernel_registry.h" @@ -161,8 +162,7 @@ namespace test { std::string CreateModel() { RegisterCustomKernel(); - - Model model("ModelWithOpaque", false); + Model model("ModelWithOpaque", false, logging::LoggingManager::DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; diff --git a/onnxruntime/test/optimizer/dummy_graph_transformer.h b/onnxruntime/test/optimizer/dummy_graph_transformer.h index 8116d0ba5f..5e82b251b0 100644 --- a/onnxruntime/test/optimizer/dummy_graph_transformer.h +++ b/onnxruntime/test/optimizer/dummy_graph_transformer.h @@ -20,7 +20,7 @@ class DummyGraphTransformer : public GraphTransformer { private: mutable bool transformer_invoked_; - Status ApplyImpl(Graph& /*graph*/, bool& /*modified*/, int /*graph_level*/) const override { + Status ApplyImpl(Graph& /*graph*/, bool& /*modified*/, int /*graph_level*/, const logging::Logger&) const override { transformer_invoked_ = true; return Status::OK(); } @@ -43,11 +43,12 @@ class DummyRewriteRule : public RewriteRule { private: mutable bool rewrite_rule_invoked_; - bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/) const override { + bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override { return true; } - Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) const override { + Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, + const logging::Logger& /*logger*/) const override { rewrite_rule_invoked_ = true; return Status::OK(); } diff --git a/onnxruntime/test/optimizer/free_dimension_override_test.cc b/onnxruntime/test/optimizer/free_dimension_override_test.cc index 37075ab15f..0961650a45 100644 --- a/onnxruntime/test/optimizer/free_dimension_override_test.cc +++ b/onnxruntime/test/optimizer/free_dimension_override_test.cc @@ -18,10 +18,12 @@ namespace onnxruntime { namespace test { TEST(FreeDimensionOverrideTransformerTest, Test) { - string model_uri = "testdata/abs_free_dimensions.onnx"; + auto model_uri = ORT_TSTR("testdata/abs_free_dimensions.onnx"); std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, + DefaultLoggingManager().DefaultLogger()) + .IsOK()); Graph& graph = model->MainGraph(); // The model's input shape has two free dimensions, which have the denotation of DATA_BATCH @@ -38,7 +40,8 @@ TEST(FreeDimensionOverrideTransformerTest, Test) { onnxruntime::GraphTransformerManager graph_transformation_mgr(5); graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level1); - graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1); + graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, + DefaultLoggingManager().DefaultLogger()); // Verify that the shape of the input graph has the correct values diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index d71c400d32..976840549c 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -43,12 +43,12 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { -static const std::string MODEL_FOLDER = "testdata/transform/"; +#define MODEL_FOLDER ORT_TSTR("testdata/transform/") TEST(GraphTransformationTests, IdentityElimination) { - string model_uri = MODEL_FOLDER + "abs-id-max.onnx"; + auto model_uri = MODEL_FOLDER "abs-id-max.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Identity"] == 1); @@ -57,16 +57,16 @@ TEST(GraphTransformationTests, IdentityElimination) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Identity"] == 0); } TEST(GraphTransformationTests, DropoutElimination) { - string model_uri = MODEL_FOLDER + "dropout.onnx"; + auto model_uri = MODEL_FOLDER "dropout.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Identity"] == 5); @@ -76,7 +76,7 @@ TEST(GraphTransformationTests, DropoutElimination) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); // Of the 6 Dropout nodes in the graph, all but the ones named `d1` and `d6` should have been removed. @@ -88,11 +88,11 @@ TEST(GraphTransformationTests, DropoutElimination) { } TEST(GraphTransformationTests, SliceElimination) { - std::vector model_names = {"slice-v1-elim.onnx", "slice-v11-elim.onnx"}; + std::vector > model_names = {ORT_TSTR("slice-v1-elim.onnx"), ORT_TSTR("slice-v11-elim.onnx")}; for (const auto& model_name : model_names) { - string model_uri = MODEL_FOLDER + model_name; + auto model_uri = MODEL_FOLDER + model_name; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); int initial_slice_num = op_to_count["Slice"]; @@ -101,7 +101,7 @@ TEST(GraphTransformationTests, SliceElimination) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); // Only one Slice operator is redundant and is removed. @@ -110,9 +110,9 @@ TEST(GraphTransformationTests, SliceElimination) { } TEST(GraphTransformationTests, ConstantFolding) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Unsqueeze"] == 2); @@ -120,7 +120,7 @@ TEST(GraphTransformationTests, ConstantFolding) { onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); @@ -138,7 +138,7 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { auto create_subgraph = [&](GraphProto& graph_proto) { // create subgraph that has an Add node to add a local and parent graph initializer - Model model("ConstantFoldingSubgraphTest_subgraph"); + Model model("ConstantFoldingSubgraphTest_subgraph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TensorProto local_constant(value_tensor); @@ -160,7 +160,7 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { graph_proto = graph.ToGraphProto(); }; - Model model("ConstantFoldingSubgraphTest_main_graph"); + Model model("ConstantFoldingSubgraphTest_main_graph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); // add initializer at parent level @@ -192,7 +192,7 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1); + status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(status.IsOK()) << status; op_to_count = CountOpsInGraph(graph); @@ -201,9 +201,9 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { } TEST(GraphTransformationTests, ShapeToInitializer) { - string model_uri = MODEL_FOLDER + "shape-add.onnx"; + auto model_uri = MODEL_FOLDER "shape-add.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Shape"] == 4); @@ -213,7 +213,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); // Two of the Shapes are not eliminated because: @@ -224,7 +224,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) { // Check transformations in the case of a subgraph with constant inputs. TEST(GraphTransformationTests, SubgraphWithConstantInputs) { - string model_uri = MODEL_FOLDER + "constant-subgraph.onnx"; + auto model_uri = MODEL_FOLDER "constant-subgraph.onnx"; SessionOptions so; so.graph_optimization_level = TransformerLevel::Level2; @@ -233,7 +233,7 @@ TEST(GraphTransformationTests, SubgraphWithConstantInputs) { ASSERT_TRUE(session_object.Load(model_uri).IsOK()); std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -247,10 +247,10 @@ TEST(GraphTransformationTests, SubgraphWithConstantInputs) { } TEST(GraphTransformationTests, FuseConvBNNoBias) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-no-bias.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); std::string bn_output_name; @@ -268,7 +268,7 @@ TEST(GraphTransformationTests, FuseConvBNNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); @@ -282,10 +282,10 @@ TEST(GraphTransformationTests, FuseConvBNNoBias) { } TEST(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-no-bias.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); // add an optional output to the BN node. should not fuse if this is present @@ -302,21 +302,21 @@ TEST(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 1); } TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { - std::vector test_models = {"fusion/fuse-conv-bn-mul-add-unsqueeze.onnx", - "fusion/fuse-conv-bn-mul-add-unsqueeze.negative_axes.onnx", - "fusion/fuse-conv-bn-mul-add-unsqueeze-no-bias.onnx"}; + std::vector > test_models = {ORT_TSTR("fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"), + ORT_TSTR("fusion/fuse-conv-bn-mul-add-unsqueeze.negative_axes.onnx"), + ORT_TSTR("fusion/fuse-conv-bn-mul-add-unsqueeze-no-bias.onnx")}; for (const auto& model : test_models) { - string model_uri = MODEL_FOLDER + model; + auto model_uri = MODEL_FOLDER + model; std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model)); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger())); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -327,7 +327,7 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); @@ -339,16 +339,16 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { #ifndef DISABLE_CONTRIB_OPS TEST(GraphTransformationTests, FuseConvActivation) { - std::unordered_map model_to_op_name{{"fusion/conv_relu.onnx", "Relu"}, - {"fusion/conv_clip.onnx", "Clip"}, - {"fusion/conv_sigmoid.onnx", "Sigmoid"}, - {"fusion/conv_tanh.onnx", "Tanh"}, - {"fusion/conv_leakyrelu.onnx", "LeakyRelu"}}; + std::unordered_map, std::string> model_to_op_name{{ORT_TSTR("fusion/conv_relu.onnx"), "Relu"}, + {ORT_TSTR("fusion/conv_clip.onnx"), "Clip"}, + {ORT_TSTR("fusion/conv_sigmoid.onnx"), "Sigmoid"}, + {ORT_TSTR("fusion/conv_tanh.onnx"), "Tanh"}, + {ORT_TSTR("fusion/conv_leakyrelu.onnx"), "LeakyRelu"}}; for (const auto& model : model_to_op_name) { - std::string model_uri = MODEL_FOLDER + model.first; + auto model_uri = MODEL_FOLDER + model.first; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); @@ -357,7 +357,7 @@ TEST(GraphTransformationTests, FuseConvActivation) { // Apply transformer onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count[model.second] == 0); @@ -365,9 +365,9 @@ TEST(GraphTransformationTests, FuseConvActivation) { } TEST(GraphTransformationTests, FuseConvClip11Activation) { - std::string model_uri = MODEL_FOLDER + "fusion/conv_clip11.onnx"; + auto model_uri = MODEL_FOLDER "fusion/conv_clip11.onnx"; std::shared_ptr p_model; - auto status = Model::Load(model_uri, p_model); + auto status = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(status.IsOK()) << status; Graph& graph = p_model->MainGraph(); @@ -377,7 +377,7 @@ TEST(GraphTransformationTests, FuseConvClip11Activation) { // Apply transformer onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); ASSERT_EQ(op_to_count["Clip"], 1); @@ -404,10 +404,10 @@ TEST(GraphTransformationTests, FuseConvClip11Activation) { #endif TEST(GraphTransformationTests, FuseConvMulNoBias) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-mul-no-bias.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-mul-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -416,7 +416,7 @@ TEST(GraphTransformationTests, FuseConvMulNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Mul"] == 0); @@ -424,10 +424,10 @@ TEST(GraphTransformationTests, FuseConvMulNoBias) { } TEST(GraphTransformationTests, FuseConvAddNoBias) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-add-no-bias.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-add-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -436,7 +436,7 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0); @@ -446,10 +446,10 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) { // if IR version is 4 or higher the weights can be overridden if there's a matching graph input. // check that we don't fuse if that is the case TEST(GraphTransformationTests, NegativeFuseConvAddNoBias) { - string model_uri = MODEL_FOLDER + "fusion/negative-fuse-conv-add-no-bias.onnx"; + auto model_uri = MODEL_FOLDER "fusion/negative-fuse-conv-add-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -458,7 +458,7 @@ TEST(GraphTransformationTests, NegativeFuseConvAddNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); // Nodes are not fused because the weights to conv/add are not constants (they appear in the graph inputs). // Unsqueeze is also not eliminated as the initializer that is its input is also not constant @@ -468,10 +468,10 @@ TEST(GraphTransformationTests, NegativeFuseConvAddNoBias) { } TEST(GraphTransformationTests, FuseConvAddMul3D) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-add-mul-3d.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-add-mul-3d.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -480,7 +480,7 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0); @@ -488,10 +488,10 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) { } TEST(GraphTransformationTests, FuseConvAddMul3D_2) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-add-mul-3d-2.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-add-mul-3d-2.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -500,7 +500,7 @@ TEST(GraphTransformationTests, FuseConvAddMul3D_2) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0); @@ -508,15 +508,15 @@ TEST(GraphTransformationTests, FuseConvAddMul3D_2) { } TEST(GraphTransformationTests, MatMulAddFusion_two_input) { - string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx"; + auto model_uri = MODEL_FOLDER "matmul_add_fusion/2Input/model.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["MatMul"] == 0); @@ -525,15 +525,15 @@ TEST(GraphTransformationTests, MatMulAddFusion_two_input) { } TEST(GraphTransformationTests, MatMulAddFusion_three_input) { - string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/model.onnx"; + auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/model.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["MatMul"] == 0); @@ -543,15 +543,15 @@ TEST(GraphTransformationTests, MatMulAddFusion_three_input) { #ifndef DISABLE_CONTRIB_OPS TEST(GraphTransformationTests, Gemm_Relu_three_input) { - string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/gemm_relu.onnx"; + auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); std::map op_to_count1 = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()).IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Relu"] == 0); @@ -559,7 +559,7 @@ TEST(GraphTransformationTests, Gemm_Relu_three_input) { #endif TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { - string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-add-mul-float16.onnx"; + auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-add-mul-float16.onnx"; SessionOptions so; so.session_logid = "GraphTransformationTests.LoadModelToTransform"; @@ -567,7 +567,7 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { ASSERT_TRUE(session_object.Load(model_uri).IsOK()); std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformerL1"); rule_transformer_L1->Register(onnxruntime::make_unique()); @@ -617,7 +617,7 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { TEST(GraphTransformationTests, ReluClip6Fusion) { // Clip op schema changed for opset version 11. Until Clip op is updated in ORT hard coding this model to use // older opset. - Model model("ReluClip6Fusion", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, {}); + Model model("ReluClip6Fusion", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, {}, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -669,7 +669,7 @@ TEST(GraphTransformationTests, ReluClip6Fusion) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Relu"] == 0); @@ -685,7 +685,7 @@ TEST(GraphTransformationTests, ReluClip6Fusion) { // test handling of Clip 11 TEST(GraphTransformationTests, ReluClip11Fusion) { - Model model("ReluClip6Fusion"); //, true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 11}}, {}); + Model model("ReluClip6Fusion", false, DefaultLoggingManager().DefaultLogger()); //, true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 11}}, {}); auto& graph = model.MainGraph(); std::vector inputs; @@ -757,7 +757,7 @@ TEST(GraphTransformationTests, ReluClip11Fusion) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1); + status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(status.IsOK()) << status; op_to_count = CountOpsInGraph(graph); @@ -805,14 +805,14 @@ TEST(GraphTransformationTests, ReluClip11Fusion) { #ifndef DISABLE_CONTRIB_OPS TEST(GraphTransformationTests, GeluFusionTest) { - string model_uri = MODEL_FOLDER + "fusion/gelu.onnx"; + auto model_uri = MODEL_FOLDER "fusion/gelu.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -824,15 +824,15 @@ TEST(GraphTransformationTests, GeluFusionTest) { } TEST(GraphTransformationTests, AddGeluFusionTest) { - string model_uri = MODEL_FOLDER + "fusion/add_gelu_fusion.onnx"; + auto model_uri = MODEL_FOLDER "fusion/add_gelu_fusion.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Div"] == 0); @@ -844,14 +844,14 @@ TEST(GraphTransformationTests, AddGeluFusionTest) { } TEST(GraphTransformationTests, LayerNormFusionTest) { - string model_uri = MODEL_FOLDER + "fusion/layer_norm.onnx"; + auto model_uri = MODEL_FOLDER "fusion/layer_norm.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -880,14 +880,14 @@ TEST(GraphTransformationTests, LayerNormFusionTest) { } TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) { - string model_uri = MODEL_FOLDER + "fusion/layer_norm_sub_dup.onnx"; + auto model_uri = MODEL_FOLDER "fusion/layer_norm_sub_dup.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -916,15 +916,15 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) { } TEST(GraphTransformationTests, SkipLayerNormFusionTest) { - string model_uri = MODEL_FOLDER + "fusion/skip_layer_norm.onnx"; + auto model_uri = MODEL_FOLDER "fusion/skip_layer_norm.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 0fd8746075..3eea1bf0dd 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -170,7 +170,8 @@ void NchwcOptimizerTester(const std::function& bu // Build the model for this test. std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = opset_version; - Model model("nchwc", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version); + Model model("nchwc", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + {}, DefaultLoggingManager().DefaultLogger()); NchwcTestHelper helper(model.MainGraph()); build_test_case(helper); ASSERT_TRUE(model.MainGraph().Resolve().IsOK()); diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index 097ba35821..ed70cc4b9e 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -26,7 +26,7 @@ namespace test { static const std::string MODEL_FOLDER = "testdata/transform/"; TEST(OptimizerTest, Basic) { - Model model("OptimizerBasic"); + Model model("OptimizerBasic", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); const int tensor_dim = 10; @@ -67,7 +67,7 @@ TEST(OptimizerTest, Basic) { OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set); std::vector fetch_mlvalue_idxs{info.GetMLValueIndex("out")}; OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); - const logging::Logger& logger = ::onnxruntime::test::DefaultLoggingManager().DefaultLogger(); + const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); for (auto& node : graph.Nodes()) { auto* kernel = info.GetKernel(node.Index()); diff --git a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc index c98d4d76bb..5b5d7cb33f 100644 --- a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc +++ b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc @@ -18,10 +18,12 @@ namespace onnxruntime { namespace test { TEST(RuleBasedGraphTransformerTest, TestCompatibleProviders) { - string model_uri = "testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; + auto model_uri = ORT_TSTR("testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"); std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, + DefaultLoggingManager().DefaultLogger()) + .IsOK()); Graph& graph = model->MainGraph(); // Create rule based transformer with a dummy rewrite rule and register it with Cuda as compatible provider @@ -44,7 +46,8 @@ TEST(RuleBasedGraphTransformerTest, TestCompatibleProviders) { graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level2); graph_transformation_mgr.Register(std::move(graph_transformer1), TransformerLevel::Level2); - graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2); + graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, + DefaultLoggingManager().DefaultLogger()); // Validate transformer registered with CUDA as compatible provider is not called. ASSERT_FALSE(dummy_rule_ptr->IsRewriteRuleInvoked()); diff --git a/onnxruntime/test/providers/cpu/controlflow/if_test.cc b/onnxruntime/test/providers/cpu/controlflow/if_test.cc index 7ccc15dbf1..43b405b610 100644 --- a/onnxruntime/test/providers/cpu/controlflow/if_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/if_test.cc @@ -147,7 +147,7 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(bool then_branch, const R bool include_dim_values = options.include_dim_values_in_subgraph; bool sym_dim_zero = options.symbolic_dim_value_in_main_graph == 0; - Model model(then_branch ? "If_then" : "If_else"); + Model model(then_branch ? "If_then" : "If_else", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index bede6451fe..c0a8d913a8 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -97,7 +97,7 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options bool is_cond_1d = options.subgraph_cond_1d_tensor; bool is_iter_num_1d = options.subgraph_iter_num_1d_tensor; - Model model("Loop subgraph"); + Model model("Loop subgraph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -474,7 +474,7 @@ TEST(Loop, ZeroIterations) { TEST(Loop, InfiniteLoopTermination) { auto create_subgraph = [](const RunOptions&) { - Model model("Infinite Loop subgraph"); + Model model("Infinite Loop subgraph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -632,7 +632,7 @@ TEST(Loop, SubgraphInputShadowsOuterScopeValue) { TEST(Loop, Opset11WithNoVariadicInputsAndOutputs) { auto create_subgraph = []() { - Model model("Loop opset 11 op body graph"); + Model model("Loop opset 11 op body graph", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc index 51263e253c..a34fcea6ab 100644 --- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc @@ -270,7 +270,7 @@ static void RunTest_v8(const std::string test_name, int64_t batch_size, int64_t OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& failure_message = "") { // create model that will be used to initialize subgraph. currently there's no direct way to create a Graph instance. - Model model(test_name); + Model model(test_name, false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); CreateSubgraph(graph, options, options.add_bad_shape ? failure_message : ""); auto& proto = graph.ToGraphProto(); @@ -332,7 +332,7 @@ static void RunTest_v9(const std::string test_name, int64_t sequence_len, int64_ OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& failure_message = "") { // create model that will be used to initialize subgraph. currently there's no direct way to create a Graph instance. - Model model(test_name); + Model model(test_name, false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); CreateSubgraph(graph, options, options.add_bad_shape ? failure_message : ""); auto& proto = graph.ToGraphProto(); @@ -860,7 +860,7 @@ TEST(Scan9, TransposeOutput) { TEST(Scan9, TransposeOutputDim2) { // Construct scan body subgraph with 1 scan inputs, 1 scan outputs // scan-in-1 => scan-out-1 - Model model("ScanBody"); + Model model("ScanBody", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto float_tensor; @@ -1028,7 +1028,7 @@ void MixedTypeInputs(bool is_v8) { // state-in-2 => scan-out-2 // scan-in-2 => state-out-2 - Model model("ScanBody"); + Model model("ScanBody", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto float_tensor; @@ -1096,7 +1096,7 @@ TEST_8_AND_9(MixedTypeInputs); // create a subgraph that will have unknown dimensions in both the loop state variable and output // after shape inferencing. void UnknownDimInSubgraphOutput(bool is_v8, bool mixed_execution_providers = false) { - Model model("ScanBody"); + Model model("ScanBody", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); TypeProto float_tensor; diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index 126ecad221..42b0e23c1c 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -12,6 +12,7 @@ #include "core/framework/utils.h" #include "core/framework/path_lib.h" #include +#include "test/test_environment.h" namespace onnxruntime { namespace { @@ -21,6 +22,8 @@ void PutAllNodesOnOneProvider(Graph& graph, const std::string& provider_type) { } } } // namespace + +namespace test { TEST(MemcpyTest, copy1) { concurrency::ThreadPool tp{"test", 1}; @@ -39,7 +42,7 @@ TEST(MemcpyTest, copy1) { const bool result = mp.ParseFromZeroCopyStream(&zero_copy_input) && model_istream.eof(); ASSERT_TRUE(result); - Model model(mp); + Model model(mp, nullptr, DefaultLoggingManager().DefaultLogger()); st = model.MainGraph().Resolve(); ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); PutAllNodesOnOneProvider(model.MainGraph(), onnxruntime::kCpuExecutionProvider); @@ -61,4 +64,5 @@ TEST(MemcpyTest, copy1) { st = utils::CopyOneInputAcrossDevices(s, "X", input, output); ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); } +} // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 330e247ef7..857857daf9 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -398,7 +398,10 @@ std::unique_ptr OpTester::BuildGraph() { std::unordered_map domain_to_version; domain_to_version[domain_] = opset_version_; auto p_model = onnxruntime::make_unique("test", false, ModelMetaData(), - custom_schema_registries_, domain_to_version); + custom_schema_registries_, + domain_to_version, + std::vector{}, + DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = p_model->MainGraph(); AddNodes(graph, node_input_defs, output_defs, add_attribute_funcs_); diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index e3274200fe..1d51175cb0 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -25,7 +25,7 @@ void VerifyOutputs(const std::vector& fetches, const std::vector inputs; std::vector outputs; @@ -104,7 +104,7 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) { } TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { - onnxruntime::Model model("graph_1"); + onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; std::vector outputs; diff --git a/onnxruntime/test/util/test_environment.cc b/onnxruntime/test/util/test_environment.cc index 73400a5892..2d6d074d16 100644 --- a/onnxruntime/test/util/test_environment.cc +++ b/onnxruntime/test/util/test_environment.cc @@ -43,7 +43,7 @@ TestEnvironment::TestEnvironment(int argc, char** argv, bool create_default_logg &default_logger_id); // make sure default logging manager exists and is working - auto logger = ::onnxruntime::test::DefaultLoggingManager().DefaultLogger(); + auto logger = DefaultLoggingManager().DefaultLogger(); LOGS(logger, VERBOSE) << "Logging manager initialized."; }