mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
Address PR comments
Add unit tests
This commit is contained in:
parent
45d5906358
commit
453f13a2b5
12 changed files with 253 additions and 114 deletions
|
|
@ -40,7 +40,7 @@
|
|||
#include "core/graph/node_arg.h"
|
||||
#include "core/graph/ort_format_load_options.h"
|
||||
|
||||
// Type from Graph API in ORT C API so can't be in a namespace
|
||||
// Type from Model Builder API in ORT C API so can't be in a namespace
|
||||
struct OrtGraph;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -5047,7 +5047,7 @@ struct OrtModelBuilderApi {
|
|||
/** \brief Create an OrtValueInfo for use as an OrtGraph input or output.
|
||||
*
|
||||
* \param[in] name The name of the input or output.
|
||||
* \param[in] type_info The type information for the input or output.
|
||||
* \param[in] type_info The type information for the input or output. The provided value is copied.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
@ -5102,7 +5102,7 @@ struct OrtModelBuilderApi {
|
|||
*
|
||||
* \since Version 1.21.
|
||||
*/
|
||||
ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, const char* domain_name, _In_ const char* node_name,
|
||||
ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name,
|
||||
_In_reads_(input_names_len) const char* const* input_names, size_t input_names_len,
|
||||
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
|
||||
_In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len,
|
||||
|
|
@ -5263,9 +5263,16 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Create an OrtSession using the OrtModel.
|
||||
*
|
||||
* Create an inference session using the OrtModel.
|
||||
* Create an inference session using the OrtModel instance.
|
||||
* The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs
|
||||
* and SetGraphOutputs must have been called.
|
||||
* This will validate the model, run optimizers, and prepare the session for inferencing.
|
||||
*
|
||||
* \param[in] env The OrtEnv instance.
|
||||
* \param[in] model The OrtModel instance.
|
||||
* \param[in] options The OrtSessionOptions instance.
|
||||
* \param[out] out The OrtSession instance.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.21.
|
||||
|
|
@ -5275,11 +5282,17 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Create an OrtSession to augment an existing model.
|
||||
*
|
||||
* Create an OrtSession with an existing model that can be augmented with additional nodes.
|
||||
* Nodes can be added to the model using AddNodeToGraph.
|
||||
* Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes.
|
||||
* Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling
|
||||
* FinalizeModelBuilderSession.
|
||||
* Create an OrtSession with an existing model that will be augmented with additional nodes.
|
||||
* Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the
|
||||
* model is finalized.
|
||||
*
|
||||
* To add nodes to the existing model, first create an OrtModel using CreateModel.
|
||||
* Add additional nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
|
||||
* Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs to reflect changes made by the new
|
||||
* nodes. The list of inputs/outputs should be for the overall model and not just the new nodes.
|
||||
*
|
||||
* Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
|
||||
* session for inferencing by calling FinalizeModelBuilderSession.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
@ -5291,11 +5304,22 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Create an OrtSession to augment an existing model.
|
||||
*
|
||||
* Create an OrtSession with an existing model that can be augmented with additional nodes.
|
||||
* Nodes can be added to the model using AddNodeToGraph.
|
||||
* Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes.
|
||||
* Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling
|
||||
* FinalizeModelBuilderSession.
|
||||
* Create an OrtSession with an existing model that will be augmented with additional nodes.
|
||||
* Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the
|
||||
* model is finalized.
|
||||
*
|
||||
* To add nodes to the existing model, first create an OrtModel using CreateModel.
|
||||
* Add additional nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
|
||||
* Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs to reflect changes made by the new
|
||||
* nodes. The list of inputs/outputs should be for the overall model and not just the new nodes.
|
||||
*
|
||||
* Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
|
||||
* session for inferencing by calling FinalizeModelBuilderSession.
|
||||
*
|
||||
* \param[in] env The OrtEnv instance.
|
||||
* \param[in] model_data The model data for the existing model to augment.
|
||||
* \param[in] model_data_length The length of the model data.
|
||||
* \param[in] options The OrtSessionOptions instance.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
@ -5308,12 +5332,12 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Apply the changes from the model to the session.
|
||||
*
|
||||
* Apply the changes from the model to the session that was created using CreateModelBuilderSession[FromArray].
|
||||
* Apply the changes from the model to a session that was created using CreateModelBuilderSession[FromArray].
|
||||
* All changes will be validated.
|
||||
* Call FinalizeModelBuilderSession to prepare the session for inferencing.
|
||||
*
|
||||
* Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel.
|
||||
* i.e. you don't need to call SetGraphInputs/Outputs if they are unchanged.
|
||||
* i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
|
|||
|
|
@ -667,7 +667,8 @@ struct ModelMetadata;
|
|||
|
||||
namespace ModelBuilderAPI {
|
||||
struct Model;
|
||||
}
|
||||
struct ValueInfo;
|
||||
} // namespace ModelBuilderAPI
|
||||
|
||||
/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
|
||||
* and release them at the end of the scope. The lifespan of the given allocator
|
||||
|
|
@ -1120,6 +1121,11 @@ struct ConstSessionImpl : Base<T> {
|
|||
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
|
||||
|
||||
int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain
|
||||
|
||||
// TODO: Should we move ValueInfo from ModelBuilderAPI to the ORT API? It will also be relevant to the Plugin EP API
|
||||
// although the internal implementation may differ there.
|
||||
std::vector<ModelBuilderAPI::ValueInfo> GetInputs() const;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> GetOutputs() const;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -2592,9 +2598,6 @@ struct NodeImpl : Ort::detail::Base<T> {
|
|||
};
|
||||
} // namespace detail
|
||||
|
||||
// Const object holder that does not own the underlying object
|
||||
using ConstNode = detail::NodeImpl<Ort::detail::Unowned<const OrtNode>>;
|
||||
|
||||
/** \brief Wrapper around ::OrtNode
|
||||
*
|
||||
*/
|
||||
|
|
@ -2616,8 +2619,6 @@ struct Node : detail::NodeImpl<OrtNode> {
|
|||
const std::vector<std::string>& output_names,
|
||||
std::vector<OpAttr>& attributes);
|
||||
|
||||
ConstNode GetConst() const { return ConstNode{this->p_}; }
|
||||
|
||||
private:
|
||||
static void Init(const std::string& operator_name, const std::string& operator_domain,
|
||||
const std::string& node_name,
|
||||
|
|
@ -2640,9 +2641,6 @@ struct GraphImpl : Ort::detail::Base<T> {
|
|||
};
|
||||
} // namespace detail
|
||||
|
||||
// Const object holder that does not own the underlying object
|
||||
using ConstGraph = detail::GraphImpl<Ort::detail::Unowned<const OrtGraph>>;
|
||||
|
||||
/** \brief Wrapper around ::OrtGraph
|
||||
*
|
||||
*/
|
||||
|
|
@ -2650,8 +2648,6 @@ struct Graph : detail::GraphImpl<OrtGraph> {
|
|||
explicit Graph(std::nullptr_t) {} ///< No instance is created
|
||||
explicit Graph(OrtGraph* p) : GraphImpl<OrtGraph>{p} {} ///< Take ownership of a pointer created by C API
|
||||
Graph();
|
||||
|
||||
ConstGraph GetConst() const { return ConstGraph{this->p_}; }
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
|
@ -2675,7 +2671,7 @@ struct Model : detail::ModelImpl<OrtModel> {
|
|||
|
||||
explicit Model(std::nullptr_t) {} ///< No instance is created
|
||||
explicit Model(OrtModel* p) : ModelImpl<OrtModel>{p} {} ///< Take ownership of a pointer created by C API
|
||||
Model(const std::vector<DomainOpsetPair>& opsets);
|
||||
explicit Model(const std::vector<DomainOpsetPair>& opsets);
|
||||
|
||||
ConstModel GetConst() const { return ConstModel{this->p_}; }
|
||||
};
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@
|
|||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
// Convert OrtStatus to Ort::Status and return
|
||||
// instead of throwing
|
||||
|
|
@ -1109,6 +1111,36 @@ inline int ConstSessionImpl<T>::GetOpset(const std::string& domain) const {
|
|||
return opset;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ModelBuilderAPI::ValueInfo> ConstSessionImpl<T>::GetInputs() const {
|
||||
const std::vector<std::string> input_names = GetInputNames();
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> inputs;
|
||||
inputs.reserve(input_names.size());
|
||||
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
auto type_info = GetInputTypeInfo(i);
|
||||
inputs.emplace_back(ModelBuilderAPI::ValueInfo{input_names[i], type_info.GetConst()});
|
||||
}
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ModelBuilderAPI::ValueInfo> ConstSessionImpl<T>::GetOutputs() const {
|
||||
const std::vector<std::string> output_names = GetOutputNames();
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> outputs;
|
||||
outputs.reserve(output_names.size());
|
||||
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
auto type_info = GetOutputTypeInfo(i);
|
||||
outputs.emplace_back(ModelBuilderAPI::ValueInfo{output_names[i], type_info.GetConst()});
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
const char* const* output_names, size_t output_count) {
|
||||
|
|
@ -2328,6 +2360,7 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con
|
|||
}
|
||||
|
||||
namespace ModelBuilderAPI {
|
||||
namespace detail {
|
||||
inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>& strings) {
|
||||
std::vector<const char*> ptrs;
|
||||
ptrs.reserve(strings.size());
|
||||
|
|
@ -2336,6 +2369,7 @@ inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>
|
|||
|
||||
return ptrs;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// static
|
||||
inline void Node::Init(const std::string& operator_name, const std::string& operator_domain,
|
||||
|
|
@ -2344,8 +2378,8 @@ inline void Node::Init(const std::string& operator_name, const std::string& oper
|
|||
const std::vector<std::string>& output_names,
|
||||
std::vector<OpAttr>& attributes,
|
||||
OrtNode*& node) {
|
||||
auto inputs = StringsToCharPtrs(input_names);
|
||||
auto outputs = StringsToCharPtrs(output_names);
|
||||
auto inputs = detail::StringsToCharPtrs(input_names);
|
||||
auto outputs = detail::StringsToCharPtrs(output_names);
|
||||
|
||||
std::vector<OrtOpAttr*> attributes_ptrs;
|
||||
attributes_ptrs.reserve(attributes.size());
|
||||
|
|
|
|||
|
|
@ -300,13 +300,3 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio
|
|||
// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default]
|
||||
// “Efficient”: OS treats this workload as efficiency oriented with low scheduling priority and efficient processor performance.
|
||||
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
|
||||
|
||||
// Create an Inference Session that will use the Model Builder API to create/update the model.
|
||||
// This flag will create the session but not fully initialize it. A model, if provided, will be loaded.
|
||||
// A session logger will be created, and execution providers will be registered.
|
||||
// Any device specific allocators and IDataTransfer objects will be registered.
|
||||
// This allows CreateAllocator to return device specific allocators registered by EPs.
|
||||
// FUTURE: This will also allow CopyTensors to utilize the IDataTransfer objects
|
||||
// "0": Disabled. [DEFAULT]
|
||||
// "1": Enable Model Builder Session
|
||||
static const char* const kOrtSessionOptionsEnableModelBuilder = "session.model_builder_session";
|
||||
|
|
|
|||
|
|
@ -266,18 +266,6 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reports whether the TensorProto's external data entry refers to memory instead of an external file,
// i.e. whether the parsed external data location equals kTensorProtoMemoryAddressTag.
// A TensorProto without external data always yields false.
// A failure to parse the external data entries is surfaced through ORT_THROW_IF_ERROR.
bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
  if (!utils::HasExternalData(tensor_proto)) {
    return false;
  }

  std::unique_ptr<onnxruntime::ExternalDataInfo> parsed_info;
  ORT_THROW_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), parsed_info));

  return parsed_info->GetRelPath() == onnxruntime::utils::kTensorProtoMemoryAddressTag;
}
|
||||
|
||||
void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
|
||||
tensor_proto.set_raw_data(std::move(param));
|
||||
}
|
||||
|
|
@ -1301,22 +1289,15 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor,
|
|||
const auto* raw_data = tensor.DataRaw();
|
||||
ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor.");
|
||||
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
|
||||
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
// we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto.
|
||||
// use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the
|
||||
// high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path.
|
||||
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(raw_data));
|
||||
|
||||
ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("location");
|
||||
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("offset");
|
||||
entry->set_value(std::to_string(offset));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("length");
|
||||
entry->set_value(std::to_string(tensor.SizeInBytes()));
|
||||
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
|
||||
offset, tensor.SizeInBytes(), tensor_proto);
|
||||
|
||||
} else {
|
||||
utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -514,10 +514,6 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) {
|
|||
}
|
||||
#endif
|
||||
|
||||
// Check if the TensorProto has an external data entry that points to memory rather than an external file.
|
||||
// The external data location will be kTensorProtoMemoryAddressTag in this case.
|
||||
bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto);
|
||||
|
||||
// UnpackTensor from raw data or the type specific data field. Does not handle external data.
|
||||
// If the tensor does not contain raw data then raw_data should be nullptr and raw_data_len should be 0.
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -5807,21 +5807,11 @@ Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updat
|
|||
|
||||
if (is_external) {
|
||||
// pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo
|
||||
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
const void* data_offset = t.DataRaw(); // address of memory not offset into file
|
||||
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(data_offset));
|
||||
|
||||
ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("location");
|
||||
// magic tag for existing memory that causes 'offset' to be treated as a pointer to the memory
|
||||
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("offset");
|
||||
entry->set_value(std::to_string(offset));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("length");
|
||||
entry->set_value(std::to_string(t.SizeInBytes()));
|
||||
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
|
||||
offset, t.SizeInBytes(), tensor_proto);
|
||||
|
||||
// copy OrtValue to keep it alive and to store the deleter if provided.
|
||||
ortvalue_initializers_.emplace(name, v);
|
||||
|
|
|
|||
|
|
@ -300,8 +300,6 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
|
|||
const auto* fbs_raw_data = fbs_tensor.raw_data();
|
||||
if (fbs_raw_data) {
|
||||
if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
|
||||
initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
|
||||
const void* data_offset = fbs_raw_data->Data();
|
||||
// we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto.
|
||||
|
|
@ -309,15 +307,9 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
|
|||
// high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path.
|
||||
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(data_offset));
|
||||
|
||||
ONNX_NAMESPACE::StringStringEntryProto* entry = initializer.mutable_external_data()->Add();
|
||||
entry->set_key("location");
|
||||
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
|
||||
entry = initializer.mutable_external_data()->Add();
|
||||
entry->set_key("offset");
|
||||
entry->set_value(std::to_string(offset));
|
||||
entry = initializer.mutable_external_data()->Add();
|
||||
entry->set_key("length");
|
||||
entry->set_value(std::to_string(fbs_raw_data->size()));
|
||||
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
|
||||
offset, fbs_raw_data->size(), initializer);
|
||||
|
||||
} else {
|
||||
// fbs_raw_data is uint8_t vector, so the size is byte size
|
||||
initializer.set_raw_data(fbs_raw_data->Data(), fbs_raw_data->size());
|
||||
|
|
|
|||
|
|
@ -627,7 +627,7 @@ class InferenceSession {
|
|||
/// convenience pointer to logger. should always be the same as session_state_.Logger();
|
||||
const logging::Logger* session_logger_;
|
||||
|
||||
// The list of execution providers.
|
||||
// The list of execution providers.
|
||||
// This MUST be prior to model_ in case there are values in the model that were allocated using an allocator
|
||||
// provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the
|
||||
// EP instance.
|
||||
|
|
|
|||
|
|
@ -48,13 +48,15 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateValueInfo, _In_ const char* name,
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) {
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info,
|
||||
_Out_ const char** name) {
|
||||
API_IMPL_BEGIN
|
||||
*name = value_info->name.c_str();
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info) {
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info,
|
||||
_Outptr_ const OrtTypeInfo** type_info) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
*type_info = value_info->type_info.get();
|
||||
|
|
|
|||
|
|
@ -369,6 +369,8 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) {
|
|||
ModelBuilderAPI::Model model(opsets);
|
||||
model.AddGraph(graph);
|
||||
|
||||
auto session = CreateSession(*ort_env, model);
|
||||
|
||||
std::vector<Input<float>> inputs(1);
|
||||
auto& input = inputs[0];
|
||||
input.name = "X";
|
||||
|
|
@ -379,7 +381,6 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) {
|
|||
|
||||
std::vector<int64_t> expected_dims = {3, 8};
|
||||
|
||||
auto session = CreateSession(*ort_env, model);
|
||||
TestInference<float>(session, inputs, "Z", expected_dims,
|
||||
{340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f,
|
||||
596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f,
|
||||
|
|
@ -410,12 +411,15 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets; // no additional opsets required
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
|
||||
std::vector<std::string> input_names = session.GetInputNames();
|
||||
ASSERT_EQ(input_names.size(), 1);
|
||||
|
||||
TypeInfo orig_input = session.GetInputTypeInfo(0);
|
||||
ASSERT_EQ(orig_input.GetTensorTypeAndShapeInfo().GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs = session.GetInputs();
|
||||
ASSERT_EQ(graph_inputs.size(), 1);
|
||||
ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(),
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
// typically this isn't needed, but we want to replace this input and read info from it later on in the test
|
||||
// validation so we move it out of the vector so it's saved locally.
|
||||
auto orig_input_name = graph_inputs[0].Name();
|
||||
auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape();
|
||||
const std::string new_input_name = "Int64Input";
|
||||
|
||||
// Add Cast node to convert input from int64 to float
|
||||
|
|
@ -423,16 +427,16 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT));
|
||||
|
||||
ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, {input_names[0]},
|
||||
ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"},
|
||||
// the existing node will now consume the output from the Cast instead of a graph input
|
||||
{orig_input_name},
|
||||
attributes);
|
||||
|
||||
// we're replacing the only input, so we don't need to call session.GetInputTypeInfo(x) to copy other inputs
|
||||
// in order to preserve them
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
|
||||
// we're replacing the only input. the shape is the same but the name and data type change.
|
||||
TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
|
||||
orig_input.GetTensorTypeAndShapeInfo().GetShape());
|
||||
input_shape);
|
||||
auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst());
|
||||
graph_inputs.emplace_back(new_input_name, input_type_info.GetConst());
|
||||
graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst());
|
||||
|
||||
ModelBuilderAPI::Graph graph; // new info to augment the model with
|
||||
|
||||
|
|
@ -441,13 +445,12 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
|
||||
// the node we added does not require any new opsets.
|
||||
model.AddGraph(graph);
|
||||
|
||||
session.FinalizeModelBuilderSession(model, so);
|
||||
|
||||
std::vector<Input<int64_t>> inputs(1);
|
||||
auto& input = inputs[0];
|
||||
input.name = new_input_name.c_str();
|
||||
input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape();
|
||||
input.dims = input_shape;
|
||||
|
||||
auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies<int64_t>());
|
||||
input.values.resize(size_t(num_values));
|
||||
|
|
@ -465,8 +468,8 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so);
|
||||
std::vector<Input<float>> expected_inputs(1);
|
||||
auto& expected_input = expected_inputs[0];
|
||||
expected_input.name = input_names[0].c_str();
|
||||
expected_input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape();
|
||||
expected_input.name = orig_input_name.c_str();
|
||||
expected_input.dims = input_shape;
|
||||
expected_input.values.reserve(size_t(num_values));
|
||||
std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values),
|
||||
[&](int64_t value) { return float(value); });
|
||||
|
|
@ -489,10 +492,141 @@ TEST(ModelBuilderAPITest, InvalidDimension) {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Tests required
|
||||
TEST(ModelBuilderAPITest, CreateInvalidModel_NoOpsets) {
|
||||
Ort::ModelBuilderAPI::Graph graph;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_outputs;
|
||||
|
||||
- Create invalid model. Graph::Resolve should fail.
|
||||
- Invalid edit. Graph::Resolve should fail.
|
||||
- All the non-tensor Create*TypeInfo functions need to be validated
|
||||
*/
|
||||
std::vector<int64_t> dims({4});
|
||||
TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims);
|
||||
auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst());
|
||||
graph_inputs.emplace_back("X", type_info.GetConst());
|
||||
graph_outputs.emplace_back("Z", type_info.GetConst());
|
||||
|
||||
graph.SetInputs(graph_inputs);
|
||||
graph.SetOutputs(graph_outputs);
|
||||
|
||||
ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "X"}, {"Z"});
|
||||
|
||||
graph.AddNode(node);
|
||||
|
||||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets;
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
model.AddGraph(graph);
|
||||
|
||||
try {
|
||||
auto session = CreateSession(*ort_env, model);
|
||||
FAIL();
|
||||
} catch (const Ort::Exception& e) {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain"));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ModelBuilderAPITest, CreateInvalidModel_MissingValue) {
|
||||
Ort::ModelBuilderAPI::Graph graph;
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_outputs;
|
||||
|
||||
std::vector<int64_t> dims({4});
|
||||
TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims);
|
||||
auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst());
|
||||
graph_inputs.emplace_back("X", type_info.GetConst());
|
||||
graph_outputs.emplace_back("Z", type_info.GetConst());
|
||||
|
||||
graph.SetInputs(graph_inputs);
|
||||
graph.SetOutputs(graph_outputs);
|
||||
|
||||
ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "missing"}, {"Z"});
|
||||
graph.AddNode(node);
|
||||
|
||||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets{{onnxruntime::kOnnxDomain, 18}};
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
model.AddGraph(graph);
|
||||
|
||||
try {
|
||||
auto session = CreateSession(*ort_env, model);
|
||||
FAIL();
|
||||
} catch (const Ort::Exception& e) {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("Node input 'missing' is not a graph input, "
|
||||
"initializer, or output of a previous node."));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ModelBuilderAPITest, InvalidModelEdit) {
|
||||
// Add a node but make the edit invalid in various ways
|
||||
// - add node but don't update graph inputs
|
||||
// - add node with invalid domain
|
||||
const auto edit_model = [](bool invalid_domain) {
|
||||
SessionOptions so;
|
||||
|
||||
// Set this to save the model if you want to debug.
|
||||
// so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx"));
|
||||
|
||||
Session session = Session::CreateModelBuilderSession(*ort_env, TSTR("testdata/mnist.onnx"), so);
|
||||
|
||||
ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string
|
||||
|
||||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets; // no additional opsets required
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
ModelBuilderAPI::Graph graph; // new info to augment the model with
|
||||
|
||||
const char* domain = invalid_domain ? "invalid_domain" : onnxruntime::kOnnxDomain;
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs = session.GetInputs();
|
||||
ASSERT_EQ(graph_inputs.size(), 1);
|
||||
ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(),
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
const std::string new_input_name = "Int64Input";
|
||||
|
||||
// Add Cast node to convert input from int64 to float
|
||||
std::vector<OpAttr> attributes;
|
||||
int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT));
|
||||
|
||||
ModelBuilderAPI::Node node("Cast", domain, "NewInputNode", {new_input_name},
|
||||
// the existing node will now consume the output from the Cast instead of a graph input
|
||||
{graph_inputs[0].Name()},
|
||||
attributes);
|
||||
graph.AddNode(node);
|
||||
|
||||
if (invalid_domain) {
|
||||
// we're replacing the only input. the shape is the same but the name and data type change.
|
||||
TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
|
||||
graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape());
|
||||
auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst());
|
||||
graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst());
|
||||
graph.SetInputs(graph_inputs);
|
||||
} else {
|
||||
// model should be invalid as we didn't connect the new node up to the graph inputs
|
||||
}
|
||||
|
||||
// the node we added does not require any new opsets.
|
||||
model.AddGraph(graph);
|
||||
|
||||
try {
|
||||
session.FinalizeModelBuilderSession(model, so);
|
||||
FAIL() << "Should have failed to resolve graph due to invalid edits.";
|
||||
} catch (const Ort::Exception& e) {
|
||||
if (invalid_domain) {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain 'invalid_domain'"));
|
||||
} else {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("This is an invalid model"));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
edit_model(false);
|
||||
edit_model(true); // add node with invalid domain
|
||||
}
|
||||
|
||||
TEST(ModelBuilderAPITest, CreateTypeInfo) {
|
||||
// sparse tensor
|
||||
|
||||
// sequence
|
||||
|
||||
// map
|
||||
|
||||
// optional
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue