mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Always enable ORT format model loading. (#9586)
This commit is contained in:
parent
5c56fa0def
commit
c315d1b3cd
21 changed files with 7 additions and 121 deletions
|
|
@ -117,7 +117,6 @@ cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handlin
|
|||
option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF)
|
||||
option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build." OFF)
|
||||
option(onnxruntime_REDUCED_OPS_BUILD "Reduced set of kernels are registered in build via modification of the kernel registration source files." OFF)
|
||||
option(onnxruntime_DISABLE_ORT_FORMAT_LOAD "Disable loading an ORT format model when onnxruntime_MINIMAL_BUILD=OFF (i.e. in a full build)." OFF)
|
||||
option(onnxruntime_DISABLE_EXTERNAL_INITIALIZERS "Don't allow models to load external data" OFF)
|
||||
cmake_dependent_option(onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION "Enable runtime graph optimization of ORT format models." ON
|
||||
"NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD" OFF)
|
||||
|
|
@ -334,7 +333,6 @@ endif()
|
|||
# ORT build with as much excluded as possible. Supports ORT flatbuffers models only.
|
||||
if (onnxruntime_MINIMAL_BUILD)
|
||||
add_compile_definitions(ORT_MINIMAL_BUILD)
|
||||
add_compile_definitions(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
if (onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
# enable EPs that compile kernels at runtime
|
||||
|
|
@ -376,11 +374,6 @@ if (onnxruntime_MINIMAL_BUILD)
|
|||
string(APPEND CMAKE_C_FLAGS " -g")
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
# support ORT format model loading unless onnxruntime_DISABLE_ORT_FORMAT_LOAD is set
|
||||
if (NOT onnxruntime_DISABLE_ORT_FORMAT_LOAD)
|
||||
add_compile_definitions(ENABLE_ORT_FORMAT_LOAD)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
|
||||
|
|
|
|||
|
|
@ -431,13 +431,11 @@ class Node {
|
|||
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
static Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Node& fbs_node, Graph& graph,
|
||||
const logging::Logger& logger, std::unique_ptr<Node>& node);
|
||||
|
||||
Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Node& fbs_node, const logging::Logger& logger);
|
||||
Status LoadEdgesFromOrtFormat(const onnxruntime::experimental::fbs::NodeEdge& fbs_node_edgs, const Graph& graph);
|
||||
#endif
|
||||
|
||||
/**
|
||||
@class Definitions
|
||||
|
|
@ -1169,7 +1167,6 @@ class Graph {
|
|||
|
||||
virtual ~Graph();
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
static common::Status LoadFromOrtFormat(
|
||||
const onnxruntime::experimental::fbs::Graph& fbs_graph, const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
|
|
@ -1182,7 +1179,7 @@ class Graph {
|
|||
static Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs_graph,
|
||||
Graph& parent_graph, const Node& parent_node,
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
|
||||
#endif
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);
|
||||
|
||||
|
|
@ -1201,10 +1198,8 @@ class Graph {
|
|||
Graph* parent_graph, const Node* parent_node,
|
||||
const logging::Logger& logger);
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
// Populate Graph instance from ORT format serialized data.
|
||||
common::Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs_graph);
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Constructor: Given a <GraphProto> loaded from model file, construct
|
||||
|
|
|
|||
|
|
@ -165,8 +165,6 @@ Status SaveValueInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
|||
|
||||
#endif // #if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
void LoadStringFromOrtFormat(std::string& dst, const flatbuffers::String* fbs_string) {
|
||||
if (fbs_string)
|
||||
dst = fbs_string->c_str();
|
||||
|
|
@ -307,8 +305,6 @@ Status LoadOpsetImportOrtFormat(const flatbuffers::Vector<flatbuffers::Offset<fb
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
bool IsOrtFormatModelBytes(const void* bytes, int num_bytes) {
|
||||
return num_bytes > 8 && // check buffer is large enough to contain identifier so we don't read random memory
|
||||
fbs::InferenceSessionBufferHasIdentifier(bytes);
|
||||
|
|
|
|||
|
|
@ -43,8 +43,6 @@ onnxruntime::common::Status SaveValueInfoOrtFormat(
|
|||
flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::ValueInfoProto& value_info_proto,
|
||||
flatbuffers::Offset<fbs::ValueInfo>& fbs_value_info);
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
void LoadStringFromOrtFormat(std::string& dst, const flatbuffers::String* fbs_string);
|
||||
|
||||
// This macro is to be used on a protobuf message (protobug_msg), which will not create an empty string field (str_field)
|
||||
|
|
@ -62,8 +60,6 @@ onnxruntime::common::Status LoadOpsetImportOrtFormat(
|
|||
const flatbuffers::Vector<flatbuffers::Offset<fbs::OperatorSetId>>* fbs_op_set_ids,
|
||||
std::unordered_map<std::string, int>& domain_to_version);
|
||||
|
||||
#endif
|
||||
|
||||
// check if filename ends in .ort
|
||||
template <typename T>
|
||||
bool IsOrtFormatModel(const std::basic_string<T>& filename) {
|
||||
|
|
|
|||
|
|
@ -949,7 +949,6 @@ Status SessionState::CreateSubgraphSessionState() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_state,
|
||||
const KernelRegistryManager& kernel_registry_manager) {
|
||||
using experimental::utils::FbsSessionStateViewer;
|
||||
|
|
@ -1011,7 +1010,6 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Calculate the use count of a constant initialized tensor, including the use in subgraph.
|
||||
// Note: This function doesn't handle the case below:
|
||||
|
|
@ -1052,13 +1050,7 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
|
|||
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState());
|
||||
|
||||
if (serialized_session_state) {
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
ORT_RETURN_IF_ERROR(LoadFromOrtFormat(*serialized_session_state, kernel_registry_manager));
|
||||
#else
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"ORT format model is not supported in this build.");
|
||||
#endif
|
||||
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));
|
||||
|
|
|
|||
|
|
@ -297,14 +297,12 @@ class SessionState {
|
|||
flatbuffers::Offset<onnxruntime::experimental::fbs::SessionState>& fbs_session_state) const;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
void SetCompiledKernelHashes(std::unordered_map<std::string, uint64_t>&& compiled_kernel_hashes) {
|
||||
compiled_kernel_hashes_ = std::move(compiled_kernel_hashes);
|
||||
}
|
||||
|
||||
Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::SessionState& fbs_session_state,
|
||||
const KernelRegistryManager& kernel_registry_manager);
|
||||
#endif
|
||||
|
||||
Status FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
|
|
|
|||
|
|
@ -611,7 +611,6 @@ flatbuffers::Offset<fbs::NodeEdge> Node::SaveEdgesToOrtFormat(flatbuffers::FlatB
|
|||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
Status Node::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Node& fbs_node, Graph& graph,
|
||||
const logging::Logger& logger, std::unique_ptr<Node>& node) {
|
||||
node.reset(new Node(fbs_node.index(), graph));
|
||||
|
|
@ -708,7 +707,6 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::experimental::fbs::NodeEd
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
void Node::Init(const std::string& name,
|
||||
|
|
@ -4019,7 +4017,6 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) {
|
|||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs_graph,
|
||||
const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
|
|
@ -4218,6 +4215,4 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Gr
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -174,8 +174,6 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
|||
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor,
|
||||
TensorProto& initializer) {
|
||||
initializer.Clear();
|
||||
|
|
@ -310,8 +308,6 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
} // namespace utils
|
||||
} // namespace experimental
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -52,8 +52,6 @@ onnxruntime::common::Status SaveAttributeOrtFormat(
|
|||
flatbuffers::Offset<fbs::Attribute>& fbs_attr, const Path& model_path,
|
||||
const onnxruntime::Graph* subgraph);
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
onnxruntime::common::Status LoadInitializerOrtFormat(
|
||||
const fbs::Tensor& fbs_tensor, ONNX_NAMESPACE::TensorProto& initializer);
|
||||
|
||||
|
|
@ -69,8 +67,6 @@ onnxruntime::common::Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_att
|
|||
Graph& graph, Node& node,
|
||||
const logging::Logger& logger);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace utils
|
||||
} // namespace experimental
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -711,7 +711,6 @@ common::Status Model::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
|||
Model::Model() : model_path_{} {
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
|
|
@ -780,6 +779,5 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
|
|||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -263,14 +263,12 @@ class Model {
|
|||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
static common::Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::Model& fbs_model,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
#endif
|
||||
const logging::Logger& logger,
|
||||
std::unique_ptr<Model>& model);
|
||||
#endif
|
||||
|
||||
private:
|
||||
Model();
|
||||
|
|
|
|||
|
|
@ -181,7 +181,6 @@ std::atomic<uint32_t> InferenceSession::global_session_id_{1};
|
|||
// Version 4 - update kernel def hashing to not depend on ordering of type constraint types (NOT BACKWARDS COMPATIBLE)
|
||||
static constexpr const char* kOrtModelVersion = "4";
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
// Check if the given ort model version is supported in this build
|
||||
static bool IsOrtModelVersionSupported(const std::string& ort_model_version) {
|
||||
// The ort model versions we will support in this build
|
||||
|
|
@ -192,7 +191,6 @@ static bool IsOrtModelVersionSupported(const std::string& ort_model_version) {
|
|||
|
||||
return kSupportedOrtModelVersions.find(ort_model_version) != kSupportedOrtModelVersions.cend();
|
||||
}
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options,
|
||||
const ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
|
|
@ -703,11 +701,7 @@ common::Status InferenceSession::Load(const std::string& model_uri) {
|
|||
|
||||
if ((has_explicit_type && model_type == "ORT") ||
|
||||
(!has_explicit_type && experimental::utils::IsOrtFormatModel(model_uri))) {
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
return LoadOrtModel(model_uri);
|
||||
#else
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ORT format model is not supported in this build.");
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
@ -730,11 +724,7 @@ common::Status InferenceSession::Load(const std::wstring& model_uri) {
|
|||
|
||||
if ((has_explicit_type && model_type == "ORT") ||
|
||||
(!has_explicit_type && experimental::utils::IsOrtFormatModel(model_uri))) {
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
return LoadOrtModel(model_uri);
|
||||
#else
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ORT format model is not supported in this build.");
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
@ -758,11 +748,7 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
|
|||
if ((has_explicit_type && model_type == "ORT") ||
|
||||
(!has_explicit_type &&
|
||||
experimental::utils::IsOrtFormatModelBytes(model_data, model_data_len))) {
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
return LoadOrtModel(model_data, model_data_len);
|
||||
#else
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ORT format model is not supported in this build.");
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
@ -963,8 +949,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
template <typename T>
|
||||
static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
|
||||
std::basic_string<ORTCHAR_T>& model_location,
|
||||
|
|
@ -1090,8 +1074,6 @@ Status InferenceSession::LoadOrtModel(std::function<Status()> load_ort_format_mo
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
bool InferenceSession::IsInitialized() const {
|
||||
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
||||
return is_inited_;
|
||||
|
|
@ -1153,7 +1135,6 @@ common::Status InferenceSession::AddPrePackedWeightsContainer(PrepackedWeightsCo
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
namespace {
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
|
||||
|
|
@ -1230,7 +1211,6 @@ Status AssignNodesToEpsFromHashes(Graph& graph, const fbs::SessionState& fbs_ses
|
|||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
common::Status InferenceSession::Initialize() {
|
||||
Status status = Status::OK();
|
||||
|
|
@ -1375,7 +1355,6 @@ common::Status InferenceSession::Initialize() {
|
|||
{
|
||||
ORT_ENFORCE(loading_ort_format && serialized_session_state != nullptr);
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// nodes are already partitioned, but a custom EP may compile some at runtime.
|
||||
// run the partitioning to allow that to happen.
|
||||
|
|
@ -1393,9 +1372,6 @@ common::Status InferenceSession::Initialize() {
|
|||
|
||||
ORT_RETURN_IF_ERROR(AssignNodesToEpsFromHashes(graph, *serialized_session_state,
|
||||
kernel_registry_manager_, *session_logger_));
|
||||
#else // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
ORT_NOT_IMPLEMENTED("ORT format loading not enabled.");
|
||||
#endif
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
|
|
|
|||
|
|
@ -514,7 +514,6 @@ class InferenceSession {
|
|||
common::Status SaveToOrtFormat(const std::basic_string<ORTCHAR_T>& filepath) const;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
/**
|
||||
* Load an ORT format model.
|
||||
* @param model_uri absolute path of the model file.
|
||||
|
|
@ -537,8 +536,6 @@ class InferenceSession {
|
|||
|
||||
common::Status LoadOrtModel(std::function<Status()> load_ort_format_model_bytes) ORT_MUST_USE_RESULT;
|
||||
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
// Create a Logger for a single execution if possible. Otherwise use the default logger.
|
||||
// If a new logger is created, it will also be stored in new_run_logger,
|
||||
// which must remain valid for the duration of the execution.
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// if we can't load an ORT format model we can't really test anything
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
|
@ -528,5 +525,3 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBufferNoCopy) {
|
|||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
|
|
|||
|
|
@ -269,11 +269,9 @@ std::unique_ptr<TestModelInfo> TestModelInfo::LoadOnnxModel(_In_ const PATH_CHAR
|
|||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
std::unique_ptr<TestModelInfo> TestModelInfo::LoadOrtModel(_In_ const PATH_CHAR_TYPE* model_url) {
|
||||
return std::unique_ptr<TestModelInfo>(new OnnxModelInfo(model_url, true));
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* test_case_dir must have contents of:
|
||||
|
|
@ -662,9 +660,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
|
|||
is_valid_model = is_onnx_format;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
is_valid_model = is_valid_model || is_ort_format;
|
||||
#endif
|
||||
if (!is_valid_model)
|
||||
return true;
|
||||
|
||||
|
|
@ -688,11 +684,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
|
|||
ORT_THROW("onnx model is not supported in this build");
|
||||
#endif
|
||||
} else if (is_ort_format) {
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
model_info = TestModelInfo::LoadOrtModel(p.c_str());
|
||||
#else
|
||||
ORT_THROW("ort model is not supported in this build");
|
||||
#endif
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED(ToMBString(filename_str), " is not supported");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,9 +73,7 @@ class TestModelInfo {
|
|||
static std::unique_ptr<TestModelInfo> LoadOnnxModel(_In_ const PATH_CHAR_TYPE* model_url);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
static std::unique_ptr<TestModelInfo> LoadOrtModel(_In_ const PATH_CHAR_TYPE* model_url);
|
||||
#endif
|
||||
|
||||
static const std::string unknown_version;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2,29 +2,23 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "onnx_model_info.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "re2/re2.h"
|
||||
#include "pb_helper.h"
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "pb_helper.h"
|
||||
#include "re2/re2.h"
|
||||
|
||||
#include "core/flatbuffers/schema/ort.fbs.h"
|
||||
#include "core/flatbuffers/flatbuffers_utils.h"
|
||||
#include "core/platform/env.h"
|
||||
|
||||
using namespace onnxruntime::experimental;
|
||||
|
||||
#endif
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
OnnxModelInfo::OnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url, bool is_ort_model)
|
||||
: model_url_(model_url) {
|
||||
if (is_ort_model) {
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
InitOrtModelInfo(model_url);
|
||||
#else
|
||||
ORT_THROW("ort model is not supported in this build");
|
||||
#endif
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
InitOnnxModelInfo(model_url);
|
||||
|
|
@ -100,8 +94,6 @@ void OnnxModelInfo::InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url) { /
|
|||
|
||||
#endif // #if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
void OnnxModelInfo::InitOrtModelInfo(_In_ const PATH_CHAR_TYPE* model_url) {
|
||||
std::vector<uint8_t> bytes;
|
||||
size_t num_bytes = 0;
|
||||
|
|
@ -175,5 +167,3 @@ void OnnxModelInfo::InitOrtModelInfo(_In_ const PATH_CHAR_TYPE* model_url) {
|
|||
ORT_THROW_IF_ERROR(add_node_args(fbs_graph->inputs(), input_value_info_));
|
||||
ORT_THROW_IF_ERROR(add_node_args(fbs_graph->outputs(), output_value_info_));
|
||||
}
|
||||
|
||||
#endif //#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
|
|
|||
|
|
@ -18,9 +18,7 @@ class OnnxModelInfo : public TestModelInfo {
|
|||
void InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
void InitOrtModelInfo(_In_ const PATH_CHAR_TYPE* model_url);
|
||||
#endif
|
||||
|
||||
public:
|
||||
OnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url, bool is_ort_model = false);
|
||||
|
|
|
|||
|
|
@ -240,11 +240,9 @@ static std::unique_ptr<TestModelInfo> CreateModelInfo(const PerformanceTestConfi
|
|||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
if (HasExtensionOf(file_path, ORT_TSTR("ort"))) {
|
||||
return TestModelInfo::LoadOrtModel(performance_test_config_.model_info.model_file_path.c_str());
|
||||
}
|
||||
#endif
|
||||
|
||||
ORT_NOT_IMPLEMENTED(ToMBString(file_path), " is not supported");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// if we can't load an ORT format model we can't really test anything
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
// custom ops are only supported in a minimal build if explicitly enabled
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
|
||||
|
|
@ -147,5 +144,3 @@ TEST(OrtFormatCustomOpTests, LoadOrtModel) {
|
|||
#endif
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
|
||||
#endif // #if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
|
|
|||
|
|
@ -535,8 +535,6 @@ def parse_arguments():
|
|||
parser.add_argument("--disable_rtti", action='store_true', help="Disable RTTI (reduces binary size)")
|
||||
parser.add_argument("--disable_exceptions", action='store_true',
|
||||
help="Disable exceptions to reduce binary size. Requires --minimal_build.")
|
||||
parser.add_argument("--disable_ort_format_load", action='store_true',
|
||||
help='Disable support for loading ORT format models in a non-minimal build.')
|
||||
|
||||
parser.add_argument(
|
||||
"--rocm_version", help="The version of ROCM stack to use. ")
|
||||
|
|
@ -760,7 +758,6 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
|
|||
"-Donnxruntime_DISABLE_ML_OPS=" + ("ON" if args.disable_ml_ops else "OFF"),
|
||||
"-Donnxruntime_DISABLE_RTTI=" + ("ON" if args.disable_rtti else "OFF"),
|
||||
"-Donnxruntime_DISABLE_EXCEPTIONS=" + ("ON" if args.disable_exceptions else "OFF"),
|
||||
"-Donnxruntime_DISABLE_ORT_FORMAT_LOAD=" + ("ON" if args.disable_ort_format_load else "OFF"),
|
||||
# Need to use 'is not None' with minimal_build check as it could be an empty list.
|
||||
"-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build is not None else "OFF"),
|
||||
"-Donnxruntime_EXTENDED_MINIMAL_BUILD=" + ("ON" if args.minimal_build and 'extended' in args.minimal_build
|
||||
|
|
@ -2022,9 +2019,6 @@ def main():
|
|||
if args.enable_pybind and args.disable_exceptions:
|
||||
raise BuildError('Python bindings require exceptions to be enabled.')
|
||||
|
||||
if args.minimal_build is not None and args.disable_ort_format_load:
|
||||
raise BuildError('Minimal build requires loading ORT format models.')
|
||||
|
||||
if args.nnapi_min_api:
|
||||
if not args.use_nnapi:
|
||||
raise BuildError("Using --nnapi_min_api requires --use_nnapi")
|
||||
|
|
|
|||
Loading…
Reference in a new issue