Allow querying a GraphProto's doc_string as part of ModelMetadata (#6248)

This commit is contained in:
Hariharan Seshadri 2021-01-06 11:48:03 +05:30 committed by GitHub
parent eea3806db1
commit d42399e1b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 182 additions and 45 deletions

View file

@ -1058,6 +1058,7 @@ namespace Microsoft.ML.OnnxRuntime
private string _graphName;
private string _domain;
private string _description;
private string _graphDescription;
private long _version;
private Dictionary<string, string> _customMetadataMap = new Dictionary<string, string>();
@ -1107,6 +1108,14 @@ namespace Microsoft.ML.OnnxRuntime
_description = NativeOnnxValueHelper.StringFromNativeUtf8(descriptionHandle);
}
// Process graph description
IntPtr graphDescriptionHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphDescription(modelMetadataHandle, allocator.Pointer, out graphDescriptionHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, graphDescriptionHandle, 0))
{
_graphDescription = NativeOnnxValueHelper.StringFromNativeUtf8(graphDescriptionHandle);
}
// Process version
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetVersion(modelMetadataHandle, out _version));
@ -1205,6 +1214,18 @@ namespace Microsoft.ML.OnnxRuntime
}
}
/// <summary>
/// Unstructured graph description
/// </summary>
/// <value>description string</value>
public string GraphDescription
{
get
{
return _graphDescription;
}
}
/// <summary>
/// Version number
/// </summary>

View file

@ -188,6 +188,7 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr SetGlobalDenormalAsZero;
public IntPtr CreateArenaCfg;
public IntPtr ReleaseArenaCfg;
public IntPtr ModelMetadataGetGraphDescription;
}
internal static class NativeMethods
@ -325,6 +326,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtModelMetadataGetGraphName = (DOrtModelMetadataGetGraphName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphName, typeof(DOrtModelMetadataGetGraphName));
OrtModelMetadataGetDomain = (DOrtModelMetadataGetDomain)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDomain, typeof(DOrtModelMetadataGetDomain));
OrtModelMetadataGetDescription = (DOrtModelMetadataGetDescription)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDescription, typeof(DOrtModelMetadataGetDescription));
OrtModelMetadataGetGraphDescription = (DOrtModelMetadataGetGraphDescription)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphDescription, typeof(DOrtModelMetadataGetGraphDescription));
OrtModelMetadataGetVersion = (DOrtModelMetadataGetVersion)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetVersion, typeof(DOrtModelMetadataGetVersion));
OrtModelMetadataGetCustomMetadataMapKeys = (DOrtModelMetadataGetCustomMetadataMapKeys)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetCustomMetadataMapKeys, typeof(DOrtModelMetadataGetCustomMetadataMapKeys));
OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap));
@ -961,6 +963,16 @@ namespace Microsoft.ML.OnnxRuntime
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetDescription OrtModelMetadataGetDescription;
/// <summary>
/// Gets the description associated with a ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="value">(output) graph description from the ModelMetadata instance</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetGraphDescription OrtModelMetadataGetGraphDescription;
/// <summary>
/// Gets the version associated with a ModelMetadata instance
/// </summary>

View file

