mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Support accessing a model's metadata in C# (#4867)
Implement access to model's metadata in C#
This commit is contained in:
parent
26bd8c2085
commit
6c26e52134
6 changed files with 307 additions and 13 deletions
|
|
@ -20,7 +20,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
protected Dictionary<string, NodeMetadata> _inputMetadata, _outputMetadata, _overridableInitializerMetadata;
|
||||
private SessionOptions _builtInSessionOptions = null;
|
||||
private RunOptions _builtInRunOptions = null;
|
||||
|
||||
private ModelMetadata _modelMetadata = null;
|
||||
|
||||
#region Public API
|
||||
|
||||
|
|
@ -641,12 +641,17 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
return result;
|
||||
}
|
||||
|
||||
//TODO: kept internal until implemented
|
||||
internal ModelMetadata ModelMetadata
|
||||
public ModelMetadata ModelMetadata
|
||||
{
|
||||
get
|
||||
{
|
||||
return new ModelMetadata(); //TODO: implement
|
||||
if (_modelMetadata != null)
|
||||
{
|
||||
return _modelMetadata;
|
||||
}
|
||||
|
||||
_modelMetadata = new ModelMetadata(this);
|
||||
return _modelMetadata;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -993,9 +998,158 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
|
||||
|
||||
internal class ModelMetadata
|
||||
public class ModelMetadata
|
||||
{
|
||||
//TODO: placeholder for Model metadata. Currently C-API does not expose this.
|
||||
private string _producerName;
|
||||
private string _graphName;
|
||||
private string _domain;
|
||||
private string _description;
|
||||
private long _version;
|
||||
private Dictionary<string, string> _customMetadataMap = new Dictionary<string, string>();
|
||||
|
||||
internal ModelMetadata(InferenceSession session)
|
||||
{
|
||||
IntPtr modelMetadataHandle = IntPtr.Zero;
|
||||
|
||||
var allocator = OrtAllocator.DefaultInstance;
|
||||
|
||||
// Get the native ModelMetadata instance associated with the InferenceSession
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetModelMetadata(session.Handle, out modelMetadataHandle));
|
||||
|
||||
try
|
||||
{
|
||||
|
||||
// Process producer name
|
||||
IntPtr producerNameHandle = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetProducerName(modelMetadataHandle, allocator.Pointer, out producerNameHandle));
|
||||
using (var ortAllocation = new OrtMemoryAllocation(allocator, producerNameHandle, 0))
|
||||
{
|
||||
_producerName = NativeOnnxValueHelper.StringFromNativeUtf8(producerNameHandle);
|
||||
}
|
||||
|
||||
// Process graph name
|
||||
IntPtr graphNameHandle = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphName(modelMetadataHandle, allocator.Pointer, out graphNameHandle));
|
||||
using (var ortAllocation = new OrtMemoryAllocation(allocator, graphNameHandle, 0))
|
||||
{
|
||||
_graphName = NativeOnnxValueHelper.StringFromNativeUtf8(graphNameHandle);
|
||||
}
|
||||
|
||||
|
||||
// Process domain
|
||||
IntPtr domainHandle = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDomain(modelMetadataHandle, allocator.Pointer, out domainHandle));
|
||||
using (var ortAllocation = new OrtMemoryAllocation(allocator, domainHandle, 0))
|
||||
{
|
||||
_domain = NativeOnnxValueHelper.StringFromNativeUtf8(domainHandle);
|
||||
}
|
||||
|
||||
// Process description
|
||||
IntPtr descriptionHandle = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDescription(modelMetadataHandle, allocator.Pointer, out descriptionHandle));
|
||||
using (var ortAllocation = new OrtMemoryAllocation(allocator, descriptionHandle, 0))
|
||||
{
|
||||
_description = NativeOnnxValueHelper.StringFromNativeUtf8(descriptionHandle);
|
||||
}
|
||||
|
||||
// Process version
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetVersion(modelMetadataHandle, out _version));
|
||||
|
||||
|
||||
// Process CustomMetadata Map
|
||||
IntPtr customMetadataMapKeysHandle = IntPtr.Zero;
|
||||
long numKeys;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetCustomMetadataMapKeys(modelMetadataHandle, allocator.Pointer, out customMetadataMapKeysHandle, out numKeys));
|
||||
|
||||
// We have received an array of null terminated C strings which are the keys that we can use to lookup the custom metadata map
|
||||
// The OrtAllocator will finally free the customMetadataMapKeysHandle
|
||||
using (var ortAllocationKeysArray = new OrtMemoryAllocation(allocator, customMetadataMapKeysHandle, 0))
|
||||
using (var ortAllocationKeys = new DisposableList<OrtMemoryAllocation>((int)numKeys))
|
||||
{
|
||||
// Put all the handles to each key in the DisposableList to be disposed off in an exception-safe manner
|
||||
for (int i = 0; i < (int)numKeys; ++i)
|
||||
{
|
||||
ortAllocationKeys.Add(new OrtMemoryAllocation(allocator, Marshal.ReadIntPtr(customMetadataMapKeysHandle, IntPtr.Size * i), 0));
|
||||
}
|
||||
|
||||
// Process each key via the stored key handles
|
||||
foreach(var allocation in ortAllocationKeys)
|
||||
{
|
||||
IntPtr keyHandle = allocation.Pointer;
|
||||
IntPtr valueHandle = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataLookupCustomMetadataMap(modelMetadataHandle, allocator.Pointer, keyHandle, out valueHandle));
|
||||
|
||||
using (var ortAllocationValue = new OrtMemoryAllocation(allocator, valueHandle, 0))
|
||||
{
|
||||
var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyHandle);
|
||||
var value = NativeOnnxValueHelper.StringFromNativeUtf8(valueHandle);
|
||||
|
||||
// Put the key/value pair into the dictionary
|
||||
_customMetadataMap[key] = value;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
finally
|
||||
{
|
||||
|
||||
// Free ModelMetadata handle
|
||||
NativeMethods.OrtReleaseModelMetadata(modelMetadataHandle);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public string ProducerName
|
||||
{
|
||||
get
|
||||
{
|
||||
return _producerName;
|
||||
}
|
||||
}
|
||||
|
||||
public string GraphName
|
||||
{
|
||||
get
|
||||
{
|
||||
return _graphName;
|
||||
}
|
||||
}
|
||||
|
||||
public string Domain
|
||||
{
|
||||
get
|
||||
{
|
||||
return _domain;
|
||||
}
|
||||
}
|
||||
|
||||
public string Description
|
||||
{
|
||||
get
|
||||
{
|
||||
return _description;
|
||||
}
|
||||
}
|
||||
|
||||
public long Version
|
||||
{
|
||||
get
|
||||
{
|
||||
return _version;
|
||||
}
|
||||
}
|
||||
|
||||
public Dictionary<string, string> CustomMetadataMap
|
||||
{
|
||||
get
|
||||
{
|
||||
return _customMetadataMap;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -293,6 +293,18 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions));
|
||||
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
|
||||
OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue));
|
||||
|
||||
|
||||
OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata));
|
||||
OrtModelMetadataGetProducerName = (DOrtModelMetadataGetProducerName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetProducerName, typeof(DOrtModelMetadataGetProducerName));
|
||||
OrtModelMetadataGetGraphName = (DOrtModelMetadataGetGraphName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphName, typeof(DOrtModelMetadataGetGraphName));
|
||||
OrtModelMetadataGetDomain = (DOrtModelMetadataGetDomain)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDomain, typeof(DOrtModelMetadataGetDomain));
|
||||
OrtModelMetadataGetDescription = (DOrtModelMetadataGetDescription)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDescription, typeof(DOrtModelMetadataGetDescription));
|
||||
OrtModelMetadataGetVersion = (DOrtModelMetadataGetVersion)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetVersion, typeof(DOrtModelMetadataGetVersion));
|
||||
OrtModelMetadataGetCustomMetadataMapKeys = (DOrtModelMetadataGetCustomMetadataMapKeys)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetCustomMetadataMapKeys, typeof(DOrtModelMetadataGetCustomMetadataMapKeys));
|
||||
OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap));
|
||||
OrtReleaseModelMetadata = (DOrtReleaseModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.ReleaseModelMetadata, typeof(DOrtReleaseModelMetadata));
|
||||
|
||||
}
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
|
|
@ -763,6 +775,99 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
#endregion IoBinding API
|
||||
|
||||
#region ModelMetadata API
|
||||
|
||||
/// <summary>
|
||||
/// Gets the ModelMetadata associated with an InferenceSession
|
||||
/// </summary>
|
||||
/// <param name="session">instance of OrtSession</param>
|
||||
/// <param name="modelMetadata">(output) instance of OrtModelMetadata</param>
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtSessionGetModelMetadata(IntPtr /* (const OrtSession*) */ session, out IntPtr /* (OrtModelMetadata**) */ modelMetadata);
|
||||
public static DOrtSessionGetModelMetadata OrtSessionGetModelMetadata;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the producer name associated with a ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="allocator">instance of OrtAllocator</param>
|
||||
/// <param name="value">(output) producer name from the ModelMetadata instance</param>
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
|
||||
public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the graph name associated with a ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="allocator">instance of OrtAllocator</param>
|
||||
/// <param name="value">(output) graph name from the ModelMetadata instance</param>
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
|
||||
public static DOrtModelMetadataGetGraphName OrtModelMetadataGetGraphName;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the domain associated with a ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="allocator">instance of OrtAllocator</param>
|
||||
/// <param name="value">(output) domain from the ModelMetadata instance</param>
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDomain(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
|
||||
public static DOrtModelMetadataGetDomain OrtModelMetadataGetDomain;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the description associated with a ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="allocator">instance of OrtAllocator</param>
|
||||
/// <param name="value">(output) description from the ModelMetadata instance</param>
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
|
||||
public static DOrtModelMetadataGetDescription OrtModelMetadataGetDescription;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the version associated with a ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="value">(output) version from the ModelMetadata instance</param>
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetVersion(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
out long /* (int64_t*) */ value);
|
||||
public static DOrtModelMetadataGetVersion OrtModelMetadataGetVersion;
|
||||
|
||||
/// <summary>
|
||||
/// Gets all the keys in the custom metadata map in the ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="allocator">instance of OrtAllocator</param>
|
||||
/// <param name="keys">(output) all keys in the custom metadata map</param>
|
||||
/// <param name="numKeys">(output) number of keys in the custom metadata map</param>
|
||||
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetCustomMetadataMapKeys(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char***) */ keys, out long /* (int64_t*) */ numKeys);
|
||||
public static DOrtModelMetadataGetCustomMetadataMapKeys OrtModelMetadataGetCustomMetadataMapKeys;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the value associated with the given key in custom metadata map in the ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="allocator">instance of OrtAllocator</param>
|
||||
/// <param name="key">key in the custom metadata map</param>
|
||||
/// <param name="value">(output) value for the key in the custom metadata map</param>
|
||||
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataLookupCustomMetadataMap(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
IntPtr /* (OrtAllocator*) */ allocator, IntPtr /* (const char*) */ key, out IntPtr /* (char**) */ value);
|
||||
public static DOrtModelMetadataLookupCustomMetadataMap OrtModelMetadataLookupCustomMetadataMap;
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Frees ModelMetadata instance
|
||||
/// </summary>
|
||||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
public delegate void DOrtReleaseModelMetadata(IntPtr /*(OrtModelMetadata*)*/ modelMetadata);
|
||||
public static DOrtReleaseModelMetadata OrtReleaseModelMetadata;
|
||||
|
||||
#endregion ModelMetadata API
|
||||
|
||||
#region Tensor/OnnxValue API
|
||||
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetValue(IntPtr /*(OrtValue*)*/ value,
|
||||
|
|
|
|||
|
|
@ -1563,6 +1563,36 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
private void TestModelMetadata()
|
||||
{
|
||||
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "model_with_valid_ort_config_json.onnx");
|
||||
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
var modelMetadata = session.ModelMetadata;
|
||||
|
||||
Assert.Equal(1, modelMetadata.Version);
|
||||
|
||||
Assert.Equal("Hari", modelMetadata.ProducerName);
|
||||
|
||||
Assert.Equal("matmul test", modelMetadata.GraphName);
|
||||
|
||||
Assert.Equal("", modelMetadata.Domain);
|
||||
|
||||
Assert.Equal("This is a test model with a valid ORT config Json", modelMetadata.Description);
|
||||
|
||||
Assert.Equal(2, modelMetadata.CustomMetadataMap.Keys.Count);
|
||||
Assert.Equal("dummy_value", modelMetadata.CustomMetadataMap["dummy_key"]);
|
||||
Assert.Equal("{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",
|
||||
modelMetadata.CustomMetadataMap["ort_config"]);
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
private void TestModelSerialization()
|
||||
{
|
||||
|
|
|
|||
BIN
csharp/testdata/model_with_valid_ort_config_json.onnx
vendored
Normal file
BIN
csharp/testdata/model_with_valid_ort_config_json.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -773,19 +773,24 @@ TEST(CApiTest, model_metadata) {
|
|||
|
||||
int64_t num_keys_in_custom_metadata_map;
|
||||
char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map);
|
||||
ASSERT_TRUE(num_keys_in_custom_metadata_map == 1);
|
||||
ASSERT_TRUE(strcmp(custom_metadata_map_keys[0], "ort_config") == 0);
|
||||
ASSERT_TRUE(num_keys_in_custom_metadata_map == 2);
|
||||
|
||||
allocator.get()->Free(custom_metadata_map_keys[0]);
|
||||
allocator.get()->Free(custom_metadata_map_keys[1]);
|
||||
allocator.get()->Free(custom_metadata_map_keys);
|
||||
|
||||
char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
|
||||
ASSERT_TRUE(strcmp(lookup_value,
|
||||
char* lookup_value_1 = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
|
||||
ASSERT_TRUE(strcmp(lookup_value_1,
|
||||
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
|
||||
allocator.get()->Free(lookup_value);
|
||||
allocator.get()->Free(lookup_value_1);
|
||||
|
||||
char* lookup_value_2 = model_metadata.LookupCustomMetadataMap("dummy_key", allocator.get());
|
||||
ASSERT_TRUE(strcmp(lookup_value_2, "dummy_value") == 0);
|
||||
allocator.get()->Free(lookup_value_2);
|
||||
|
||||
// key doesn't exist in custom metadata map
|
||||
lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get());
|
||||
ASSERT_TRUE(lookup_value == nullptr);
|
||||
char* lookup_value_3 = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get());
|
||||
ASSERT_TRUE(lookup_value_3 == nullptr);
|
||||
}
|
||||
|
||||
// The following section tests a model with some missing metadata info
|
||||
|
|
|
|||
Binary file not shown.
Loading…
Reference in a new issue