Support accessing a model's metadata in C# (#4867)

Implement access to model's metadata in C#
This commit is contained in:
Hariharan Seshadri 2020-08-25 11:13:49 -07:00 committed by GitHub
parent 26bd8c2085
commit 6c26e52134
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 307 additions and 13 deletions

View file

@ -20,7 +20,7 @@ namespace Microsoft.ML.OnnxRuntime
protected Dictionary<string, NodeMetadata> _inputMetadata, _outputMetadata, _overridableInitializerMetadata;
private SessionOptions _builtInSessionOptions = null;
private RunOptions _builtInRunOptions = null;
private ModelMetadata _modelMetadata = null;
#region Public API
@ -641,12 +641,17 @@ namespace Microsoft.ML.OnnxRuntime
return result;
}
//TODO: kept internal until implemented
internal ModelMetadata ModelMetadata
public ModelMetadata ModelMetadata
{
get
{
return new ModelMetadata(); //TODO: implement
if (_modelMetadata != null)
{
return _modelMetadata;
}
_modelMetadata = new ModelMetadata(this);
return _modelMetadata;
}
}
@ -993,9 +998,158 @@ namespace Microsoft.ML.OnnxRuntime
}
internal class ModelMetadata
public class ModelMetadata
{
//TODO: placeholder for Model metadata. Currently C-API does not expose this.
private string _producerName;
private string _graphName;
private string _domain;
private string _description;
private long _version;
private Dictionary<string, string> _customMetadataMap = new Dictionary<string, string>();
internal ModelMetadata(InferenceSession session)
{
IntPtr modelMetadataHandle = IntPtr.Zero;
var allocator = OrtAllocator.DefaultInstance;
// Get the native ModelMetadata instance associated with the InferenceSession
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetModelMetadata(session.Handle, out modelMetadataHandle));
try
{
// Process producer name
IntPtr producerNameHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetProducerName(modelMetadataHandle, allocator.Pointer, out producerNameHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, producerNameHandle, 0))
{
_producerName = NativeOnnxValueHelper.StringFromNativeUtf8(producerNameHandle);
}
// Process graph name
IntPtr graphNameHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphName(modelMetadataHandle, allocator.Pointer, out graphNameHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, graphNameHandle, 0))
{
_graphName = NativeOnnxValueHelper.StringFromNativeUtf8(graphNameHandle);
}
// Process domain
IntPtr domainHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDomain(modelMetadataHandle, allocator.Pointer, out domainHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, domainHandle, 0))
{
_domain = NativeOnnxValueHelper.StringFromNativeUtf8(domainHandle);
}
// Process description
IntPtr descriptionHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDescription(modelMetadataHandle, allocator.Pointer, out descriptionHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, descriptionHandle, 0))
{
_description = NativeOnnxValueHelper.StringFromNativeUtf8(descriptionHandle);
}
// Process version
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetVersion(modelMetadataHandle, out _version));
// Process CustomMetadata Map
IntPtr customMetadataMapKeysHandle = IntPtr.Zero;
long numKeys;
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetCustomMetadataMapKeys(modelMetadataHandle, allocator.Pointer, out customMetadataMapKeysHandle, out numKeys));
// We have received an array of null terminated C strings which are the keys that we can use to lookup the custom metadata map
// The OrtAllocator will finally free the customMetadataMapKeysHandle
using (var ortAllocationKeysArray = new OrtMemoryAllocation(allocator, customMetadataMapKeysHandle, 0))
using (var ortAllocationKeys = new DisposableList<OrtMemoryAllocation>((int)numKeys))
{
// Put all the handles to each key in the DisposableList to be disposed off in an exception-safe manner
for (int i = 0; i < (int)numKeys; ++i)
{
ortAllocationKeys.Add(new OrtMemoryAllocation(allocator, Marshal.ReadIntPtr(customMetadataMapKeysHandle, IntPtr.Size * i), 0));
}
// Process each key via the stored key handles
foreach(var allocation in ortAllocationKeys)
{
IntPtr keyHandle = allocation.Pointer;
IntPtr valueHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataLookupCustomMetadataMap(modelMetadataHandle, allocator.Pointer, keyHandle, out valueHandle));
using (var ortAllocationValue = new OrtMemoryAllocation(allocator, valueHandle, 0))
{
var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyHandle);
var value = NativeOnnxValueHelper.StringFromNativeUtf8(valueHandle);
// Put the key/value pair into the dictionary
_customMetadataMap[key] = value;
}
}
}
}
finally
{
// Free ModelMetadata handle
NativeMethods.OrtReleaseModelMetadata(modelMetadataHandle);
}
}
public string ProducerName
{
get
{
return _producerName;
}
}
public string GraphName
{
get
{
return _graphName;
}
}
public string Domain
{
get
{
return _domain;
}
}
public string Description
{
get
{
return _description;
}
}
public long Version
{
get
{
return _version;
}
}
public Dictionary<string, string> CustomMetadataMap
{
get
{
return _customMetadataMap;
}
}
}

View file