@ -1823,6 +1823,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.Equal("This is a test model with a valid ORT config Json", modelMetadata.Description);
Assert.Equal("graph description", modelMetadata.GraphDescription);
Assert.Equal(2, modelMetadata.CustomMetadataMap.Keys.Count);
Assert.Equal("dummy_value", modelMetadata.CustomMetadataMap["dummy_key"]);
Assert.Equal("{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",

View file

@ -7,7 +7,7 @@
#include <string.h>
// This value is used in structures passed to ORT so that a newer version of ORT will still work with them
#define ORT_API_VERSION 6
#define ORT_API_VERSION 7
#ifdef __cplusplus
extern "C" {
@ -1134,6 +1134,18 @@ struct OrtApi {
int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out);
ORT_CLASS_RELEASE(ArenaCfg);
/**
* Use this API to obtain the description of the graph present in the model
* (doc_string field of the GraphProto message within the ModelProto message).
* If it doesn't exist, an empty string will be returned.
* \param model_metadata - an instance of OrtModelMetadata
* \param allocator - allocator used to allocate the string that will be returned back
* \param value - is set to a null terminated string allocated using 'allocator'.
The caller is responsible for freeing it.
*/
ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
};
/*

View file

@ -336,6 +336,7 @@ struct ModelMetadata : Base<OrtModelMetadata> {
char* GetGraphName(OrtAllocator* allocator) const;
char* GetDomain(OrtAllocator* allocator) const;
char* GetDescription(OrtAllocator* allocator) const;
char* GetGraphDescription(OrtAllocator* allocator) const;
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const;
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
int64_t GetVersion() const;

View file

@ -602,6 +602,12 @@ inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const {
return out;
}
inline char* ModelMetadata::GetGraphDescription(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));

View file

@ -24,4 +24,7 @@ This should result in ort.fbs.h being updated.
Initial support for FlatBuffers that includes Model support. Graph support including Attributes, Tensors, Tensor Sequences, Maps and Sequences. Constant initializers are also supported. Constant nodes are converted to constant initializers in the ORT format.
## Version 2.
Support for sparse initialiers. Sparse intializers are stored within ORT FlatBuffers format, which includes sparse initializers converted from Constant node attribute.
Support for sparse initialiers. Sparse intializers are stored within ORT FlatBuffers format, which includes sparse initializers converted from Constant node attribute.
## Version 3.
Support for storing `graph_doc_string` field in Model (ORT FlatBuffers format).

View file

@ -202,8 +202,10 @@ table Model {
domain:string;
model_version:int64;
doc_string:string;
graph:Graph;
graph_doc_string:string;
}
table KernelCreateInfos {

View file

@ -1815,7 +1815,8 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_DOMAIN = 12,
VT_MODEL_VERSION = 14,
VT_DOC_STRING = 16,
VT_GRAPH = 18
VT_GRAPH = 18,
VT_GRAPH_DOC_STRING = 20
};
int64_t ir_version() const {
return GetField<int64_t>(VT_IR_VERSION, 0);
@ -1841,6 +1842,9 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const onnxruntime::experimental::fbs::Graph *graph() const {
return GetPointer<const onnxruntime::experimental::fbs::Graph *>(VT_GRAPH);
}
const flatbuffers::String *graph_doc_string() const {
return GetPointer<const flatbuffers::String *>(VT_GRAPH_DOC_STRING);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int64_t>(verifier, VT_IR_VERSION) &&
@ -1858,6 +1862,8 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.VerifyString(doc_string()) &&
VerifyOffset(verifier, VT_GRAPH) &&
verifier.VerifyTable(graph()) &&
VerifyOffset(verifier, VT_GRAPH_DOC_STRING) &&
verifier.VerifyString(graph_doc_string()) &&
verifier.EndTable();
}
};
@ -1890,6 +1896,9 @@ struct ModelBuilder {
void add_graph(flatbuffers::Offset<onnxruntime::experimental::fbs::Graph> graph) {
fbb_.AddOffset(Model::VT_GRAPH, graph);
}
void add_graph_doc_string(flatbuffers::Offset<flatbuffers::String> graph_doc_string) {
fbb_.AddOffset(Model::VT_GRAPH_DOC_STRING, graph_doc_string);
}
explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@ -1911,10 +1920,12 @@ inline flatbuffers::Offset<Model> CreateModel(
flatbuffers::Offset<flatbuffers::String> domain = 0,
int64_t model_version = 0,
flatbuffers::Offset<flatbuffers::String> doc_string = 0,
flatbuffers::Offset<onnxruntime::experimental::fbs::Graph> graph = 0) {
flatbuffers::Offset<onnxruntime::experimental::fbs::Graph> graph = 0,
flatbuffers::Offset<flatbuffers::String> graph_doc_string = 0) {
ModelBuilder builder_(_fbb);
builder_.add_model_version(model_version);
builder_.add_ir_version(ir_version);
builder_.add_graph_doc_string(graph_doc_string);
builder_.add_graph(graph);
builder_.add_doc_string(doc_string);
builder_.add_domain(domain);
@ -1933,12 +1944,14 @@ inline flatbuffers::Offset<Model> CreateModelDirect(
const char *domain = nullptr,
int64_t model_version = 0,
const char *doc_string = nullptr,
flatbuffers::Offset<onnxruntime::experimental::fbs::Graph> graph = 0) {
flatbuffers::Offset<onnxruntime::experimental::fbs::Graph> graph = 0,
const char *graph_doc_string = nullptr) {
auto opset_import__ = opset_import ? _fbb.CreateVector<flatbuffers::Offset<onnxruntime::experimental::fbs::OperatorSetId>>(*opset_import) : 0;
auto producer_name__ = producer_name ? _fbb.CreateString(producer_name) : 0;
auto producer_version__ = producer_version ? _fbb.CreateString(producer_version) : 0;
auto domain__ = domain ? _fbb.CreateString(domain) : 0;
auto doc_string__ = doc_string ? _fbb.CreateString(doc_string) : 0;
auto graph_doc_string__ = graph_doc_string ? _fbb.CreateString(graph_doc_string) : 0;
return onnxruntime::experimental::fbs::CreateModel(
_fbb,
ir_version,
@ -1948,7 +1961,8 @@ inline flatbuffers::Offset<Model> CreateModelDirect(
domain__,
model_version,
doc_string__,
graph);
graph,
graph_doc_string__);
}
struct KernelCreateInfos FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {

View file

@ -181,24 +181,33 @@ Version Model::IrVersion() const {
return kNoVersion;
}
const std::string& Model::ProducerName() const {
return model_proto_.producer_name();
const std::string Model::ProducerName() const {
if (model_proto_.has_producer_name()) {
return model_proto_.producer_name();
}
return std::string();
}
void Model::SetProducerName(const std::string& producer_name) {
model_proto_.set_producer_name(producer_name);
}
const std::string& Model::ProducerVersion() const {
return model_proto_.producer_version();
const std::string Model::ProducerVersion() const {
if (model_proto_.has_producer_version()) {
return model_proto_.producer_version();
}
return std::string();
}
void Model::SetProducerVersion(const std::string& producer_version) {
model_proto_.set_producer_version(producer_version);
}
const std::string& Model::Domain() const {
return model_proto_.domain();
const std::string Model::Domain() const {
if (model_proto_.has_domain()) {
return model_proto_.domain();
}
return std::string();
}
void Model::SetDomain(const std::string& domain) {
@ -216,14 +225,24 @@ void Model::SetModelVersion(onnxruntime::Version version) {
model_proto_.set_model_version(version);
}
const std::string& Model::DocString() const {
return model_proto_.doc_string();
const std::string Model::DocString() const {
if (model_proto_.has_doc_string()) {
return model_proto_.doc_string();
}
return std::string();
}
void Model::SetDocString(const std::string& doc_string) {
model_proto_.set_doc_string(doc_string);
}
const std::string Model::GraphDocString() const {
if (model_proto_.has_graph() && model_proto_.graph().has_doc_string()) {
return model_proto_.graph().doc_string();
}
return std::string();
}
#endif // !defined(ORT_MINIMAL_BUILD)
const ModelMetaData& Model::MetaData() const noexcept {
@ -558,6 +577,8 @@ common::Status Model::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
auto domain = builder.CreateSharedString(model_proto_.domain());
auto doc_string = experimental::utils::SaveStringToOrtFormat(
builder, model_proto_.has_doc_string(), model_proto_.doc_string());
auto graph_doc_string = experimental::utils::SaveStringToOrtFormat(
builder, model_proto_.has_graph() && model_proto_.graph().has_doc_string(), model_proto_.graph().doc_string());
std::vector<flatbuffers::Offset<fbs::OperatorSetId>> op_set_ids_vec;
op_set_ids_vec.reserve(model_proto_.opset_import().size());
@ -581,6 +602,7 @@ common::Status Model::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
mb.add_domain(domain);
mb.add_model_version(model_proto_.model_version());
mb.add_doc_string(doc_string);
mb.add_graph_doc_string(graph_doc_string);
mb.add_graph(fbs_graph);
// add graph
@ -605,6 +627,9 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, producer_version, fbs_model.producer_version());
LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, domain, fbs_model.domain());
LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, doc_string, fbs_model.doc_string());
if (fbs_model.graph_doc_string()) {
model->model_proto_.mutable_graph()->set_doc_string(fbs_model.graph_doc_string()->c_str());
}
model->model_proto_.set_model_version(fbs_model.model_version());
model->model_proto_.set_ir_version(fbs_model.ir_version());
#else
@ -612,6 +637,7 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
experimental::utils::LoadStringFromOrtFormat(model->producer_version_, fbs_model.producer_version());
experimental::utils::LoadStringFromOrtFormat(model->domain_, fbs_model.domain());
experimental::utils::LoadStringFromOrtFormat(model->doc_string_, fbs_model.doc_string());
experimental::utils::LoadStringFromOrtFormat(model->graph_doc_string_, fbs_model.graph_doc_string());
model->model_version_ = fbs_model.model_version();
model->ir_version_ = fbs_model.ir_version();
#endif

View file

@ -89,20 +89,20 @@ class Model {
Version IrVersion() const;
// Get model's producer name.
// Return null pointer if not specified.
const std::string& ProducerName() const;
// Returns empty string if not specified.
const std::string ProducerName() const;
// Set model's producer name.
void SetProducerName(const std::string& producer_name);
// Get model's producer version.
// Return null pointer if not specified.
const std::string& ProducerVersion() const;
// Returns empty string if not specified.
const std::string ProducerVersion() const;
// Set model's producer version.
void SetProducerVersion(const std::string& producer_version);
// Get model's domain.
// Return null pointer if not specified.
const std::string& Domain() const;
// Returns empty string if not specified.
const std::string Domain() const;
// Set models' domain.
void SetDomain(const std::string& domain);
@ -113,34 +113,43 @@ class Model {
void SetModelVersion(onnxruntime::Version model_version);
// Get model's doc string.
// Return null pointer if not specified.
const std::string& DocString() const;
// Returns empty string if not specified.
const std::string DocString() const;
// Set models' doc string.
void SetDocString(const std::string& doc_string);
// Get graph's doc string.
// Returns empty string if not specified.
const std::string GraphDocString() const;
#else
// Get model's IR version.
// Return <kNoVersion> if not specified.
Version IrVersion() const { return ir_version_; }
// Get model's producer name.
// Return null pointer if not specified.
const std::string& ProducerName() const { return producer_name_; }
// Returns empty string if not specified.
const std::string ProducerName() const { return producer_name_; }
// Get model's producer version.
// Return null pointer if not specified.
const std::string& ProducerVersion() const { return producer_version_; }
// Returns empty string if not specified.
const std::string ProducerVersion() const { return producer_version_; }
// Get model's domain.
const std::string& Domain() const { return domain_; }
// Returns empty string if not specified.
const std::string Domain() const { return domain_; }
// Get model's version.
// Return null pointer if not specified.
Version ModelVersion() const { return model_version_; }
// Get model's doc string.
// Return null pointer if not specified.
const std::string& DocString() const { return doc_string_; }
// Returns empty string if not specified.
const std::string DocString() const { return doc_string_; }
// Get graph's doc string.
// Returns empty string if not specified.
const std::string GraphDocString() const { return graph_doc_string_; }
#endif
const ModelMetaData& MetaData() const noexcept;
@ -248,6 +257,7 @@ class Model {
int64_t ir_version_ = kNoVersion;
std::string domain_;
std::string doc_string_;
std::string graph_doc_string_;
#endif
// This is a duplication of <model_proto_.metadata_props()>.

View file

@ -105,8 +105,9 @@ std::atomic<uint32_t> InferenceSession::global_session_id_{1};
// below will also need to be updated.
// See onnxruntime/core/session/flatbuffers/schema/README.md for more details on versioning.
// Version 1 - history begins
// Version 2 - add serailization/deserialization of sparse_initializer
static constexpr const char* kOrtModelVersion = "2";
// Version 2 - add serialization/deserialization of sparse_initializer
// Version 3 - add `graph_doc_string` to Model
static constexpr const char* kOrtModelVersion = "3";
#if defined(ENABLE_ORT_FORMAT_LOAD)
// Check if the given ort model version is supported in this build
@ -116,6 +117,7 @@ static bool IsOrtModelVersionSupported(const std::string& ort_model_version) {
static const std::unordered_set<std::string> kSupportedOrtModelVersions{
std::string("1.4.0"), // This is a special model version for existing converted model
std::string("1"),
std::string("2"),
std::string(kOrtModelVersion),
};
@ -1745,6 +1747,7 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod
// save model metadata
model_metadata_.producer_name = model.ProducerName();
model_metadata_.description = model.DocString();
model_metadata_.graph_description = model.GraphDocString();
model_metadata_.domain = model.Domain();
model_metadata_.version = model.ModelVersion();
model_metadata_.custom_metadata_map = model.MetaData();

View file

@ -65,6 +65,7 @@ struct ModelMetadata {
std::string graph_name;
std::string domain;
std::string description;
std::string graph_description;
int64_t version = 0;
std::unordered_map<std::string, std::string> custom_metadata_map;
};

View file

@ -1030,6 +1030,16 @@ ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDescription,
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetGraphDescription,
_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
API_IMPL_BEGIN
auto description = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->graph_description;
*value = StrDup(description, allocator);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value) {
API_IMPL_BEGIN
@ -1884,7 +1894,7 @@ Second example, if we wanted to add and remove some members, we'd do this:
In GetApi we now make it return ort_api_3 for version 3.
*/
static constexpr OrtApi ort_api_1_to_6 = {
static constexpr OrtApi ort_api_1_to_7 = {
// NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering.
// Shipped as version 1 - DO NOT MODIFY (see above text for more information)
@ -2073,6 +2083,7 @@ static constexpr OrtApi ort_api_1_to_6 = {
// End of Version 6 - DO NOT MODIFY ABOVE (see above text for more information)
// Version 7 - In development, feel free to add/remove/rearrange here
&OrtApis::ModelMetadataGetGraphDescription,
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
@ -2080,8 +2091,8 @@ static constexpr OrtApi ort_api_1_to_6 = {
static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change");
ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
if (version >= 1 && version <= 6)
return &ort_api_1_to_6;
if (version >= 1 && version <= 7)
return &ort_api_1_to_7;
return nullptr; // Unsupported version
}

View file

@ -97,6 +97,8 @@ ORT_API_STATUS_IMPL(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_m
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value);

View file

@ -1599,6 +1599,7 @@ facilitate the comparison.)pbdoc")
.def_readwrite("graph_name", &ModelMetadata::graph_name, "graph name")
.def_readwrite("domain", &ModelMetadata::domain, "ONNX domain")
.def_readwrite("description", &ModelMetadata::description, "description of the model")
.def_readwrite("graph_description", &ModelMetadata::graph_description, "description of the graph hosted in the model")
.def_readwrite("version", &ModelMetadata::version, "version of the model")
.def_readwrite("custom_metadata_map", &ModelMetadata::custom_metadata_map, "additional metadata");

View file

@ -510,6 +510,7 @@ class TestInferenceSession(unittest.TestCase):
self.assertEqual('squeezenet_old', modelmeta.graph_name)
self.assertEqual('', modelmeta.domain)
self.assertEqual('', modelmeta.description)
self.assertEqual('', modelmeta.graph_description)
def testProfilerWithSessionOptions(self):
so = onnxrt.SessionOptions()

View file

@ -743,18 +743,18 @@ TEST(CApiTest, io_binding_cuda) {
};
Ort::SessionOptions session_options;
#ifdef USE_TENSORRT
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0));
#else
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
#endif
#ifdef USE_TENSORRT
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0));
#else
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
#endif
Ort::Session session(*ort_env, MODEL_URI, session_options);
#ifdef USE_TENSORRT
Ort::MemoryInfo info_cuda("Tensorrt", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#else
Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#endif
#ifdef USE_TENSORRT
Ort::MemoryInfo info_cuda("Tensorrt", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#else
Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#endif
Ort::Allocator cuda_allocator(session, info_cuda);
auto allocator_info = cuda_allocator.GetInfo();
@ -1126,6 +1126,10 @@ TEST(CApiTest, model_metadata) {
ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0);
allocator.get()->Free(description);
char* graph_description = model_metadata.GetGraphDescription(allocator.get());
ASSERT_TRUE(strcmp("graph description", graph_description) == 0);
allocator.get()->Free(graph_description);
int64_t version = model_metadata.GetVersion();
ASSERT_TRUE(version == 1);
@ -1165,6 +1169,11 @@ TEST(CApiTest, model_metadata) {
ASSERT_TRUE(strcmp("", description) == 0);
allocator.get()->Free(description);
// Graph description is empty
char* graph_description = model_metadata.GetGraphDescription(allocator.get());
ASSERT_TRUE(strcmp("", graph_description) == 0);
allocator.get()->Free(graph_description);
// Model does not contain custom metadata map
int64_t num_keys_in_custom_metadata_map;
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map);