Avoid using the default logger in the graph lib and optimizers (#2361)

1. Use the session logger if it is available.
2. Don't disable warning 4100 globally. We should fix the warnings instead of disabling it.
This commit is contained in:
Changming Sun 2019-11-14 13:23:28 -08:00 committed by GitHub
parent b15e43a541
commit 109b3cb450
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
112 changed files with 614 additions and 556 deletions

View file

@ -183,9 +183,7 @@ if (MSVC)
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib for gtest" FORCE)
endif()
#Always enable exception handling, even for Windows ARM
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
#Disable 4100 globally. Too many this kind errors in protobuf
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4100")
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
if (NOT onnxruntime_USE_CUDA)
SET (CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Gw /GL")
SET (CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /Gw /GL")

View file

@ -231,11 +231,11 @@ if (onnxruntime_USE_TENSORRT)
target_sources(getSupportedAPITest PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/getopt.cc)
target_include_directories(onnx2trt PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/include)
target_include_directories(getSupportedAPITest PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/include)
target_compile_options(nvonnxparser_static PRIVATE /FIio.h)
target_compile_options(nvonnxparser PRIVATE /FIio.h)
target_compile_options(trt_onnxify PRIVATE /FIio.h)
target_compile_options(onnx2trt PRIVATE /FIio.h)
target_compile_options(getSupportedAPITest PRIVATE /FIio.h)
target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100)
target_compile_options(nvonnxparser PRIVATE /FIio.h /wd4100)
target_compile_options(trt_onnxify PRIVATE /FIio.h /wd4100)
target_compile_options(onnx2trt PRIVATE /FIio.h /wd4100)
target_compile_options(getSupportedAPITest PRIVATE /FIio.h /wd4100)
endif()
include_directories(${ONNXRUNTIME_ROOT}/../cmake/external/onnx-tensorrt)
include_directories(${TENSORRT_INCLUDE_DIR})

View file

@ -478,7 +478,7 @@ endif()
add_library(onnx_test_data_proto ${TEST_SRC_DIR}/proto/tml.proto)
if(WIN32)
target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456")
target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456" "/wd4100")
endif()
add_dependencies(onnx_test_data_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})

View file

@ -7,6 +7,7 @@
#include "gsl/gsl"
#include "core/common/status.h"
#include "core/common/logging/logging.h"
#include "core/framework/tensor.h"
#include "core/framework/func_api.h"
#include "core/framework/data_transfer.h"
@ -153,10 +154,19 @@ class IExecutionProvider {
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_node,
std::string& dll_path);
void SetLogger(const logging::Logger* logger) {
logger_ = logger;
}
const logging::Logger* GetLogger() const {
return logger_;
}
private:
const std::string type_;
AllocatorMap allocators_;
//It will be set when this object is registered to a session
const logging::Logger* logger_ = nullptr;
// convenience list of the allocators so GetAllocatorList doesn't have to build a new vector each time
// contains the same instances as allocators_
std::vector<gsl::not_null<const IAllocator*>> allocator_list_;

View file

@ -37,5 +37,6 @@ Create a new Function instance.
@param customized_func the IndexedSubGraph to use for the Function.
*/
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
std::unique_ptr<IndexedSubGraph> customized_func);
std::unique_ptr<IndexedSubGraph> customized_func,
const logging::Logger& logger);
} // namespace onnxruntime

View file

@ -13,6 +13,7 @@
#include "core/common/common.h"
#include "core/common/const_pointer_container.h"
#include "core/common/status.h"
#include "core/common/logging/logging.h"
#include "core/graph/basic_types.h"
#include "core/graph/constants.h"
#include "core/graph/graph_nodes.h"
@ -822,6 +823,7 @@ class Graph {
const std::unordered_map<std::string, int>& domain_to_version,
Version ir_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
const logging::Logger& logger,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions = {});
// internal use by the Graph class only
@ -831,6 +833,7 @@ class Graph {
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
Graph* parent_graph,
const Node* parent_node,
const logging::Logger& logger,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions = {});
// Add node with specified <node_proto>.
@ -1038,6 +1041,8 @@ class Graph {
// number of times Resolve has run.
int num_resolves_ = 0;
const logging::Logger& logger_;
};
} // namespace onnxruntime

View file

@ -68,13 +68,13 @@ class NodeArg {
@param strict If true, the shape update will fail if there are incompatible values.
If false, will be lenient and merge only shape info that can be validly processed.
@returns Success unless there is existing type or shape info that can't be successfully updated. */
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict = true);
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger);
/** Validate and merge type [and shape] info from node_arg.
@param strict If true, the shape update will fail if there are incompatible values.
If false, will be lenient and merge only shape info that can be validly processed.
@returns Success unless there is existing type or shape info that can't be successfully updated. */
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict = true);
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger);
/** Gets this NodeArg as a ValueInfoProto. */
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

View file

@ -7,15 +7,7 @@
#include "core/common/status.h"
#include "core/platform/ort_mutex.h"
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
#include "onnx/defs/schema.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
#include "core/graph/onnx_protobuf.h"
#include <mutex>
#include <deque>
#include "sstream"

View file

@ -37,19 +37,19 @@ class GraphTransformer {
@param[out] modified Set to true if the Graph was modified.
@returns Status with success or error information.
*/
common::Status Apply(Graph& graph, bool& modified) const;
common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const;
protected:
/** Helper method to call ApplyImpl on any subgraphs in the Node. */
common::Status Recurse(Node& node, bool& modified, int graph_level) const {
common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const {
int subgraph_level = ++graph_level;
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
auto& subgraph = *entry.second;
ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level));
ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level, logger));
}
return Status::OK();
}
}
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer);
@ -61,7 +61,8 @@ class GraphTransformer {
// You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases
// the call to Graph::Resolve in Apply prior to ApplyImpl being called, and after ApplyImpl fore the main graph
// completes (if 'modified' is true) should suffice.
virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const = 0;
virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger)
const = 0;
const std::string name_;
const std::unordered_set<std::string> compatible_provider_types_;

View file

@ -66,8 +66,8 @@ class RewriteRule {
@param[in] node The Node to apply the rewrite to.
@param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
@returns Status indicating success or providing error information */
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK();
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const {
return SatisfyCondition(graph, node, logger) ? Apply(graph, node, rule_effect, logger) : Status::OK();
}
private:
@ -79,11 +79,11 @@ class RewriteRule {
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
of the nodes. */
virtual bool SatisfyCondition(const Graph& graph, const Node& node) const = 0;
virtual bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const = 0;
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
The return-value of node may be different from the input-value due to rewriting.
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const = 0;
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const = 0;
};
} // namespace onnxruntime

View file

@ -63,7 +63,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
@returns Status indicating success or providing error information. */
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::reference_wrapper<const RewriteRule>>& rules,
RewriteRule::RewriteRuleEffect& rule_effect) const;
RewriteRule::RewriteRuleEffect& rule_effect, const logging::Logger& logger) const;
private:
using RuleEffect = RewriteRule::RewriteRuleEffect;
@ -76,7 +76,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
std::vector<std::reference_wrapper<const RewriteRule>> any_op_type_rules_;
// Performs a single top-down traversal of the graph and applies all registered rules.
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -55,7 +55,7 @@ extern "C" {
#ifdef _WIN32
#define ORT_TSTR(X) L##X
#else
#define ORT_TSTR(X) (X)
#define ORT_TSTR(X) X
#endif
#endif

View file

@ -47,9 +47,8 @@ class ExecutionProviders {
ORT_IGNORE_RETURN_VALUE(allocator_idx_map_.insert({allocator->Info(), new_provider_idx}));
}
exec_providers_.push_back(std::move(p_exec_provider));
exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(std::move(p_exec_provider));
return Status::OK();
}

View file

