diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 2c40b41774..39a08962af 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -40,7 +40,7 @@ #include "core/graph/node_arg.h" #include "core/graph/ort_format_load_options.h" -// Type from Graph API in ORT C API so can't be in a namespace +// Type from Model Builder API in ORT C API so can't be in a namespace struct OrtGraph; namespace onnxruntime { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index fcd42323ef..e0beb2f22b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5047,7 +5047,7 @@ struct OrtModelBuilderApi { /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. * * \param[in] name The name of the input or output. - * \param[in] type_info The type information for the input or output. + * \param[in] type_info The type information for the input or output. The provided value is copied. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -5102,7 +5102,7 @@ struct OrtModelBuilderApi { * * \since Version 1.21. */ - ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, const char* domain_name, _In_ const char* node_name, + ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name, _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, @@ -5263,9 +5263,16 @@ struct OrtModelBuilderApi { /** \brief Create an OrtSession using the OrtModel. * - * Create an inference session using the OrtModel. + * Create an inference session using the OrtModel instance. 
+ * The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs + * and SetGraphOutputs must have been called. * This will validate the model, run optimizers, and prepare the session for inferencing. * + * \param[in] env The OrtEnv instance. + * \param[in] model The OrtModel instance. + * \param[in] options The OrtSessionOptions instance. + * \param[out] out The OrtSession instance. + * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.21. @@ -5275,11 +5282,17 @@ struct OrtModelBuilderApi { /** \brief Create an OrtSession to augment an existing model. * - * Create an OrtSession with an existing model that can be augmented with additional nodes. - * Nodes can be added to the model using AddNodeToGraph. - * Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes. - * Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling - * FinalizeModelBuilderSession. + * Create an OrtSession with an existing model that will be augmented with additional nodes. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes to the existing model, first create an OrtModel using CreateModel. + * Add additional nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. + * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs to reflect changes made by the new + * nodes. The list of inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the + * session for inferencing by calling FinalizeModelBuilderSession. 
* * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -5291,11 +5304,22 @@ struct OrtModelBuilderApi { /** \brief Create an OrtSession to augment an existing model. * - * Create an OrtSession with an existing model that can be augmented with additional nodes. - * Nodes can be added to the model using AddNodeToGraph. - * Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes. - * Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling - * FinalizeModelBuilderSession. + * Create an OrtSession with an existing model that will be augmented with additional nodes. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes to the existing model, first create an OrtModel using CreateModel. + * Add additional nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. + * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs to reflect changes made by the new + * nodes. The list of inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the + * session for inferencing by calling FinalizeModelBuilderSession. + * + * \param[in] env The OrtEnv instance. + * \param[in] model_data The model data for the existing model to augment. + * \param[in] model_data_length The length of the model data. + * \param[in] options The OrtSessionOptions instance. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -5308,12 +5332,12 @@ struct OrtModelBuilderApi { /** \brief Apply the changes from the model to the session. * - * Apply the changes from the model to the session that was created using CreateModelBuilderSession[FromArray].
+ * Apply the changes from the model to a session that was created using CreateModelBuilderSession[FromArray]. * All changes will be validated. * Call FinalizeModelBuilderSession to prepare the session for inferencing. * * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel. - * i.e. you don't need to call SetGraphInputs/Outputs if they are unchanged. + * i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged. * * \snippet{doc} snippets.dox OrtStatus Return Value * diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 715f61b171..061c24ba4c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -667,7 +667,8 @@ struct ModelMetadata; namespace ModelBuilderAPI { struct Model; -} +struct ValueInfo; +} // namespace ModelBuilderAPI /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators * and release them at the end of the scope. The lifespan of the given allocator @@ -1120,6 +1121,11 @@ struct ConstSessionImpl : Base { TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain + + // TODO: Should we move ValueInfo from ModelBuilderAPI to the ORT API? It will also be relevant to the Plugin EP API + // although the internal implementation may differ there. 
+ std::vector GetInputs() const; + std::vector GetOutputs() const; }; template @@ -2592,9 +2598,6 @@ struct NodeImpl : Ort::detail::Base { }; } // namespace detail -// Const object holder that does not own the underlying object -using ConstNode = detail::NodeImpl>; - /** \brief Wrapper around ::OrtNode * */ @@ -2616,8 +2619,6 @@ struct Node : detail::NodeImpl { const std::vector& output_names, std::vector& attributes); - ConstNode GetConst() const { return ConstNode{this->p_}; } - private: static void Init(const std::string& operator_name, const std::string& operator_domain, const std::string& node_name, @@ -2640,9 +2641,6 @@ struct GraphImpl : Ort::detail::Base { }; } // namespace detail -// Const object holder that does not own the underlying object -using ConstGraph = detail::GraphImpl>; - /** \brief Wrapper around ::OrtGraph * */ @@ -2650,8 +2648,6 @@ struct Graph : detail::GraphImpl { explicit Graph(std::nullptr_t) {} ///< No instance is created explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API Graph(); - - ConstGraph GetConst() const { return ConstGraph{this->p_}; } }; namespace detail { @@ -2675,7 +2671,7 @@ struct Model : detail::ModelImpl { explicit Model(std::nullptr_t) {} ///< No instance is created explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API - Model(const std::vector& opsets); + explicit Model(const std::vector& opsets); ConstModel GetConst() const { return ConstModel{this->p_}; } }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 7365d39938..de1ab81080 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -10,7 +10,9 @@ #include #include #include +#include #include +#include // Convert OrtStatus to Ort::Status and return // instead of throwing @@ -1109,6 +1111,36 @@ inline int 
ConstSessionImpl::GetOpset(const std::string& domain) const { return opset; } +template +std::vector ConstSessionImpl::GetInputs() const { + const std::vector input_names = GetInputNames(); + + std::vector inputs; + inputs.reserve(input_names.size()); + + for (int i = 0; i < input_names.size(); ++i) { + auto type_info = GetInputTypeInfo(i); + inputs.emplace_back(ModelBuilderAPI::ValueInfo{input_names[i], type_info.GetConst()}); + } + + return inputs; +} + +template +std::vector ConstSessionImpl::GetOutputs() const { + const std::vector output_names = GetOutputNames(); + + std::vector outputs; + outputs.reserve(output_names.size()); + + for (int i = 0; i < output_names.size(); ++i) { + auto type_info = GetOutputTypeInfo(i); + outputs.emplace_back(ModelBuilderAPI::ValueInfo{output_names[i], type_info.GetConst()}); + } + + return outputs; +} + template inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count) { @@ -2328,6 +2360,7 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con } namespace ModelBuilderAPI { +namespace detail { inline std::vector StringsToCharPtrs(const std::vector& strings) { std::vector ptrs; ptrs.reserve(strings.size()); @@ -2336,6 +2369,7 @@ inline std::vector StringsToCharPtrs(const std::vector return ptrs; } +} // namespace detail // static inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, @@ -2344,8 +2378,8 @@ inline void Node::Init(const std::string& operator_name, const std::string& oper const std::vector& output_names, std::vector& attributes, OrtNode*& node) { - auto inputs = StringsToCharPtrs(input_names); - auto outputs = StringsToCharPtrs(output_names); + auto inputs = detail::StringsToCharPtrs(input_names); + auto outputs = detail::StringsToCharPtrs(output_names); std::vector attributes_ptrs; 
attributes_ptrs.reserve(attributes.size()); diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 0c3ebebc52..64a4dd19c1 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -300,13 +300,3 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio // “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] // “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; - -// Create an Inference Session that will use the Model Builder API to create/update the model. -// This flag will create the session but not fully initialize it. A model, if provided, will be loaded. -// A session logger will be created, and execution providers will be registered. -// Any device specific allocators and IDataTransfer objects will be registered. -// This allows CreateAllocator to return device specific allocators registered by EPs. -// FUTURE: This will also allow CopyTensors to utilize the IDataTransfer objects -// "0": Disabled. 
[DEFAULT] -// "1": Enable Model Builder Session -static const char* const kOrtSessionOptionsEnableModelBuilder = "session.model_builder_session"; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index c372bd97fc..bf8fcb02c1 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -266,18 +266,6 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, return Status::OK(); } -bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto) { - bool has_external_memory = false; - if (utils::HasExternalData(tensor_proto)) { - std::unique_ptr external_data_info; - ORT_THROW_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - - has_external_memory = external_data_info->GetRelPath() == onnxruntime::utils::kTensorProtoMemoryAddressTag; - } - - return has_external_memory; -} - void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) { tensor_proto.set_raw_data(std::move(param)); } @@ -1301,22 +1289,15 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const auto* raw_data = tensor.DataRaw(); ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); - tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. 
auto offset = narrow(reinterpret_cast(raw_data)); - ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("location"); - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(tensor.SizeInBytes())); + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, tensor.SizeInBytes(), tensor_proto); + } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index ea6edb7fcd..7b9a478423 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -514,10 +514,6 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) { } #endif -// Check if the TensorProto has an external data entry that points to memory rather than an external file. -// The external data location will be kTensorProtoMemoryAddressTag in this case. -bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto); - // UnpackTensor from raw data or the type specific data field. Does not handle external data. // If the tensor does not contain raw data then raw_data should be nullptr and raw_data_len should be 0. template diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 660b0cf288..6008168e66 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -5807,21 +5807,11 @@ Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updat if (is_external) { // pre-existing memory that we don't own. 
avoid a copy by storing the pointer in the ExternalDataInfo - tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); - const void* data_offset = t.DataRaw(); // address of memory not offset into file auto offset = narrow(reinterpret_cast(data_offset)); - ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("location"); - // magic tag for existing memory that causes 'offset' to be treated as a pointer to the memory - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(t.SizeInBytes())); + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, t.SizeInBytes(), tensor_proto); // copy OrtValue to keep it alive and to store the deleter if provided. ortvalue_initializers_.emplace(name, v); diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 922759b02e..199aa79cc1 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -300,8 +300,6 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init const auto* fbs_raw_data = fbs_tensor.raw_data(); if (fbs_raw_data) { if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { - initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); - static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); const void* data_offset = fbs_raw_data->Data(); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. 
@@ -309,15 +307,9 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. auto offset = narrow(reinterpret_cast(data_offset)); - ONNX_NAMESPACE::StringStringEntryProto* entry = initializer.mutable_external_data()->Add(); - entry->set_key("location"); - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = initializer.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = initializer.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(fbs_raw_data->size())); + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, fbs_raw_data->size(), initializer); + } else { // fbs_raw_data is uint8_t vector, so the size is byte size initializer.set_raw_data(fbs_raw_data->Data(), fbs_raw_data->size()); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 89a2693d19..6432a7c940 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -627,7 +627,7 @@ class InferenceSession { /// convenience pointer to logger. should always be the same as session_state_.Logger(); const logging::Logger* session_logger_; - // The list of execution providers. + // The list of execution providers. // This MUST be prior to model_ in case there are values in the model that were allocated using an allocator // provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the // EP instance. 
diff --git a/onnxruntime/core/session/model_builder_c_api.cc b/onnxruntime/core/session/model_builder_c_api.cc index 8eac1ebce3..8d71eaed07 100644 --- a/onnxruntime/core/session/model_builder_c_api.cc +++ b/onnxruntime/core/session/model_builder_c_api.cc @@ -48,13 +48,15 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateValueInfo, _In_ const char* name, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) { +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info, + _Out_ const char** name) { API_IMPL_BEGIN *name = value_info->name.c_str(); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info) { +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtTypeInfo** type_info) { API_IMPL_BEGIN *type_info = value_info->type_info.get(); diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index b6314f48a0..91d82e042d 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -369,6 +369,8 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { ModelBuilderAPI::Model model(opsets); model.AddGraph(graph); + auto session = CreateSession(*ort_env, model); + std::vector> inputs(1); auto& input = inputs[0]; input.name = "X"; @@ -379,7 +381,6 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { std::vector expected_dims = {3, 8}; - auto session = CreateSession(*ort_env, model); TestInference(session, inputs, "Z", expected_dims, {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, @@ -410,12 +411,15 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { std::vector opsets; // no additional 
opsets required ModelBuilderAPI::Model model(opsets); - std::vector input_names = session.GetInputNames(); - ASSERT_EQ(input_names.size(), 1); - - TypeInfo orig_input = session.GetInputTypeInfo(0); - ASSERT_EQ(orig_input.GetTensorTypeAndShapeInfo().GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + std::vector graph_inputs = session.GetInputs(); + ASSERT_EQ(graph_inputs.size(), 1); + ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + // typically this isn't needed, but we want to replace this input and read info from it later on in the test + // validation so we move it out of the vector so it's saved locally. + auto orig_input_name = graph_inputs[0].Name(); + auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape(); const std::string new_input_name = "Int64Input"; // Add Cast node to convert input from float to int64 @@ -423,16 +427,16 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); - ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, {input_names[0]}, + ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, + // the existing node will now consume the output from the Cast instead of a graph input + {orig_input_name}, attributes); - // we're replacing the only input, so we don't need to call session.GetInputTypeInfo(x) to copy other inputs - // in order to preserve them - std::vector graph_inputs; + // we're replacing the only input. the shape is the same but the name and data type change. 
TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - orig_input.GetTensorTypeAndShapeInfo().GetShape()); + input_shape); auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); - graph_inputs.emplace_back(new_input_name, input_type_info.GetConst()); + graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst()); ModelBuilderAPI::Graph graph; // new info to augment the model with @@ -441,13 +445,12 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { // the node we added does not require any new opsets. model.AddGraph(graph); - session.FinalizeModelBuilderSession(model, so); std::vector> inputs(1); auto& input = inputs[0]; input.name = new_input_name.c_str(); - input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape(); + input.dims = input_shape; auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies()); input.values.resize(size_t(num_values)); @@ -465,8 +468,8 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so); std::vector> expected_inputs(1); auto& expected_input = expected_inputs[0]; - expected_input.name = input_names[0].c_str(); - expected_input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape(); + expected_input.name = orig_input_name.c_str(); + expected_input.dims = input_shape; expected_input.values.reserve(size_t(num_values)); std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values), [&](int64_t value) { return float(value); }); @@ -489,10 +492,141 @@ TEST(ModelBuilderAPITest, InvalidDimension) { } } -/* -Tests required +TEST(ModelBuilderAPITest, CreateInvalidModel_NoOpsets) { + Ort::ModelBuilderAPI::Graph graph; + std::vector graph_inputs; + std::vector graph_outputs; -- Create invalid model. Graph::Resolve should fail. -- Invalid edit. Graph::Resolve should fail. 
-- All the non-tensor Create*TypeInfo functions need to be validated -*/ + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph_outputs.emplace_back("Z", type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "X"}, {"Z"}); + + graph.AddNode(node); + + std::vector opsets; + ModelBuilderAPI::Model model(opsets); + model.AddGraph(graph); + + try { + auto session = CreateSession(*ort_env, model); + FAIL(); + } catch (const Ort::Exception& e) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain")); + } +} + +TEST(ModelBuilderAPITest, CreateInvalidModel_MissingValue) { + Ort::ModelBuilderAPI::Graph graph; + + std::vector graph_inputs; + std::vector graph_outputs; + + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph_outputs.emplace_back("Z", type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "missing"}, {"Z"}); + graph.AddNode(node); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + ModelBuilderAPI::Model model(opsets); + model.AddGraph(graph); + + try { + auto session = CreateSession(*ort_env, model); + FAIL(); + } catch (const Ort::Exception& e) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Node input 'missing' is not a graph input, " + "initializer, or output of a previous node.")); + } +} + +TEST(ModelBuilderAPITest, InvalidModelEdit) { + // Add a node but make the edit invalid 
in various ways + // - add node but don't update graph inputs + // - add node with invalid domain + const auto edit_model = [](bool invalid_domain) { + SessionOptions so; + + // Set this to save the model if you want to debug. + // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); + + Session session = Session::CreateModelBuilderSession(*ort_env, TSTR("testdata/mnist.onnx"), so); + + ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string + + std::vector opsets; // no additional opsets required + ModelBuilderAPI::Model model(opsets); + ModelBuilderAPI::Graph graph; // new info to augment the model with + + const char* domain = invalid_domain ? "invalid_domain" : onnxruntime::kOnnxDomain; + + std::vector graph_inputs = session.GetInputs(); + ASSERT_EQ(graph_inputs.size(), 1); + ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + const std::string new_input_name = "Int64Input"; + + // Add Cast node to convert input from int64 to float + std::vector attributes; + int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); + + ModelBuilderAPI::Node node("Cast", domain, "NewInputNode", {new_input_name}, + // the existing node will now consume the output from the Cast instead of a graph input + {graph_inputs[0].Name()}, + attributes); + graph.AddNode(node); + + if (invalid_domain) { + // we're replacing the only input. the shape is the same but the name and data type change.
+ TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape()); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst()); + graph.SetInputs(graph_inputs); + } else { + // model should be invalid as we didn't connect the new node up to the graph inputs + } + + // the node we added does not require any new opsets. + model.AddGraph(graph); + + try { + session.FinalizeModelBuilderSession(model, so); + FAIL() << "Should have failed to resolve graph due to invalid edits."; + } catch (const Ort::Exception& e) { + if (invalid_domain) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain 'invalid_domain'")); + } else { + ASSERT_THAT(e.what(), ::testing::HasSubstr("This is an invalid model")); + } + } + }; + + edit_model(false); + edit_model(true); // add node with invalid domain +} + +TEST(ModelBuilderAPITest, CreateTypeInfo) { + // sparse tensor + + // sequence + + // map + + // optional +}