From de9f1ff1ff8be487b7521bea80b688e4d4e69bf2 Mon Sep 17 00:00:00 2001 From: jignparm Date: Tue, 12 Mar 2019 10:11:14 -0700 Subject: [PATCH] Add new C function OrtOnnxTypeFromTypeInfo (#585) --- .../InferenceSession.cs | 20 +++++++++++++++++-- .../NamedOnnxValue.cs | 2 +- .../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 3 +++ .../InferenceTest.cs | 9 +++++++++ .../core/session/onnxruntime_c_api.h | 5 +++++ .../core/framework/onnxruntime_typeinfo.cc | 4 ++++ onnxruntime/core/providers/cpu/symbols.txt | 1 + onnxruntime/test/shared_lib/test_io_types.cc | 3 +++ 8 files changed, 44 insertions(+), 3 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 9e2422e820..bcf680fbe4 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -299,6 +299,12 @@ namespace Microsoft.ML.OnnxRuntime internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) { + var valueType = NativeMethods.OrtOnnxTypeFromTypeInfo(typeInfo); + if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR) + { + return new NodeMetadata(valueType, new int[] { }, typeof(NamedOnnxValue)); + } + IntPtr tensorInfo = NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo); // Convert the newly introduced OrtTypeInfo* to the older OrtTypeAndShapeInfo* @@ -317,7 +323,7 @@ namespace Microsoft.ML.OnnxRuntime { intDimensions[i] = (int)dimensions[i]; } - return new NodeMetadata(intDimensions, dotnetType); + return new NodeMetadata(valueType, intDimensions, dotnetType); } #endregion @@ -360,15 +366,25 @@ namespace Microsoft.ML.OnnxRuntime /// public class NodeMetadata { + private OnnxValueType _onnxValueType; private int[] _dimensions; private Type _type; - internal NodeMetadata(int[] dimensions, Type type) + internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, Type type) { + _onnxValueType = onnxValueType; _dimensions = dimensions; _type = type; } + public OnnxValueType OnnxValueType + { + get + { + return _onnxValueType; + } + } + public int[] Dimensions { get diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs index 3c9fa43c83..07c00943fd 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs @@ -345,7 +345,7 @@ namespace Microsoft.ML.OnnxRuntime DataTypeMax = 17 } - internal enum OnnxValueType + public enum OnnxValueType { ONNX_TYPE_UNKNOWN = 0, ONNX_TYPE_TENSOR = 1, diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 10f391d297..756fe203e2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -249,6 +249,9 @@ namespace Microsoft.ML.OnnxRuntime [DllImport(nativeLib, CharSet = charSet)] public static extern OnnxValueType /*Onnxtype*/ OrtGetValueType(IntPtr /*(OrtValue*)*/ value); + [DllImport(nativeLib, CharSet = charSet)] + public static extern OnnxValueType /*Onnxtype*/ OrtOnnxTypeFromTypeInfo(IntPtr /*(OrtTypeInfo*)*/ typeinfo); + [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtGetValueCount(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(size_t*)*/ count); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 9ac868fd0a..0fb6cd77b7 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -536,6 +536,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_sequence_map_int_float.pb"); using (var session = new InferenceSession(modelPath)) { + + var outMeta = session.OutputMetadata; + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType); + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType); + var container = new List(); var tensorIn = new DenseTensor(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 }); var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn); @@ -584,6 +589,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests using (var session = new InferenceSession(modelPath)) { + var outMeta = session.OutputMetadata; + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType); + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType); + var container = new List(); var tensorIn = new DenseTensor(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 }); var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b93451adb8..0fc5ee4e5d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -358,6 +358,11 @@ ORT_API_STATUS(OrtGetTensorMemSizeInBytesFromTensorProto, _In_ const void* input */ ORT_API(const OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ OrtTypeInfo*); +/** + * Return OnnxType from OrtTypeInfo + */ +ORT_API(enum ONNXType, OrtOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*); + /** * The retured value should be released by calling OrtReleaseTensorTypeAndShapeInfo */ diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 9179ad3d96..e2984ff180 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -21,6 +21,10 @@ OrtTypeInfo::~OrtTypeInfo() { OrtReleaseTensorTypeAndShapeInfo(data); } +ORT_API(enum ONNXType, OrtOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input) { + return input->type; +} + ORT_API(const struct OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ struct OrtTypeInfo* input) { return input->type == ONNX_TYPE_TENSOR ? input->data : nullptr; } diff --git a/onnxruntime/core/providers/cpu/symbols.txt b/onnxruntime/core/providers/cpu/symbols.txt index c8eac9cd59..127ae0b866 100644 --- a/onnxruntime/core/providers/cpu/symbols.txt +++ b/onnxruntime/core/providers/cpu/symbols.txt @@ -49,6 +49,7 @@ OrtGetValue OrtGetValueCount OrtGetValueType OrtIsTensor +OrtOnnxTypeFromTypeInfo OrtReleaseAllocator OrtReleaseAllocatorInfo OrtReleaseCustomOpDomain diff --git a/onnxruntime/test/shared_lib/test_io_types.cc b/onnxruntime/test/shared_lib/test_io_types.cc index d3a0a1eee6..f67325f34f 100644 --- a/onnxruntime/test/shared_lib/test_io_types.cc +++ b/onnxruntime/test/shared_lib/test_io_types.cc @@ -25,6 +25,9 @@ static void TestModelInfo(const OrtSession* inference_session, bool is_input, co input_type_info.reset(t); } ASSERT_NE(nullptr, input_type_info); + enum ONNXType otype = OrtOnnxTypeFromTypeInfo(input_type_info.get()); + ASSERT_EQ(ONNX_TYPE_TENSOR, otype); + const OrtTensorTypeAndShapeInfo* p = OrtCastTypeInfoToTensorInfo(input_type_info.get()); ASSERT_NE(nullptr, p);