@ -25,6 +25,7 @@ namespace {
Status RemoveFileSpec(PWSTR pszPath, size_t cchPath) {
assert(pszPath != nullptr && pszPath[0] != L'\0');
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
(void)cchPath;
for (PWSTR t = L"\0"; *t == L'\0'; t = PathRemoveBackslashW(pszPath))
;
PWSTR pszLast = PathSkipRootW(pszPath);

View file

@ -3,15 +3,7 @@
#pragma once
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
#include "onnx/defs/schema.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
#include "core/graph/onnx_protobuf.h"
namespace onnxruntime {
namespace contrib {

View file

@ -3,15 +3,7 @@
#pragma once
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
#include "onnx/defs/schema.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
#include "core/graph/onnx_protobuf.h"
namespace onnxruntime {
namespace contrib {

View file

@ -145,15 +145,15 @@ static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& su
}
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
std::unique_ptr<IndexedSubGraph> customized_func)
: parent_graph_(&graph) {
std::unique_ptr<IndexedSubGraph> customized_func,
const logging::Logger& logger)
: parent_graph_(&graph),
body_("fused_function_subgraph", false, onnxruntime::ModelMetaData(),
IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}),
graph.DomainToVersionMap(), {}, logger)
{
customized_func_body_ = std::move(customized_func);
// Construct body.
body_ = onnxruntime::make_unique<onnxruntime::Model>("fused_function_subgraph", false, onnxruntime::ModelMetaData(),
IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}),
graph.DomainToVersionMap());
auto& function_body_graph = body_->MainGraph();
auto& function_body_graph = body_.MainGraph();
auto meta_def = customized_func_body_->GetMetaDef();
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
@ -220,15 +220,24 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
}
static std::unordered_map<std::string, int> GetOpsetVersionMap(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto){
return std::unordered_map<std::string, int>{{onnxruntime::kOnnxDomain, static_cast<int>(onnx_func_proto.since_version())}};
}
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
const onnxruntime::NodeIndex& node_index,
const ONNX_NAMESPACE::FunctionProto& onnx_func_proto)
: parent_graph_(&graph) {
const ONNX_NAMESPACE::FunctionProto& onnx_func_proto,
const logging::Logger& logger)
: parent_graph_(&graph),
body_ (onnx_func_proto.name(), false, onnxruntime::ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
GetOpsetVersionMap(onnx_func_proto), {}, logger),
onnx_func_proto_(onnx_func_proto)
{
// Make a copy of the FunctionProto.
// All FunctionBody ops with the same op type seem to share the same FunctionProto struct within a model.
// Hence, we make a copy prior to generating the graph representation of the function,
// as we might make some modifications to the FunctionProto along the way
onnx_func_proto_ = onnx_func_proto;
auto node_in_parent_graph = parent_graph_->GetNode(node_index);
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
@ -290,9 +299,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
std::unordered_map<std::string, int> domain_to_version;
//TODO: set correct domain and version
domain_to_version[onnxruntime::kOnnxDomain] = static_cast<int>(onnx_func_proto_.since_version());
body_ = onnxruntime::make_unique<onnxruntime::Model>(onnx_func_proto_.name(), false, onnxruntime::ModelMetaData(),
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
auto& function_body_graph = body_->MainGraph();
auto& function_body_graph = body_.MainGraph();
// Add node and node args into subgraph
// The subgraph preserved the input/output tensor names
// in the parent graph for later inlining purpose
@ -379,7 +387,7 @@ const ONNX_NAMESPACE::OpSchema& FunctionImpl::OpSchema() const {
}
const onnxruntime::Graph& FunctionImpl::Body() const {
return body_->MainGraph();
return body_.MainGraph();
}
const IndexedSubGraph& FunctionImpl::GetIndexedSubGraph() const {
@ -391,7 +399,8 @@ const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const {
}
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
std::unique_ptr<IndexedSubGraph> customized_func) {
return onnxruntime::make_unique<FunctionImpl>(graph, std::move(customized_func));
std::unique_ptr<IndexedSubGraph> customized_func,
const logging::Logger& logger) {
return onnxruntime::make_unique<FunctionImpl>(graph, std::move(customized_func), logger);
}
} // namespace onnxruntime

View file

@ -2,12 +2,13 @@
// Licensed under the MIT License.
#pragma once
#include "core/common/logging/logging.h"
#include "core/graph/function.h"
#include "core/graph/model.h"
namespace onnxruntime {
class Graph;
class Node;
class Model;
} // namespace onnxruntime
namespace onnxruntime {
@ -16,11 +17,13 @@ namespace onnxruntime {
class FunctionImpl final : public Function {
public:
FunctionImpl(const onnxruntime::Graph& graph,
std::unique_ptr<IndexedSubGraph> customized_func);
std::unique_ptr<IndexedSubGraph> customized_func,
const logging::Logger& logger);
FunctionImpl(const onnxruntime::Graph& graph,
const onnxruntime::NodeIndex& node_index,
const ONNX_NAMESPACE::FunctionProto& onnx_func);
const ONNX_NAMESPACE::FunctionProto& onnx_func,
const logging::Logger& logger);
~FunctionImpl() override;
@ -36,7 +39,7 @@ class FunctionImpl final : public Function {
const onnxruntime::Graph* const parent_graph_;
std::unique_ptr<IndexedSubGraph> customized_func_body_;
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
std::unique_ptr<onnxruntime::Model> body_;
onnxruntime::Model body_;
ONNX_NAMESPACE::FunctionProto onnx_func_proto_;
};

View file

@ -59,7 +59,7 @@ static bool UsingLatestOnnxOpset(const DomainToVersionMap& opset_versions) {
static Status MergeShapeInfo(const std::string& output_name,
const TypeProto_Tensor& source, TypeProto_Tensor& target,
bool strict) {
bool strict, const logging::Logger& logger) {
try {
ONNX_NAMESPACE::mergeInShapeInfo(source, target);
} catch (const ONNX_NAMESPACE::InferenceError& ex) {
@ -70,9 +70,10 @@ static Status MergeShapeInfo(const std::string& output_name,
// mergeInShapeInfo does nothing unless source.shape() is not null, and there would be no conflict if
// target.shape() was empty. 'assert' just in case that ever changes.
assert(utils::HasShape(source) && utils::HasShape(target));
LOGS_DEFAULT(WARNING) << "Error merging shape info for output. '" << output_name
<< "' source:" << source.shape() << " target:" << target.shape()
<< ". Falling back to lenient merge.";
LOGS(logger, WARNING) << "Error merging shape info for output. '" << output_name
<< "' source:" << source.shape() << " target:" << target.shape()
<< ". Falling back to lenient merge.";
ONNX_NAMESPACE::UnionShapeInfo(source.shape(), target);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output:", output_name, " ", ex.what());
@ -207,7 +208,7 @@ void NodeArg::ClearShape() {
}
}
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict) {
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger) {
if (!utils::HasType(node_arg_info_)) {
*node_arg_info_.mutable_type() = input_type;
type_ = DataTypeUtils::ToType(node_arg_info_.type());
@ -236,7 +237,7 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
if (utils::HasShape(input_tensor_type)) {
auto& current_tensor_type = *current_type.mutable_tensor_type();
if (utils::HasShape(current_tensor_type)) {
ORT_RETURN_IF_ERROR(MergeShapeInfo(Name(), input_tensor_type, current_tensor_type, strict));
ORT_RETURN_IF_ERROR(MergeShapeInfo(Name(), input_tensor_type, current_tensor_type, strict, logger));
} else {
current_tensor_type = input_tensor_type;
}
@ -274,11 +275,11 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
return Status::OK();
}
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict) {
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger) {
auto status = Status::OK();
if (utils::HasType(node_arg.node_arg_info_))
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict);
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, logger);
return status;
}
@ -698,11 +699,13 @@ Graph::Graph(GraphProto* graph_proto,
const std::unordered_map<std::string, int>& domain_to_version,
Version ir_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
const logging::Logger& logger,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions)
: Graph(graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, model_functions) {}
: Graph(graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, model_functions) {}
Graph::Graph(GraphProto* graph_proto, const std::unordered_map<std::string, int>& domain_to_version, Version ir_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node,
const logging::Logger& logger,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions)
: graph_proto_(graph_proto),
schema_registry_(schema_registry),
@ -712,7 +715,8 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map<std::string, int>
ir_version_(ir_version),
using_latest_onnx_opset_(UsingLatestOnnxOpset(domain_to_version)),
parent_graph_(parent_graph),
parent_node_(parent_node) {
parent_node_(parent_node),
logger_(logger) {
ORT_ENFORCE(graph_proto != nullptr, "graph_proto cannot be null");
ArgNameToTypeMap name_to_type_map;
@ -769,7 +773,7 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map<std::string, int>
// so we prefer the shape from the initializer
name_to_type_map[tensor.name()] = t;
if (matching_graph_input != nullptr) {
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t));
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, logger));
}
} else {
// v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these.
@ -805,7 +809,7 @@ Graph::Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::Graph
: Graph(&subgraph_proto,
parent_graph.DomainToVersionMap(), parent_graph.IrVersion(), parent_graph.schema_registry_,
&parent_graph,
&parent_node) {
&parent_node, parent_graph.logger_) {
}
Status Graph::VerifyNoDuplicateName() {
@ -1456,7 +1460,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
const auto& subgraph_input = *subgraph_inputs->at(i);
NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
status = mutable_nodearg->UpdateTypeAndShape(input_type);
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, subgraph.logger_);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
}
@ -1477,7 +1481,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
if (!subgraph_nodearg)
continue;
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg);
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, subgraph.logger_);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
}
@ -1666,7 +1670,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op) {
// that have no values.
TypeProto_Tensor merge_target;
(*merge_target.mutable_shape()) = *output_def->Shape();
auto status = MergeShapeInfo(output_def->Name(), tensor_type, merge_target, using_latest_onnx_opset_);
auto status = MergeShapeInfo(output_def->Name(), tensor_type, merge_target, using_latest_onnx_opset_, logger_);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node_name, " ", status.ErrorMessage());
}
@ -1784,8 +1788,9 @@ Status Graph::VerifyNodeAndOpMatch() {
auto iter = model_functions_.find(node.OpType());
if (iter != model_functions_.end()) {
const ONNX_NAMESPACE::FunctionProto* model_function_proto = iter->second;
auto model_func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *model_function_proto);
function_container_.emplace_back(std::move(model_func_ptr));
function_container_.emplace_back(onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(),
*model_function_proto,
logger_));
node.SetFunctionBody(*function_container_.back());
}
@ -1805,7 +1810,8 @@ Status Graph::VerifyNodeAndOpMatch() {
if (node.op_ && node.op_->HasFunction()) {
auto onnx_function_proto = node.op_->GetFunction();
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *onnx_function_proto);
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *onnx_function_proto,
logger_);
function_container_.emplace_back(std::move(func_ptr));
node.SetFunctionBody(*function_container_.back());
}
@ -2402,12 +2408,12 @@ void Graph::CleanUnusedInitializers() {
// on the first call to Graph::Resolve we are removing unnecessary initializers that should be removed
// from the model.
// on later calls we are removing initializers that optimizations have made redundant.
if (num_resolves_ == 0) {
LOGS_DEFAULT(WARNING) << "Removing initializer '"
<< name << "'. It is not used by any node and should be removed from the model.";
} else {
LOGS_DEFAULT(INFO) << "Removing initializer '" << name << "'. It is no longer used by any node.";
}
if (num_resolves_ == 0) {
LOGS(logger_, WARNING) << "Removing initializer '"
<< name << "'. It is not used by any node and should be removed from the model.";
} else {
LOGS(logger_, INFO) << "Removing initializer '" << name << "'. It is no longer used by any node.";
}
erase_list.push_back(name);
}
@ -2704,7 +2710,7 @@ Node& Graph::FuseSubGraph(std::unique_ptr<::onnxruntime::IndexedSubGraph> sub_gr
func_meta_def->domain);
fused_node.SetNodeType(Node::Type::Fused);
function_container_.emplace_back(MakeFunction(*this, std::move(sub_graph)));
function_container_.emplace_back(MakeFunction(*this, std::move(sub_graph), logger_));
fused_node.SetFunctionBody(*function_container_.back());
// Remove nodes fused above.

View file

@ -183,15 +183,15 @@ static void RemoveGraphEdges(Graph& graph, const std::vector<GraphEdge>& edges)
/** Given a graph, a list of edges, and a NodeArg name, checks if each of the edges provides an implicit input
to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs.
This is important when removing a node with this NodeArg as input. */
bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph,
const std::vector<GraphEdge>& output_edges,
const std::string& new_arg_name) {
static bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph,
const std::vector<GraphEdge>& output_edges,
const std::string& new_arg_name, const logging::Logger& logger) {
for (const auto& output_edge : output_edges) {
if (OutputEdgeProvidesImplicitInput(graph, output_edge)) {
const Node& output_edge_node = *graph.GetNode(output_edge.dst_node);
if (!CanUpdateImplicitInputNameInSubgraph(output_edge_node, output_edge.arg_name, new_arg_name)) {
LOGS_DEFAULT(WARNING) << " Implicit input name " << output_edge.arg_name
<< " cannot be safely updated to " << new_arg_name << " in one of the subgraphs.";
LOGS(logger, WARNING) << " Implicit input name " << output_edge.arg_name
<< " cannot be safely updated to " << new_arg_name << " in one of the subgraphs.";
return false;
}
}
@ -354,7 +354,7 @@ bool IsOutputUsed(const Node& node, int index) {
return false;
}
bool CanRemoveNode(const Graph& graph, const Node& node) {
bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& logger) {
const std::string* output_name = nullptr;
if (!IsOnlyOneOutputUsed(graph, node, output_name)) {
return false;
@ -386,14 +386,15 @@ bool CanRemoveNode(const Graph& graph, const Node& node) {
if (new_name) {
// Check that changing the current output name to the new name won't break any subgraphs that consume it
std::vector<GraphEdge> output_edges = GetNodeOutputEdges(node);
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, *new_name);
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, *new_name, logger);
}
return can_remove;
}
bool RemoveNode(Graph& graph, Node& node) {
assert(CanRemoveNode(graph, node));
//TODO: enable the check back
//assert(CanRemoveNode(graph, node, nullptr));
// Note: Node does not produce any graph outputs, and only a single output is used.
@ -411,7 +412,8 @@ bool RemoveNode(Graph& graph, Node& node) {
ORT_THROW("Should be unreachable if CanRemoveNodeAndMergeEdges is in sync with the logic here.");
}
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name) {
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name,
const logging::Logger& logger) {
// we have no way to handle replacing multiple outputs so check only one is used
const std::string* output_name = nullptr;
if (!IsOnlyOneOutputUsed(graph, node, output_name)) {
@ -435,7 +437,7 @@ bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const s
// Check that changing the current output name to the new name won't break any subgraphs
// that consume the current name
std::vector<GraphEdge> output_edges = GetNodeOutputEdges(node);
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, initializer_name);
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, initializer_name, logger);
}
return can_remove;

View file

@ -108,7 +108,7 @@ Conditions:
Subgraph rules:
- Removing the node won't break a subgraph that consumes the node's output
*/
bool CanRemoveNode(const Graph& graph, const Node& node);
bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& logger);
/** Removes the given Node from the Graph.
See CanRemoveNode for the conditions that must be satisfied in order to remove the node.
@ -128,7 +128,8 @@ Conditions:
- otherwise the required graph output will not be produced
- Removing the node won't break a subgraph that consumes the node's output
*/
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name);
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name,
const logging::Logger& logger);
/** Remove a node and replace its output with the provided NodeArg for an initializer.
See CanReplaceNodeWithInitializer for the conditions that must be satisfied in order to remove the node.*/

View file

@ -31,7 +31,8 @@ Model::Model(const std::string& graph_name,
const ModelMetaData& model_metadata,
const IOnnxRuntimeOpSchemaRegistryList& local_registries,
const std::unordered_map<std::string, int>& domain_to_version,
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_functions) {
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_functions,
const logging::Logger& logger) {
model_proto_ = onnxruntime::make_unique<ModelProto>();
model_proto_->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
model_proto_->mutable_graph()->set_name(graph_name);
@ -69,14 +70,17 @@ Model::Model(const std::string& graph_name,
// need to call private ctor so can't use make_shared
GSL_SUPPRESS(r .11)
graph_.reset(new Graph(model_proto_->mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry, model_functions_map));
graph_.reset(new Graph(model_proto_->mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry,
logger, model_functions_map));
}
Model::Model(const ModelProto& model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries)
: Model(onnxruntime::make_unique<ModelProto>(model_proto), local_registries) {
Model::Model(const ModelProto& model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger)
: Model(onnxruntime::make_unique<ModelProto>(model_proto), local_registries, logger) {
}
Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
if (!model_proto) {
throw std::invalid_argument("ModelProto was null.");
}
@ -111,13 +115,13 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) {
// TODO: Check if we can upgrade all the current opset 6 models that are being tested
// in CI to opset 7 or above
LOGS_DEFAULT(WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
"with opset version 7 or above for opset domain 'ai.onnx'. "
"Please upgrade your model to opset 7 or higher. "
"For now, this opset "
<< version
<< " model may run depending upon legacy support "
"of some older opset version operators.";
LOGS(logger, WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
"with opset version 7 or above for opset domain 'ai.onnx'. "
"Please upgrade your model to opset 7 or higher. "
"For now, this opset "
<< version
<< " model may run depending upon legacy support "
"of some older opset version operators.";
}
// We need to overwrite the domain here with ("") or else the loop below will try to find ("")
// in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11).
@ -146,7 +150,8 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
// create instance. need to call private ctor so can't use make_unique
GSL_SUPPRESS(r .11)
graph_.reset(new Graph(model_proto_->mutable_graph(), domain_to_version, IrVersion(), schema_registry, model_functions_map));
graph_.reset(new Graph(model_proto_->mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger,
model_functions_map));
}
Version Model::IrVersion() const {
@ -237,7 +242,9 @@ Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) {
return Status::OK();
}
Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
// we expect a graph to be present
if (!utils::HasGraph(model_proto)) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf.");
@ -246,7 +253,7 @@ Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model,
// need to call private ctor so can't use make_shared
GSL_SUPPRESS(r .11)
try {
model.reset(new Model(model_proto, local_registries));
model.reset(new Model(model_proto, local_registries, logger));
} catch (const std::exception& ex) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
}
@ -256,7 +263,9 @@ Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model,
return Status::OK();
}
Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Model>& model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Model>& model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
// we expect a graph to be present
if (!utils::HasGraph(*p_model_proto)) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf.");
@ -265,7 +274,7 @@ Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Mo
// need to call private ctor so can't use make_shared
GSL_SUPPRESS(r .11)
try {
model.reset(new Model(std::move(p_model_proto), local_registries));
model.reset(new Model(std::move(p_model_proto), local_registries, logger));
} catch (const std::exception& ex) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
}
@ -276,7 +285,9 @@ Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Mo
}
template <typename T>
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
int fd;
Status status = Env::Default().FileOpenRd(file_path, fd);
if (!status.IsOK()) {
@ -293,7 +304,7 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model, con
}
}
try {
status = Model::Load(fd, p_model, local_registries);
status = Model::Load(fd, p_model, local_registries, logger);
} catch (std::exception& ex) {
GSL_SUPPRESS(es .84)
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
@ -328,36 +339,32 @@ static Status SaveModel(Model& model, const T& file_path) {
}
#ifdef _WIN32
GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load
GSL_SUPPRESS(r .35)
Status Model::Load(const std::wstring& file_path, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
return LoadModel(file_path, p_model, local_registries);
}
Status Model::Save(Model& model, const std::wstring& file_path) {
return SaveModel(model, file_path);
}
#endif
GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load
GSL_SUPPRESS(r .35)
Status Model::Load(const std::string& file_path, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
return LoadModel(file_path, p_model, local_registries);
Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path, std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
return LoadModel(file_path, p_model, local_registries, logger);
}
Status Model::Save(Model& model, const std::string& file_path) {
return SaveModel(model, file_path);
}
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
std::unique_ptr<ModelProto> modelProto = onnxruntime::make_unique<ModelProto>();
const bool result = modelProto->ParseFromArray(p_bytes, count);
if (!result) {
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
}
p_model = std::make_shared<Model>(std::move(modelProto), local_registries);
p_model = std::make_shared<Model>(std::move(modelProto), local_registries, logger);
ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));
@ -368,7 +375,8 @@ using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::ZeroCopyInputStream;
Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
if (fd < 0) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> less than 0.");
}
@ -394,7 +402,7 @@ Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOp
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
}
#endif
p_model = std::make_shared<Model>(std::move(model_proto), local_registries);
p_model = std::make_shared<Model>(std::move(model_proto), local_registries, logger);
ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));

View file

@ -8,6 +8,7 @@
#include <climits>
#include <string>
#include "core/graph/graph_viewer.h"
#include "core/session/onnxruntime_c_api.h"
#include "gsl/gsl"
@ -22,23 +23,32 @@ class Model {
public:
static constexpr Version kNoVersion = INT64_MAX;
explicit Model(const std::string& graph_name,
bool is_onnx_domain_only,
const logging::Logger& logger)
:Model(graph_name,is_onnx_domain_only, ModelMetaData(),IOnnxRuntimeOpSchemaRegistryList(),{},{},
logger){}
// Construct model from scratch.
explicit Model(const std::string& graph_name,
bool is_onnx_domain_only = false,
const ModelMetaData& model_metadata = ModelMetaData(),
const IOnnxRuntimeOpSchemaRegistryList& local_registries = IOnnxRuntimeOpSchemaRegistryList(),
const std::unordered_map<std::string, int>& domain_to_version = {},
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_specific_functions = {});
bool is_onnx_domain_only,
const ModelMetaData& model_metadata,
const IOnnxRuntimeOpSchemaRegistryList& local_registries,
const std::unordered_map<std::string, int>& domain_to_version,
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_specific_functions,
const logging::Logger& logger);
// NOTE: after calling this constructor, <*this> model will
// hold a copy of <model_proto>.
explicit Model(const ONNX_NAMESPACE::ModelProto& model_proto,
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);
// NOTE: after calling this constructor, <*this> model will
// own the <model_proto>.
explicit Model(std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto,
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);
// Get model's IR version.
// Return <kNoVersion> if not specified.
@ -88,10 +98,6 @@ class Model {
#ifdef _WIN32
static common::Status Save(Model& model, const std::wstring& file_path);
// TODO(Task:132) Use of shared_ptr<X>* in Load/Save methods is confusing.
static common::Status Load(const std::wstring& file_path, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registry = nullptr);
#endif
static common::Status Save(Model& model, const std::string& file_path);
@ -99,23 +105,29 @@ class Model {
static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto);
static common::Status Load(const std::string& file_path,
// TODO(Task:132) Use of shared_ptr<X>* in Load/Save methods is confusing.
static common::Status Load(const std::basic_string<ORTCHAR_T>& file_path,
/*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);
static common::Status Load(int fd, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);
// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
static common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);
static common::Status Load(const ONNX_NAMESPACE::ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);
static common::Status Load(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto,
/*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);
private:
// Model data.

View file

@ -5,15 +5,7 @@
#include <functional>
#include <unordered_map>
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
#include "onnx/defs/schema.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
#include "core/graph/onnx_protobuf.h"
#include "core/common/status.h"
#include "core/graph/constants.h"

View file

@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@ -21,7 +21,7 @@ Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) c
auto& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||

View file

@ -17,7 +17,7 @@ class AddGeluFusion : public GraphTransformer {
: GraphTransformer("AddGeluFusion", compatible_execution_providers) {
}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -11,7 +11,7 @@ using namespace onnxruntime::common;
namespace onnxruntime {
Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
auto& order = graph_viewer.GetNodesInTopologicalOrder();
@ -21,7 +21,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
continue;
}
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
InitializedTensorSet constant_inputs;
@ -52,7 +52,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context));
std::vector<OrtValue> fetches;
frame.GetOutputs(fetches);
ORT_RETURN_IF_ERROR(frame.GetOutputs(fetches));
// Go over all output node args and substitute them with the newly computed tensors, which will be
// added to the graph as initializers.
@ -62,8 +62,8 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
OrtValue& ort_value = fetches[fetch_idx];
if (!ort_value.IsTensor()) {
LOGS_DEFAULT(WARNING) << "Unsupported output type of " << ort_value.Type()
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type()
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
unsupported_output_type = true;
break;
}

View file

@ -25,7 +25,7 @@ class ConstantFolding : public GraphTransformer {
const std::unordered_set<std::string> excluded_op_types_ =
{"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"};
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
/** Create a TensorProto that has the same value as the given OrtValue
and the same type and dimensions as the given NodeArg. */

View file

@ -72,7 +72,7 @@ static bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& m
} // namespace
Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
@ -82,7 +82,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
if (!node)
continue;
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", {1, 11}) ||
!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders()) ||

View file

@ -13,7 +13,7 @@ class ConvActivationFusion : public GraphTransformer {
: GraphTransformer("ConvActivationFusion", compatible_execution_providers) {}
private:
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const {
Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified, const logging::Logger&) const {
auto& conv_node = node;
auto& add_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index()); // get mutable next node
const auto& conv_inputs = conv_node.InputDefs();
@ -103,7 +103,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
return Status::OK();
}
bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
node.GetOutputEdgesCount() != 1) {
return false;

View file

@ -23,9 +23,9 @@ class ConvAddFusion : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
auto& conv_node = node;
Node& bn_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index());
@ -144,7 +144,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
return Status::OK();
}
bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
node.GetOutputEdgesCount() != 1) {
return false;

View file

@ -23,9 +23,9 @@ class ConvBNFusion : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
auto& conv_node = node;
auto& mul_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index());
@ -114,7 +114,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
return Status::OK();
}
bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
node.GetOutputEdgesCount() != 1) {
return false;

View file

@ -22,9 +22,9 @@ class ConvMulFusion : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -10,7 +10,7 @@
namespace onnxruntime {
Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
@ -18,7 +18,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule
return Status::OK();
}
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) const {
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
// We currently support elimination for Dropout operator v1, v6, v7, and v10.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10})) {
return false;
@ -28,7 +28,7 @@ bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) co
// It can be safely removed if it has only one output that is used (checked by CanRemoveNode)
// and that output is not the 'mask' output.
// The 'is_test' attribute in v1 and v6 is captured by the check for the 'mask' output.
return graph_utils::CanRemoveNode(graph, node) && !graph_utils::IsOutputUsed(node, 1);
return graph_utils::CanRemoveNode(graph, node, logger) && !graph_utils::IsOutputUsed(node, 1);
}
} // namespace onnxruntime

View file

@ -23,9 +23,9 @@ class EliminateDropout : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -28,7 +28,7 @@ static std::string ToLower(std::string s) {
}
}
Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/) const {
Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const {
for (const onnxruntime::NodeArg* graph_input : graph.GetInputs()) {
// Get the current input's type and shape
const auto* input_type = graph_input->TypeAsProto();
@ -59,10 +59,10 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
// If this dimension actually has a value but it doesn't match the override value, return an
// error.
if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) {
LOGS_DEFAULT(ERROR) << "The model has input '" << graph_input->Name() << "' "
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
<< "but the size of this dimension " << dimension.dim_value() << " "
<< "does not equal the specified override of" << dimension_override << ".";
LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' "
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
<< "but the size of this dimension " << dimension.dim_value() << " "
<< "does not equal the specified override of" << dimension_override << ".";
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
}

View file

@ -23,7 +23,7 @@ class FreeDimensionOverrideTransformer : public GraphTransformer {
explicit FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply);
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
std::map<std::string, int64_t> dimension_override_by_denotation_;
};

View file

@ -67,7 +67,7 @@ static bool IsSupportedDataType(const Node& node) {
return true;
}
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@ -77,7 +77,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
continue; // we removed the node as part of an earlier fusion
Node& div = *p_div;
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
!graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||

View file

@ -21,7 +21,7 @@ class GeluFusion : public GraphTransformer {
GeluFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GeluFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -19,7 +19,7 @@ bool IsFusableActivation(const Node& node) {
}
} // namespace
Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
@ -30,7 +30,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
continue; // node was removed
auto& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||

View file

@ -12,7 +12,7 @@ class GemmActivationFusion : public GraphTransformer {
GemmActivationFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GemmActivationFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -7,11 +7,11 @@ using namespace ::onnxruntime::common;
namespace onnxruntime {
Status GraphTransformer::Apply(Graph& graph, bool& modified) const {
Status GraphTransformer::Apply(Graph& graph, bool& modified, const logging::Logger& logger) const {
// the Graph should be in a good state prior this being called, so there should be no need to call Resolve here
// ORT_RETURN_IF_ERROR(graph.Resolve());
auto status = ApplyImpl(graph, modified, 0);
auto status = ApplyImpl(graph, modified, 0, logger);
ORT_RETURN_IF_ERROR(status);
// At least currently, some transformers (InsertCastTransformer and MemcpyTransformer) need this to be called

View file

@ -8,7 +8,7 @@ using namespace ::onnxruntime::common;
namespace onnxruntime {
common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, TransformerLevel level) const {
common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, TransformerLevel level, const logging::Logger& logger) const {
const auto& transformers = level_to_transformer_map_.find(level);
if (transformers == level_to_transformer_map_.end()) {
return Status::OK();
@ -18,7 +18,7 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor
bool graph_changed = false;
for (const auto& transformer : transformers->second) {
bool modified = false;
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified));
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger));
graph_changed = graph_changed || modified;
}
if (!graph_changed) {

View file

@ -3,6 +3,7 @@
#pragma once
#include "core/common/logging/logging.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/rewrite_rule.h"
@ -20,7 +21,7 @@ class GraphTransformerManager {
common::Status Register(std::unique_ptr<GraphTransformer> transformer, TransformerLevel level);
// Apply all transformers registered for the given level on the given graph
common::Status ApplyTransformers(Graph& graph, TransformerLevel level) const;
common::Status ApplyTransformers(Graph& graph, TransformerLevel level, const logging::Logger& logger) const;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformerManager);

View file

@ -10,7 +10,7 @@
namespace onnxruntime {
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
@ -18,8 +18,8 @@ Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rul
return Status::OK();
}
bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node) const {
return graph_utils::CanRemoveNode(graph, node);
bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
return graph_utils::CanRemoveNode(graph, node, logger);
}
} // namespace onnxruntime

View file

@ -23,9 +23,9 @@ class EliminateIdentity : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
}; // namespace onnxruntime
} // namespace onnxruntime

View file

@ -99,7 +99,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
}
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override {
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override {
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacement_defs;
auto output_args = graph.GetOutputs();
@ -187,7 +187,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
}
if (!removed) {
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
}
}
@ -195,7 +195,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
}
};
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const {
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
if (force_cpu_fp32_)
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph));
@ -267,7 +267,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
node->ReplaceDefs(replacement_defs);
modified = modified || casted;
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
}
auto status = Status::OK();
@ -282,7 +282,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
RemoveDuplicateCastTransformer remover;
// RemoveDuplicateCastTransformer is a special transformer required for correctness.
// It is provider agnostic so simply send an empty vector.
status = remover.Apply(graph, modified);
status = remover.Apply(graph, modified, logger);
}
return status;

View file

@ -22,7 +22,7 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer {
}
private:
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const;

View file

@ -43,7 +43,7 @@ X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
| |
+---------------------+
*/
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
std::vector<std::reference_wrapper<Node>> nodes_to_remove;
@ -54,7 +54,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
continue; // we removed the node as part of an earlier fusion
Node& reduce_mean_node = *p_reduce_mean;
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) ||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||

View file

@ -21,7 +21,7 @@ class LayerNormFusion : public GraphTransformer {
LayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("LayerNormFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@ -21,7 +21,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
auto& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||

View file

@ -12,7 +12,7 @@ class MatMulAddFusion : public GraphTransformer {
MatMulAddFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("MatMulAddFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -727,13 +727,13 @@ void NchwcTransformerImpl::Finalize(bool& modified) {
}
}
Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
NchwcTransformerImpl impl(graph);
GraphViewer graph_viewer(graph);
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(index);
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (node.GetExecutionProviderType() == kCpuExecutionProvider) {
impl.Transform(node);
}

View file

@ -19,7 +19,7 @@ class NchwcTransformer : public GraphTransformer {
NchwcTransformer() noexcept : GraphTransformer("NchwcTransformer") {}
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -9,7 +9,7 @@
namespace onnxruntime {
Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
const auto& next_node = *node.OutputNodesBegin();
// Clip opset 6 has min and max as attributes. they're inputs from opset 11 on.
@ -109,7 +109,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
return Status::OK();
}
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const {
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) {
return false;
}
@ -127,7 +127,7 @@ bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const
return false;
}
if (!graph_utils::CanRemoveNode(graph, node)) {
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
return false;
}

View file

@ -21,9 +21,9 @@ class FuseReluClip : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -28,9 +28,9 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr<RewriteRule> rule) {
Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::reference_wrapper<const RewriteRule>>& rules,
RuleEffect& rule_effect) const {
RuleEffect& rule_effect, const logging::Logger& logger) const {
for (const RewriteRule& rule : rules) {
ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect));
ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect, logger));
// If the current node was removed as a result of a rule, stop rule application for that node.
if (rule_effect == RuleEffect::kRemovedCurrentNode) {
break;
@ -39,7 +39,7 @@ Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
return Status::OK();
}
Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
auto& order = graph_viewer.GetNodesInTopologicalOrder();
@ -65,13 +65,13 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
rules = GetRewriteRulesForOpType(node->OpType());
if (rules) {
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect, logger));
}
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
rules = GetAnyOpRewriteRules();
if (rules) {
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect, logger));
}
}
@ -81,7 +81,7 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
}
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
}
}

View file

@ -12,7 +12,7 @@
namespace onnxruntime {
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
// Store the statically inferred shape of the input to the Shape operator.
const ONNX_NAMESPACE::TensorShapeProto* input_shape_proto = node.InputDefs()[0]->Shape();
std::vector<int64_t> input_dims;
@ -49,7 +49,7 @@ Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& ru
return Status::OK();
}
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) const {
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1})) {
return false;
}
@ -70,7 +70,7 @@ bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node)
// we're going to create an initializer with the same name as the node output
const auto& new_initializer_name = node.OutputDefs()[0]->Name();
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_initializer_name)) {
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_initializer_name, logger)) {
return false;
}

View file

