Add new C function OrtOnnxTypeFromTypeInfo (#585)

This commit is contained in:
jignparm 2019-03-12 10:11:14 -07:00 committed by Changming Sun
parent f048fc5fb0
commit de9f1ff1ff
8 changed files with 44 additions and 3 deletions

View file

@ -299,6 +299,12 @@ namespace Microsoft.ML.OnnxRuntime
internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
{
var valueType = NativeMethods.OrtOnnxTypeFromTypeInfo(typeInfo);
if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR)
{
return new NodeMetadata(valueType, new int[] { }, typeof(NamedOnnxValue));
}
IntPtr tensorInfo = NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo);
// Convert the newly introduced OrtTypeInfo* to the older OrtTypeAndShapeInfo*
@ -317,7 +323,7 @@ namespace Microsoft.ML.OnnxRuntime
{
intDimensions[i] = (int)dimensions[i];
}
return new NodeMetadata(intDimensions, dotnetType);
return new NodeMetadata(valueType, intDimensions, dotnetType);
}
#endregion
@ -360,15 +366,25 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
public class NodeMetadata
{
private OnnxValueType _onnxValueType;
private int[] _dimensions;
private Type _type;
internal NodeMetadata(int[] dimensions, Type type)
internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, Type type)
{
_onnxValueType = onnxValueType;
_dimensions = dimensions;
_type = type;
}
public OnnxValueType OnnxValueType
{
get
{
return _onnxValueType;
}
}
public int[] Dimensions
{
get

View file

@ -345,7 +345,7 @@ namespace Microsoft.ML.OnnxRuntime
DataTypeMax = 17
}
internal enum OnnxValueType
public enum OnnxValueType
{
ONNX_TYPE_UNKNOWN = 0,
ONNX_TYPE_TENSOR = 1,

View file

@ -249,6 +249,9 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern OnnxValueType /*Onnxtype*/ OrtGetValueType(IntPtr /*(OrtValue*)*/ value);
[DllImport(nativeLib, CharSet = charSet)]
public static extern OnnxValueType /*Onnxtype*/ OrtOnnxTypeFromTypeInfo(IntPtr /*(OrtTypeInfo*)*/ typeinfo);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtGetValueCount(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(size_t*)*/ count);

View file

@ -536,6 +536,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_sequence_map_int_float.pb");
using (var session = new InferenceSession(modelPath))
{
var outMeta = session.OutputMetadata;
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType);
Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType);
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<float>(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
@ -584,6 +589,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
using (var session = new InferenceSession(modelPath))
{
var outMeta = session.OutputMetadata;
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType);
Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType);
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<float>(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);

View file

@ -358,6 +358,11 @@ ORT_API_STATUS(OrtGetTensorMemSizeInBytesFromTensorProto, _In_ const void* input
*/
ORT_API(const OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ OrtTypeInfo*);
/**
* Return OnnxType from OrtTypeInfo
*/
ORT_API(enum ONNXType, OrtOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*);
/**
* The retured value should be released by calling OrtReleaseTensorTypeAndShapeInfo
*/

View file

@ -21,6 +21,10 @@ OrtTypeInfo::~OrtTypeInfo() {
OrtReleaseTensorTypeAndShapeInfo(data);
}
ORT_API(enum ONNXType, OrtOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input) {
return input->type;
}
ORT_API(const struct OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ struct OrtTypeInfo* input) {
return input->type == ONNX_TYPE_TENSOR ? input->data : nullptr;
}

View file

@ -49,6 +49,7 @@ OrtGetValue
OrtGetValueCount
OrtGetValueType
OrtIsTensor
OrtOnnxTypeFromTypeInfo
OrtReleaseAllocator
OrtReleaseAllocatorInfo
OrtReleaseCustomOpDomain

View file

@ -25,6 +25,9 @@ static void TestModelInfo(const OrtSession* inference_session, bool is_input, co
input_type_info.reset(t);
}
ASSERT_NE(nullptr, input_type_info);
enum ONNXType otype = OrtOnnxTypeFromTypeInfo(input_type_info.get());
ASSERT_EQ(ONNX_TYPE_TENSOR, otype);
const OrtTensorTypeAndShapeInfo* p = OrtCastTypeInfoToTensorInfo(input_type_info.get());
ASSERT_NE(nullptr, p);