Address PR comments

Add unit tests
This commit is contained in:
Scott McKay 2025-01-16 19:43:46 +10:00
parent 45d5906358
commit 453f13a2b5
12 changed files with 253 additions and 114 deletions

View file

@ -40,7 +40,7 @@
#include "core/graph/node_arg.h"
#include "core/graph/ort_format_load_options.h"
// Type from Graph API in ORT C API so can't be in a namespace
// Type from Model Builder API in ORT C API so can't be in a namespace
struct OrtGraph;
namespace onnxruntime {

View file

@ -5047,7 +5047,7 @@ struct OrtModelBuilderApi {
/** \brief Create an OrtValueInfo for use as an OrtGraph input or output.
*
* \param[in] name The name of the input or output.
* \param[in] type_info The type information for the input or output.
* \param[in] type_info The type information for the input or output. The provided value is copied.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
@ -5102,7 +5102,7 @@ struct OrtModelBuilderApi {
*
* \since Version 1.21.
*/
ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, const char* domain_name, _In_ const char* node_name,
ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name,
_In_reads_(input_names_len) const char* const* input_names, size_t input_names_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len,
@ -5263,9 +5263,16 @@ struct OrtModelBuilderApi {
/** \brief Create an OrtSession using the OrtModel.
*
* Create an inference session using the OrtModel.
* Create an inference session using the OrtModel instance.
* The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs
* and SetGraphOutputs must have been called.
* This will validate the model, run optimizers, and prepare the session for inferencing.
*
* \param[in] env The OrtEnv instance.
* \param[in] model The OrtModel instance.
* \param[in] options The OrtSessionOptions instance.
* \param[out] out The OrtSession instance.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.21.
@ -5275,11 +5282,17 @@ struct OrtModelBuilderApi {
/** \brief Create an OrtSession to augment an existing model.
*
* Create an OrtSession with an existing model that can be augmented with additional nodes.
* Nodes can be added to the model using AddNodeToGraph.
* Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes.
* Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling
* FinalizeModelBuilderSession.
* Create an OrtSession with an existing model that will be augmented with additional nodes.
* Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the
* model is finalized.
*
* To add nodes to the existing model, first create an OrtModel using CreateModel.
* Add additional nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
* Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs to reflect changes made by the new
* nodes. The list of inputs/outputs should be for the overall model and not just the new nodes.
*
* Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
* session for inferencing by calling FinalizeModelBuilderSession.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
@ -5291,11 +5304,22 @@ struct OrtModelBuilderApi {
/** \brief Create an OrtSession to augment an existing model.
*
* Create an OrtSession with an existing model that can be augmented with additional nodes.
* Nodes can be added to the model using AddNodeToGraph.
* Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes.
* Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling
* FinalizeModelBuilderSession.
* Create an OrtSession with an existing model that will be augmented with additional nodes.
* Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the
* model is finalized.
*
* To add nodes to the existing model, first create an OrtModel using CreateModel.
* Add additional nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
* Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs to reflect changes made by the new
* nodes. The list of inputs/outputs should be for the overall model and not just the new nodes.
*
* Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
* session for inferencing by calling FinalizeModelBuilderSession.
*
* \param{in} env The OrtEnv instance.
* \param{in} model_data The model data for the existing model to augment.
* \param{in} model_data_length The length of the model data.
* \param{in} options The OrtSessionOptions instance.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
@ -5308,12 +5332,12 @@ struct OrtModelBuilderApi {
/** \brief Apply the changes from the model to the session.
*
* Apply the changes from the model to the session that was created using CreateModelBuilderSession[FromArray].
* Apply the changes from the model to a session that was created using CreateModelBuilderSession[FromArray].
* All changes will be validated.
* Call FinalizeModelBuilderSession to prepare the session for inferencing.
*
* Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel.
* i.e. you don't need to call SetGraphInputs/Outputs if they are unchanged.
* i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*

View file

@ -667,7 +667,8 @@ struct ModelMetadata;
namespace ModelBuilderAPI {
struct Model;
}
struct ValueInfo;
} // namespace ModelBuilderAPI
/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
* and release them at the end of the scope. The lifespan of the given allocator
@ -1120,6 +1121,11 @@ struct ConstSessionImpl : Base<T> {
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain
// TODO: Should we move ValueInfo from ModelBuilderAPI to the ORT API? It will also be relevant to the Plugin EP API
// although the internal implementation may differ there.
std::vector<ModelBuilderAPI::ValueInfo> GetInputs() const;
std::vector<ModelBuilderAPI::ValueInfo> GetOutputs() const;
};
template <typename T>
@ -2592,9 +2598,6 @@ struct NodeImpl : Ort::detail::Base<T> {
};
} // namespace detail
// Const object holder that does not own the underlying object
using ConstNode = detail::NodeImpl<Ort::detail::Unowned<const OrtNode>>;
/** \brief Wrapper around ::OrtNode
*
*/
@ -2616,8 +2619,6 @@ struct Node : detail::NodeImpl<OrtNode> {
const std::vector<std::string>& output_names,
std::vector<OpAttr>& attributes);
ConstNode GetConst() const { return ConstNode{this->p_}; }
private:
static void Init(const std::string& operator_name, const std::string& operator_domain,
const std::string& node_name,
@ -2640,9 +2641,6 @@ struct GraphImpl : Ort::detail::Base<T> {
};
} // namespace detail
// Const object holder that does not own the underlying object
using ConstGraph = detail::GraphImpl<Ort::detail::Unowned<const OrtGraph>>;
/** \brief Wrapper around ::OrtGraph
*
*/
@ -2650,8 +2648,6 @@ struct Graph : detail::GraphImpl<OrtGraph> {
explicit Graph(std::nullptr_t) {} ///< No instance is created
explicit Graph(OrtGraph* p) : GraphImpl<OrtGraph>{p} {} ///< Take ownership of a pointer created by C API
Graph();
ConstGraph GetConst() const { return ConstGraph{this->p_}; }
};
namespace detail {
@ -2675,7 +2671,7 @@ struct Model : detail::ModelImpl<OrtModel> {
explicit Model(std::nullptr_t) {} ///< No instance is created
explicit Model(OrtModel* p) : ModelImpl<OrtModel>{p} {} ///< Take ownership of a pointer created by C API
Model(const std::vector<DomainOpsetPair>& opsets);
explicit Model(const std::vector<DomainOpsetPair>& opsets);
ConstModel GetConst() const { return ConstModel{this->p_}; }
};

View file

@ -10,7 +10,9 @@
#include <algorithm>
#include <functional>
#include <iterator>
#include <string>
#include <type_traits>
#include <vector>
// Convert OrtStatus to Ort::Status and return
// instead of throwing
@ -1109,6 +1111,36 @@ inline int ConstSessionImpl<T>::GetOpset(const std::string& domain) const {
return opset;
}
template <typename T>
std::vector<ModelBuilderAPI::ValueInfo> ConstSessionImpl<T>::GetInputs() const {
const std::vector<std::string> input_names = GetInputNames();
std::vector<ModelBuilderAPI::ValueInfo> inputs;
inputs.reserve(input_names.size());
for (int i = 0; i < input_names.size(); ++i) {
auto type_info = GetInputTypeInfo(i);
inputs.emplace_back(ModelBuilderAPI::ValueInfo{input_names[i], type_info.GetConst()});
}
return inputs;
}
template <typename T>
std::vector<ModelBuilderAPI::ValueInfo> ConstSessionImpl<T>::GetOutputs() const {
const std::vector<std::string> output_names = GetOutputNames();
std::vector<ModelBuilderAPI::ValueInfo> outputs;
outputs.reserve(output_names.size());
for (int i = 0; i < output_names.size(); ++i) {
auto type_info = GetOutputTypeInfo(i);
outputs.emplace_back(ModelBuilderAPI::ValueInfo{output_names[i], type_info.GetConst()});
}
return outputs;
}
template <typename T>
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count) {
@ -2328,6 +2360,7 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con
}
namespace ModelBuilderAPI {
namespace detail {
inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>& strings) {
std::vector<const char*> ptrs;
ptrs.reserve(strings.size());
@ -2336,6 +2369,7 @@ inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>
return ptrs;
}
} // namespace detail
// static
inline void Node::Init(const std::string& operator_name, const std::string& operator_domain,
@ -2344,8 +2378,8 @@ inline void Node::Init(const std::string& operator_name, const std::string& oper
const std::vector<std::string>& output_names,
std::vector<OpAttr>& attributes,
OrtNode*& node) {
auto inputs = StringsToCharPtrs(input_names);
auto outputs = StringsToCharPtrs(output_names);
auto inputs = detail::StringsToCharPtrs(input_names);
auto outputs = detail::StringsToCharPtrs(output_names);
std::vector<OrtOpAttr*> attributes_ptrs;
attributes_ptrs.reserve(attributes.size());

View file

@ -300,13 +300,3 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio
// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default]
// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
// Create an Inference Session that will use the Model Builder API to create/update the model.
// This flag will create the session but not fully initialize it. A model, if provided, will be loaded.
// A session logger will be created, and execution providers will be registered.
// Any device specific allocators and IDataTransfer objects will be registered.
// This allows CreateAllocator to return device specific allocators registered by EPs.
// FUTURE: This will also allow CopyTensors to utilize the IDataTransfer objects
// "0": Disabled. [DEFAULT]
// "1": Enable Model Builder Session
static const char* const kOrtSessionOptionsEnableModelBuilder = "session.model_builder_session";

View file

@ -266,18 +266,6 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
return Status::OK();
}
bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
bool has_external_memory = false;
if (utils::HasExternalData(tensor_proto)) {
std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
ORT_THROW_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
has_external_memory = external_data_info->GetRelPath() == onnxruntime::utils::kTensorProtoMemoryAddressTag;
}
return has_external_memory;
}
void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
tensor_proto.set_raw_data(std::move(param));
}
@ -1301,22 +1289,15 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor,
const auto* raw_data = tensor.DataRaw();
ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor.");
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
// we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto.
// use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the
// high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path.
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(raw_data));
ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("location");
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("offset");
entry->set_value(std::to_string(offset));
entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("length");
entry->set_value(std::to_string(tensor.SizeInBytes()));
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
offset, tensor.SizeInBytes(), tensor_proto);
} else {
utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes());
}