@ -24,9 +24,9 @@ class ShapeToInitializer : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -26,7 +26,7 @@ static bool IsSupportedDataType(const Node& node) {
/**
Skip Layer Normalization will fuse Add + LayerNormalization into one node.
*/
Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
std::vector<std::reference_wrapper<Node>> nodes_to_remove;
@ -37,7 +37,7 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
continue; // we removed the node as part of an earlier fusion.
Node& add_node = *p_add;
ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
!graph_utils::IsSupportedProvider(add_node, GetCompatibleExecutionProviders()) ||

View file

@ -18,10 +18,10 @@ The formula corresponding to LayerNorm activation subgraph:
*/
class SkipLayerNormFusion : public GraphTransformer {
public:
SkipLayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
explicit SkipLayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("SkipLayerNormFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -9,7 +9,7 @@
namespace onnxruntime {
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
@ -17,13 +17,13 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_e
return Status::OK();
}
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) const {
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
// We currently support elimination for Slice operator v1.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11})) {
return false;
}
if (!graph_utils::CanRemoveNode(graph, node)) {
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
return false;
}

View file

@ -23,9 +23,9 @@ class EliminateSlice : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -67,7 +67,7 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
// and mainly provides the subgraph recursion functionality
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
for (auto& provider : provider_types_) {
if (provider != onnxruntime::kCpuExecutionProvider &&
provider != onnxruntime::kMklDnnExecutionProvider &&
@ -91,7 +91,7 @@ common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
// handle any subgraphs in nodes
for (auto& node : graph.Nodes()) {
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
}
return Status::OK();

View file

@ -23,7 +23,7 @@ class MemcpyTransformer : public GraphTransformer {
: GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {}
private:
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
const std::vector<std::string> provider_types_;
std::reference_wrapper<const KernelRegistryManager> registry_manager_;

View file

@ -11,13 +11,13 @@ using namespace onnxruntime::common;
namespace onnxruntime {
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const {
NodeArg& input_def = *node.MutableInputDefs()[0];
const auto& tensor_proto = *graph_utils::GetConstantInitializer(graph, input_def.Name());
auto new_name = graph.GenerateNodeArgName("UnsqueezeElimination_" + input_def.Name());
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_name)) {
LOGS_DEFAULT(WARNING) << "UnsqueezeElimination cannot remove node " << node.Name();
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_name, logger)) {
LOGS(logger, WARNING) << "UnsqueezeElimination cannot remove node " << node.Name();
return Status::OK();
}
@ -68,7 +68,7 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect&
return Status::OK();
}
bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node) const {
bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
// Attempt to remove an Unsqueeze operator only if it gets a constant initializer as input.
return graph_utils::IsConstantInitializer(graph, node.InputDefs()[0]->Name());
}

View file

@ -23,9 +23,9 @@ class UnsqueezeElimination : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -151,8 +151,8 @@ class WindowsEnv : public Env {
}
Status MapFileIntoMemory(
const ORTCHAR_T* file_path, FileOffsetType offset, size_t length,
MappedMemoryPtr& mapped_memory) const override {
const ORTCHAR_T*, FileOffsetType, size_t,
MappedMemoryPtr&) const override {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "MapFileIntoMemory is not implemented on Windows.");
}

View file

@ -20,25 +20,23 @@ namespace Dml
}
onnxruntime::common::Status GraphTransformer::ApplyImpl(
onnxruntime::Graph& graph,
onnxruntime::Graph& graph,
bool& modified,
int graph_level) const
{
modified = false;
// Perform fusion
{
bool transformModifiedGraph = false;
PerformOperatorFusion(&graph, &transformModifiedGraph);
modified |= transformModifiedGraph;
int graph_level, const onnxruntime::logging::Logger&) const {
modified = false;
if (modified)
{
ORT_RETURN_IF_ERROR(graph.Resolve());
}
// Perform fusion
{
bool transformModifiedGraph = false;
PerformOperatorFusion(&graph, &transformModifiedGraph);
modified |= transformModifiedGraph;
if (modified) {
ORT_RETURN_IF_ERROR(graph.Resolve());
}
}
return onnxruntime::common::Status::OK();
return onnxruntime::common::Status::OK();
}
static std::string GetUniqueNodeName(const onnxruntime::Node* node)

View file

@ -3,6 +3,7 @@
#pragma once
// Lotus framework headers for onnxruntime::IExecutionProvider (not part of the operator ABI).
#include "core/common/logging/logging.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/execution_provider.h"
#include "core/framework/op_kernel.h"
@ -18,7 +19,7 @@ namespace Dml
GraphTransformer(const std::string& name, std::shared_ptr<onnxruntime::KernelRegistry> dmlRegistry);
private:
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level = 0) const final;
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final;
private:
void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;

View file

@ -13,7 +13,7 @@ using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const {
Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified, const onnxruntime::logging::Logger&) const {
auto& BatchNormalization_node = node;
const auto& add_node = *BatchNormalization_node.OutputNodesBegin();
const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs();
@ -88,7 +88,7 @@ Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleE
return Status::OK();
}
bool BatchNormalizationAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
bool BatchNormalizationAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7}) ||
node.GetOutputEdgesCount() != 1) {
return false;

View file

@ -23,9 +23,9 @@ class BatchNormalizationAddFusion : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -13,7 +13,7 @@ using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const {
auto& BatchNormalization_node = node;
const auto& mul_node = *BatchNormalization_node.OutputNodesBegin();
const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs();
@ -101,7 +101,7 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE
return Status::OK();
}
bool BatchNormalizationMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
bool BatchNormalizationMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7}) ||
node.GetOutputEdgesCount() != 1) {
return false;

View file

@ -22,9 +22,9 @@ class BatchNormalizationMulFusion : public RewriteRule {
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
bool SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const override;
};
} // namespace onnxruntime

View file

@ -58,7 +58,7 @@ NGRAPHCustomOp::~NGRAPHCustomOp() {
}
//This method gets called in critical path of execution: Optimize
Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const {
Status NGRAPHCustomOp::Initialize(const OrtApi* api, OrtKernelContext* context) const {
Ort::CustomOpApi ort{*api};
size_t num_inputs = ort.KernelContext_GetInputCount(context);
@ -176,7 +176,7 @@ Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* c
}
//This method gets called in critical path of execution: Optimize
Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const {
Status NGRAPHCustomOp::Compute(const OrtApi* api, OrtKernelContext* context) const {
Ort::CustomOpApi ort{*api};
// Initialize nGraph function if it is not already initialized.

View file

@ -29,12 +29,12 @@ class NGRAPHCustomOp {
const ONNX_NAMESPACE::ModelProto& model_proto,
const std::shared_ptr<ngraph::runtime::Backend>& ng_backend);
Status Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const;
Status Compute(const OrtApi* api, OrtKernelContext* context) const;
~NGRAPHCustomOp();
private:
Status Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const;
Status Initialize(const OrtApi* api, OrtKernelContext* context) const;
std::shared_ptr<ngraph::runtime::Backend> ng_backend_;

View file

@ -527,14 +527,15 @@ NGRAPHExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
return result;
}
static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node) {
static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node, const logging::Logger& logger) {
const auto* node_function = fused_node->GetFunctionBody();
ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", fused_node->Name());
const Graph& node_subgraph = node_function->Body();
onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{},
IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap()};
IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(),
std::vector<ONNX_NAMESPACE::FunctionProto>(), logger};
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
@ -551,9 +552,7 @@ Status NGRAPHExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& f
// Local copy of backend since, class members cannot be captured.
auto ngraph_backend = ng_backend_;
compute_info.create_state_func = [model_proto = GetModelProtoFromFusedNode(fused_node), ngraph_backend]
(ComputeContext* context, FunctionState* state)
{
compute_info.create_state_func = [model_proto = GetModelProtoFromFusedNode(fused_node, *GetLogger()), ngraph_backend](ComputeContext* context, FunctionState* state) {
auto* p = new ngraph_ep::NGRAPHCustomOp(context, model_proto, ngraph_backend);
*state = p;
return 0;
@ -564,7 +563,7 @@ Status NGRAPHExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& f
delete reinterpret_cast<onnxruntime::ngraph_ep::NGRAPHCustomOp*>(state);
};
compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
onnxruntime::ngraph_ep::NGRAPHCustomOp* ng_custom_op = reinterpret_cast<onnxruntime::ngraph_ep::NGRAPHCustomOp*>(state);
return ng_custom_op->Compute(api, context);
};

View file

@ -223,7 +223,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
if (group.second) {
nodes_list_output.push_back(group);
} else {
onnxruntime::Model model_build(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap());
onnxruntime::Model model_build(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
onnxruntime::Graph& graph_build = model_build.MainGraph();
//Add node and node args
@ -285,7 +285,7 @@ std::vector<std::unique_ptr<ComputeCapability>>
TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const {
// Construct modelproto from graph
onnxruntime::Model model(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap());
onnxruntime::Model model(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
onnxruntime::Graph& graph_build = model.MainGraph();
for (const auto& node : graph.Nodes()) {
std::vector<onnxruntime::NodeArg*> inputs, outputs;
@ -379,7 +379,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
}
const Graph& graph_body = func_body->Body();
onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph_body.DomainToVersionMap());
onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph_body.DomainToVersionMap(), std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
*(model_proto.mutable_graph()) = graph_body.ToGraphProto();
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

View file

@ -173,9 +173,9 @@ common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr<IExec
return st;
}
}
execution_providers_.Add(provider_type, std::move(p_exec_provider));
return Status::OK();
p_exec_provider->SetLogger(session_logger_);
return execution_providers_.Add(provider_type, std::move(p_exec_provider));
}
common::Status InferenceSession::RegisterGraphTransformer(
@ -270,7 +270,8 @@ common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
AddCustomOpDomains({domain.get()});
}
#endif
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
*session_logger_);
};
common::Status st = Load(loader, "model_loading_uri");
@ -296,7 +297,8 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) {
AddCustomOpDomains({domain.get()});
}
#endif
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
*session_logger_);
};
return Load(loader, "model_loading_proto");
@ -311,7 +313,7 @@ common::Status InferenceSession::Load(std::unique_ptr<ModelProto> p_model_proto)
}
#endif
return onnxruntime::Model::Load(std::move(p_model_proto), model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr);
HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_);
};
return Load(loader, "model_loading_proto");
@ -333,7 +335,8 @@ common::Status InferenceSession::Load(std::istream& model_istream) {
AddCustomOpDomains({domain.get()});
}
#endif
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
*session_logger_);
};
return Load(loader, "model_loading_istream");
@ -355,7 +358,8 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
}
#endif
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
*session_logger_);
};
return Load(loader, "model_loading_array");
@ -375,7 +379,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// 5. insert cast nodes.
// first apply global(execution provider independent), level 1(default/system/basic) graph to graph optimizations
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1));
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *session_logger_));
#ifdef USE_DML
// TODO: this is a temporary workaround to apply the DML EP's custom graph transformer prior to partitioning. This
@ -390,7 +394,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
Dml::GraphTransformer dml_transformer(onnxruntime::kDmlExecutionProvider, std::move(dml_registry));
bool modified = false;
dml_transformer.Apply(graph, modified);
dml_transformer.Apply(graph, modified, *session_logger_);
}
#endif
@ -401,12 +405,12 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// apply transformers except default transformers
// Default transformers are required for correctness and they are owned and run by inference session
for (int i = static_cast<int>(TransformerLevel::Level1); i < static_cast<int>(TransformerLevel::MaxTransformerLevel); i++) {
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i)));
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *session_logger_));
}
bool modified = false;
// Insert cast node/s.
ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified));
ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified, *session_logger_));
// Now every node should be already assigned to an execution provider
std::unordered_map<std::string, std::vector<std::string>> node_placements;
@ -455,7 +459,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// Insert copy node/s.
MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager};
ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified));
ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified, *session_logger_));
return common::Status::OK();
}