@ -293,6 +293,18 @@ namespace Microsoft.ML.OnnxRuntime
OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions));
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue));
OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata));
OrtModelMetadataGetProducerName = (DOrtModelMetadataGetProducerName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetProducerName, typeof(DOrtModelMetadataGetProducerName));
OrtModelMetadataGetGraphName = (DOrtModelMetadataGetGraphName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphName, typeof(DOrtModelMetadataGetGraphName));
OrtModelMetadataGetDomain = (DOrtModelMetadataGetDomain)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDomain, typeof(DOrtModelMetadataGetDomain));
OrtModelMetadataGetDescription = (DOrtModelMetadataGetDescription)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDescription, typeof(DOrtModelMetadataGetDescription));
OrtModelMetadataGetVersion = (DOrtModelMetadataGetVersion)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetVersion, typeof(DOrtModelMetadataGetVersion));
OrtModelMetadataGetCustomMetadataMapKeys = (DOrtModelMetadataGetCustomMetadataMapKeys)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetCustomMetadataMapKeys, typeof(DOrtModelMetadataGetCustomMetadataMapKeys));
OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap));
OrtReleaseModelMetadata = (DOrtReleaseModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.ReleaseModelMetadata, typeof(DOrtReleaseModelMetadata));
}
[DllImport(nativeLib, CharSet = charSet)]
@ -763,6 +775,99 @@ namespace Microsoft.ML.OnnxRuntime
#endregion IoBinding API
#region ModelMetadata API
/// <summary>
/// Gets the ModelMetadata associated with an InferenceSession
/// </summary>
/// <param name="session">instance of OrtSession</param>
/// <param name="modelMetadata">(output) instance of OrtModelMetadata</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtSessionGetModelMetadata(IntPtr /* (const OrtSession*) */ session, out IntPtr /* (OrtModelMetadata**) */ modelMetadata);
public static DOrtSessionGetModelMetadata OrtSessionGetModelMetadata;
/// <summary>
/// Gets the producer name associated with a ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="value">(output) producer name from the ModelMetadata instance</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName;
/// <summary>
/// Gets the graph name associated with a ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="value">(output) graph name from the ModelMetadata instance</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetGraphName OrtModelMetadataGetGraphName;
/// <summary>
/// Gets the domain associated with a ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="value">(output) domain from the ModelMetadata instance</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDomain(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetDomain OrtModelMetadataGetDomain;
/// <summary>
/// Gets the description associated with a ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="value">(output) description from the ModelMetadata instance</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetDescription OrtModelMetadataGetDescription;
/// <summary>
/// Gets the version associated with a ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="value">(output) version from the ModelMetadata instance</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetVersion(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
out long /* (int64_t*) */ value);
public static DOrtModelMetadataGetVersion OrtModelMetadataGetVersion;
/// <summary>
/// Gets all the keys in the custom metadata map in the ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="keys">(output) all keys in the custom metadata map</param>
/// <param name="numKeys">(output) number of keys in the custom metadata map</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetCustomMetadataMapKeys(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char***) */ keys, out long /* (int64_t*) */ numKeys);
public static DOrtModelMetadataGetCustomMetadataMapKeys OrtModelMetadataGetCustomMetadataMapKeys;
/// <summary>
/// Gets the value associated with the given key in custom metadata map in the ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="key">key in the custom metadata map</param>
/// <param name="value">(output) value for the key in the custom metadata map</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataLookupCustomMetadataMap(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, IntPtr /* (const char*) */ key, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataLookupCustomMetadataMap OrtModelMetadataLookupCustomMetadataMap;
/// <summary>
/// Frees ModelMetadata instance
/// </summary>
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
public delegate void DOrtReleaseModelMetadata(IntPtr /*(OrtModelMetadata*)*/ modelMetadata);
public static DOrtReleaseModelMetadata OrtReleaseModelMetadata;
#endregion ModelMetadata API
#region Tensor/OnnxValue API
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetValue(IntPtr /*(OrtValue*)*/ value,

View file

@ -1563,6 +1563,36 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
[Fact]
private void TestModelMetadata()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "model_with_valid_ort_config_json.onnx");
using (var session = new InferenceSession(modelPath))
{
var modelMetadata = session.ModelMetadata;
Assert.Equal(1, modelMetadata.Version);
Assert.Equal("Hari", modelMetadata.ProducerName);
Assert.Equal("matmul test", modelMetadata.GraphName);
Assert.Equal("", modelMetadata.Domain);
Assert.Equal("This is a test model with a valid ORT config Json", modelMetadata.Description);
Assert.Equal(2, modelMetadata.CustomMetadataMap.Keys.Count);
Assert.Equal("dummy_value", modelMetadata.CustomMetadataMap["dummy_key"]);
Assert.Equal("{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",
modelMetadata.CustomMetadataMap["ort_config"]);
}
}
[Fact]
private void TestModelSerialization()
{

Binary file not shown.

View file

@ -773,19 +773,24 @@ TEST(CApiTest, model_metadata) {
int64_t num_keys_in_custom_metadata_map;
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map);
ASSERT_TRUE(num_keys_in_custom_metadata_map == 1);
ASSERT_TRUE(strcmp(custom_metadata_map_keys[0], "ort_config") == 0);
ASSERT_TRUE(num_keys_in_custom_metadata_map == 2);
allocator.get()->Free(custom_metadata_map_keys[0]);
allocator.get()->Free(custom_metadata_map_keys[1]);
allocator.get()->Free(custom_metadata_map_keys);
char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
ASSERT_TRUE(strcmp(lookup_value,
char* lookup_value_1 = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
ASSERT_TRUE(strcmp(lookup_value_1,
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
allocator.get()->Free(lookup_value);
allocator.get()->Free(lookup_value_1);
char* lookup_value_2 = model_metadata.LookupCustomMetadataMap("dummy_key", allocator.get());
ASSERT_TRUE(strcmp(lookup_value_2, "dummy_value") == 0);
allocator.get()->Free(lookup_value_2);
// key doesn't exist in custom metadata map
lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get());
ASSERT_TRUE(lookup_value == nullptr);
char* lookup_value_3 = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get());
ASSERT_TRUE(lookup_value_3 == nullptr);
}
// The following section tests a model with some missing metadata info