diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 3fa97a0c18..a7b50003ef 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.OnnxRuntime protected Dictionary _inputMetadata, _outputMetadata, _overridableInitializerMetadata; private SessionOptions _builtInSessionOptions = null; private RunOptions _builtInRunOptions = null; - + private ModelMetadata _modelMetadata = null; #region Public API @@ -641,12 +641,17 @@ namespace Microsoft.ML.OnnxRuntime return result; } - //TODO: kept internal until implemented - internal ModelMetadata ModelMetadata + public ModelMetadata ModelMetadata { get { - return new ModelMetadata(); //TODO: implement + if (_modelMetadata != null) + { + return _modelMetadata; + } + + _modelMetadata = new ModelMetadata(this); + return _modelMetadata; } } @@ -993,9 +998,158 @@ namespace Microsoft.ML.OnnxRuntime } - internal class ModelMetadata + public class ModelMetadata { - //TODO: placeholder for Model metadata. Currently C-API does not expose this. + private string _producerName; + private string _graphName; + private string _domain; + private string _description; + private long _version; + private Dictionary _customMetadataMap = new Dictionary(); + + internal ModelMetadata(InferenceSession session) + { + IntPtr modelMetadataHandle = IntPtr.Zero; + + var allocator = OrtAllocator.DefaultInstance; + + // Get the native ModelMetadata instance associated with the InferenceSession + + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetModelMetadata(session.Handle, out modelMetadataHandle)); + + try + { + + // Process producer name + IntPtr producerNameHandle = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetProducerName(modelMetadataHandle, allocator.Pointer, out producerNameHandle)); + using (var ortAllocation = new OrtMemoryAllocation(allocator, producerNameHandle, 0)) + { + _producerName = NativeOnnxValueHelper.StringFromNativeUtf8(producerNameHandle); + } + + // Process graph name + IntPtr graphNameHandle = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphName(modelMetadataHandle, allocator.Pointer, out graphNameHandle)); + using (var ortAllocation = new OrtMemoryAllocation(allocator, graphNameHandle, 0)) + { + _graphName = NativeOnnxValueHelper.StringFromNativeUtf8(graphNameHandle); + } + + + // Process domain + IntPtr domainHandle = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDomain(modelMetadataHandle, allocator.Pointer, out domainHandle)); + using (var ortAllocation = new OrtMemoryAllocation(allocator, domainHandle, 0)) + { + _domain = NativeOnnxValueHelper.StringFromNativeUtf8(domainHandle); + } + + // Process description + IntPtr descriptionHandle = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDescription(modelMetadataHandle, allocator.Pointer, out descriptionHandle)); + using (var ortAllocation = new OrtMemoryAllocation(allocator, descriptionHandle, 0)) + { + _description = NativeOnnxValueHelper.StringFromNativeUtf8(descriptionHandle); + } + + // Process version + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetVersion(modelMetadataHandle, out _version)); + + + // Process CustomMetadata Map + IntPtr customMetadataMapKeysHandle = IntPtr.Zero; + long numKeys; + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetCustomMetadataMapKeys(modelMetadataHandle, allocator.Pointer, out customMetadataMapKeysHandle, out numKeys)); + + // We have received an array of null terminated C strings which are the keys that we can use to lookup the custom metadata map + // The OrtAllocator will finally free the customMetadataMapKeysHandle + using (var ortAllocationKeysArray = new OrtMemoryAllocation(allocator, customMetadataMapKeysHandle, 0)) + using (var ortAllocationKeys = new DisposableList((int)numKeys)) + { + // Put all the handles to each key in the DisposableList to be disposed off in an exception-safe manner + for (int i = 0; i < (int)numKeys; ++i) + { + ortAllocationKeys.Add(new OrtMemoryAllocation(allocator, Marshal.ReadIntPtr(customMetadataMapKeysHandle, IntPtr.Size * i), 0)); + } + + // Process each key via the stored key handles + foreach(var allocation in ortAllocationKeys) + { + IntPtr keyHandle = allocation.Pointer; + IntPtr valueHandle = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataLookupCustomMetadataMap(modelMetadataHandle, allocator.Pointer, keyHandle, out valueHandle)); + + using (var ortAllocationValue = new OrtMemoryAllocation(allocator, valueHandle, 0)) + { + var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyHandle); + var value = NativeOnnxValueHelper.StringFromNativeUtf8(valueHandle); + + // Put the key/value pair into the dictionary + _customMetadataMap[key] = value; + + } + } + } + } + + finally + { + + // Free ModelMetadata handle + NativeMethods.OrtReleaseModelMetadata(modelMetadataHandle); + + } + + } + + public string ProducerName + { + get + { + return _producerName; + } + } + + public string GraphName + { + get + { + return _graphName; + } + } + + public string Domain + { + get + { + return _domain; + } + } + + public string Description + { + get + { + return _description; + } + } + + public long Version + { + get + { + return _version; + } + } + + public Dictionary CustomMetadataMap + { + get + { + return _customMetadataMap; + } + } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index ee79996964..4741889bbb 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -293,6 +293,18 @@ namespace Microsoft.ML.OnnxRuntime OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions)); OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount)); OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue)); + + + OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata)); + OrtModelMetadataGetProducerName = (DOrtModelMetadataGetProducerName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetProducerName, typeof(DOrtModelMetadataGetProducerName)); + OrtModelMetadataGetGraphName = (DOrtModelMetadataGetGraphName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphName, typeof(DOrtModelMetadataGetGraphName)); + OrtModelMetadataGetDomain = (DOrtModelMetadataGetDomain)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDomain, typeof(DOrtModelMetadataGetDomain)); + OrtModelMetadataGetDescription = (DOrtModelMetadataGetDescription)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetDescription, typeof(DOrtModelMetadataGetDescription)); + OrtModelMetadataGetVersion = (DOrtModelMetadataGetVersion)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetVersion, typeof(DOrtModelMetadataGetVersion)); + OrtModelMetadataGetCustomMetadataMapKeys = (DOrtModelMetadataGetCustomMetadataMapKeys)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetCustomMetadataMapKeys, typeof(DOrtModelMetadataGetCustomMetadataMapKeys)); + OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap)); + OrtReleaseModelMetadata = (DOrtReleaseModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.ReleaseModelMetadata, typeof(DOrtReleaseModelMetadata)); + } [DllImport(nativeLib, CharSet = charSet)] @@ -763,6 +775,99 @@ namespace Microsoft.ML.OnnxRuntime #endregion IoBinding API + #region ModelMetadata API + + /// + /// Gets the ModelMetadata associated with an InferenceSession + /// + /// instance of OrtSession + /// (output) instance of OrtModelMetadata + public delegate IntPtr /* (OrtStatus*) */ DOrtSessionGetModelMetadata(IntPtr /* (const OrtSession*) */ session, out IntPtr /* (OrtModelMetadata**) */ modelMetadata); + public static DOrtSessionGetModelMetadata OrtSessionGetModelMetadata; + + /// + /// Gets the producer name associated with a ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// instance of OrtAllocator + /// (output) producer name from the ModelMetadata instance + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName; + + /// + /// Gets the graph name associated with a ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// instance of OrtAllocator + /// (output) graph name from the ModelMetadata instance + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + public static DOrtModelMetadataGetGraphName OrtModelMetadataGetGraphName; + + /// + /// Gets the domain associated with a ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// instance of OrtAllocator + /// (output) domain from the ModelMetadata instance + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDomain(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + public static DOrtModelMetadataGetDomain OrtModelMetadataGetDomain; + + /// + /// Gets the description associated with a ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// instance of OrtAllocator + /// (output) description from the ModelMetadata instance + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + public static DOrtModelMetadataGetDescription OrtModelMetadataGetDescription; + + /// + /// Gets the version associated with a ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// (output) version from the ModelMetadata instance + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetVersion(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + out long /* (int64_t*) */ value); + public static DOrtModelMetadataGetVersion OrtModelMetadataGetVersion; + + /// + /// Gets all the keys in the custom metadata map in the ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// instance of OrtAllocator + /// (output) all keys in the custom metadata map + /// (output) number of keys in the custom metadata map + + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetCustomMetadataMapKeys(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char***) */ keys, out long /* (int64_t*) */ numKeys); + public static DOrtModelMetadataGetCustomMetadataMapKeys OrtModelMetadataGetCustomMetadataMapKeys; + + /// + /// Gets the value associated with the given key in custom metadata map in the ModelMetadata instance + /// + /// instance of OrtModelMetadata + /// instance of OrtAllocator + /// key in the custom metadata map + /// (output) value for the key in the custom metadata map + + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataLookupCustomMetadataMap(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + IntPtr /* (OrtAllocator*) */ allocator, IntPtr /* (const char*) */ key, out IntPtr /* (char**) */ value); + public static DOrtModelMetadataLookupCustomMetadataMap OrtModelMetadataLookupCustomMetadataMap; + + + /// + /// Frees ModelMetadata instance + /// + /// instance of OrtModelMetadata + public delegate void DOrtReleaseModelMetadata(IntPtr /*(OrtModelMetadata*)*/ modelMetadata); + public static DOrtReleaseModelMetadata OrtReleaseModelMetadata; + + #endregion ModelMetadata API + #region Tensor/OnnxValue API public delegate IntPtr /*(OrtStatus*)*/ DOrtGetValue(IntPtr /*(OrtValue*)*/ value, diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index b00af13299..6fbb6055e9 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -1563,6 +1563,36 @@ namespace Microsoft.ML.OnnxRuntime.Tests } } + [Fact] + private void TestModelMetadata() + { + + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "model_with_valid_ort_config_json.onnx"); + + using (var session = new InferenceSession(modelPath)) + { + var modelMetadata = session.ModelMetadata; + + Assert.Equal(1, modelMetadata.Version); + + Assert.Equal("Hari", modelMetadata.ProducerName); + + Assert.Equal("matmul test", modelMetadata.GraphName); + + Assert.Equal("", modelMetadata.Domain); + + Assert.Equal("This is a test model with a valid ORT config Json", modelMetadata.Description); + + Assert.Equal(2, modelMetadata.CustomMetadataMap.Keys.Count); + Assert.Equal("dummy_value", modelMetadata.CustomMetadataMap["dummy_key"]); + Assert.Equal("{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}", + modelMetadata.CustomMetadataMap["ort_config"]); + + + + } + } + [Fact] private void TestModelSerialization() { diff --git a/csharp/testdata/model_with_valid_ort_config_json.onnx b/csharp/testdata/model_with_valid_ort_config_json.onnx new file mode 100644 index 0000000000..f2a0a9bb8e Binary files /dev/null and b/csharp/testdata/model_with_valid_ort_config_json.onnx differ diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 77f81f8154..0bb5266859 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -773,19 +773,24 @@ TEST(CApiTest, model_metadata) { int64_t num_keys_in_custom_metadata_map; char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map); - ASSERT_TRUE(num_keys_in_custom_metadata_map == 1); - ASSERT_TRUE(strcmp(custom_metadata_map_keys[0], "ort_config") == 0); + ASSERT_TRUE(num_keys_in_custom_metadata_map == 2); + allocator.get()->Free(custom_metadata_map_keys[0]); + allocator.get()->Free(custom_metadata_map_keys[1]); allocator.get()->Free(custom_metadata_map_keys); - char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get()); - ASSERT_TRUE(strcmp(lookup_value, + char* lookup_value_1 = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get()); + ASSERT_TRUE(strcmp(lookup_value_1, "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0); - allocator.get()->Free(lookup_value); + allocator.get()->Free(lookup_value_1); + + char* lookup_value_2 = model_metadata.LookupCustomMetadataMap("dummy_key", allocator.get()); + ASSERT_TRUE(strcmp(lookup_value_2, "dummy_value") == 0); + allocator.get()->Free(lookup_value_2); // key doesn't exist in custom metadata map - lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get()); - ASSERT_TRUE(lookup_value == nullptr); + char* lookup_value_3 = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get()); + ASSERT_TRUE(lookup_value_3 == nullptr); } // The following section tests a model with some missing metadata info diff --git a/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx b/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx index a57b83f71a..f2a0a9bb8e 100644 Binary files a/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx and b/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx differ