From d42399e1b07ce61e95aae88bc6b6ea5dcaae2011 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 6 Jan 2021 11:48:03 +0530 Subject: [PATCH] Allow querying a GraphProto's doc_string as part of ModelMetadata (#6248) --- .../InferenceSession.cs | 21 +++++++++ .../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 12 +++++ .../InferenceTest.cs | 2 + .../model_with_valid_ort_config_json.onnx | Bin 348 -> 367 bytes .../core/session/onnxruntime_c_api.h | 14 +++++- .../core/session/onnxruntime_cxx_api.h | 1 + .../core/session/onnxruntime_cxx_inline.h | 6 +++ onnxruntime/core/flatbuffers/schema/README.md | 5 ++- onnxruntime/core/flatbuffers/schema/ort.fbs | 4 +- onnxruntime/core/flatbuffers/schema/ort.fbs.h | 22 +++++++-- onnxruntime/core/graph/model.cc | 42 ++++++++++++++---- onnxruntime/core/graph/model.h | 40 ++++++++++------- onnxruntime/core/session/inference_session.cc | 7 ++- onnxruntime/core/session/inference_session.h | 1 + onnxruntime/core/session/onnxruntime_c_api.cc | 17 +++++-- onnxruntime/core/session/ort_apis.h | 2 + .../python/onnxruntime_pybind_state.cc | 1 + .../test/python/onnxruntime_test_python.py | 1 + onnxruntime/test/shared_lib/test_inference.cc | 29 +++++++----- .../model_with_valid_ort_config_json.onnx | Bin 348 -> 367 bytes 20 files changed, 182 insertions(+), 45 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index d78d2bfed9..af59d13df8 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -1058,6 +1058,7 @@ namespace Microsoft.ML.OnnxRuntime private string _graphName; private string _domain; private string _description; + private string _graphDescription; private long _version; private Dictionary _customMetadataMap = new Dictionary(); @@ -1107,6 +1108,14 @@ namespace Microsoft.ML.OnnxRuntime _description = NativeOnnxValueHelper.StringFromNativeUtf8(descriptionHandle); } + // Process graph description + IntPtr graphDescriptionHandle = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphDescription(modelMetadataHandle, allocator.Pointer, out graphDescriptionHandle)); + using (var ortAllocation = new OrtMemoryAllocation(allocator, graphDescriptionHandle, 0)) + { + _graphDescription = NativeOnnxValueHelper.StringFromNativeUtf8(graphDescriptionHandle); + } + // Process version NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetVersion(modelMetadataHandle, out _version)); @@ -1205,6 +1214,18 @@ namespace Microsoft.ML.OnnxRuntime } } + /// + /// Unstructured graph description + /// + /// description string + public string GraphDescription + { + get + { + return _graphDescription; + } + } + /// /// Version number /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 18e2206a1e..f416fa2008 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -188,6 +188,7 @@ namespace Microsoft.ML.OnnxRuntime public IntPtr SetGlobalDenormalAsZero; public IntPtr CreateArenaCfg; public IntPtr ReleaseArenaCfg; + public IntPtr ModelMetadataGetGraphDescription; } internal static class NativeMethods @@ -325,6 +326,7 @@ namespace Microsoft.ML.OnnxRuntime OrtModelMetadataGetGraphName = (DOrtModelMetadataGetGraphName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphName, typeof(DOrtModelMetadataGetGraphName)); OrtModelMetadataGetDomain = (DOrtModelMetadataGetDomain)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDomain, typeof(DOrtModelMetadataGetDomain)); OrtModelMetadataGetDescription = (DOrtModelMetadataGetDescription)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDescription, typeof(DOrtModelMetadataGetDescription)); + OrtModelMetadataGetGraphDescription = (DOrtModelMetadataGetGraphDescription)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphDescription, typeof(DOrtModelMetadataGetGraphDescription)); OrtModelMetadataGetVersion = (DOrtModelMetadataGetVersion)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetVersion, typeof(DOrtModelMetadataGetVersion)); OrtModelMetadataGetCustomMetadataMapKeys = (DOrtModelMetadataGetCustomMetadataMapKeys)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetCustomMetadataMapKeys, typeof(DOrtModelMetadataGetCustomMetadataMapKeys)); OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap)); @@ -961,6 +963,16 @@ namespace Microsoft.ML.OnnxRuntime IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetDescription OrtModelMetadataGetDescription; + /// + /// Gets the description associated with a ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// instance of OrtAllocator + /// (output) graph description from the ModelMetadata instance + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + public static DOrtModelMetadataGetGraphDescription OrtModelMetadataGetGraphDescription; + /// /// Gets the version associated with a ModelMetadata instance /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index caab17bb47..1ef865b736 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -1823,6 +1823,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests Assert.Equal("This is a test model with a valid ORT config Json", modelMetadata.Description); + Assert.Equal("graph description", modelMetadata.GraphDescription); + Assert.Equal(2, modelMetadata.CustomMetadataMap.Keys.Count); Assert.Equal("dummy_value", modelMetadata.CustomMetadataMap["dummy_key"]); Assert.Equal("{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}", diff --git a/csharp/testdata/model_with_valid_ort_config_json.onnx b/csharp/testdata/model_with_valid_ort_config_json.onnx index f2a0a9bb8e72dfe7016e3e5ef019fde5e340f7f2..fd0e0706c9b4da18c6b28feb24ff49622509625e 100644 GIT binary patch delta 34 pcmcb^^qy&g9b?Tzdwby^!StfUf((U})Z*l#%z~24{Je>E_W;@04PyWR delta 14 WcmaFQbcbnz9b@K1d;5tK?*RZSr3M!O diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 262e388e08..ab8b65cbd2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7,7 +7,7 @@ #include // This value is used in structures passed to ORT so that a newer version of ORT will still work with them -#define ORT_API_VERSION 6 +#define ORT_API_VERSION 7 #ifdef __cplusplus extern "C" { @@ -1134,6 +1134,18 @@ struct OrtApi { int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out); ORT_CLASS_RELEASE(ArenaCfg); + + /** + * Use this API to obtain the description of the graph present in the model + * (doc_string field of the GraphProto message within the ModelProto message). + * If it doesn't exist, an empty string will be returned. + * \param model_metadata - an instance of OrtModelMetadata + * \param allocator - allocator used to allocate the string that will be returned back + * \param value - is set to a null terminated string allocated using 'allocator'. + The caller is responsible for freeing it. + */ + ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 9907c56a4a..d5aa79a79d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -336,6 +336,7 @@ struct ModelMetadata : Base { char* GetGraphName(OrtAllocator* allocator) const; char* GetDomain(OrtAllocator* allocator) const; char* GetDescription(OrtAllocator* allocator) const; + char* GetGraphDescription(OrtAllocator* allocator) const; char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; int64_t GetVersion() const; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index ee89acde8f..a5ce8219f6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -602,6 +602,12 @@ inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const { return out; } +inline char* ModelMetadata::GetGraphDescription(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); + return out; +} + inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); diff --git a/onnxruntime/core/flatbuffers/schema/README.md b/onnxruntime/core/flatbuffers/schema/README.md index 6815a1e19c..94b24cacb3 100644 --- a/onnxruntime/core/flatbuffers/schema/README.md +++ b/onnxruntime/core/flatbuffers/schema/README.md @@ -24,4 +24,7 @@ This should result in ort.fbs.h being updated. Initial support for FlatBuffers that includes Model support. Graph support including Attributes, Tensors, Tensor Sequences, Maps and Sequences. Constant initializers are also supported. Constant nodes are converted to constant initializers in the ORT format. ## Version 2. -Support for sparse initialiers. Sparse intializers are stored within ORT FlatBuffers format, which includes sparse initializers converted from Constant node attribute. \ No newline at end of file +Support for sparse initialiers. Sparse intializers are stored within ORT FlatBuffers format, which includes sparse initializers converted from Constant node attribute. + +## Version 3. +Support for storing `graph_doc_string` field in Model (ORT FlatBuffers format). \ No newline at end of file diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs b/onnxruntime/core/flatbuffers/schema/ort.fbs index cb111247d9..fa4e216cda 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs @@ -202,8 +202,10 @@ table Model { domain:string; model_version:int64; doc_string:string; - + graph:Graph; + + graph_doc_string:string; } table KernelCreateInfos { diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs.h b/onnxruntime/core/flatbuffers/schema/ort.fbs.h index 3da4d060c9..a56461d594 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs.h @@ -1815,7 +1815,8 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_DOMAIN = 12, VT_MODEL_VERSION = 14, VT_DOC_STRING = 16, - VT_GRAPH = 18 + VT_GRAPH = 18, + VT_GRAPH_DOC_STRING = 20 }; int64_t ir_version() const { return GetField(VT_IR_VERSION, 0); @@ -1841,6 +1842,9 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const onnxruntime::experimental::fbs::Graph *graph() const { return GetPointer(VT_GRAPH); } + const flatbuffers::String *graph_doc_string() const { + return GetPointer(VT_GRAPH_DOC_STRING); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_IR_VERSION) && @@ -1858,6 +1862,8 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(doc_string()) && VerifyOffset(verifier, VT_GRAPH) && verifier.VerifyTable(graph()) && + VerifyOffset(verifier, VT_GRAPH_DOC_STRING) && + verifier.VerifyString(graph_doc_string()) && verifier.EndTable(); } }; @@ -1890,6 +1896,9 @@ struct ModelBuilder { void add_graph(flatbuffers::Offset graph) { fbb_.AddOffset(Model::VT_GRAPH, graph); } + void add_graph_doc_string(flatbuffers::Offset graph_doc_string) { + fbb_.AddOffset(Model::VT_GRAPH_DOC_STRING, graph_doc_string); + } explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1911,10 +1920,12 @@ inline flatbuffers::Offset CreateModel( flatbuffers::Offset domain = 0, int64_t model_version = 0, flatbuffers::Offset doc_string = 0, - flatbuffers::Offset graph = 0) { + flatbuffers::Offset graph = 0, + flatbuffers::Offset graph_doc_string = 0) { ModelBuilder builder_(_fbb); builder_.add_model_version(model_version); builder_.add_ir_version(ir_version); + builder_.add_graph_doc_string(graph_doc_string); builder_.add_graph(graph); builder_.add_doc_string(doc_string); builder_.add_domain(domain); @@ -1933,12 +1944,14 @@ inline flatbuffers::Offset CreateModelDirect( const char *domain = nullptr, int64_t model_version = 0, const char *doc_string = nullptr, - flatbuffers::Offset graph = 0) { + flatbuffers::Offset graph = 0, + const char *graph_doc_string = nullptr) { auto opset_import__ = opset_import ? _fbb.CreateVector>(*opset_import) : 0; auto producer_name__ = producer_name ? _fbb.CreateString(producer_name) : 0; auto producer_version__ = producer_version ? _fbb.CreateString(producer_version) : 0; auto domain__ = domain ? _fbb.CreateString(domain) : 0; auto doc_string__ = doc_string ? _fbb.CreateString(doc_string) : 0; + auto graph_doc_string__ = graph_doc_string ? _fbb.CreateString(graph_doc_string) : 0; return onnxruntime::experimental::fbs::CreateModel( _fbb, ir_version, @@ -1948,7 +1961,8 @@ inline flatbuffers::Offset CreateModelDirect( domain__, model_version, doc_string__, - graph); + graph, + graph_doc_string__); } struct KernelCreateInfos FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 9fcd078fd8..a5ebe136f0 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -181,24 +181,33 @@ Version Model::IrVersion() const { return kNoVersion; } -const std::string& Model::ProducerName() const { - return model_proto_.producer_name(); +const std::string Model::ProducerName() const { + if (model_proto_.has_producer_name()) { + return model_proto_.producer_name(); + } + return std::string(); } void Model::SetProducerName(const std::string& producer_name) { model_proto_.set_producer_name(producer_name); } -const std::string& Model::ProducerVersion() const { - return model_proto_.producer_version(); +const std::string Model::ProducerVersion() const { + if (model_proto_.has_producer_version()) { + return model_proto_.producer_version(); + } + return std::string(); } void Model::SetProducerVersion(const std::string& producer_version) { model_proto_.set_producer_version(producer_version); } -const std::string& Model::Domain() const { - return model_proto_.domain(); +const std::string Model::Domain() const { + if (model_proto_.has_domain()) { + return model_proto_.domain(); + } + return std::string(); } void Model::SetDomain(const std::string& domain) { @@ -216,14 +225,24 @@ void Model::SetModelVersion(onnxruntime::Version version) { model_proto_.set_model_version(version); } -const std::string& Model::DocString() const { - return model_proto_.doc_string(); +const std::string Model::DocString() const { + if (model_proto_.has_doc_string()) { + return model_proto_.doc_string(); + } + return std::string(); } void Model::SetDocString(const std::string& doc_string) { model_proto_.set_doc_string(doc_string); } +const std::string Model::GraphDocString() const { + if (model_proto_.has_graph() && model_proto_.graph().has_doc_string()) { + return model_proto_.graph().doc_string(); + } + return std::string(); +} + #endif // !defined(ORT_MINIMAL_BUILD) const ModelMetaData& Model::MetaData() const noexcept { @@ -558,6 +577,8 @@ common::Status Model::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, auto domain = builder.CreateSharedString(model_proto_.domain()); auto doc_string = experimental::utils::SaveStringToOrtFormat( builder, model_proto_.has_doc_string(), model_proto_.doc_string()); + auto graph_doc_string = experimental::utils::SaveStringToOrtFormat( + builder, model_proto_.has_graph() && model_proto_.graph().has_doc_string(), model_proto_.graph().doc_string()); std::vector> op_set_ids_vec; op_set_ids_vec.reserve(model_proto_.opset_import().size()); @@ -581,6 +602,7 @@ common::Status Model::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, mb.add_domain(domain); mb.add_model_version(model_proto_.model_version()); mb.add_doc_string(doc_string); + mb.add_graph_doc_string(graph_doc_string); mb.add_graph(fbs_graph); // add graph @@ -605,6 +627,9 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, producer_version, fbs_model.producer_version()); LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, domain, fbs_model.domain()); LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, doc_string, fbs_model.doc_string()); + if (fbs_model.graph_doc_string()) { + model->model_proto_.mutable_graph()->set_doc_string(fbs_model.graph_doc_string()->c_str()); + } model->model_proto_.set_model_version(fbs_model.model_version()); model->model_proto_.set_ir_version(fbs_model.ir_version()); #else @@ -612,6 +637,7 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, experimental::utils::LoadStringFromOrtFormat(model->producer_version_, fbs_model.producer_version()); experimental::utils::LoadStringFromOrtFormat(model->domain_, fbs_model.domain()); experimental::utils::LoadStringFromOrtFormat(model->doc_string_, fbs_model.doc_string()); + experimental::utils::LoadStringFromOrtFormat(model->graph_doc_string_, fbs_model.graph_doc_string()); model->model_version_ = fbs_model.model_version(); model->ir_version_ = fbs_model.ir_version(); #endif diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 53968de60e..7f4dfd0d2f 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -89,20 +89,20 @@ class Model { Version IrVersion() const; // Get model's producer name. - // Return null pointer if not specified. - const std::string& ProducerName() const; + // Returns empty string if not specified. + const std::string ProducerName() const; // Set model's producer name. void SetProducerName(const std::string& producer_name); // Get model's producer version. - // Return null pointer if not specified. - const std::string& ProducerVersion() const; + // Returns empty string if not specified. + const std::string ProducerVersion() const; // Set model's producer version. void SetProducerVersion(const std::string& producer_version); // Get model's domain. - // Return null pointer if not specified. - const std::string& Domain() const; + // Returns empty string if not specified. + const std::string Domain() const; // Set models' domain. void SetDomain(const std::string& domain); @@ -113,34 +113,43 @@ class Model { void SetModelVersion(onnxruntime::Version model_version); // Get model's doc string. - // Return null pointer if not specified. - const std::string& DocString() const; + // Returns empty string if not specified. + const std::string DocString() const; // Set models' doc string. void SetDocString(const std::string& doc_string); + + // Get graph's doc string. + // Returns empty string if not specified. + const std::string GraphDocString() const; + #else // Get model's IR version. // Return if not specified. Version IrVersion() const { return ir_version_; } // Get model's producer name. - // Return null pointer if not specified. - const std::string& ProducerName() const { return producer_name_; } + // Returns empty string if not specified. + const std::string ProducerName() const { return producer_name_; } // Get model's producer version. - // Return null pointer if not specified. - const std::string& ProducerVersion() const { return producer_version_; } + // Returns empty string if not specified. + const std::string ProducerVersion() const { return producer_version_; } // Get model's domain. - const std::string& Domain() const { return domain_; } + // Returns empty string if not specified. + const std::string Domain() const { return domain_; } // Get model's version. // Return null pointer if not specified. Version ModelVersion() const { return model_version_; } // Get model's doc string. - // Return null pointer if not specified. - const std::string& DocString() const { return doc_string_; } + // Returns empty string if not specified. + const std::string DocString() const { return doc_string_; } + // Get graph's doc string. + // Returns empty string if not specified. + const std::string GraphDocString() const { return graph_doc_string_; } #endif const ModelMetaData& MetaData() const noexcept; @@ -248,6 +257,7 @@ class Model { int64_t ir_version_ = kNoVersion; std::string domain_; std::string doc_string_; + std::string graph_doc_string_; #endif // This is a duplication of . diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 59f445f5e1..660b5933df 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -105,8 +105,9 @@ std::atomic InferenceSession::global_session_id_{1}; // below will also need to be updated. // See onnxruntime/core/session/flatbuffers/schema/README.md for more details on versioning. // Version 1 - history begins -// Version 2 - add serailization/deserialization of sparse_initializer -static constexpr const char* kOrtModelVersion = "2"; +// Version 2 - add serialization/deserialization of sparse_initializer +// Version 3 - add `graph_doc_string` to Model +static constexpr const char* kOrtModelVersion = "3"; #if defined(ENABLE_ORT_FORMAT_LOAD) // Check if the given ort model version is supported in this build @@ -116,6 +117,7 @@ static bool IsOrtModelVersionSupported(const std::string& ort_model_version) { static const std::unordered_set kSupportedOrtModelVersions{ std::string("1.4.0"), // This is a special model version for existing converted model std::string("1"), + std::string("2"), std::string(kOrtModelVersion), }; @@ -1745,6 +1747,7 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod // save model metadata model_metadata_.producer_name = model.ProducerName(); model_metadata_.description = model.DocString(); + model_metadata_.graph_description = model.GraphDocString(); model_metadata_.domain = model.Domain(); model_metadata_.version = model.ModelVersion(); model_metadata_.custom_metadata_map = model.MetaData(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 446c5d2753..2b98143621 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -65,6 +65,7 @@ struct ModelMetadata { std::string graph_name; std::string domain; std::string description; + std::string graph_description; int64_t version = 0; std::unordered_map custom_metadata_map; }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index db097983d8..6f613efb92 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1030,6 +1030,16 @@ ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDescription, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetGraphDescription, + _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value) { + API_IMPL_BEGIN + auto description = reinterpret_cast(model_metadata)->graph_description; + *value = StrDup(description, allocator); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value) { API_IMPL_BEGIN @@ -1884,7 +1894,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_6 = { +static constexpr OrtApi ort_api_1_to_7 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. // Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -2073,6 +2083,7 @@ static constexpr OrtApi ort_api_1_to_6 = { // End of Version 6 - DO NOT MODIFY ABOVE (see above text for more information) // Version 7 - In development, feel free to add/remove/rearrange here + &OrtApis::ModelMetadataGetGraphDescription, }; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) @@ -2080,8 +2091,8 @@ static constexpr OrtApi ort_api_1_to_6 = { static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change"); ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { - if (version >= 1 && version <= 6) - return &ort_api_1_to_6; + if (version >= 1 && version <= 7) + return &ort_api_1_to_7; return nullptr; // Unsupported version } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 4516fbddf0..a49b785c5a 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -97,6 +97,8 @@ ORT_API_STATUS_IMPL(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_m _Inout_ OrtAllocator* allocator, _Outptr_ char** value); ORT_API_STATUS_IMPL(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); +ORT_API_STATUS_IMPL(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); ORT_API_STATUS_IMPL(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 4126c9229b..37f90e0ff6 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1599,6 +1599,7 @@ facilitate the comparison.)pbdoc") .def_readwrite("graph_name", &ModelMetadata::graph_name, "graph name") .def_readwrite("domain", &ModelMetadata::domain, "ONNX domain") .def_readwrite("description", &ModelMetadata::description, "description of the model") + .def_readwrite("graph_description", &ModelMetadata::graph_description, "description of the graph hosted in the model") .def_readwrite("version", &ModelMetadata::version, "version of the model") .def_readwrite("custom_metadata_map", &ModelMetadata::custom_metadata_map, "additional metadata"); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e188d31e47..ec58417e73 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -510,6 +510,7 @@ class TestInferenceSession(unittest.TestCase): self.assertEqual('squeezenet_old', modelmeta.graph_name) self.assertEqual('', modelmeta.domain) self.assertEqual('', modelmeta.description) + self.assertEqual('', modelmeta.graph_description) def testProfilerWithSessionOptions(self): so = onnxrt.SessionOptions() diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index a36d66eae1..3eb03fedcf 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -743,18 +743,18 @@ TEST(CApiTest, io_binding_cuda) { }; Ort::SessionOptions session_options; - #ifdef USE_TENSORRT - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); - #else - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); - #endif +#ifdef USE_TENSORRT + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); +#else + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); +#endif Ort::Session session(*ort_env, MODEL_URI, session_options); - #ifdef USE_TENSORRT - Ort::MemoryInfo info_cuda("Tensorrt", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); - #else - Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); - #endif +#ifdef USE_TENSORRT + Ort::MemoryInfo info_cuda("Tensorrt", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#else + Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#endif Ort::Allocator cuda_allocator(session, info_cuda); auto allocator_info = cuda_allocator.GetInfo(); @@ -1126,6 +1126,10 @@ TEST(CApiTest, model_metadata) { ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0); allocator.get()->Free(description); + char* graph_description = model_metadata.GetGraphDescription(allocator.get()); + ASSERT_TRUE(strcmp("graph description", graph_description) == 0); + allocator.get()->Free(graph_description); + int64_t version = model_metadata.GetVersion(); ASSERT_TRUE(version == 1); @@ -1165,6 +1169,11 @@ TEST(CApiTest, model_metadata) { ASSERT_TRUE(strcmp("", description) == 0); allocator.get()->Free(description); + // Graph description is empty + char* graph_description = model_metadata.GetGraphDescription(allocator.get()); + ASSERT_TRUE(strcmp("", graph_description) == 0); + allocator.get()->Free(graph_description); + // Model does not contain custom metadata map int64_t num_keys_in_custom_metadata_map; char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map); diff --git a/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx b/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx index f2a0a9bb8e72dfe7016e3e5ef019fde5e340f7f2..fd0e0706c9b4da18c6b28feb24ff49622509625e 100644 GIT binary patch delta 34 pcmcb^^qy&g9b?Tzdwby^!StfUf((U})Z*l#%z~24{Je>E_W;@04PyWR delta 14 WcmaFQbcbnz9b@K1d;5tK?*RZSr3M!O