View file

@ -58,7 +58,7 @@ std::vector<float> Add_Simple(const std::vector<float>& input_a_data, const std:
const std::vector<float>& input_small_size = input_a_data.size() < input_b_data.size() ? input_a_data : input_b_data;
std::vector<float> output(input_large_size.size());
for (int iter = 0; iter < input_large_size.size() / input_small_size.size(); iter++) {
for (size_t iter = 0; iter < input_large_size.size() / input_small_size.size(); iter++) {
std::transform(input_large_size.begin() + iter * input_small_size.size(),
input_large_size.begin() + (iter + 1) * input_small_size.size(),
input_small_size.begin(),

View file

@ -12,6 +12,7 @@
#include "test/framework/model_builder_utils.h"
#include "core/framework/allocation_planner.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "test/test_environment.h"
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
@ -162,7 +163,10 @@ class PlannerTest : public ::testing::Test {
public:
PlannerTest()
: model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_, nullptr) {
: model_("test", false, DefaultLoggingManager().DefaultLogger()),
graph_(model_.MainGraph()),
tp_("test", 1),
state_(execution_providers_, false, &tp_, nullptr) {
std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
in_place_kernel_ =
KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build();
@ -171,8 +175,6 @@ class PlannerTest : public ::testing::Test {
execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider));
}
~PlannerTest() = default;
onnxruntime::NodeArg* Arg(const std::string& name) {
auto iter = name_to_arg_.find(name);
if (name_to_arg_.end() != iter) return iter->second;

View file

@ -73,7 +73,9 @@ CREATE_INITIALIZER_FUNC(float, TensorProto_DataType_FLOAT, add_float_data)
CREATE_INITIALIZER_FUNC(int64_t, TensorProto_DataType_INT64, add_int64_data)
// TO DO: Figure out a way to enable it again
TEST(CUDAFenceTests, DISABLED_PartOnCPU) {
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test");
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test",
false,
DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
@ -135,7 +137,7 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) {
}
TEST(CUDAFenceTests, TileWithInitializer) {
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test");
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
@ -189,7 +191,8 @@ TEST(CUDAFenceTests, TileWithInitializer) {
TEST(CUDAFenceTests, TileWithComputedInput) {
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[onnxruntime::kOnnxDomain] = 7;
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);

View file

@ -21,7 +21,7 @@ namespace test {
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
std::shared_ptr<onnxruntime::Model> DummyGraphWithClip() {
auto model = std::make_shared<onnxruntime::Model>("test");
auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
@ -42,7 +42,7 @@ class ExecutionFrameTest : public ::testing::Test {
};
TEST_F(ExecutionFrameTest, TensorAllocationTest) {
onnxruntime::Model model("test");
onnxruntime::Model model("test", false, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model.MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
@ -112,7 +112,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
TEST_F(ExecutionFrameTest, FeedInDataTest) {
onnxruntime::Model model("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
std::unordered_map<std::string, int>{{"", 10}});
std::unordered_map<std::string, int>{{"", 10}}, {},
DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model.MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
@ -164,7 +165,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
auto xp_type = cpu_xp->Type();
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[onnxruntime::kOnnxDomain] = 7;
onnxruntime::Model model("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
onnxruntime::Model model("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model.MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);

View file

@ -37,7 +37,6 @@
#include "test/providers/provider_test_utils.h"
#include "test/optimizer/dummy_graph_transformer.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "gtest/gtest.h"
using namespace std;
@ -133,16 +132,17 @@ class InferenceSessionGetGraphWrapper : public InferenceSession {
namespace test {
static void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
const std::vector<float>& expected_values);
static const std::string MODEL_URI = "testdata/mul_1.onnx";
static const std::string MODEL_URI_NO_OPSET = "testdata/mul_1.noopset.onnx";
static constexpr const ORTCHAR_T* MODEL_URI = ORT_TSTR("testdata/mul_1.onnx");
static constexpr const ORTCHAR_T* MODEL_URI_NO_OPSET = ORT_TSTR("testdata/mul_1.noopset.onnx");
//static const std::string MODEL_URI = "./testdata/squeezenet/model.onnx"; // TODO enable this after we've weights?
static void CreateMatMulModel(std::unique_ptr<onnxruntime::Model>& p_model, ProviderType provider_type) {
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[onnxruntime::kOnnxDomain] = 7;
// Generate the input & output def lists
p_model = onnxruntime::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version);
std::vector<ONNX_NAMESPACE::FunctionProto> model_specific_functions;
p_model = onnxruntime::make_unique<Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, model_specific_functions, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = p_model->MainGraph();
TypeProto tensor_float;
@ -455,11 +455,11 @@ TEST(InferenceSessionTests, ModelMetadata) {
so.session_logid = "InferenceSessionTests.ModelMetadata";
InferenceSession session_object{so, &DefaultLoggingManager()};
string model_uri = "../models/opset8/test_squeezenet/model.onnx";
auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx");
ASSERT_TRUE(session_object.Load(model_uri).IsOK());
std::shared_ptr<onnxruntime::Model> p_model;
Status st = onnxruntime::Model::Load(model_uri, p_model);
Status st = onnxruntime::Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(st.IsOK());
const onnxruntime::Graph& graph = p_model->MainGraph();
@ -1029,7 +1029,7 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
}
TEST(ExecutionProviderTest, FunctionTest) {
onnxruntime::Model model("graph_1");
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
@ -1116,7 +1116,7 @@ TEST(ExecutionProviderTest, FunctionTest) {
}
TEST(ExecutionProviderTest, FunctionInlineTest) {
onnxruntime::Model model("graph_1");
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
ONNX_NAMESPACE::FunctionProto fc_proto;
fc_proto.set_name("FC");

View file

@ -7,16 +7,17 @@
#include "core/graph/model.h"
#include "gtest/gtest.h"
#include "test_utils.h"
#include "test/test_environment.h"
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
static const std::string MODEL_FOLDER = "testdata/transform/";
#define MODEL_FOLDER ORT_TSTR("testdata/transform/")
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
TEST(TransformerTest, InsertCastGPUTest) {
auto model = std::make_shared<onnxruntime::Model>("test");
auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float_16;
@ -38,7 +39,7 @@ TEST(TransformerTest, InsertCastGPUTest) {
InsertCastTransformer transformer("Test");
bool modified = true;
status = transformer.Apply(graph, modified);
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(status.IsOK());
status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
@ -64,7 +65,7 @@ TEST(TransformerTest, InsertCastGPUTest) {
}
TEST(TransformerTest, InsertCastAllCPUTest) {
auto model = std::make_shared<onnxruntime::Model>("test");
auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float_16;
@ -86,7 +87,7 @@ TEST(TransformerTest, InsertCastAllCPUTest) {
InsertCastTransformer transformer("Test");
bool modified = true;
EXPECT_TRUE(transformer.Apply(graph, modified).IsOK());
EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK());
status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_EQ(graph.NumberOfNodes(), 7);
@ -109,9 +110,9 @@ TEST(TransformerTest, InsertCastAllCPUTest) {
// test that when there are 3 Cast ops in a row we remove the correct ones
TEST(TransformerTest, ThreeInARowRemoval) {
std::string model_uri = MODEL_FOLDER + "triple-cast.onnx";
auto model_uri = MODEL_FOLDER ORT_TSTR("triple-cast.onnx");
std::shared_ptr<Model> model;
auto status = Model::Load(model_uri, model);
auto status = Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(status.IsOK()) << status;
Graph& graph = model->MainGraph();
@ -123,7 +124,7 @@ TEST(TransformerTest, ThreeInARowRemoval) {
InsertCastTransformer transformer("Test");
bool modified = false;
status = transformer.Apply(graph, modified);
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(status.IsOK()) << status;
EXPECT_TRUE(modified) << "Transformer should have removed some Cast nodes";
status = graph.Resolve();

View file

@ -8,6 +8,7 @@
#include "core/graph/model.h"
#include "gtest/gtest.h"
#include "test_utils.h"
#include "test/test_environment.h"
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
@ -74,7 +75,8 @@ TEST(TransformerTest, MemcpyTransformerTest) {
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = 7;
auto model = std::make_shared<onnxruntime::Model>("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version);
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float_type;
@ -111,7 +113,7 @@ TEST(TransformerTest, MemcpyTransformerTest) {
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
bool modified = false;
status = transformer.Apply(graph, modified);
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(modified);
@ -128,7 +130,8 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) {
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = 7;
auto model = std::make_shared<onnxruntime::Model>("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version);
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TypeProto tensor_float_type;
@ -165,7 +168,7 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) {
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
bool modified = false;
status = transformer.Apply(graph, modified);
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(modified);
@ -204,7 +207,8 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
false,
ModelMetaData(),
IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version);
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();
TensorProto parent_constant(value_tensor);
@ -221,7 +225,8 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
false,
ModelMetaData(),
IOnnxRuntimeOpSchemaRegistryList(),
subgraph_domain_to_version);
subgraph_domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& subgraph = sub_model->MainGraph();
TensorProto local_constant(value_tensor);
@ -278,7 +283,7 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
bool modified = false;
status = transformer.Apply(graph, modified);
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(modified);
}

View file

@ -298,7 +298,7 @@ TEST_F(OpaqueTypeTests, RunModel) {
IOnnxRuntimeOpSchemaRegistryList custom_schema_registries_ = {registry->GetOpschemaRegistry()};
std::unordered_map<std::string, int> domain_to_version = {{onnxruntime::kMLDomain, 8}};
Model model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version);
Model model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;

View file

@ -14,6 +14,7 @@
#include "core/graph/op.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "gtest/gtest.h"
#include "test/test_environment.h"
using namespace ONNX_NAMESPACE;
using namespace std;
@ -42,7 +43,7 @@ TEST(SessionStateTest, AddGetKernelTest) {
ExecutionProviders execution_providers;
SessionState s{execution_providers, true, &tp, nullptr};
onnxruntime::Model model("graph_1");
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
@ -94,10 +95,11 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
const TestParam& param = GetParam();
concurrency::ThreadPool tp{"test", 1};
std::string model_path = "testdata/optional_inputs_ir" + std::to_string(param.ir_version) + ".onnx";
std::basic_ostringstream<ORTCHAR_T> oss;
oss << ORT_TSTR("testdata/optional_inputs_ir") << param.ir_version << ORT_TSTR(".onnx");
Status status;
std::shared_ptr<Model> model;
ASSERT_TRUE((status = Model::Load(model_path, model)).IsOK()) << status;
ASSERT_TRUE((status = Model::Load(oss.str(), model, nullptr, DefaultLoggingManager().DefaultLogger())).IsOK()) << status;
Graph& graph = model->MainGraph();
// take a copy as this gets cleared during session state initialization
InitializedTensorSet initializers = graph.GetAllInitializedTensors();
@ -112,7 +114,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
ASSERT_TRUE(status.IsOK()) << status;
SessionState session_state(execution_providers, param.enable_mem_pattern, &tp, nullptr);
SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph, session_state,
SessionStateInitializer session_initializer(param.enable_mem_pattern, oss.str(), graph, session_state,
execution_providers, krm);
GraphPartitioner partitioner(krm, execution_providers);

View file

@ -7,6 +7,8 @@
#include "gtest/gtest.h"
#include "core/graph/model.h"
#include "test/framework/model_builder_utils.h"
#include "test/test_environment.h"
using namespace ONNX_NAMESPACE;
using namespace std;
@ -22,7 +24,7 @@ class ShapeInferenceTest : public ::testing::Test {
std::unordered_map<string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
public:
ShapeInferenceTest() : model_("Test"), node_count_(0) {}
ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {}
void Input(const std::string& name, const Type& type) {
name_to_arg_[name] = onnxruntime::make_unique<onnxruntime::NodeArg>(name, &type.value);

View file

@ -3,17 +3,7 @@
#include "core/framework/data_types.h"
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
#include "onnx/defs/schema.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
#include "core/graph/onnx_protobuf.h"
#include "core/graph/constants.h"
#include "core/framework/op_kernel.h"
@ -293,7 +283,7 @@ class SparseTensorTests : public testing::Test {
registry(std::make_shared<CustomRegistry>()),
custom_schema_registries_{registry->GetOpschemaRegistry()},
domain_to_version{{onnxruntime::kMLDomain, 10}},
model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version),
model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version, {}, DefaultLoggingManager().DefaultLogger()),
graph(model.MainGraph()) {
EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK());
}
@ -306,21 +296,21 @@ class SparseTensorTests : public testing::Test {
schema.SinceVersion(10);
schemas.push_back(schema);
Action register_kernel = [](CustomRegistry* registry) {
Action register_kernel = [](CustomRegistry* registry2) {
auto kernel_def_builder = Op::KernelDef();
kernel_def_builder
.SetDomain(onnxruntime::kMLDomain)
.SinceVersion(10)
.Provider(onnxruntime::kCpuExecutionProvider);
KernelCreateFn kernel_create_fn = [](const OpKernelInfo& info) { return new typename Op::OpKernelImpl(info); };
EXPECT_TRUE(registry->RegisterCustomKernel(kernel_def_builder, kernel_create_fn).IsOK());
EXPECT_TRUE(registry2->RegisterCustomKernel(kernel_def_builder, kernel_create_fn).IsOK());
};
register_actions.push_back(register_kernel);
}
void RegisterOps() {
EXPECT_TRUE(registry->RegisterOpSet(schemas, onnxruntime::kMLDomain, 10, 11).IsOK());
for (auto registerop : register_actions)
for (auto& registerop : register_actions)
registerop(registry.get());
}

View file

@ -121,7 +121,7 @@ static const bool kSchemasRegistered = RegisterCustomSchemas();
TEST(GraphTraversalTest, ReverseDFS) {
ASSERT_TRUE(kSchemasRegistered);
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
/* Case 1: A normal graph.
@ -219,7 +219,7 @@ TEST(GraphTraversalTest, ReverseDFS) {
}
TEST(ResolvingGraphTest, GraphConstruction_VerifyNoDuplicateName) {
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
EXPECT_EQ("graph_1", graph.Name());
@ -252,7 +252,7 @@ TEST(ResolvingGraphTest, GraphConstruction_VerifyNoDuplicateName) {
}
TEST(ResolvingGraphTest, GraphConstruction_VerifyNodeAndOpMatch) {
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<NodeArg*> inputs;
@ -275,7 +275,7 @@ TEST(ResolvingGraphTest, GraphConstruction_VerifyNodeAndOpMatch) {
TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
ASSERT_TRUE(kSchemasRegistered);
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
/* A normal graph.
@ -336,7 +336,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
EXPECT_TRUE(Model::Save(model, "graph_1.onnx").IsOK());
std::shared_ptr<Model> model2;
EXPECT_TRUE(Model::Load("graph_1.onnx", model2).IsOK());
EXPECT_TRUE(Model::Load(ORT_TSTR("graph_1.onnx"), model2, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
auto model_proto = model.ToProto();
auto model_proto2 = model2->ToProto();
@ -345,7 +345,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
// Load the model again to ensure that it's still the right thing.
//EXPECT_EQ(Model::Load(model_proto2, &model2), Status::OK());
model2.reset(new Model(model_proto2));
model2.reset(new Model(model_proto2, nullptr, DefaultLoggingManager().DefaultLogger()));
Graph& graph2 = model2->MainGraph();
for (auto& node : graph2.Nodes()) {
auto node_name_to_input_output_iter = expected_node_name_to_input_output_args.find(node.Name());
@ -368,7 +368,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
TEST(ResolvingGraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
ASSERT_TRUE(kSchemasRegistered);
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
// node_1 (Identity) node_2 (Identity)
@ -448,7 +448,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) {
ASSERT_TRUE(kSchemasRegistered);
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::unordered_map<std::string, int> map;
@ -541,7 +541,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
ASSERT_TRUE(result) << "Failed to load model from serialized protobuf";
std::shared_ptr<onnxruntime::Model> p_tmp_model;
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr);
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr, DefaultLoggingManager().DefaultLogger());
auto& graph2 = p_tmp_model->MainGraph();
status = graph2.Resolve();
@ -555,7 +555,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) {
ASSERT_TRUE(kSchemasRegistered);
Model model("UnusedInitializerIsIgnored");
Model model("UnusedInitializerIsIgnored", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<NodeArg*> inputs;
@ -596,7 +596,7 @@ TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) {
ASSERT_TRUE(result) << "Failed to load model from serialized protobuf";
std::shared_ptr<onnxruntime::Model> p_tmp_model;
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr);
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr, DefaultLoggingManager().DefaultLogger());
auto& graph2 = p_tmp_model->MainGraph();
status = graph2.Resolve();
@ -621,7 +621,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsNotAcyclic) {
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
auto& input_arg1 = graph.GetOrCreateNodeArg("node_1_in_1", &tensor_int32);
auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
@ -643,7 +643,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsNotAcyclic) {
}
TEST(ResolvingGraphTest, GraphConstruction_OnlyInitializer) {
onnxruntime::Model model("graph");
onnxruntime::Model model("graph", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
ONNX_NAMESPACE::TensorProto weight;
@ -663,7 +663,7 @@ TEST(ResolvingGraphTest, GraphConstruction_OnlyInitializer) {
TEST(ResolvingGraphTest, GraphConstruction_TypeInference) {
ASSERT_TRUE(kSchemasRegistered);
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
/* Case 1: A normal graph.
@ -725,7 +725,7 @@ TEST(ResolvingGraphTest, GraphConstruction_TypeInference) {
EXPECT_TRUE(Model::Save(model, "model_x.onnx").IsOK());
std::shared_ptr<Model> loaded_model;
EXPECT_TRUE(Model::Load("model_x.onnx", loaded_model).IsOK());
EXPECT_TRUE(Model::Load(ORT_TSTR("model_x.onnx"), loaded_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
EXPECT_EQ(2, loaded_model->MainGraph().GetInputs().size());
auto& graph_proto = graph.ToGraphProto();
@ -740,7 +740,7 @@ TEST(ResolvingGraphTest, GraphConstruction_TypeInference) {
TEST(ResolvingGraphTest, ShapeInferenceErrorHandling) {
ASSERT_TRUE(kSchemasRegistered);
Model model("graph");
Model model("graph", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
TypeProto tensor_int32;
@ -765,7 +765,7 @@ TEST(TestAddAttribute, AddTensorAttribute) {
.Output(0, "output_1", "docstr for output_1.", "tensor(int64)");
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
TypeProto output_type;
output_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
@ -810,7 +810,7 @@ void AddAttribute(onnxruntime::Node& p_node, const std::string& attr_name, std::
TEST(TypeInferenceTest, TypeAttribute) {
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", nullptr);
outputs.push_back(&output_arg);
@ -834,7 +834,7 @@ TEST(TypeInferenceTest, VariadicOutput) {
std::vector<NodeArg*> outputs;
TypeProto tensor_type;
tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
auto& X = graph.GetOrCreateNodeArg("X", &tensor_type);
inputs.push_back(&X);
@ -851,7 +851,7 @@ TEST(TypeInferenceTest, VariadicOutput) {
// test that we prefer the graph input shape for a non-const initializer (initializer with matching graph input)
TEST(TypeInferenceTest, NonConstInitializer) {
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
TypeProto tensor_type_no_shape;
@ -901,7 +901,7 @@ TEST(TypeInferenceTest, NonConstInitializer) {
ASSERT_TRUE(model.ToProto().SerializeToString(&s1));
ASSERT_TRUE(model_proto.ParseFromString(s1));
auto status = onnxruntime::Model::Load(model_proto, p_model, nullptr);
auto status = onnxruntime::Model::Load(model_proto, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(status.IsOK()) << status;
auto& graph2 = p_model->MainGraph();
@ -910,7 +910,7 @@ TEST(TypeInferenceTest, NonConstInitializer) {
// Test that Graph::Resolve identifies name-duplication across initializer and node-output-arg
TEST(NameResolutionTest, DuplicateName) {
Model model("graph_1");
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
ONNX_NAMESPACE::TensorProto weight;
@ -940,7 +940,7 @@ TEST(NameResolutionTest, DuplicateName) {
}
TEST(GraphUpdateTest, ReplaceInitializedTensor) {
Model model{"GraphUpdateTest"};
Model model{"GraphUpdateTest", false, DefaultLoggingManager().DefaultLogger()};
auto& graph = model.MainGraph();
const std::string initializer_name = "initializer";
@ -1011,7 +1011,7 @@ TEST(GraphUpdateTest, ReplaceInitializedTensor) {
}
TEST(GraphUpdateTest, AddRemoveInitializerHandling) {
Model m{"test_model"};
Model m{"test_model", false, DefaultLoggingManager().DefaultLogger()};
Graph& graph = m.MainGraph();
auto create_tensor_proto = [](const std::string& name, int32_t value) {

View file

@ -7,6 +7,8 @@
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/graph/op.h"
#include "core/session/onnxruntime_c_api.h"
#include "test/test_environment.h"
#include "gtest/gtest.h"
using namespace onnxruntime;
@ -47,29 +49,18 @@ static void TestResolve(onnxruntime::Graph& graph) {
TEST(ONNXModelsTest, squeeze_net) {
// NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load("../models/opset8/test_squeezenet/model.onnx", model).IsOK());
ASSERT_TRUE(Model::Load(ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"), model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
TestResolve(model->MainGraph());
#ifdef _WIN32
// wstring version
std::shared_ptr<Model> model2;
ASSERT_TRUE(Model::Load(L"../models/opset8/test_squeezenet/model.onnx", model2).IsOK());
TestResolve(model2->MainGraph());
#endif
}
#endif
TEST(ONNXModelsTest, non_existing_model) {
// NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located
std::shared_ptr<Model> model;
common::Status st = Model::Load("./testdata/non_existing_model_XXXXXX/model.onnx", model);
common::Status st = Model::Load(ORT_TSTR("./testdata/non_existing_model_XXXXXX/model.onnx"), model, nullptr,
DefaultLoggingManager().DefaultLogger());
ASSERT_FALSE(st.IsOK());
ASSERT_EQ(st.Code(), common::NO_SUCHFILE);
#ifdef _WIN32
// wstring version
std::shared_ptr<Model> model2;
ASSERT_FALSE(Model::Load(L"./testdata/non_existing_model_XXXXXX/model.onnx", model2).IsOK());
ASSERT_EQ(st.Code(), common::NO_SUCHFILE);
#endif
}
#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS
@ -78,7 +69,7 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) {
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::ZeroCopyInputStream;
int fd;
ASSERT_TRUE(Env::Default().FileOpenRd("../models/opset8/test_bvlc_alexnet/model.onnx", fd).IsOK());
ASSERT_TRUE(Env::Default().FileOpenRd(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), fd).IsOK());
ASSERT_TRUE(fd > 0);
std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
std::unique_ptr<CodedInputStream> coded_input(new CodedInputStream(raw_input.get()));
@ -92,7 +83,9 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) {
ASSERT_TRUE(Env::Default().FileClose(fd).IsOK());
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load("../models/opset8/test_bvlc_alexnet/model.onnx", model).IsOK());
ASSERT_TRUE(Model::Load(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), model, nullptr,
DefaultLoggingManager().DefaultLogger())
.IsOK());
// Check the graph input/output/value_info should have the same size as specified in the model file.
EXPECT_EQ(model_proto.graph().value_info_size(), model->MainGraph().GetValueInfo().size());
@ -101,21 +94,23 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) {
TestResolve(model->MainGraph());
}
class ONNXModelsTest : public ::testing::TestWithParam<const char*> {
class ONNXModelsTest : public ::testing::TestWithParam<const ORTCHAR_T*> {
// You can implement all the usual fixture class members here.
// To access the test parameter, call GetParam() from class
// TestWithParam<T>.
public:
std::string GetModelFileName() const {
std::ostringstream oss;
oss << "../models/opset7/test_" << GetParam() << "/model.onnx";
std::basic_string<ORTCHAR_T> GetModelFileName() const {
std::basic_ostringstream<ORTCHAR_T> oss;
oss << ORT_TSTR("../models/opset7/test_") << GetParam() << ORT_TSTR("/model.onnx");
return oss.str();
}
};
TEST_P(ONNXModelsTest, LoadFromFile) {
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(GetModelFileName(), model).IsOK());
ASSERT_TRUE(Model::Load(GetModelFileName(), model, nullptr,
DefaultLoggingManager().DefaultLogger())
.IsOK());
TestResolve(model->MainGraph());
}
@ -137,18 +132,20 @@ TEST_P(ONNXModelsTest, LoadFromProtobuf) {
ASSERT_TRUE(result);
ASSERT_TRUE(Env::Default().FileClose(fd).IsOK());
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(std::move(model_proto), model).IsOK());
ASSERT_TRUE(Model::Load(std::move(model_proto), model, nullptr,
DefaultLoggingManager().DefaultLogger())
.IsOK());
TestResolve(model->MainGraph());
}
#ifndef DISABLE_CONTRIB_OPS
INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
ONNXModelsTest,
::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "tiny_yolov2", "vgg19", "zfnet512"));
::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("tiny_yolov2"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512")));
#else
INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
ONNXModelsTest,
::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "vgg19", "zfnet512"));
::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512")));
#endif
#endif
@ -159,7 +156,9 @@ INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
// for Graph::Resolve to succeed when processing the subgraph.
TEST(ONNXModelsTest, TestIRv4NonInputInitializers) {
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK());
ASSERT_TRUE(Model::Load(ORT_TSTR("testdata/subgraph_implicit_input_from_initializer.onnx"), model, nullptr,
DefaultLoggingManager().DefaultLogger())
.IsOK());
EXPECT_TRUE(model->MainGraph().Resolve().IsOK());
}
@ -170,7 +169,8 @@ TEST(ONNXModelsTest, TestIRv4NonInputInitializers) {
TEST(ONNXModelsTest, TestModelsWithAnOpContainingAFunctionBody) {
std::shared_ptr<Model> model;
auto status = Model::Load("testdata/model_containing_op_with_function_body.onnx", model);
auto status = Model::Load(ORT_TSTR("testdata/model_containing_op_with_function_body.onnx"), model, nullptr,
DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(status.IsOK()) << status;
status = model->MainGraph().Resolve();

View file

@ -8,6 +8,7 @@
#include "core/graph/schema_registry.h"
#include "core/graph/model.h"
#include "core/graph/op.h"
#include "test/test_environment.h"
using namespace ONNX_NAMESPACE;
@ -31,7 +32,7 @@ TEST(FormalParamTest, Success) {
}
TEST(FeatureVectorizerTest, TraditionalMlOpTest) {
Model model("traditionalMl");
Model model("traditionalMl", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
// Case: A traditional ml graph.

View file

@ -6,6 +6,8 @@
#include "core/graph/graph_utils.h"
#include "core/graph/model.h"
#include "test/test_environment.h"
using ONNX_NAMESPACE::Utils::DataTypeUtils;
using namespace ONNX_NAMESPACE;
@ -71,7 +73,8 @@ static GraphProto CreateNodeRemovalSubgraph(const std::string& new_output_name =
std::string suffix = add_second_level ? ".top" : ".bottom";
std::string constant_output_name = (new_output_name.empty() ? "constant_in_0" + suffix : new_output_name);
Model model("CreateNodeRemovalSubgraph:" + constant_output_name);
Model model("CreateNodeRemovalSubgraph:" + constant_output_name, false,
DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
TypeProto float_scalar_tensor;
@ -198,7 +201,8 @@ static void CheckNodeRemovalSubgraphUpdate(const std::string& new_name, const Gr
}
static void UpdateSubgraphWhenRemovingNode(bool include_nested = false) {
Model model(std::string("UpdateSubgraphWhenRemovingNode") + (include_nested ? ":Nested" : ":SingleLevel"));
Model model(std::string("UpdateSubgraphWhenRemovingNode") + (include_nested ? ":Nested" : ":SingleLevel"),
false, DefaultLoggingManager().DefaultLogger());
CreateNodeRemovalGraph(model, true, include_nested);
@ -232,13 +236,15 @@ TEST(GraphUtils, UpdateNestedSubgraphWhenRemovingNode) {
// we can't remove a node if it is used as an implicit input in a subgraph, and changing the implicit input name
// will result with in a clash with an existing node in the subgraph
static void DontRemoveNodeIfItWillBreakSubgraph(bool test_nested = false) {
Model model(std::string("DontRemoveNodeIfItWillBreakSubgraph") + (test_nested ? ":Nested" : ":SingleLevel"));
Model model(std::string("DontRemoveNodeIfItWillBreakSubgraph") + (test_nested ? ":Nested" : ":SingleLevel"),
false, DefaultLoggingManager().DefaultLogger());
CreateNodeRemovalGraph(model, false, test_nested);
auto& graph = model.MainGraph();
auto& node_to_remove = *graph.GetNode(1);
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, node_to_remove));
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, node_to_remove,
DefaultLoggingManager().DefaultLogger()));
}
TEST(GraphUtils, DontRemoveNodeIfItWillBreakSubgraph) {
@ -253,7 +259,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
// Create a graph with 5 Id nodes. The graph structure is as follows: Id0 ( Id1 Id2 ( Id3 Id4 ) ).
// First we remove Id2, which leads to: Id0 ( Id1 Id4 Id5 ).
// Then we remove Id1, which leads to: Id2 Id4 Id5, being fed the initializer.
Model model("MultiEdgeRemovalGraph");
Model model("MultiEdgeRemovalGraph", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
TypeProto float_tensor;
@ -303,7 +309,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
}
TEST(GraphUtils, TestMultiOutputRemoveNode) {
Model model("MultiOutputRemovalGraph");
Model model("MultiOutputRemovalGraph", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
TypeProto float_tensor;
@ -341,14 +347,17 @@ TEST(GraphUtils, TestMultiOutputRemoveNode) {
// Try to remove do_0, which should return false
// because both outputs are consumed by downstream Operators.
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, *nodes[0]));
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, *nodes[0],
DefaultLoggingManager().DefaultLogger()));
// Try removing do_0 after removing id_2, which should return true
// because it now has exactly one output consumed by downstream Operators.
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[1]));
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[1],
DefaultLoggingManager().DefaultLogger()));
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[1]));
ASSERT_FALSE(graph_utils::IsOutputUsed(*nodes[0], 0));
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[0]));
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[0],
DefaultLoggingManager().DefaultLogger()));
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[0]));
}

View file

@ -28,7 +28,7 @@ BENCHMARK(BM_CPUAllocator)->Arg(4)->Arg(sizeof(Tensor));
static void BM_ResolveGraph(benchmark::State& state) {
std::shared_ptr<onnxruntime::Model> model_copy;
auto st = onnxruntime::Model::Load("../models/opset8/test_tiny_yolov2/model.onnx", model_copy);
auto st = onnxruntime::Model::Load(ORT_TSTR("../models/opset8/test_tiny_yolov2/model.onnx"), model_copy);
if (!st.IsOK()) {
printf("Parse model failed: %s", st.ErrorMessage().c_str());
abort();

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include <algorithm>
#include <core/common/logging/logging.h>
#include "core/framework/data_types.h"
#include "core/framework/execution_providers.h"
#include "core/framework/kernel_registry.h"
@ -161,8 +162,7 @@ namespace test {
std::string CreateModel() {
RegisterCustomKernel();
Model model("ModelWithOpaque", false);
Model model("ModelWithOpaque", false, logging::LoggingManager::DefaultLogger());
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;

View file

@ -20,7 +20,7 @@ class DummyGraphTransformer : public GraphTransformer {
private:
mutable bool transformer_invoked_;
Status ApplyImpl(Graph& /*graph*/, bool& /*modified*/, int /*graph_level*/) const override {
Status ApplyImpl(Graph& /*graph*/, bool& /*modified*/, int /*graph_level*/, const logging::Logger&) const override {
transformer_invoked_ = true;
return Status::OK();
}
@ -43,11 +43,12 @@ class DummyRewriteRule : public RewriteRule {
private:
mutable bool rewrite_rule_invoked_;
bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/) const override {
bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override {
return true;
}
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) const override {
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/,
const logging::Logger& /*logger*/) const override {
rewrite_rule_invoked_ = true;
return Status::OK();
}

Some files were not shown because too many files have changed in this diff Show more