View file

@ -514,10 +514,6 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) {
}
#endif
// Check if the TensorProto has an external data entry that points to memory rather than an external file.
// The external data location will be kTensorProtoMemoryAddressTag in this case.
bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto);
// UnpackTensor from raw data or the type specific data field. Does not handle external data.
// If the tensor does not contain raw data then raw_data should be nullptr and raw_data_len should be 0.
template <typename T>

View file

@ -5807,21 +5807,11 @@ Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updat
if (is_external) {
// pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
const void* data_offset = t.DataRaw(); // address of memory not offset into file
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(data_offset));
ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("location");
// magic tag for existing memory that causes 'offset' to be treated as a pointer to the memory
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("offset");
entry->set_value(std::to_string(offset));
entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("length");
entry->set_value(std::to_string(t.SizeInBytes()));
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
offset, t.SizeInBytes(), tensor_proto);
// copy OrtValue to keep it alive and to store the deleter if provided.
ortvalue_initializers_.emplace(name, v);

View file

@ -300,8 +300,6 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
const auto* fbs_raw_data = fbs_tensor.raw_data();
if (fbs_raw_data) {
if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
const void* data_offset = fbs_raw_data->Data();
// we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto.
@ -309,15 +307,9 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
// high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path.
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(data_offset));
ONNX_NAMESPACE::StringStringEntryProto* entry = initializer.mutable_external_data()->Add();
entry->set_key("location");
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
entry = initializer.mutable_external_data()->Add();
entry->set_key("offset");
entry->set_value(std::to_string(offset));
entry = initializer.mutable_external_data()->Add();
entry->set_key("length");
entry->set_value(std::to_string(fbs_raw_data->size()));
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
offset, fbs_raw_data->size(), initializer);
} else {
// fbs_raw_data is uint8_t vector, so the size is byte size
initializer.set_raw_data(fbs_raw_data->Data(), fbs_raw_data->size());

View file

@ -627,7 +627,7 @@ class InferenceSession {
/// convenience pointer to logger. should always be the same as session_state_.Logger();
const logging::Logger* session_logger_;
// The list of execution providers.
// The list of execution providers.
// This MUST be prior to model_ in case there are values in the model that were allocated using an allocator
// provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the
// EP instance.

View file

@ -48,13 +48,15 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateValueInfo, _In_ const char* name,
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) {
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info,
_Out_ const char** name) {
API_IMPL_BEGIN
*name = value_info->name.c_str();
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info) {
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info,
_Outptr_ const OrtTypeInfo** type_info) {
API_IMPL_BEGIN
*type_info = value_info->type_info.get();

View file

@ -369,6 +369,8 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) {
ModelBuilderAPI::Model model(opsets);
model.AddGraph(graph);
auto session = CreateSession(*ort_env, model);
std::vector<Input<float>> inputs(1);
auto& input = inputs[0];
input.name = "X";
@ -379,7 +381,6 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) {
std::vector<int64_t> expected_dims = {3, 8};
auto session = CreateSession(*ort_env, model);
TestInference<float>(session, inputs, "Z", expected_dims,
{340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f,
596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f,
@ -410,12 +411,15 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets; // no additional opsets required
ModelBuilderAPI::Model model(opsets);
std::vector<std::string> input_names = session.GetInputNames();
ASSERT_EQ(input_names.size(), 1);
TypeInfo orig_input = session.GetInputTypeInfo(0);
ASSERT_EQ(orig_input.GetTensorTypeAndShapeInfo().GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs = session.GetInputs();
ASSERT_EQ(graph_inputs.size(), 1);
ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
// typically this isn't needed, but we want to replace this input and read info from it later on in the test
// validation so we move it out of the vector so it's saved locally.
auto orig_input_name = graph_inputs[0].Name();
auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape();
const std::string new_input_name = "Int64Input";
// Add Cast node to convert input from float to int64
@ -423,16 +427,16 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT));
ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, {input_names[0]},
ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"},
// the existing node will now consume the output from the Cast instead of a graph input
{orig_input_name},
attributes);
// we're replacing the only input, so we don't need to call session.GetInputTypeInfo(x) to copy other inputs
// in order to preserve them
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
// we're replacing the only input. the shape is the same but the name and data type change.
TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
orig_input.GetTensorTypeAndShapeInfo().GetShape());
input_shape);
auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst());
graph_inputs.emplace_back(new_input_name, input_type_info.GetConst());
graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst());
ModelBuilderAPI::Graph graph; // new info to augment the model with
@ -441,13 +445,12 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
// the node we added does not require any new opsets.
model.AddGraph(graph);
session.FinalizeModelBuilderSession(model, so);
std::vector<Input<int64_t>> inputs(1);
auto& input = inputs[0];
input.name = new_input_name.c_str();
input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape();
input.dims = input_shape;
auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies<int64_t>());
input.values.resize(size_t(num_values));
@ -465,8 +468,8 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so);
std::vector<Input<float>> expected_inputs(1);
auto& expected_input = expected_inputs[0];
expected_input.name = input_names[0].c_str();
expected_input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape();
expected_input.name = orig_input_name.c_str();
expected_input.dims = input_shape;
expected_input.values.reserve(size_t(num_values));
std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values),
[&](int64_t value) { return float(value); });
@ -489,10 +492,141 @@ TEST(ModelBuilderAPITest, InvalidDimension) {
}
}
/*
Tests required
TEST(ModelBuilderAPITest, CreateInvalidModel_NoOpsets) {
Ort::ModelBuilderAPI::Graph graph;
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
std::vector<ModelBuilderAPI::ValueInfo> graph_outputs;
- Create invalid model. Graph::Resolve should fail.
- Invalid edit. Graph::Resolve should fail.
- All the non-tensor Create*TypeInfo functions need to be validated
*/
std::vector<int64_t> dims({4});
TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims);
auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst());
graph_inputs.emplace_back("X", type_info.GetConst());
graph_outputs.emplace_back("Z", type_info.GetConst());
graph.SetInputs(graph_inputs);
graph.SetOutputs(graph_outputs);
ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "X"}, {"Z"});
graph.AddNode(node);
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets;
ModelBuilderAPI::Model model(opsets);
model.AddGraph(graph);
try {
auto session = CreateSession(*ort_env, model);
FAIL();
} catch (const Ort::Exception& e) {
ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain"));
}
}
TEST(ModelBuilderAPITest, CreateInvalidModel_MissingValue) {
Ort::ModelBuilderAPI::Graph graph;
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
std::vector<ModelBuilderAPI::ValueInfo> graph_outputs;
std::vector<int64_t> dims({4});
TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims);
auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst());
graph_inputs.emplace_back("X", type_info.GetConst());
graph_outputs.emplace_back("Z", type_info.GetConst());
graph.SetInputs(graph_inputs);
graph.SetOutputs(graph_outputs);
ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "missing"}, {"Z"});
graph.AddNode(node);
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets{{onnxruntime::kOnnxDomain, 18}};
ModelBuilderAPI::Model model(opsets);
model.AddGraph(graph);
try {
auto session = CreateSession(*ort_env, model);
FAIL();
} catch (const Ort::Exception& e) {
ASSERT_THAT(e.what(), ::testing::HasSubstr("Node input 'missing' is not a graph input, "
"initializer, or output of a previous node."));
}
}
TEST(ModelBuilderAPITest, InvalidModelEdit) {
// Add a node but make the edit invalid in various ways
// - add node but don't update graph inputs
// - add node with invalid domain
const auto edit_model = [](bool invalid_domain) {
SessionOptions so;
// Set this to save the model if you want to debug.
// so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx"));
Session session = Session::CreateModelBuilderSession(*ort_env, TSTR("testdata/mnist.onnx"), so);
ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets; // no additional opsets required
ModelBuilderAPI::Model model(opsets);
ModelBuilderAPI::Graph graph; // new info to augment the model with
const char* domain = invalid_domain ? "invalid_domain" : onnxruntime::kOnnxDomain;
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs = session.GetInputs();
ASSERT_EQ(graph_inputs.size(), 1);
ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
const std::string new_input_name = "Int64Input";
// Add Cast node to convert input from float to int64
std::vector<OpAttr> attributes;
int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT));
ModelBuilderAPI::Node node("Cast", domain, "NewInputNode", {new_input_name},
// the existing node will now consume the output from the Cast instead of a graph input
{graph_inputs[0].Name()},
attributes);
graph.AddNode(node);
if (invalid_domain) {
// we're replacing the only input. the shape is the same but the name and data type change.
TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape());
auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst());
graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst());
graph.SetInputs(graph_inputs);
} else {
// model should be invalid as we didn't connect the new node up to the graph inputs
}
// the node we added does not require any new opsets.
model.AddGraph(graph);
try {
session.FinalizeModelBuilderSession(model, so);
FAIL() << "Should have failed to resolve graph due to invalid edits.";
} catch (const Ort::Exception& e) {
if (invalid_domain) {
ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain 'invalid_domain'"));
} else {
ASSERT_THAT(e.what(), ::testing::HasSubstr("This is an invalid model"));
}
}
};
edit_model(false);
edit_model(true); // add node with invalid domain
}
TEST(ModelBuilderAPITest, CreateTypeInfo) {
// sparse tensor
// sequence
// map
// optional
}