mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
Avoid using the default logger in the graph lib and optimizers (#2361)
1. Use the session logger if it is available. 2. Don't disable warning 4100 globally. We should fix the warnings instead of disabling it.
This commit is contained in:
parent
b15e43a541
commit
109b3cb450
112 changed files with 614 additions and 556 deletions
|
|
@ -183,9 +183,7 @@ if (MSVC)
|
|||
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib for gtest" FORCE)
|
||||
endif()
|
||||
#Always enable exception handling, even for Windows ARM
|
||||
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
|
||||
#Disable 4100 globally. Too many this kind errors in protobuf
|
||||
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4100")
|
||||
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
|
||||
if (NOT onnxruntime_USE_CUDA)
|
||||
SET (CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Gw /GL")
|
||||
SET (CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /Gw /GL")
|
||||
|
|
|
|||
|
|
@ -231,11 +231,11 @@ if (onnxruntime_USE_TENSORRT)
|
|||
target_sources(getSupportedAPITest PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/getopt.cc)
|
||||
target_include_directories(onnx2trt PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/include)
|
||||
target_include_directories(getSupportedAPITest PRIVATE ${ONNXRUNTIME_ROOT}/test/win_getopt/mb/include)
|
||||
target_compile_options(nvonnxparser_static PRIVATE /FIio.h)
|
||||
target_compile_options(nvonnxparser PRIVATE /FIio.h)
|
||||
target_compile_options(trt_onnxify PRIVATE /FIio.h)
|
||||
target_compile_options(onnx2trt PRIVATE /FIio.h)
|
||||
target_compile_options(getSupportedAPITest PRIVATE /FIio.h)
|
||||
target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100)
|
||||
target_compile_options(nvonnxparser PRIVATE /FIio.h /wd4100)
|
||||
target_compile_options(trt_onnxify PRIVATE /FIio.h /wd4100)
|
||||
target_compile_options(onnx2trt PRIVATE /FIio.h /wd4100)
|
||||
target_compile_options(getSupportedAPITest PRIVATE /FIio.h /wd4100)
|
||||
endif()
|
||||
include_directories(${ONNXRUNTIME_ROOT}/../cmake/external/onnx-tensorrt)
|
||||
include_directories(${TENSORRT_INCLUDE_DIR})
|
||||
|
|
|
|||
|
|
@ -478,7 +478,7 @@ endif()
|
|||
|
||||
add_library(onnx_test_data_proto ${TEST_SRC_DIR}/proto/tml.proto)
|
||||
if(WIN32)
|
||||
target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456")
|
||||
target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456" "/wd4100")
|
||||
endif()
|
||||
add_dependencies(onnx_test_data_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include "gsl/gsl"
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/func_api.h"
|
||||
#include "core/framework/data_transfer.h"
|
||||
|
|
@ -153,10 +154,19 @@ class IExecutionProvider {
|
|||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_node,
|
||||
std::string& dll_path);
|
||||
|
||||
void SetLogger(const logging::Logger* logger) {
|
||||
logger_ = logger;
|
||||
}
|
||||
|
||||
const logging::Logger* GetLogger() const {
|
||||
return logger_;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string type_;
|
||||
AllocatorMap allocators_;
|
||||
|
||||
//It will be set when this object is registered to a session
|
||||
const logging::Logger* logger_ = nullptr;
|
||||
// convenience list of the allocators so GetAllocatorList doesn't have to build a new vector each time
|
||||
// contains the same instances as allocators_
|
||||
std::vector<gsl::not_null<const IAllocator*>> allocator_list_;
|
||||
|
|
|
|||
|
|
@ -37,5 +37,6 @@ Create a new Function instance.
|
|||
@param customized_func the IndexedSubGraph to use for the Function.
|
||||
*/
|
||||
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
|
||||
std::unique_ptr<IndexedSubGraph> customized_func);
|
||||
std::unique_ptr<IndexedSubGraph> customized_func,
|
||||
const logging::Logger& logger);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include "core/common/common.h"
|
||||
#include "core/common/const_pointer_container.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/graph/basic_types.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/graph/graph_nodes.h"
|
||||
|
|
@ -822,6 +823,7 @@ class Graph {
|
|||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
Version ir_version,
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
const logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions = {});
|
||||
|
||||
// internal use by the Graph class only
|
||||
|
|
@ -831,6 +833,7 @@ class Graph {
|
|||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
Graph* parent_graph,
|
||||
const Node* parent_node,
|
||||
const logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions = {});
|
||||
|
||||
// Add node with specified <node_proto>.
|
||||
|
|
@ -1038,6 +1041,8 @@ class Graph {
|
|||
|
||||
// number of times Resolve has run.
|
||||
int num_resolves_ = 0;
|
||||
|
||||
const logging::Logger& logger_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -68,13 +68,13 @@ class NodeArg {
|
|||
@param strict If true, the shape update will fail if there are incompatible values.
|
||||
If false, will be lenient and merge only shape info that can be validly processed.
|
||||
@returns Success unless there is existing type or shape info that can't be successfully updated. */
|
||||
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict = true);
|
||||
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger);
|
||||
|
||||
/** Validate and merge type [and shape] info from node_arg.
|
||||
@param strict If true, the shape update will fail if there are incompatible values.
|
||||
If false, will be lenient and merge only shape info that can be validly processed.
|
||||
@returns Success unless there is existing type or shape info that can't be successfully updated. */
|
||||
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict = true);
|
||||
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger);
|
||||
|
||||
/** Gets this NodeArg as a ValueInfoProto. */
|
||||
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
|
||||
|
|
|
|||
|
|
@ -7,15 +7,7 @@
|
|||
#include "core/common/status.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
#include "onnx/defs/schema.h"
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include <mutex>
|
||||
#include <deque>
|
||||
#include "sstream"
|
||||
|
|
|
|||
|
|
@ -37,19 +37,19 @@ class GraphTransformer {
|
|||
@param[out] modified Set to true if the Graph was modified.
|
||||
@returns Status with success or error information.
|
||||
*/
|
||||
common::Status Apply(Graph& graph, bool& modified) const;
|
||||
common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const;
|
||||
|
||||
protected:
|
||||
/** Helper method to call ApplyImpl on any subgraphs in the Node. */
|
||||
common::Status Recurse(Node& node, bool& modified, int graph_level) const {
|
||||
common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
int subgraph_level = ++graph_level;
|
||||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
auto& subgraph = *entry.second;
|
||||
ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level));
|
||||
ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level, logger));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer);
|
||||
|
|
@ -61,7 +61,8 @@ class GraphTransformer {
|
|||
// You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases
|
||||
// the call to Graph::Resolve in Apply prior to ApplyImpl being called, and after ApplyImpl fore the main graph
|
||||
// completes (if 'modified' is true) should suffice.
|
||||
virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const = 0;
|
||||
virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger)
|
||||
const = 0;
|
||||
|
||||
const std::string name_;
|
||||
const std::unordered_set<std::string> compatible_provider_types_;
|
||||
|
|
|
|||
|
|
@ -66,8 +66,8 @@ class RewriteRule {
|
|||
@param[in] node The Node to apply the rewrite to.
|
||||
@param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
|
||||
@returns Status indicating success or providing error information */
|
||||
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK();
|
||||
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const {
|
||||
return SatisfyCondition(graph, node, logger) ? Apply(graph, node, rule_effect, logger) : Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -79,11 +79,11 @@ class RewriteRule {
|
|||
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
|
||||
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
|
||||
of the nodes. */
|
||||
virtual bool SatisfyCondition(const Graph& graph, const Node& node) const = 0;
|
||||
virtual bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const = 0;
|
||||
|
||||
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
|
||||
The return-value of node may be different from the input-value due to rewriting.
|
||||
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
|
||||
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const = 0;
|
||||
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const = 0;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
|
|||
@returns Status indicating success or providing error information. */
|
||||
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
|
||||
const std::vector<std::reference_wrapper<const RewriteRule>>& rules,
|
||||
RewriteRule::RewriteRuleEffect& rule_effect) const;
|
||||
RewriteRule::RewriteRuleEffect& rule_effect, const logging::Logger& logger) const;
|
||||
|
||||
private:
|
||||
using RuleEffect = RewriteRule::RewriteRuleEffect;
|
||||
|
|
@ -76,7 +76,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
|
|||
std::vector<std::reference_wrapper<const RewriteRule>> any_op_type_rules_;
|
||||
|
||||
// Performs a single top-down traversal of the graph and applies all registered rules.
|
||||
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ extern "C" {
|
|||
#ifdef _WIN32
|
||||
#define ORT_TSTR(X) L##X
|
||||
#else
|
||||
#define ORT_TSTR(X) (X)
|
||||
#define ORT_TSTR(X) X
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -47,9 +47,8 @@ class ExecutionProviders {
|
|||
ORT_IGNORE_RETURN_VALUE(allocator_idx_map_.insert({allocator->Info(), new_provider_idx}));
|
||||
}
|
||||
|
||||
exec_providers_.push_back(std::move(p_exec_provider));
|
||||
exec_provider_ids_.push_back(provider_id);
|
||||
|
||||
exec_providers_.push_back(std::move(p_exec_provider));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ namespace {
|
|||
Status RemoveFileSpec(PWSTR pszPath, size_t cchPath) {
|
||||
assert(pszPath != nullptr && pszPath[0] != L'\0');
|
||||
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
|
||||
(void)cchPath;
|
||||
for (PWSTR t = L"\0"; *t == L'\0'; t = PathRemoveBackslashW(pszPath))
|
||||
;
|
||||
PWSTR pszLast = PathSkipRootW(pszPath);
|
||||
|
|
|
|||
|
|
@ -3,15 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
#include "onnx/defs/schema.h"
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
|
|
@ -3,15 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
#include "onnx/defs/schema.h"
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
|
|
@ -145,15 +145,15 @@ static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& su
|
|||
}
|
||||
|
||||
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
||||
std::unique_ptr<IndexedSubGraph> customized_func)
|
||||
: parent_graph_(&graph) {
|
||||
std::unique_ptr<IndexedSubGraph> customized_func,
|
||||
const logging::Logger& logger)
|
||||
: parent_graph_(&graph),
|
||||
body_("fused_function_subgraph", false, onnxruntime::ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}),
|
||||
graph.DomainToVersionMap(), {}, logger)
|
||||
{
|
||||
customized_func_body_ = std::move(customized_func);
|
||||
|
||||
// Construct body.
|
||||
body_ = onnxruntime::make_unique<onnxruntime::Model>("fused_function_subgraph", false, onnxruntime::ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}),
|
||||
graph.DomainToVersionMap());
|
||||
auto& function_body_graph = body_->MainGraph();
|
||||
auto& function_body_graph = body_.MainGraph();
|
||||
|
||||
auto meta_def = customized_func_body_->GetMetaDef();
|
||||
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
|
||||
|
|
@ -220,15 +220,24 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
}
|
||||
|
||||
static std::unordered_map<std::string, int> GetOpsetVersionMap(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto){
|
||||
return std::unordered_map<std::string, int>{{onnxruntime::kOnnxDomain, static_cast<int>(onnx_func_proto.since_version())}};
|
||||
}
|
||||
|
||||
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
||||
const onnxruntime::NodeIndex& node_index,
|
||||
const ONNX_NAMESPACE::FunctionProto& onnx_func_proto)
|
||||
: parent_graph_(&graph) {
|
||||
const ONNX_NAMESPACE::FunctionProto& onnx_func_proto,
|
||||
const logging::Logger& logger)
|
||||
: parent_graph_(&graph),
|
||||
body_ (onnx_func_proto.name(), false, onnxruntime::ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
GetOpsetVersionMap(onnx_func_proto), {}, logger),
|
||||
onnx_func_proto_(onnx_func_proto)
|
||||
{
|
||||
// Make a copy of the FunctionProto.
|
||||
// All FunctionBody ops with the same op type seem to share the same FunctionProto struct within a model.
|
||||
// Hence, we make a copy prior to generating the graph representation of the function,
|
||||
// as we might make some modifications to the FunctionProto along the way
|
||||
onnx_func_proto_ = onnx_func_proto;
|
||||
|
||||
|
||||
auto node_in_parent_graph = parent_graph_->GetNode(node_index);
|
||||
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
|
||||
|
|
@ -290,9 +299,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
std::unordered_map<std::string, int> domain_to_version;
|
||||
//TODO: set correct domain and version
|
||||
domain_to_version[onnxruntime::kOnnxDomain] = static_cast<int>(onnx_func_proto_.since_version());
|
||||
body_ = onnxruntime::make_unique<onnxruntime::Model>(onnx_func_proto_.name(), false, onnxruntime::ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
|
||||
auto& function_body_graph = body_->MainGraph();
|
||||
|
||||
auto& function_body_graph = body_.MainGraph();
|
||||
// Add node and node args into subgraph
|
||||
// The subgraph preserved the input/output tensor names
|
||||
// in the parent graph for later inlining purpose
|
||||
|
|
@ -379,7 +387,7 @@ const ONNX_NAMESPACE::OpSchema& FunctionImpl::OpSchema() const {
|
|||
}
|
||||
|
||||
const onnxruntime::Graph& FunctionImpl::Body() const {
|
||||
return body_->MainGraph();
|
||||
return body_.MainGraph();
|
||||
}
|
||||
|
||||
const IndexedSubGraph& FunctionImpl::GetIndexedSubGraph() const {
|
||||
|
|
@ -391,7 +399,8 @@ const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const {
|
|||
}
|
||||
|
||||
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
|
||||
std::unique_ptr<IndexedSubGraph> customized_func) {
|
||||
return onnxruntime::make_unique<FunctionImpl>(graph, std::move(customized_func));
|
||||
std::unique_ptr<IndexedSubGraph> customized_func,
|
||||
const logging::Logger& logger) {
|
||||
return onnxruntime::make_unique<FunctionImpl>(graph, std::move(customized_func), logger);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,12 +2,13 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/graph/function.h"
|
||||
#include "core/graph/model.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class Graph;
|
||||
class Node;
|
||||
class Model;
|
||||
} // namespace onnxruntime
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -16,11 +17,13 @@ namespace onnxruntime {
|
|||
class FunctionImpl final : public Function {
|
||||
public:
|
||||
FunctionImpl(const onnxruntime::Graph& graph,
|
||||
std::unique_ptr<IndexedSubGraph> customized_func);
|
||||
std::unique_ptr<IndexedSubGraph> customized_func,
|
||||
const logging::Logger& logger);
|
||||
|
||||
FunctionImpl(const onnxruntime::Graph& graph,
|
||||
const onnxruntime::NodeIndex& node_index,
|
||||
const ONNX_NAMESPACE::FunctionProto& onnx_func);
|
||||
const ONNX_NAMESPACE::FunctionProto& onnx_func,
|
||||
const logging::Logger& logger);
|
||||
|
||||
~FunctionImpl() override;
|
||||
|
||||
|
|
@ -36,7 +39,7 @@ class FunctionImpl final : public Function {
|
|||
const onnxruntime::Graph* const parent_graph_;
|
||||
std::unique_ptr<IndexedSubGraph> customized_func_body_;
|
||||
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
|
||||
std::unique_ptr<onnxruntime::Model> body_;
|
||||
onnxruntime::Model body_;
|
||||
ONNX_NAMESPACE::FunctionProto onnx_func_proto_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ static bool UsingLatestOnnxOpset(const DomainToVersionMap& opset_versions) {
|
|||
|
||||
static Status MergeShapeInfo(const std::string& output_name,
|
||||
const TypeProto_Tensor& source, TypeProto_Tensor& target,
|
||||
bool strict) {
|
||||
bool strict, const logging::Logger& logger) {
|
||||
try {
|
||||
ONNX_NAMESPACE::mergeInShapeInfo(source, target);
|
||||
} catch (const ONNX_NAMESPACE::InferenceError& ex) {
|
||||
|
|
@ -70,9 +70,10 @@ static Status MergeShapeInfo(const std::string& output_name,
|
|||
// mergeInShapeInfo does nothing unless source.shape() is not null, and there would be no conflict if
|
||||
// target.shape() was empty. 'assert' just in case that ever changes.
|
||||
assert(utils::HasShape(source) && utils::HasShape(target));
|
||||
LOGS_DEFAULT(WARNING) << "Error merging shape info for output. '" << output_name
|
||||
<< "' source:" << source.shape() << " target:" << target.shape()
|
||||
<< ". Falling back to lenient merge.";
|
||||
LOGS(logger, WARNING) << "Error merging shape info for output. '" << output_name
|
||||
<< "' source:" << source.shape() << " target:" << target.shape()
|
||||
<< ". Falling back to lenient merge.";
|
||||
|
||||
ONNX_NAMESPACE::UnionShapeInfo(source.shape(), target);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output:", output_name, " ", ex.what());
|
||||
|
|
@ -207,7 +208,7 @@ void NodeArg::ClearShape() {
|
|||
}
|
||||
}
|
||||
|
||||
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict) {
|
||||
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger) {
|
||||
if (!utils::HasType(node_arg_info_)) {
|
||||
*node_arg_info_.mutable_type() = input_type;
|
||||
type_ = DataTypeUtils::ToType(node_arg_info_.type());
|
||||
|
|
@ -236,7 +237,7 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
|
|||
if (utils::HasShape(input_tensor_type)) {
|
||||
auto& current_tensor_type = *current_type.mutable_tensor_type();
|
||||
if (utils::HasShape(current_tensor_type)) {
|
||||
ORT_RETURN_IF_ERROR(MergeShapeInfo(Name(), input_tensor_type, current_tensor_type, strict));
|
||||
ORT_RETURN_IF_ERROR(MergeShapeInfo(Name(), input_tensor_type, current_tensor_type, strict, logger));
|
||||
} else {
|
||||
current_tensor_type = input_tensor_type;
|
||||
}
|
||||
|
|
@ -274,11 +275,11 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict) {
|
||||
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger) {
|
||||
auto status = Status::OK();
|
||||
|
||||
if (utils::HasType(node_arg.node_arg_info_))
|
||||
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict);
|
||||
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, logger);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
|
@ -698,11 +699,13 @@ Graph::Graph(GraphProto* graph_proto,
|
|||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
Version ir_version,
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
const logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions)
|
||||
: Graph(graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, model_functions) {}
|
||||
: Graph(graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, model_functions) {}
|
||||
|
||||
Graph::Graph(GraphProto* graph_proto, const std::unordered_map<std::string, int>& domain_to_version, Version ir_version,
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node,
|
||||
const logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions)
|
||||
: graph_proto_(graph_proto),
|
||||
schema_registry_(schema_registry),
|
||||
|
|
@ -712,7 +715,8 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map<std::string, int>
|
|||
ir_version_(ir_version),
|
||||
using_latest_onnx_opset_(UsingLatestOnnxOpset(domain_to_version)),
|
||||
parent_graph_(parent_graph),
|
||||
parent_node_(parent_node) {
|
||||
parent_node_(parent_node),
|
||||
logger_(logger) {
|
||||
ORT_ENFORCE(graph_proto != nullptr, "graph_proto cannot be null");
|
||||
ArgNameToTypeMap name_to_type_map;
|
||||
|
||||
|
|
@ -769,7 +773,7 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map<std::string, int>
|
|||
// so we prefer the shape from the initializer
|
||||
name_to_type_map[tensor.name()] = t;
|
||||
if (matching_graph_input != nullptr) {
|
||||
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t));
|
||||
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, logger));
|
||||
}
|
||||
} else {
|
||||
// v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these.
|
||||
|
|
@ -805,7 +809,7 @@ Graph::Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::Graph
|
|||
: Graph(&subgraph_proto,
|
||||
parent_graph.DomainToVersionMap(), parent_graph.IrVersion(), parent_graph.schema_registry_,
|
||||
&parent_graph,
|
||||
&parent_node) {
|
||||
&parent_node, parent_graph.logger_) {
|
||||
}
|
||||
|
||||
Status Graph::VerifyNoDuplicateName() {
|
||||
|
|
@ -1456,7 +1460,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
|||
const auto& subgraph_input = *subgraph_inputs->at(i);
|
||||
|
||||
NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
|
||||
status = mutable_nodearg->UpdateTypeAndShape(input_type);
|
||||
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, subgraph.logger_);
|
||||
if (!status.IsOK()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
|
||||
}
|
||||
|
|
@ -1477,7 +1481,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
|||
if (!subgraph_nodearg)
|
||||
continue;
|
||||
|
||||
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg);
|
||||
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, subgraph.logger_);
|
||||
if (!status.IsOK()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
|
||||
}
|
||||
|
|
@ -1666,7 +1670,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op) {
|
|||
// that have no values.
|
||||
TypeProto_Tensor merge_target;
|
||||
(*merge_target.mutable_shape()) = *output_def->Shape();
|
||||
auto status = MergeShapeInfo(output_def->Name(), tensor_type, merge_target, using_latest_onnx_opset_);
|
||||
auto status = MergeShapeInfo(output_def->Name(), tensor_type, merge_target, using_latest_onnx_opset_, logger_);
|
||||
if (!status.IsOK()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node_name, " ", status.ErrorMessage());
|
||||
}
|
||||
|
|
@ -1784,8 +1788,9 @@ Status Graph::VerifyNodeAndOpMatch() {
|
|||
auto iter = model_functions_.find(node.OpType());
|
||||
if (iter != model_functions_.end()) {
|
||||
const ONNX_NAMESPACE::FunctionProto* model_function_proto = iter->second;
|
||||
auto model_func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *model_function_proto);
|
||||
function_container_.emplace_back(std::move(model_func_ptr));
|
||||
function_container_.emplace_back(onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(),
|
||||
*model_function_proto,
|
||||
logger_));
|
||||
node.SetFunctionBody(*function_container_.back());
|
||||
}
|
||||
|
||||
|
|
@ -1805,7 +1810,8 @@ Status Graph::VerifyNodeAndOpMatch() {
|
|||
|
||||
if (node.op_ && node.op_->HasFunction()) {
|
||||
auto onnx_function_proto = node.op_->GetFunction();
|
||||
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *onnx_function_proto);
|
||||
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *onnx_function_proto,
|
||||
logger_);
|
||||
function_container_.emplace_back(std::move(func_ptr));
|
||||
node.SetFunctionBody(*function_container_.back());
|
||||
}
|
||||
|
|
@ -2402,12 +2408,12 @@ void Graph::CleanUnusedInitializers() {
|
|||
// on the first call to Graph::Resolve we are removing unnecessary initializers that should be removed
|
||||
// from the model.
|
||||
// on later calls we are removing initializers that optimizations have made redundant.
|
||||
if (num_resolves_ == 0) {
|
||||
LOGS_DEFAULT(WARNING) << "Removing initializer '"
|
||||
<< name << "'. It is not used by any node and should be removed from the model.";
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "Removing initializer '" << name << "'. It is no longer used by any node.";
|
||||
}
|
||||
if (num_resolves_ == 0) {
|
||||
LOGS(logger_, WARNING) << "Removing initializer '"
|
||||
<< name << "'. It is not used by any node and should be removed from the model.";
|
||||
} else {
|
||||
LOGS(logger_, INFO) << "Removing initializer '" << name << "'. It is no longer used by any node.";
|
||||
}
|
||||
|
||||
erase_list.push_back(name);
|
||||
}
|
||||
|
|
@ -2704,7 +2710,7 @@ Node& Graph::FuseSubGraph(std::unique_ptr<::onnxruntime::IndexedSubGraph> sub_gr
|
|||
func_meta_def->domain);
|
||||
|
||||
fused_node.SetNodeType(Node::Type::Fused);
|
||||
function_container_.emplace_back(MakeFunction(*this, std::move(sub_graph)));
|
||||
function_container_.emplace_back(MakeFunction(*this, std::move(sub_graph), logger_));
|
||||
fused_node.SetFunctionBody(*function_container_.back());
|
||||
|
||||
// Remove nodes fused above.
|
||||
|
|
|
|||
|
|
@ -183,15 +183,15 @@ static void RemoveGraphEdges(Graph& graph, const std::vector<GraphEdge>& edges)
|
|||
/** Given a graph, a list of edges, and a NodeArg name, checks if each of the edges provides an implicit input
|
||||
to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs.
|
||||
This is important when removing a node with this NodeArg as input. */
|
||||
bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph,
|
||||
const std::vector<GraphEdge>& output_edges,
|
||||
const std::string& new_arg_name) {
|
||||
static bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph,
|
||||
const std::vector<GraphEdge>& output_edges,
|
||||
const std::string& new_arg_name, const logging::Logger& logger) {
|
||||
for (const auto& output_edge : output_edges) {
|
||||
if (OutputEdgeProvidesImplicitInput(graph, output_edge)) {
|
||||
const Node& output_edge_node = *graph.GetNode(output_edge.dst_node);
|
||||
if (!CanUpdateImplicitInputNameInSubgraph(output_edge_node, output_edge.arg_name, new_arg_name)) {
|
||||
LOGS_DEFAULT(WARNING) << " Implicit input name " << output_edge.arg_name
|
||||
<< " cannot be safely updated to " << new_arg_name << " in one of the subgraphs.";
|
||||
LOGS(logger, WARNING) << " Implicit input name " << output_edge.arg_name
|
||||
<< " cannot be safely updated to " << new_arg_name << " in one of the subgraphs.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -354,7 +354,7 @@ bool IsOutputUsed(const Node& node, int index) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool CanRemoveNode(const Graph& graph, const Node& node) {
|
||||
bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& logger) {
|
||||
const std::string* output_name = nullptr;
|
||||
if (!IsOnlyOneOutputUsed(graph, node, output_name)) {
|
||||
return false;
|
||||
|
|
@ -386,14 +386,15 @@ bool CanRemoveNode(const Graph& graph, const Node& node) {
|
|||
if (new_name) {
|
||||
// Check that changing the current output name to the new name won't break any subgraphs that consume it
|
||||
std::vector<GraphEdge> output_edges = GetNodeOutputEdges(node);
|
||||
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, *new_name);
|
||||
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, *new_name, logger);
|
||||
}
|
||||
|
||||
return can_remove;
|
||||
}
|
||||
|
||||
bool RemoveNode(Graph& graph, Node& node) {
|
||||
assert(CanRemoveNode(graph, node));
|
||||
//TODO: enable the check back
|
||||
//assert(CanRemoveNode(graph, node, nullptr));
|
||||
|
||||
// Note: Node does not produce any graph outputs, and only a single output is used.
|
||||
|
||||
|
|
@ -411,7 +412,8 @@ bool RemoveNode(Graph& graph, Node& node) {
|
|||
ORT_THROW("Should be unreachable if CanRemoveNodeAndMergeEdges is in sync with the logic here.");
|
||||
}
|
||||
|
||||
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name) {
|
||||
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name,
|
||||
const logging::Logger& logger) {
|
||||
// we have no way to handle replacing multiple outputs so check only one is used
|
||||
const std::string* output_name = nullptr;
|
||||
if (!IsOnlyOneOutputUsed(graph, node, output_name)) {
|
||||
|
|
@ -435,7 +437,7 @@ bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const s
|
|||
// Check that changing the current output name to the new name won't break any subgraphs
|
||||
// that consume the current name
|
||||
std::vector<GraphEdge> output_edges = GetNodeOutputEdges(node);
|
||||
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, initializer_name);
|
||||
can_remove = CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, initializer_name, logger);
|
||||
}
|
||||
|
||||
return can_remove;
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ Conditions:
|
|||
Subgraph rules:
|
||||
- Removing the node won't break a subgraph that consumes the node's output
|
||||
*/
|
||||
bool CanRemoveNode(const Graph& graph, const Node& node);
|
||||
bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& logger);
|
||||
|
||||
/** Removes the given Node from the Graph.
|
||||
See CanRemoveNode for the conditions that must be satisfied in order to remove the node.
|
||||
|
|
@ -128,7 +128,8 @@ Conditions:
|
|||
- otherwise the required graph output will not be produced
|
||||
- Removing the node won't break a subgraph that consumes the node's output
|
||||
*/
|
||||
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name);
|
||||
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name,
|
||||
const logging::Logger& logger);
|
||||
|
||||
/** Remove a node and replace its output with the provided NodeArg for an initializer.
|
||||
See CanReplaceNodeWithInitializer for the conditions that must be satisfied in order to remove the node.*/
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ Model::Model(const std::string& graph_name,
|
|||
const ModelMetaData& model_metadata,
|
||||
const IOnnxRuntimeOpSchemaRegistryList& local_registries,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_functions) {
|
||||
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_functions,
|
||||
const logging::Logger& logger) {
|
||||
model_proto_ = onnxruntime::make_unique<ModelProto>();
|
||||
model_proto_->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
model_proto_->mutable_graph()->set_name(graph_name);
|
||||
|
|
@ -69,14 +70,17 @@ Model::Model(const std::string& graph_name,
|
|||
|
||||
// need to call private ctor so can't use make_shared
|
||||
GSL_SUPPRESS(r .11)
|
||||
graph_.reset(new Graph(model_proto_->mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry, model_functions_map));
|
||||
graph_.reset(new Graph(model_proto_->mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry,
|
||||
logger, model_functions_map));
|
||||
}
|
||||
|
||||
Model::Model(const ModelProto& model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries)
|
||||
: Model(onnxruntime::make_unique<ModelProto>(model_proto), local_registries) {
|
||||
Model::Model(const ModelProto& model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger)
|
||||
: Model(onnxruntime::make_unique<ModelProto>(model_proto), local_registries, logger) {
|
||||
}
|
||||
|
||||
Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger) {
|
||||
if (!model_proto) {
|
||||
throw std::invalid_argument("ModelProto was null.");
|
||||
}
|
||||
|
|
@ -111,13 +115,13 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
|
|||
if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) {
|
||||
// TODO: Check if we can upgrade all the current opset 6 models that are being tested
|
||||
// in CI to opset 7 or above
|
||||
LOGS_DEFAULT(WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
|
||||
"with opset version 7 or above for opset domain 'ai.onnx'. "
|
||||
"Please upgrade your model to opset 7 or higher. "
|
||||
"For now, this opset "
|
||||
<< version
|
||||
<< " model may run depending upon legacy support "
|
||||
"of some older opset version operators.";
|
||||
LOGS(logger, WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
|
||||
"with opset version 7 or above for opset domain 'ai.onnx'. "
|
||||
"Please upgrade your model to opset 7 or higher. "
|
||||
"For now, this opset "
|
||||
<< version
|
||||
<< " model may run depending upon legacy support "
|
||||
"of some older opset version operators.";
|
||||
}
|
||||
// We need to overwrite the domain here with ("") or else the loop below will try to find ("")
|
||||
// in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11).
|
||||
|
|
@ -146,7 +150,8 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
|
|||
|
||||
// create instance. need to call private ctor so can't use make_unique
|
||||
GSL_SUPPRESS(r .11)
|
||||
graph_.reset(new Graph(model_proto_->mutable_graph(), domain_to_version, IrVersion(), schema_registry, model_functions_map));
|
||||
graph_.reset(new Graph(model_proto_->mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger,
|
||||
model_functions_map));
|
||||
}
|
||||
|
||||
Version Model::IrVersion() const {
|
||||
|
|
@ -237,7 +242,9 @@ Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger) {
|
||||
// we expect a graph to be present
|
||||
if (!utils::HasGraph(model_proto)) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf.");
|
||||
|
|
@ -246,7 +253,7 @@ Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model,
|
|||
// need to call private ctor so can't use make_shared
|
||||
GSL_SUPPRESS(r .11)
|
||||
try {
|
||||
model.reset(new Model(model_proto, local_registries));
|
||||
model.reset(new Model(model_proto, local_registries, logger));
|
||||
} catch (const std::exception& ex) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
|
||||
}
|
||||
|
|
@ -256,7 +263,9 @@ Status Model::Load(const ModelProto& model_proto, std::shared_ptr<Model>& model,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Model>& model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Model>& model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger) {
|
||||
// we expect a graph to be present
|
||||
if (!utils::HasGraph(*p_model_proto)) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf.");
|
||||
|
|
@ -265,7 +274,7 @@ Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Mo
|
|||
// need to call private ctor so can't use make_shared
|
||||
GSL_SUPPRESS(r .11)
|
||||
try {
|
||||
model.reset(new Model(std::move(p_model_proto), local_registries));
|
||||
model.reset(new Model(std::move(p_model_proto), local_registries, logger));
|
||||
} catch (const std::exception& ex) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
|
||||
}
|
||||
|
|
@ -276,7 +285,9 @@ Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Mo
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger) {
|
||||
int fd;
|
||||
Status status = Env::Default().FileOpenRd(file_path, fd);
|
||||
if (!status.IsOK()) {
|
||||
|
|
@ -293,7 +304,7 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model, con
|
|||
}
|
||||
}
|
||||
try {
|
||||
status = Model::Load(fd, p_model, local_registries);
|
||||
status = Model::Load(fd, p_model, local_registries, logger);
|
||||
} catch (std::exception& ex) {
|
||||
GSL_SUPPRESS(es .84)
|
||||
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
|
||||
|
|
@ -328,36 +339,32 @@ static Status SaveModel(Model& model, const T& file_path) {
|
|||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load
|
||||
GSL_SUPPRESS(r .35)
|
||||
Status Model::Load(const std::wstring& file_path, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
return LoadModel(file_path, p_model, local_registries);
|
||||
}
|
||||
|
||||
Status Model::Save(Model& model, const std::wstring& file_path) {
|
||||
return SaveModel(model, file_path);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load
|
||||
GSL_SUPPRESS(r .35)
|
||||
Status Model::Load(const std::string& file_path, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
return LoadModel(file_path, p_model, local_registries);
|
||||
Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path, std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger) {
|
||||
return LoadModel(file_path, p_model, local_registries, logger);
|
||||
}
|
||||
|
||||
Status Model::Save(Model& model, const std::string& file_path) {
|
||||
return SaveModel(model, file_path);
|
||||
}
|
||||
|
||||
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
|
||||
std::unique_ptr<ModelProto> modelProto = onnxruntime::make_unique<ModelProto>();
|
||||
const bool result = modelProto->ParseFromArray(p_bytes, count);
|
||||
if (!result) {
|
||||
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
|
||||
}
|
||||
|
||||
p_model = std::make_shared<Model>(std::move(modelProto), local_registries);
|
||||
p_model = std::make_shared<Model>(std::move(modelProto), local_registries, logger);
|
||||
|
||||
ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));
|
||||
|
||||
|
|
@ -368,7 +375,8 @@ using ::google::protobuf::io::CodedInputStream;
|
|||
using ::google::protobuf::io::FileInputStream;
|
||||
using ::google::protobuf::io::ZeroCopyInputStream;
|
||||
|
||||
Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries) {
|
||||
Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger) {
|
||||
if (fd < 0) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> less than 0.");
|
||||
}
|
||||
|
|
@ -394,7 +402,7 @@ Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOp
|
|||
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
|
||||
}
|
||||
#endif
|
||||
p_model = std::make_shared<Model>(std::move(model_proto), local_registries);
|
||||
p_model = std::make_shared<Model>(std::move(model_proto), local_registries, logger);
|
||||
|
||||
ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <climits>
|
||||
#include <string>
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
#include "gsl/gsl"
|
||||
|
||||
|
|
@ -22,23 +23,32 @@ class Model {
|
|||
public:
|
||||
static constexpr Version kNoVersion = INT64_MAX;
|
||||
|
||||
explicit Model(const std::string& graph_name,
|
||||
bool is_onnx_domain_only,
|
||||
const logging::Logger& logger)
|
||||
:Model(graph_name,is_onnx_domain_only, ModelMetaData(),IOnnxRuntimeOpSchemaRegistryList(),{},{},
|
||||
logger){}
|
||||
|
||||
// Construct model from scratch.
|
||||
explicit Model(const std::string& graph_name,
|
||||
bool is_onnx_domain_only = false,
|
||||
const ModelMetaData& model_metadata = ModelMetaData(),
|
||||
const IOnnxRuntimeOpSchemaRegistryList& local_registries = IOnnxRuntimeOpSchemaRegistryList(),
|
||||
const std::unordered_map<std::string, int>& domain_to_version = {},
|
||||
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_specific_functions = {});
|
||||
bool is_onnx_domain_only,
|
||||
const ModelMetaData& model_metadata,
|
||||
const IOnnxRuntimeOpSchemaRegistryList& local_registries,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_specific_functions,
|
||||
const logging::Logger& logger);
|
||||
|
||||
// NOTE: after calling this constructor, <*this> model will
|
||||
// hold a copy of <model_proto>.
|
||||
explicit Model(const ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger);
|
||||
|
||||
// NOTE: after calling this constructor, <*this> model will
|
||||
// own the <model_proto>.
|
||||
explicit Model(std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger);
|
||||
|
||||
// Get model's IR version.
|
||||
// Return <kNoVersion> if not specified.
|
||||
|
|
@ -88,10 +98,6 @@ class Model {
|
|||
|
||||
#ifdef _WIN32
|
||||
static common::Status Save(Model& model, const std::wstring& file_path);
|
||||
|
||||
// TODO(Task:132) Use of shared_ptr<X>* in Load/Save methods is confusing.
|
||||
static common::Status Load(const std::wstring& file_path, /*out*/ std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registry = nullptr);
|
||||
#endif
|
||||
static common::Status Save(Model& model, const std::string& file_path);
|
||||
|
||||
|
|
@ -99,23 +105,29 @@ class Model {
|
|||
|
||||
static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto);
|
||||
|
||||
static common::Status Load(const std::string& file_path,
|
||||
// TODO(Task:132) Use of shared_ptr<X>* in Load/Save methods is confusing.
|
||||
static common::Status Load(const std::basic_string<ORTCHAR_T>& file_path,
|
||||
/*out*/ std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger);
|
||||
|
||||
static common::Status Load(int fd, /*out*/ std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger);
|
||||
|
||||
// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
|
||||
static common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger);
|
||||
|
||||
static common::Status Load(const ONNX_NAMESPACE::ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger);
|
||||
|
||||
static common::Status Load(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto,
|
||||
/*out*/ std::shared_ptr<Model>& p_model,
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries = nullptr);
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
const logging::Logger& logger);
|
||||
|
||||
private:
|
||||
// Model data.
|
||||
|
|
|
|||
|
|
@ -5,15 +5,7 @@
|
|||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
#include "onnx/defs/schema.h"
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/graph/constants.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ Status AddGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) c
|
|||
|
||||
auto& node = *node_ptr;
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class AddGeluFusion : public GraphTransformer {
|
|||
: GraphTransformer("AddGeluFusion", compatible_execution_providers) {
|
||||
}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ using namespace onnxruntime::common;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
continue;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
|
||||
|
||||
InitializedTensorSet constant_inputs;
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context));
|
||||
|
||||
std::vector<OrtValue> fetches;
|
||||
frame.GetOutputs(fetches);
|
||||
ORT_RETURN_IF_ERROR(frame.GetOutputs(fetches));
|
||||
|
||||
// Go over all output node args and substitute them with the newly computed tensors, which will be
|
||||
// added to the graph as initializers.
|
||||
|
|
@ -62,8 +62,8 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
OrtValue& ort_value = fetches[fetch_idx];
|
||||
|
||||
if (!ort_value.IsTensor()) {
|
||||
LOGS_DEFAULT(WARNING) << "Unsupported output type of " << ort_value.Type()
|
||||
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
|
||||
LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type()
|
||||
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
|
||||
unsupported_output_type = true;
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class ConstantFolding : public GraphTransformer {
|
|||
const std::unordered_set<std::string> excluded_op_types_ =
|
||||
{"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"};
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
/** Create a TensorProto that has the same value as the given OrtValue
|
||||
and the same type and dimensions as the given NodeArg. */
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ static bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& m
|
|||
|
||||
} // namespace
|
||||
|
||||
Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
if (!node)
|
||||
continue;
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", {1, 11}) ||
|
||||
!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders()) ||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class ConvActivationFusion : public GraphTransformer {
|
|||
: GraphTransformer("ConvActivationFusion", compatible_execution_providers) {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const {
|
||||
Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified, const logging::Logger&) const {
|
||||
auto& conv_node = node;
|
||||
auto& add_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index()); // get mutable next node
|
||||
const auto& conv_inputs = conv_node.InputDefs();
|
||||
|
|
@ -103,7 +103,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class ConvAddFusion : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
auto& conv_node = node;
|
||||
Node& bn_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index());
|
||||
|
||||
|
|
@ -144,7 +144,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class ConvBNFusion : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
auto& conv_node = node;
|
||||
auto& mul_node = *graph.GetNode(conv_node.OutputNodesBegin()->Index());
|
||||
|
||||
|
|
@ -114,7 +114,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ class ConvMulFusion : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
|
|
@ -18,7 +18,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
// We currently support elimination for Dropout operator v1, v6, v7, and v10.
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10})) {
|
||||
return false;
|
||||
|
|
@ -28,7 +28,7 @@ bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) co
|
|||
// It can be safely removed if it has only one output that is used (checked by CanRemoveNode)
|
||||
// and that output is not the 'mask' output.
|
||||
// The 'is_test' attribute in v1 and v6 is captured by the check for the 'mask' output.
|
||||
return graph_utils::CanRemoveNode(graph, node) && !graph_utils::IsOutputUsed(node, 1);
|
||||
return graph_utils::CanRemoveNode(graph, node, logger) && !graph_utils::IsOutputUsed(node, 1);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class EliminateDropout : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ static std::string ToLower(std::string s) {
|
|||
}
|
||||
}
|
||||
|
||||
Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/) const {
|
||||
Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const {
|
||||
for (const onnxruntime::NodeArg* graph_input : graph.GetInputs()) {
|
||||
// Get the current input's type and shape
|
||||
const auto* input_type = graph_input->TypeAsProto();
|
||||
|
|
@ -59,10 +59,10 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
|
|||
// If this dimension actually has a value but it doesn't match the override value, return an
|
||||
// error.
|
||||
if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) {
|
||||
LOGS_DEFAULT(ERROR) << "The model has input '" << graph_input->Name() << "' "
|
||||
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
|
||||
<< "but the size of this dimension " << dimension.dim_value() << " "
|
||||
<< "does not equal the specified override of" << dimension_override << ".";
|
||||
LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' "
|
||||
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
|
||||
<< "but the size of this dimension " << dimension.dim_value() << " "
|
||||
<< "does not equal the specified override of" << dimension_override << ".";
|
||||
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class FreeDimensionOverrideTransformer : public GraphTransformer {
|
|||
explicit FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply);
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
std::map<std::string, int64_t> dimension_override_by_denotation_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ static bool IsSupportedDataType(const Node& node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
|
|||
continue; // we removed the node as part of an earlier fusion
|
||||
|
||||
Node& div = *p_div;
|
||||
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
|
||||
!graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class GeluFusion : public GraphTransformer {
|
|||
GeluFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("GeluFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ bool IsFusableActivation(const Node& node) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
continue; // node was removed
|
||||
|
||||
auto& node = *node_ptr;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class GemmActivationFusion : public GraphTransformer {
|
|||
GemmActivationFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("GemmActivationFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ using namespace ::onnxruntime::common;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status GraphTransformer::Apply(Graph& graph, bool& modified) const {
|
||||
Status GraphTransformer::Apply(Graph& graph, bool& modified, const logging::Logger& logger) const {
|
||||
// the Graph should be in a good state prior this being called, so there should be no need to call Resolve here
|
||||
// ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
|
||||
auto status = ApplyImpl(graph, modified, 0);
|
||||
auto status = ApplyImpl(graph, modified, 0, logger);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
// At least currently, some transformers (InsertCastTransformer and MemcpyTransformer) need this to be called
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ using namespace ::onnxruntime::common;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, TransformerLevel level) const {
|
||||
common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, TransformerLevel level, const logging::Logger& logger) const {
|
||||
const auto& transformers = level_to_transformer_map_.find(level);
|
||||
if (transformers == level_to_transformer_map_.end()) {
|
||||
return Status::OK();
|
||||
|
|
@ -18,7 +18,7 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor
|
|||
bool graph_changed = false;
|
||||
for (const auto& transformer : transformers->second) {
|
||||
bool modified = false;
|
||||
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified));
|
||||
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger));
|
||||
graph_changed = graph_changed || modified;
|
||||
}
|
||||
if (!graph_changed) {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
|
@ -20,7 +21,7 @@ class GraphTransformerManager {
|
|||
common::Status Register(std::unique_ptr<GraphTransformer> transformer, TransformerLevel level);
|
||||
|
||||
// Apply all transformers registered for the given level on the given graph
|
||||
common::Status ApplyTransformers(Graph& graph, TransformerLevel level) const;
|
||||
common::Status ApplyTransformers(Graph& graph, TransformerLevel level, const logging::Logger& logger) const;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformerManager);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
|
|
@ -18,8 +18,8 @@ Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rul
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
return graph_utils::CanRemoveNode(graph, node);
|
||||
bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
return graph_utils::CanRemoveNode(graph, node, logger);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class EliminateIdentity : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
}; // namespace onnxruntime
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override {
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override {
|
||||
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacement_defs;
|
||||
|
||||
auto output_args = graph.GetOutputs();
|
||||
|
|
@ -187,7 +187,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
}
|
||||
|
||||
if (!removed) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -195,7 +195,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
}
|
||||
};
|
||||
|
||||
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const {
|
||||
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
if (force_cpu_fp32_)
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph));
|
||||
|
||||
|
|
@ -267,7 +267,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
|
|||
node->ReplaceDefs(replacement_defs);
|
||||
modified = modified || casted;
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
|
||||
}
|
||||
|
||||
auto status = Status::OK();
|
||||
|
|
@ -282,7 +282,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
|
|||
RemoveDuplicateCastTransformer remover;
|
||||
// RemoveDuplicateCastTransformer is a special transformer required for correctness.
|
||||
// It is provider agnostic so simply send an empty vector.
|
||||
status = remover.Apply(graph, modified);
|
||||
status = remover.Apply(graph, modified, logger);
|
||||
}
|
||||
|
||||
return status;
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer {
|
|||
}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const;
|
||||
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
|
|||
| |
|
||||
+---------------------+
|
||||
*/
|
||||
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::vector<std::reference_wrapper<Node>> nodes_to_remove;
|
||||
|
|
@ -54,7 +54,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
continue; // we removed the node as part of an earlier fusion
|
||||
|
||||
Node& reduce_mean_node = *p_reduce_mean;
|
||||
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) ||
|
||||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class LayerNormFusion : public GraphTransformer {
|
|||
LayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("LayerNormFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
|
||||
auto& node = *node_ptr;
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class MatMulAddFusion : public GraphTransformer {
|
|||
MatMulAddFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("MatMulAddFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -727,13 +727,13 @@ void NchwcTransformerImpl::Finalize(bool& modified) {
|
|||
}
|
||||
}
|
||||
|
||||
Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
NchwcTransformerImpl impl(graph);
|
||||
GraphViewer graph_viewer(graph);
|
||||
|
||||
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(index);
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
if (node.GetExecutionProviderType() == kCpuExecutionProvider) {
|
||||
impl.Transform(node);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class NchwcTransformer : public GraphTransformer {
|
|||
NchwcTransformer() noexcept : GraphTransformer("NchwcTransformer") {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
|
||||
// Clip opset 6 has min and max as attributes. they're inputs from opset 11 on.
|
||||
|
|
@ -109,7 +109,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -127,7 +127,7 @@ bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!graph_utils::CanRemoveNode(graph, node)) {
|
||||
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ class FuseReluClip : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr<RewriteRule> rule) {
|
|||
|
||||
Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
|
||||
const std::vector<std::reference_wrapper<const RewriteRule>>& rules,
|
||||
RuleEffect& rule_effect) const {
|
||||
RuleEffect& rule_effect, const logging::Logger& logger) const {
|
||||
for (const RewriteRule& rule : rules) {
|
||||
ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect));
|
||||
ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect, logger));
|
||||
// If the current node was removed as a result of a rule, stop rule application for that node.
|
||||
if (rule_effect == RuleEffect::kRemovedCurrentNode) {
|
||||
break;
|
||||
|
|
@ -39,7 +39,7 @@ Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -65,13 +65,13 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
|
||||
rules = GetRewriteRulesForOpType(node->OpType());
|
||||
if (rules) {
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect, logger));
|
||||
}
|
||||
|
||||
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
|
||||
rules = GetAnyOpRewriteRules();
|
||||
if (rules) {
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect, logger));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -81,7 +81,7 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
}
|
||||
|
||||
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
// Store the statically inferred shape of the input to the Shape operator.
|
||||
const ONNX_NAMESPACE::TensorShapeProto* input_shape_proto = node.InputDefs()[0]->Shape();
|
||||
std::vector<int64_t> input_dims;
|
||||
|
|
@ -49,7 +49,7 @@ Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& ru
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1})) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -70,7 +70,7 @@ bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node)
|
|||
|
||||
// we're going to create an initializer with the same name as the node output
|
||||
const auto& new_initializer_name = node.OutputDefs()[0]->Name();
|
||||
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_initializer_name)) {
|
||||
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_initializer_name, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,9 +24,9 @@ class ShapeToInitializer : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ static bool IsSupportedDataType(const Node& node) {
|
|||
/**
|
||||
Skip Layer Normalization will fuse Add + LayerNormalization into one node.
|
||||
*/
|
||||
Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::vector<std::reference_wrapper<Node>> nodes_to_remove;
|
||||
|
|
@ -37,7 +37,7 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
continue; // we removed the node as part of an earlier fusion.
|
||||
|
||||
Node& add_node = *p_add;
|
||||
ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
|
||||
!graph_utils::IsSupportedProvider(add_node, GetCompatibleExecutionProviders()) ||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ The formula corresponding to LayerNorm activation subgraph:
|
|||
*/
|
||||
class SkipLayerNormFusion : public GraphTransformer {
|
||||
public:
|
||||
SkipLayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
explicit SkipLayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("SkipLayerNormFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
|
|
@ -17,13 +17,13 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_e
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
// We currently support elimination for Slice operator v1.
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!graph_utils::CanRemoveNode(graph, node)) {
|
||||
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class EliminateSlice : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
|
|||
|
||||
// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
|
||||
// and mainly provides the subgraph recursion functionality
|
||||
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
for (auto& provider : provider_types_) {
|
||||
if (provider != onnxruntime::kCpuExecutionProvider &&
|
||||
provider != onnxruntime::kMklDnnExecutionProvider &&
|
||||
|
|
@ -91,7 +91,7 @@ common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
|
||||
// handle any subgraphs in nodes
|
||||
for (auto& node : graph.Nodes()) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class MemcpyTransformer : public GraphTransformer {
|
|||
: GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {}
|
||||
|
||||
private:
|
||||
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
const std::vector<std::string> provider_types_;
|
||||
std::reference_wrapper<const KernelRegistryManager> registry_manager_;
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ using namespace onnxruntime::common;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const {
|
||||
NodeArg& input_def = *node.MutableInputDefs()[0];
|
||||
const auto& tensor_proto = *graph_utils::GetConstantInitializer(graph, input_def.Name());
|
||||
|
||||
auto new_name = graph.GenerateNodeArgName("UnsqueezeElimination_" + input_def.Name());
|
||||
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_name)) {
|
||||
LOGS_DEFAULT(WARNING) << "UnsqueezeElimination cannot remove node " << node.Name();
|
||||
if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_name, logger)) {
|
||||
LOGS(logger, WARNING) << "UnsqueezeElimination cannot remove node " << node.Name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -68,7 +68,7 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect&
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
// Attempt to remove an Unsqueeze operator only if it gets a constant initializer as input.
|
||||
return graph_utils::IsConstantInitializer(graph, node.InputDefs()[0]->Name());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class UnsqueezeElimination : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -151,8 +151,8 @@ class WindowsEnv : public Env {
|
|||
}
|
||||
|
||||
Status MapFileIntoMemory(
|
||||
const ORTCHAR_T* file_path, FileOffsetType offset, size_t length,
|
||||
MappedMemoryPtr& mapped_memory) const override {
|
||||
const ORTCHAR_T*, FileOffsetType, size_t,
|
||||
MappedMemoryPtr&) const override {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "MapFileIntoMemory is not implemented on Windows.");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,25 +20,23 @@ namespace Dml
|
|||
}
|
||||
|
||||
onnxruntime::common::Status GraphTransformer::ApplyImpl(
|
||||
onnxruntime::Graph& graph,
|
||||
onnxruntime::Graph& graph,
|
||||
bool& modified,
|
||||
int graph_level) const
|
||||
{
|
||||
modified = false;
|
||||
|
||||
// Perform fusion
|
||||
{
|
||||
bool transformModifiedGraph = false;
|
||||
PerformOperatorFusion(&graph, &transformModifiedGraph);
|
||||
modified |= transformModifiedGraph;
|
||||
int graph_level, const onnxruntime::logging::Logger&) const {
|
||||
modified = false;
|
||||
|
||||
if (modified)
|
||||
{
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
}
|
||||
// Perform fusion
|
||||
{
|
||||
bool transformModifiedGraph = false;
|
||||
PerformOperatorFusion(&graph, &transformModifiedGraph);
|
||||
modified |= transformModifiedGraph;
|
||||
|
||||
if (modified) {
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
}
|
||||
}
|
||||
|
||||
return onnxruntime::common::Status::OK();
|
||||
return onnxruntime::common::Status::OK();
|
||||
}
|
||||
|
||||
static std::string GetUniqueNodeName(const onnxruntime::Node* node)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
// Lotus framework headers for onnxruntime::IExecutionProvider (not part of the operator ABI).
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
|
@ -18,7 +19,7 @@ namespace Dml
|
|||
GraphTransformer(const std::string& name, std::shared_ptr<onnxruntime::KernelRegistry> dmlRegistry);
|
||||
|
||||
private:
|
||||
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level = 0) const final;
|
||||
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final;
|
||||
|
||||
private:
|
||||
void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const {
|
||||
Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified, const onnxruntime::logging::Logger&) const {
|
||||
auto& BatchNormalization_node = node;
|
||||
const auto& add_node = *BatchNormalization_node.OutputNodesBegin();
|
||||
const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs();
|
||||
|
|
@ -88,7 +88,7 @@ Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleE
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool BatchNormalizationAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool BatchNormalizationAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7}) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class BatchNormalizationAddFusion : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const {
|
||||
auto& BatchNormalization_node = node;
|
||||
const auto& mul_node = *BatchNormalization_node.OutputNodesBegin();
|
||||
const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs();
|
||||
|
|
@ -101,7 +101,7 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool BatchNormalizationMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
bool BatchNormalizationMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7}) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ class BatchNormalizationMulFusion : public RewriteRule {
|
|||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const onnxruntime::logging::Logger&) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ NGRAPHCustomOp::~NGRAPHCustomOp() {
|
|||
}
|
||||
|
||||
//This method gets called in critical path of execution: Optimize
|
||||
Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const {
|
||||
Status NGRAPHCustomOp::Initialize(const OrtApi* api, OrtKernelContext* context) const {
|
||||
Ort::CustomOpApi ort{*api};
|
||||
|
||||
size_t num_inputs = ort.KernelContext_GetInputCount(context);
|
||||
|
|
@ -176,7 +176,7 @@ Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* c
|
|||
}
|
||||
|
||||
//This method gets called in critical path of execution: Optimize
|
||||
Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const {
|
||||
Status NGRAPHCustomOp::Compute(const OrtApi* api, OrtKernelContext* context) const {
|
||||
Ort::CustomOpApi ort{*api};
|
||||
|
||||
// Initialize nGraph function if it is not already initialized.
|
||||
|
|
|
|||
|
|
@ -29,12 +29,12 @@ class NGRAPHCustomOp {
|
|||
const ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const std::shared_ptr<ngraph::runtime::Backend>& ng_backend);
|
||||
|
||||
Status Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const;
|
||||
Status Compute(const OrtApi* api, OrtKernelContext* context) const;
|
||||
|
||||
~NGRAPHCustomOp();
|
||||
|
||||
private:
|
||||
Status Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const;
|
||||
Status Initialize(const OrtApi* api, OrtKernelContext* context) const;
|
||||
|
||||
std::shared_ptr<ngraph::runtime::Backend> ng_backend_;
|
||||
|
||||
|
|
|
|||
|
|
@ -527,14 +527,15 @@ NGRAPHExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
|
|||
return result;
|
||||
}
|
||||
|
||||
static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node) {
|
||||
static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node, const logging::Logger& logger) {
|
||||
const auto* node_function = fused_node->GetFunctionBody();
|
||||
|
||||
ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", fused_node->Name());
|
||||
|
||||
const Graph& node_subgraph = node_function->Body();
|
||||
onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{},
|
||||
IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap()};
|
||||
IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(),
|
||||
std::vector<ONNX_NAMESPACE::FunctionProto>(), logger};
|
||||
|
||||
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
|
||||
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
|
@ -551,9 +552,7 @@ Status NGRAPHExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& f
|
|||
|
||||
// Local copy of backend since, class members cannot be captured.
|
||||
auto ngraph_backend = ng_backend_;
|
||||
compute_info.create_state_func = [model_proto = GetModelProtoFromFusedNode(fused_node), ngraph_backend]
|
||||
(ComputeContext* context, FunctionState* state)
|
||||
{
|
||||
compute_info.create_state_func = [model_proto = GetModelProtoFromFusedNode(fused_node, *GetLogger()), ngraph_backend](ComputeContext* context, FunctionState* state) {
|
||||
auto* p = new ngraph_ep::NGRAPHCustomOp(context, model_proto, ngraph_backend);
|
||||
*state = p;
|
||||
return 0;
|
||||
|
|
@ -564,7 +563,7 @@ Status NGRAPHExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& f
|
|||
delete reinterpret_cast<onnxruntime::ngraph_ep::NGRAPHCustomOp*>(state);
|
||||
};
|
||||
|
||||
compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
|
||||
compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
|
||||
onnxruntime::ngraph_ep::NGRAPHCustomOp* ng_custom_op = reinterpret_cast<onnxruntime::ngraph_ep::NGRAPHCustomOp*>(state);
|
||||
return ng_custom_op->Compute(api, context);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
if (group.second) {
|
||||
nodes_list_output.push_back(group);
|
||||
} else {
|
||||
onnxruntime::Model model_build(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap());
|
||||
onnxruntime::Model model_build(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
|
||||
onnxruntime::Graph& graph_build = model_build.MainGraph();
|
||||
|
||||
//Add node and node args
|
||||
|
|
@ -285,7 +285,7 @@ std::vector<std::unique_ptr<ComputeCapability>>
|
|||
TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const {
|
||||
// Construct modelproto from graph
|
||||
onnxruntime::Model model(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap());
|
||||
onnxruntime::Model model(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
|
||||
onnxruntime::Graph& graph_build = model.MainGraph();
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
std::vector<onnxruntime::NodeArg*> inputs, outputs;
|
||||
|
|
@ -379,7 +379,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
|
|||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
|
||||
}
|
||||
const Graph& graph_body = func_body->Body();
|
||||
onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph_body.DomainToVersionMap());
|
||||
onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph_body.DomainToVersionMap(), std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
|
||||
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
|
||||
*(model_proto.mutable_graph()) = graph_body.ToGraphProto();
|
||||
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
|
|
|||
|
|
@ -173,9 +173,9 @@ common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr<IExec
|
|||
return st;
|
||||
}
|
||||
}
|
||||
execution_providers_.Add(provider_type, std::move(p_exec_provider));
|
||||
|
||||
return Status::OK();
|
||||
p_exec_provider->SetLogger(session_logger_);
|
||||
return execution_providers_.Add(provider_type, std::move(p_exec_provider));
|
||||
}
|
||||
|
||||
common::Status InferenceSession::RegisterGraphTransformer(
|
||||
|
|
@ -270,7 +270,8 @@ common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
|
|||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
|
||||
*session_logger_);
|
||||
};
|
||||
|
||||
common::Status st = Load(loader, "model_loading_uri");
|
||||
|
|
@ -296,7 +297,8 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) {
|
|||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
|
||||
*session_logger_);
|
||||
};
|
||||
|
||||
return Load(loader, "model_loading_proto");
|
||||
|
|
@ -311,7 +313,7 @@ common::Status InferenceSession::Load(std::unique_ptr<ModelProto> p_model_proto)
|
|||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(std::move(p_model_proto), model,
|
||||
HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_);
|
||||
};
|
||||
|
||||
return Load(loader, "model_loading_proto");
|
||||
|
|
@ -333,7 +335,8 @@ common::Status InferenceSession::Load(std::istream& model_istream) {
|
|||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
|
||||
*session_logger_);
|
||||
};
|
||||
|
||||
return Load(loader, "model_loading_istream");
|
||||
|
|
@ -355,7 +358,8 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
|
|||
}
|
||||
#endif
|
||||
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
|
||||
*session_logger_);
|
||||
};
|
||||
|
||||
return Load(loader, "model_loading_array");
|
||||
|
|
@ -375,7 +379,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
// 5. insert cast nodes.
|
||||
|
||||
// first apply global(execution provider independent), level 1(default/system/basic) graph to graph optimizations
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1));
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *session_logger_));
|
||||
|
||||
#ifdef USE_DML
|
||||
// TODO: this is a temporary workaround to apply the DML EP's custom graph transformer prior to partitioning. This
|
||||
|
|
@ -390,7 +394,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
Dml::GraphTransformer dml_transformer(onnxruntime::kDmlExecutionProvider, std::move(dml_registry));
|
||||
|
||||
bool modified = false;
|
||||
dml_transformer.Apply(graph, modified);
|
||||
dml_transformer.Apply(graph, modified, *session_logger_);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -401,12 +405,12 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
// apply transformers except default transformers
|
||||
// Default transformers are required for correctness and they are owned and run by inference session
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i < static_cast<int>(TransformerLevel::MaxTransformerLevel); i++) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i)));
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *session_logger_));
|
||||
}
|
||||
|
||||
bool modified = false;
|
||||
// Insert cast node/s.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified));
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified, *session_logger_));
|
||||
|
||||
// Now every node should be already assigned to an execution provider
|
||||
std::unordered_map<std::string, std::vector<std::string>> node_placements;
|
||||
|
|
@ -455,7 +459,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
|
||||
// Insert copy node/s.
|
||||
MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager};
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified));
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified, *session_logger_));
|
||||
|
||||
return common::Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ std::vector<float> Add_Simple(const std::vector<float>& input_a_data, const std:
|
|||
const std::vector<float>& input_small_size = input_a_data.size() < input_b_data.size() ? input_a_data : input_b_data;
|
||||
|
||||
std::vector<float> output(input_large_size.size());
|
||||
for (int iter = 0; iter < input_large_size.size() / input_small_size.size(); iter++) {
|
||||
for (size_t iter = 0; iter < input_large_size.size() / input_small_size.size(); iter++) {
|
||||
std::transform(input_large_size.begin() + iter * input_small_size.size(),
|
||||
input_large_size.begin() + (iter + 1) * input_small_size.size(),
|
||||
input_small_size.begin(),
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "test/framework/model_builder_utils.h"
|
||||
#include "core/framework/allocation_planner.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "test/test_environment.h"
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -162,7 +163,10 @@ class PlannerTest : public ::testing::Test {
|
|||
|
||||
public:
|
||||
PlannerTest()
|
||||
: model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_, nullptr) {
|
||||
: model_("test", false, DefaultLoggingManager().DefaultLogger()),
|
||||
graph_(model_.MainGraph()),
|
||||
tp_("test", 1),
|
||||
state_(execution_providers_, false, &tp_, nullptr) {
|
||||
std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
|
||||
in_place_kernel_ =
|
||||
KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build();
|
||||
|
|
@ -171,8 +175,6 @@ class PlannerTest : public ::testing::Test {
|
|||
execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider));
|
||||
}
|
||||
|
||||
~PlannerTest() = default;
|
||||
|
||||
onnxruntime::NodeArg* Arg(const std::string& name) {
|
||||
auto iter = name_to_arg_.find(name);
|
||||
if (name_to_arg_.end() != iter) return iter->second;
|
||||
|
|
|
|||
|
|
@ -73,7 +73,9 @@ CREATE_INITIALIZER_FUNC(float, TensorProto_DataType_FLOAT, add_float_data)
|
|||
CREATE_INITIALIZER_FUNC(int64_t, TensorProto_DataType_INT64, add_int64_data)
|
||||
// TO DO: Figure out a way to enable it again
|
||||
TEST(CUDAFenceTests, DISABLED_PartOnCPU) {
|
||||
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test");
|
||||
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test",
|
||||
false,
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -135,7 +137,7 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) {
|
|||
}
|
||||
|
||||
TEST(CUDAFenceTests, TileWithInitializer) {
|
||||
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test");
|
||||
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -189,7 +191,8 @@ TEST(CUDAFenceTests, TileWithInitializer) {
|
|||
TEST(CUDAFenceTests, TileWithComputedInput) {
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[onnxruntime::kOnnxDomain] = 7;
|
||||
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
|
||||
std::unique_ptr<onnxruntime::Model> model = onnxruntime::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ namespace test {
|
|||
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> DummyGraphWithClip() {
|
||||
auto model = std::make_shared<onnxruntime::Model>("test");
|
||||
auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -42,7 +42,7 @@ class ExecutionFrameTest : public ::testing::Test {
|
|||
};
|
||||
|
||||
TEST_F(ExecutionFrameTest, TensorAllocationTest) {
|
||||
onnxruntime::Model model("test");
|
||||
onnxruntime::Model model("test", false, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model.MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -112,7 +112,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
|
|||
|
||||
TEST_F(ExecutionFrameTest, FeedInDataTest) {
|
||||
onnxruntime::Model model("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
std::unordered_map<std::string, int>{{"", 10}});
|
||||
std::unordered_map<std::string, int>{{"", 10}}, {},
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model.MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -164,7 +165,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
|
|||
auto xp_type = cpu_xp->Type();
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[onnxruntime::kOnnxDomain] = 7;
|
||||
onnxruntime::Model model("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
|
||||
onnxruntime::Model model("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model.MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@
|
|||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/optimizer/dummy_graph_transformer.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace std;
|
||||
|
|
@ -133,16 +132,17 @@ class InferenceSessionGetGraphWrapper : public InferenceSession {
|
|||
namespace test {
|
||||
static void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
|
||||
const std::vector<float>& expected_values);
|
||||
static const std::string MODEL_URI = "testdata/mul_1.onnx";
|
||||
static const std::string MODEL_URI_NO_OPSET = "testdata/mul_1.noopset.onnx";
|
||||
static constexpr const ORTCHAR_T* MODEL_URI = ORT_TSTR("testdata/mul_1.onnx");
|
||||
static constexpr const ORTCHAR_T* MODEL_URI_NO_OPSET = ORT_TSTR("testdata/mul_1.noopset.onnx");
|
||||
//static const std::string MODEL_URI = "./testdata/squeezenet/model.onnx"; // TODO enable this after we've weights?
|
||||
|
||||
static void CreateMatMulModel(std::unique_ptr<onnxruntime::Model>& p_model, ProviderType provider_type) {
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[onnxruntime::kOnnxDomain] = 7;
|
||||
// Generate the input & output def lists
|
||||
p_model = onnxruntime::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version);
|
||||
std::vector<ONNX_NAMESPACE::FunctionProto> model_specific_functions;
|
||||
p_model = onnxruntime::make_unique<Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, model_specific_functions, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = p_model->MainGraph();
|
||||
|
||||
TypeProto tensor_float;
|
||||
|
|
@ -455,11 +455,11 @@ TEST(InferenceSessionTests, ModelMetadata) {
|
|||
|
||||
so.session_logid = "InferenceSessionTests.ModelMetadata";
|
||||
InferenceSession session_object{so, &DefaultLoggingManager()};
|
||||
string model_uri = "../models/opset8/test_squeezenet/model.onnx";
|
||||
auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx");
|
||||
ASSERT_TRUE(session_object.Load(model_uri).IsOK());
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> p_model;
|
||||
Status st = onnxruntime::Model::Load(model_uri, p_model);
|
||||
Status st = onnxruntime::Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(st.IsOK());
|
||||
const onnxruntime::Graph& graph = p_model->MainGraph();
|
||||
|
||||
|
|
@ -1029,7 +1029,7 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
|
|||
}
|
||||
|
||||
TEST(ExecutionProviderTest, FunctionTest) {
|
||||
onnxruntime::Model model("graph_1");
|
||||
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
|
@ -1116,7 +1116,7 @@ TEST(ExecutionProviderTest, FunctionTest) {
|
|||
}
|
||||
|
||||
TEST(ExecutionProviderTest, FunctionInlineTest) {
|
||||
onnxruntime::Model model("graph_1");
|
||||
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
|
||||
ONNX_NAMESPACE::FunctionProto fc_proto;
|
||||
fc_proto.set_name("FC");
|
||||
|
|
|
|||
|
|
@ -7,16 +7,17 @@
|
|||
#include "core/graph/model.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
static const std::string MODEL_FOLDER = "testdata/transform/";
|
||||
#define MODEL_FOLDER ORT_TSTR("testdata/transform/")
|
||||
|
||||
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
|
||||
TEST(TransformerTest, InsertCastGPUTest) {
|
||||
auto model = std::make_shared<onnxruntime::Model>("test");
|
||||
auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
|
||||
TypeProto tensor_float_16;
|
||||
|
|
@ -38,7 +39,7 @@ TEST(TransformerTest, InsertCastGPUTest) {
|
|||
InsertCastTransformer transformer("Test");
|
||||
|
||||
bool modified = true;
|
||||
status = transformer.Apply(graph, modified);
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK());
|
||||
status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
|
@ -64,7 +65,7 @@ TEST(TransformerTest, InsertCastGPUTest) {
|
|||
}
|
||||
|
||||
TEST(TransformerTest, InsertCastAllCPUTest) {
|
||||
auto model = std::make_shared<onnxruntime::Model>("test");
|
||||
auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
|
||||
TypeProto tensor_float_16;
|
||||
|
|
@ -86,7 +87,7 @@ TEST(TransformerTest, InsertCastAllCPUTest) {
|
|||
InsertCastTransformer transformer("Test");
|
||||
|
||||
bool modified = true;
|
||||
EXPECT_TRUE(transformer.Apply(graph, modified).IsOK());
|
||||
EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_EQ(graph.NumberOfNodes(), 7);
|
||||
|
|
@ -109,9 +110,9 @@ TEST(TransformerTest, InsertCastAllCPUTest) {
|
|||
|
||||
// test that when there are 3 Cast ops in a row we remove the correct ones
|
||||
TEST(TransformerTest, ThreeInARowRemoval) {
|
||||
std::string model_uri = MODEL_FOLDER + "triple-cast.onnx";
|
||||
auto model_uri = MODEL_FOLDER ORT_TSTR("triple-cast.onnx");
|
||||
std::shared_ptr<Model> model;
|
||||
auto status = Model::Load(model_uri, model);
|
||||
auto status = Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
|
|
@ -123,7 +124,7 @@ TEST(TransformerTest, ThreeInARowRemoval) {
|
|||
InsertCastTransformer transformer("Test");
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified);
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK()) << status;
|
||||
EXPECT_TRUE(modified) << "Transformer should have removed some Cast nodes";
|
||||
status = graph.Resolve();
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/graph/model.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
namespace onnxruntime {
|
||||
|
|
@ -74,7 +75,8 @@ TEST(TransformerTest, MemcpyTransformerTest) {
|
|||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = 7;
|
||||
auto model = std::make_shared<onnxruntime::Model>("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version);
|
||||
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
|
||||
TypeProto tensor_float_type;
|
||||
|
|
@ -111,7 +113,7 @@ TEST(TransformerTest, MemcpyTransformerTest) {
|
|||
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified);
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_TRUE(modified);
|
||||
|
||||
|
|
@ -128,7 +130,8 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) {
|
|||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = 7;
|
||||
auto model = std::make_shared<onnxruntime::Model>("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version);
|
||||
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
|
||||
TypeProto tensor_float_type;
|
||||
|
|
@ -165,7 +168,7 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) {
|
|||
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified);
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_TRUE(modified);
|
||||
|
||||
|
|
@ -204,7 +207,8 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
|
|||
false,
|
||||
ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version);
|
||||
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
|
||||
TensorProto parent_constant(value_tensor);
|
||||
|
|
@ -221,7 +225,8 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
|
|||
false,
|
||||
ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(),
|
||||
subgraph_domain_to_version);
|
||||
subgraph_domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& subgraph = sub_model->MainGraph();
|
||||
|
||||
TensorProto local_constant(value_tensor);
|
||||
|
|
@ -278,7 +283,7 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
|
|||
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified);
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_TRUE(modified);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -298,7 +298,7 @@ TEST_F(OpaqueTypeTests, RunModel) {
|
|||
IOnnxRuntimeOpSchemaRegistryList custom_schema_registries_ = {registry->GetOpschemaRegistry()};
|
||||
std::unordered_map<std::string, int> domain_to_version = {{onnxruntime::kMLDomain, 8}};
|
||||
|
||||
Model model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version);
|
||||
Model model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include "core/graph/op.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/test_environment.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace std;
|
||||
|
|
@ -42,7 +43,7 @@ TEST(SessionStateTest, AddGetKernelTest) {
|
|||
ExecutionProviders execution_providers;
|
||||
SessionState s{execution_providers, true, &tp, nullptr};
|
||||
|
||||
onnxruntime::Model model("graph_1");
|
||||
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
|
@ -94,10 +95,11 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
const TestParam& param = GetParam();
|
||||
concurrency::ThreadPool tp{"test", 1};
|
||||
|
||||
std::string model_path = "testdata/optional_inputs_ir" + std::to_string(param.ir_version) + ".onnx";
|
||||
std::basic_ostringstream<ORTCHAR_T> oss;
|
||||
oss << ORT_TSTR("testdata/optional_inputs_ir") << param.ir_version << ORT_TSTR(".onnx");
|
||||
Status status;
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE((status = Model::Load(model_path, model)).IsOK()) << status;
|
||||
ASSERT_TRUE((status = Model::Load(oss.str(), model, nullptr, DefaultLoggingManager().DefaultLogger())).IsOK()) << status;
|
||||
Graph& graph = model->MainGraph();
|
||||
// take a copy as this gets cleared during session state initialization
|
||||
InitializedTensorSet initializers = graph.GetAllInitializedTensors();
|
||||
|
|
@ -112,7 +114,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
SessionState session_state(execution_providers, param.enable_mem_pattern, &tp, nullptr);
|
||||
SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph, session_state,
|
||||
SessionStateInitializer session_initializer(param.enable_mem_pattern, oss.str(), graph, session_state,
|
||||
execution_providers, krm);
|
||||
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "test/framework/model_builder_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace std;
|
||||
|
||||
|
|
@ -22,7 +24,7 @@ class ShapeInferenceTest : public ::testing::Test {
|
|||
std::unordered_map<string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
|
||||
|
||||
public:
|
||||
ShapeInferenceTest() : model_("Test"), node_count_(0) {}
|
||||
ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {}
|
||||
|
||||
void Input(const std::string& name, const Type& type) {
|
||||
name_to_arg_[name] = onnxruntime::make_unique<onnxruntime::NodeArg>(name, &type.value);
|
||||
|
|
|
|||
|
|
@ -3,17 +3,7 @@
|
|||
|
||||
#include "core/framework/data_types.h"
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
|
||||
#include "onnx/defs/schema.h"
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
|
@ -293,7 +283,7 @@ class SparseTensorTests : public testing::Test {
|
|||
registry(std::make_shared<CustomRegistry>()),
|
||||
custom_schema_registries_{registry->GetOpschemaRegistry()},
|
||||
domain_to_version{{onnxruntime::kMLDomain, 10}},
|
||||
model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version),
|
||||
model("SparseTensorTest", false, ModelMetaData(), custom_schema_registries_, domain_to_version, {}, DefaultLoggingManager().DefaultLogger()),
|
||||
graph(model.MainGraph()) {
|
||||
EXPECT_TRUE(session_object.RegisterCustomRegistry(registry).IsOK());
|
||||
}
|
||||
|
|
@ -306,21 +296,21 @@ class SparseTensorTests : public testing::Test {
|
|||
schema.SinceVersion(10);
|
||||
schemas.push_back(schema);
|
||||
|
||||
Action register_kernel = [](CustomRegistry* registry) {
|
||||
Action register_kernel = [](CustomRegistry* registry2) {
|
||||
auto kernel_def_builder = Op::KernelDef();
|
||||
kernel_def_builder
|
||||
.SetDomain(onnxruntime::kMLDomain)
|
||||
.SinceVersion(10)
|
||||
.Provider(onnxruntime::kCpuExecutionProvider);
|
||||
KernelCreateFn kernel_create_fn = [](const OpKernelInfo& info) { return new typename Op::OpKernelImpl(info); };
|
||||
EXPECT_TRUE(registry->RegisterCustomKernel(kernel_def_builder, kernel_create_fn).IsOK());
|
||||
EXPECT_TRUE(registry2->RegisterCustomKernel(kernel_def_builder, kernel_create_fn).IsOK());
|
||||
};
|
||||
register_actions.push_back(register_kernel);
|
||||
}
|
||||
|
||||
void RegisterOps() {
|
||||
EXPECT_TRUE(registry->RegisterOpSet(schemas, onnxruntime::kMLDomain, 10, 11).IsOK());
|
||||
for (auto registerop : register_actions)
|
||||
for (auto& registerop : register_actions)
|
||||
registerop(registry.get());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ static const bool kSchemasRegistered = RegisterCustomSchemas();
|
|||
TEST(GraphTraversalTest, ReverseDFS) {
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
/* Case 1: A normal graph.
|
||||
|
|
@ -219,7 +219,7 @@ TEST(GraphTraversalTest, ReverseDFS) {
|
|||
}
|
||||
|
||||
TEST(ResolvingGraphTest, GraphConstruction_VerifyNoDuplicateName) {
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
EXPECT_EQ("graph_1", graph.Name());
|
||||
|
|
@ -252,7 +252,7 @@ TEST(ResolvingGraphTest, GraphConstruction_VerifyNoDuplicateName) {
|
|||
}
|
||||
|
||||
TEST(ResolvingGraphTest, GraphConstruction_VerifyNodeAndOpMatch) {
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
|
|
@ -275,7 +275,7 @@ TEST(ResolvingGraphTest, GraphConstruction_VerifyNodeAndOpMatch) {
|
|||
TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
/* A normal graph.
|
||||
|
|
@ -336,7 +336,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
|
|||
|
||||
EXPECT_TRUE(Model::Save(model, "graph_1.onnx").IsOK());
|
||||
std::shared_ptr<Model> model2;
|
||||
EXPECT_TRUE(Model::Load("graph_1.onnx", model2).IsOK());
|
||||
EXPECT_TRUE(Model::Load(ORT_TSTR("graph_1.onnx"), model2, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
||||
auto model_proto = model.ToProto();
|
||||
auto model_proto2 = model2->ToProto();
|
||||
|
|
@ -345,7 +345,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
|
|||
|
||||
// Load the model again to ensure that it's still the right thing.
|
||||
//EXPECT_EQ(Model::Load(model_proto2, &model2), Status::OK());
|
||||
model2.reset(new Model(model_proto2));
|
||||
model2.reset(new Model(model_proto2, nullptr, DefaultLoggingManager().DefaultLogger()));
|
||||
Graph& graph2 = model2->MainGraph();
|
||||
for (auto& node : graph2.Nodes()) {
|
||||
auto node_name_to_input_output_iter = expected_node_name_to_input_output_args.find(node.Name());
|
||||
|
|
@ -368,7 +368,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsAcyclic) {
|
|||
TEST(ResolvingGraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
// node_1 (Identity) node_2 (Identity)
|
||||
|
|
@ -448,7 +448,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
|
|||
TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) {
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::unordered_map<std::string, int> map;
|
||||
|
|
@ -541,7 +541,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
|
|||
ASSERT_TRUE(result) << "Failed to load model from serialized protobuf";
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> p_tmp_model;
|
||||
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr);
|
||||
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
|
||||
auto& graph2 = p_tmp_model->MainGraph();
|
||||
status = graph2.Resolve();
|
||||
|
|
@ -555,7 +555,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
|
|||
TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) {
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("UnusedInitializerIsIgnored");
|
||||
Model model("UnusedInitializerIsIgnored", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
|
|
@ -596,7 +596,7 @@ TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) {
|
|||
ASSERT_TRUE(result) << "Failed to load model from serialized protobuf";
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> p_tmp_model;
|
||||
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr);
|
||||
auto x = onnxruntime::Model::Load(model_proto, p_tmp_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
|
||||
auto& graph2 = p_tmp_model->MainGraph();
|
||||
status = graph2.Resolve();
|
||||
|
|
@ -621,7 +621,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsNotAcyclic) {
|
|||
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
|
||||
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
auto& input_arg1 = graph.GetOrCreateNodeArg("node_1_in_1", &tensor_int32);
|
||||
auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
|
||||
|
|
@ -643,7 +643,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckIsNotAcyclic) {
|
|||
}
|
||||
|
||||
TEST(ResolvingGraphTest, GraphConstruction_OnlyInitializer) {
|
||||
onnxruntime::Model model("graph");
|
||||
onnxruntime::Model model("graph", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
ONNX_NAMESPACE::TensorProto weight;
|
||||
|
|
@ -663,7 +663,7 @@ TEST(ResolvingGraphTest, GraphConstruction_OnlyInitializer) {
|
|||
TEST(ResolvingGraphTest, GraphConstruction_TypeInference) {
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
/* Case 1: A normal graph.
|
||||
|
|
@ -725,7 +725,7 @@ TEST(ResolvingGraphTest, GraphConstruction_TypeInference) {
|
|||
|
||||
EXPECT_TRUE(Model::Save(model, "model_x.onnx").IsOK());
|
||||
std::shared_ptr<Model> loaded_model;
|
||||
EXPECT_TRUE(Model::Load("model_x.onnx", loaded_model).IsOK());
|
||||
EXPECT_TRUE(Model::Load(ORT_TSTR("model_x.onnx"), loaded_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
EXPECT_EQ(2, loaded_model->MainGraph().GetInputs().size());
|
||||
|
||||
auto& graph_proto = graph.ToGraphProto();
|
||||
|
|
@ -740,7 +740,7 @@ TEST(ResolvingGraphTest, GraphConstruction_TypeInference) {
|
|||
TEST(ResolvingGraphTest, ShapeInferenceErrorHandling) {
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("graph");
|
||||
Model model("graph", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
TypeProto tensor_int32;
|
||||
|
|
@ -765,7 +765,7 @@ TEST(TestAddAttribute, AddTensorAttribute) {
|
|||
.Output(0, "output_1", "docstr for output_1.", "tensor(int64)");
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
TypeProto output_type;
|
||||
output_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
|
||||
|
|
@ -810,7 +810,7 @@ void AddAttribute(onnxruntime::Node& p_node, const std::string& attr_name, std::
|
|||
TEST(TypeInferenceTest, TypeAttribute) {
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", nullptr);
|
||||
outputs.push_back(&output_arg);
|
||||
|
|
@ -834,7 +834,7 @@ TEST(TypeInferenceTest, VariadicOutput) {
|
|||
std::vector<NodeArg*> outputs;
|
||||
TypeProto tensor_type;
|
||||
tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
auto& X = graph.GetOrCreateNodeArg("X", &tensor_type);
|
||||
inputs.push_back(&X);
|
||||
|
|
@ -851,7 +851,7 @@ TEST(TypeInferenceTest, VariadicOutput) {
|
|||
|
||||
// test that we prefer the graph input shape for a non-const initializer (initializer with matching graph input)
|
||||
TEST(TypeInferenceTest, NonConstInitializer) {
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
TypeProto tensor_type_no_shape;
|
||||
|
|
@ -901,7 +901,7 @@ TEST(TypeInferenceTest, NonConstInitializer) {
|
|||
ASSERT_TRUE(model.ToProto().SerializeToString(&s1));
|
||||
ASSERT_TRUE(model_proto.ParseFromString(s1));
|
||||
|
||||
auto status = onnxruntime::Model::Load(model_proto, p_model, nullptr);
|
||||
auto status = onnxruntime::Model::Load(model_proto, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
auto& graph2 = p_model->MainGraph();
|
||||
|
|
@ -910,7 +910,7 @@ TEST(TypeInferenceTest, NonConstInitializer) {
|
|||
|
||||
// Test that Graph::Resolve identifies name-duplication across initializer and node-output-arg
|
||||
TEST(NameResolutionTest, DuplicateName) {
|
||||
Model model("graph_1");
|
||||
Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
ONNX_NAMESPACE::TensorProto weight;
|
||||
|
|
@ -940,7 +940,7 @@ TEST(NameResolutionTest, DuplicateName) {
|
|||
}
|
||||
|
||||
TEST(GraphUpdateTest, ReplaceInitializedTensor) {
|
||||
Model model{"GraphUpdateTest"};
|
||||
Model model{"GraphUpdateTest", false, DefaultLoggingManager().DefaultLogger()};
|
||||
auto& graph = model.MainGraph();
|
||||
const std::string initializer_name = "initializer";
|
||||
|
||||
|
|
@ -1011,7 +1011,7 @@ TEST(GraphUpdateTest, ReplaceInitializedTensor) {
|
|||
}
|
||||
|
||||
TEST(GraphUpdateTest, AddRemoveInitializerHandling) {
|
||||
Model m{"test_model"};
|
||||
Model m{"test_model", false, DefaultLoggingManager().DefaultLogger()};
|
||||
Graph& graph = m.MainGraph();
|
||||
|
||||
auto create_tensor_proto = [](const std::string& name, int32_t value) {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
|
@ -47,29 +49,18 @@ static void TestResolve(onnxruntime::Graph& graph) {
|
|||
TEST(ONNXModelsTest, squeeze_net) {
|
||||
// NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load("../models/opset8/test_squeezenet/model.onnx", model).IsOK());
|
||||
ASSERT_TRUE(Model::Load(ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"), model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
TestResolve(model->MainGraph());
|
||||
#ifdef _WIN32
|
||||
// wstring version
|
||||
std::shared_ptr<Model> model2;
|
||||
ASSERT_TRUE(Model::Load(L"../models/opset8/test_squeezenet/model.onnx", model2).IsOK());
|
||||
TestResolve(model2->MainGraph());
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(ONNXModelsTest, non_existing_model) {
|
||||
// NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located
|
||||
std::shared_ptr<Model> model;
|
||||
common::Status st = Model::Load("./testdata/non_existing_model_XXXXXX/model.onnx", model);
|
||||
common::Status st = Model::Load(ORT_TSTR("./testdata/non_existing_model_XXXXXX/model.onnx"), model, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_FALSE(st.IsOK());
|
||||
ASSERT_EQ(st.Code(), common::NO_SUCHFILE);
|
||||
#ifdef _WIN32
|
||||
// wstring version
|
||||
std::shared_ptr<Model> model2;
|
||||
ASSERT_FALSE(Model::Load(L"./testdata/non_existing_model_XXXXXX/model.onnx", model2).IsOK());
|
||||
ASSERT_EQ(st.Code(), common::NO_SUCHFILE);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS
|
||||
|
|
@ -78,7 +69,7 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) {
|
|||
using ::google::protobuf::io::FileInputStream;
|
||||
using ::google::protobuf::io::ZeroCopyInputStream;
|
||||
int fd;
|
||||
ASSERT_TRUE(Env::Default().FileOpenRd("../models/opset8/test_bvlc_alexnet/model.onnx", fd).IsOK());
|
||||
ASSERT_TRUE(Env::Default().FileOpenRd(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), fd).IsOK());
|
||||
ASSERT_TRUE(fd > 0);
|
||||
std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
|
||||
std::unique_ptr<CodedInputStream> coded_input(new CodedInputStream(raw_input.get()));
|
||||
|
|
@ -92,7 +83,9 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) {
|
|||
ASSERT_TRUE(Env::Default().FileClose(fd).IsOK());
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load("../models/opset8/test_bvlc_alexnet/model.onnx", model).IsOK());
|
||||
ASSERT_TRUE(Model::Load(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), model, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger())
|
||||
.IsOK());
|
||||
|
||||
// Check the graph input/output/value_info should have the same size as specified in the model file.
|
||||
EXPECT_EQ(model_proto.graph().value_info_size(), model->MainGraph().GetValueInfo().size());
|
||||
|
|
@ -101,21 +94,23 @@ TEST(ONNXModelsTest1, bvlc_alexnet_1) {
|
|||
TestResolve(model->MainGraph());
|
||||
}
|
||||
|
||||
class ONNXModelsTest : public ::testing::TestWithParam<const char*> {
|
||||
class ONNXModelsTest : public ::testing::TestWithParam<const ORTCHAR_T*> {
|
||||
// You can implement all the usual fixture class members here.
|
||||
// To access the test parameter, call GetParam() from class
|
||||
// TestWithParam<T>.
|
||||
public:
|
||||
std::string GetModelFileName() const {
|
||||
std::ostringstream oss;
|
||||
oss << "../models/opset7/test_" << GetParam() << "/model.onnx";
|
||||
std::basic_string<ORTCHAR_T> GetModelFileName() const {
|
||||
std::basic_ostringstream<ORTCHAR_T> oss;
|
||||
oss << ORT_TSTR("../models/opset7/test_") << GetParam() << ORT_TSTR("/model.onnx");
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ONNXModelsTest, LoadFromFile) {
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(GetModelFileName(), model).IsOK());
|
||||
ASSERT_TRUE(Model::Load(GetModelFileName(), model, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger())
|
||||
.IsOK());
|
||||
TestResolve(model->MainGraph());
|
||||
}
|
||||
|
||||
|
|
@ -137,18 +132,20 @@ TEST_P(ONNXModelsTest, LoadFromProtobuf) {
|
|||
ASSERT_TRUE(result);
|
||||
ASSERT_TRUE(Env::Default().FileClose(fd).IsOK());
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(std::move(model_proto), model).IsOK());
|
||||
ASSERT_TRUE(Model::Load(std::move(model_proto), model, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger())
|
||||
.IsOK());
|
||||
TestResolve(model->MainGraph());
|
||||
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
|
||||
ONNXModelsTest,
|
||||
::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "tiny_yolov2", "vgg19", "zfnet512"));
|
||||
::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("tiny_yolov2"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512")));
|
||||
#else
|
||||
INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
|
||||
ONNXModelsTest,
|
||||
::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "vgg19", "zfnet512"));
|
||||
::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512")));
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
@ -159,7 +156,9 @@ INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
|
|||
// for Graph::Resolve to succeed when processing the subgraph.
|
||||
TEST(ONNXModelsTest, TestIRv4NonInputInitializers) {
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK());
|
||||
ASSERT_TRUE(Model::Load(ORT_TSTR("testdata/subgraph_implicit_input_from_initializer.onnx"), model, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger())
|
||||
.IsOK());
|
||||
EXPECT_TRUE(model->MainGraph().Resolve().IsOK());
|
||||
}
|
||||
|
||||
|
|
@ -170,7 +169,8 @@ TEST(ONNXModelsTest, TestIRv4NonInputInitializers) {
|
|||
TEST(ONNXModelsTest, TestModelsWithAnOpContainingAFunctionBody) {
|
||||
std::shared_ptr<Model> model;
|
||||
|
||||
auto status = Model::Load("testdata/model_containing_op_with_function_body.onnx", model);
|
||||
auto status = Model::Load(ORT_TSTR("testdata/model_containing_op_with_function_body.onnx"), model, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK()) << status;
|
||||
|
||||
status = model->MainGraph().Resolve();
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/graph/schema_registry.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "test/test_environment.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
|
|
@ -31,7 +32,7 @@ TEST(FormalParamTest, Success) {
|
|||
}
|
||||
|
||||
TEST(FeatureVectorizerTest, TraditionalMlOpTest) {
|
||||
Model model("traditionalMl");
|
||||
Model model("traditionalMl", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
// Case: A traditional ml graph.
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@
|
|||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/model.h"
|
||||
|
||||
#include "test/test_environment.h"
|
||||
|
||||
using ONNX_NAMESPACE::Utils::DataTypeUtils;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
|
|
@ -71,7 +73,8 @@ static GraphProto CreateNodeRemovalSubgraph(const std::string& new_output_name =
|
|||
std::string suffix = add_second_level ? ".top" : ".bottom";
|
||||
std::string constant_output_name = (new_output_name.empty() ? "constant_in_0" + suffix : new_output_name);
|
||||
|
||||
Model model("CreateNodeRemovalSubgraph:" + constant_output_name);
|
||||
Model model("CreateNodeRemovalSubgraph:" + constant_output_name, false,
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
TypeProto float_scalar_tensor;
|
||||
|
|
@ -198,7 +201,8 @@ static void CheckNodeRemovalSubgraphUpdate(const std::string& new_name, const Gr
|
|||
}
|
||||
|
||||
static void UpdateSubgraphWhenRemovingNode(bool include_nested = false) {
|
||||
Model model(std::string("UpdateSubgraphWhenRemovingNode") + (include_nested ? ":Nested" : ":SingleLevel"));
|
||||
Model model(std::string("UpdateSubgraphWhenRemovingNode") + (include_nested ? ":Nested" : ":SingleLevel"),
|
||||
false, DefaultLoggingManager().DefaultLogger());
|
||||
|
||||
CreateNodeRemovalGraph(model, true, include_nested);
|
||||
|
||||
|
|
@ -232,13 +236,15 @@ TEST(GraphUtils, UpdateNestedSubgraphWhenRemovingNode) {
|
|||
// we can't remove a node if it is used as an implicit input in a subgraph, and changing the implicit input name
|
||||
// will result with in a clash with an existing node in the subgraph
|
||||
static void DontRemoveNodeIfItWillBreakSubgraph(bool test_nested = false) {
|
||||
Model model(std::string("DontRemoveNodeIfItWillBreakSubgraph") + (test_nested ? ":Nested" : ":SingleLevel"));
|
||||
Model model(std::string("DontRemoveNodeIfItWillBreakSubgraph") + (test_nested ? ":Nested" : ":SingleLevel"),
|
||||
false, DefaultLoggingManager().DefaultLogger());
|
||||
CreateNodeRemovalGraph(model, false, test_nested);
|
||||
|
||||
auto& graph = model.MainGraph();
|
||||
auto& node_to_remove = *graph.GetNode(1);
|
||||
|
||||
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, node_to_remove));
|
||||
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, node_to_remove,
|
||||
DefaultLoggingManager().DefaultLogger()));
|
||||
}
|
||||
|
||||
TEST(GraphUtils, DontRemoveNodeIfItWillBreakSubgraph) {
|
||||
|
|
@ -253,7 +259,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
|
|||
// Create a graph with 5 Id nodes. The graph structure is as follows: Id0 ( Id1 Id2 ( Id3 Id4 ) ).
|
||||
// First we remove Id2, which leads to: Id0 ( Id1 Id4 Id5 ).
|
||||
// Then we remove Id1, which leads to: Id2 Id4 Id5, being fed the initializer.
|
||||
Model model("MultiEdgeRemovalGraph");
|
||||
Model model("MultiEdgeRemovalGraph", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
TypeProto float_tensor;
|
||||
|
|
@ -303,7 +309,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
|
|||
}
|
||||
|
||||
TEST(GraphUtils, TestMultiOutputRemoveNode) {
|
||||
Model model("MultiOutputRemovalGraph");
|
||||
Model model("MultiOutputRemovalGraph", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
TypeProto float_tensor;
|
||||
|
|
@ -341,14 +347,17 @@ TEST(GraphUtils, TestMultiOutputRemoveNode) {
|
|||
|
||||
// Try to remove do_0, which should return false
|
||||
// because both outputs are consumed by downstream Operators.
|
||||
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, *nodes[0]));
|
||||
ASSERT_FALSE(graph_utils::CanRemoveNode(graph, *nodes[0],
|
||||
DefaultLoggingManager().DefaultLogger()));
|
||||
|
||||
// Try removing do_0 after removing id_2, which should return true
|
||||
// because it now has exactly one output consumed by downstream Operators.
|
||||
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[1]));
|
||||
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[1],
|
||||
DefaultLoggingManager().DefaultLogger()));
|
||||
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[1]));
|
||||
ASSERT_FALSE(graph_utils::IsOutputUsed(*nodes[0], 0));
|
||||
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[0]));
|
||||
ASSERT_TRUE(graph_utils::CanRemoveNode(graph, *nodes[0],
|
||||
DefaultLoggingManager().DefaultLogger()));
|
||||
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[0]));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ BENCHMARK(BM_CPUAllocator)->Arg(4)->Arg(sizeof(Tensor));
|
|||
|
||||
static void BM_ResolveGraph(benchmark::State& state) {
|
||||
std::shared_ptr<onnxruntime::Model> model_copy;
|
||||
auto st = onnxruntime::Model::Load("../models/opset8/test_tiny_yolov2/model.onnx", model_copy);
|
||||
auto st = onnxruntime::Model::Load(ORT_TSTR("../models/opset8/test_tiny_yolov2/model.onnx"), model_copy);
|
||||
if (!st.IsOK()) {
|
||||
printf("Parse model failed: %s", st.ErrorMessage().c_str());
|
||||
abort();
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <core/common/logging/logging.h>
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/execution_providers.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
|
|
@ -161,8 +162,7 @@ namespace test {
|
|||
|
||||
std::string CreateModel() {
|
||||
RegisterCustomKernel();
|
||||
|
||||
Model model("ModelWithOpaque", false);
|
||||
Model model("ModelWithOpaque", false, logging::LoggingManager::DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class DummyGraphTransformer : public GraphTransformer {
|
|||
private:
|
||||
mutable bool transformer_invoked_;
|
||||
|
||||
Status ApplyImpl(Graph& /*graph*/, bool& /*modified*/, int /*graph_level*/) const override {
|
||||
Status ApplyImpl(Graph& /*graph*/, bool& /*modified*/, int /*graph_level*/, const logging::Logger&) const override {
|
||||
transformer_invoked_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -43,11 +43,12 @@ class DummyRewriteRule : public RewriteRule {
|
|||
private:
|
||||
mutable bool rewrite_rule_invoked_;
|
||||
|
||||
bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/) const override {
|
||||
bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) const override {
|
||||
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/,
|
||||
const logging::Logger& /*logger*/) const override {
|
||||
rewrite_rule_invoked_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue