mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add new C function OrtOnnxTypeFromTypeInfo (#585)
This commit is contained in:
parent
f048fc5fb0
commit
de9f1ff1ff
8 changed files with 44 additions and 3 deletions
|
|
@ -299,6 +299,12 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
|
||||
{
|
||||
var valueType = NativeMethods.OrtOnnxTypeFromTypeInfo(typeInfo);
|
||||
if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR)
|
||||
{
|
||||
return new NodeMetadata(valueType, new int[] { }, typeof(NamedOnnxValue));
|
||||
}
|
||||
|
||||
IntPtr tensorInfo = NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo);
|
||||
// Convert the newly introduced OrtTypeInfo* to the older OrtTypeAndShapeInfo*
|
||||
|
||||
|
|
@ -317,7 +323,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
intDimensions[i] = (int)dimensions[i];
|
||||
}
|
||||
return new NodeMetadata(intDimensions, dotnetType);
|
||||
return new NodeMetadata(valueType, intDimensions, dotnetType);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
|
@ -360,15 +366,25 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// </summary>
|
||||
public class NodeMetadata
|
||||
{
|
||||
private OnnxValueType _onnxValueType;
|
||||
private int[] _dimensions;
|
||||
private Type _type;
|
||||
|
||||
internal NodeMetadata(int[] dimensions, Type type)
|
||||
internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, Type type)
|
||||
{
|
||||
_onnxValueType = onnxValueType;
|
||||
_dimensions = dimensions;
|
||||
_type = type;
|
||||
}
|
||||
|
||||
public OnnxValueType OnnxValueType
|
||||
{
|
||||
get
|
||||
{
|
||||
return _onnxValueType;
|
||||
}
|
||||
}
|
||||
|
||||
public int[] Dimensions
|
||||
{
|
||||
get
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
DataTypeMax = 17
|
||||
}
|
||||
|
||||
internal enum OnnxValueType
|
||||
public enum OnnxValueType
|
||||
{
|
||||
ONNX_TYPE_UNKNOWN = 0,
|
||||
ONNX_TYPE_TENSOR = 1,
|
||||
|
|
|
|||
|
|
@ -249,6 +249,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern OnnxValueType /*Onnxtype*/ OrtGetValueType(IntPtr /*(OrtValue*)*/ value);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern OnnxValueType /*Onnxtype*/ OrtOnnxTypeFromTypeInfo(IntPtr /*(OrtTypeInfo*)*/ typeinfo);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtGetValueCount(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(size_t*)*/ count);
|
||||
|
||||
|
|
|
|||
|
|
@ -536,6 +536,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_sequence_map_int_float.pb");
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
|
||||
var outMeta = session.OutputMetadata;
|
||||
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType);
|
||||
Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType);
|
||||
|
||||
var container = new List<NamedOnnxValue>();
|
||||
var tensorIn = new DenseTensor<float>(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 });
|
||||
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
|
||||
|
|
@ -584,6 +589,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
var outMeta = session.OutputMetadata;
|
||||
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType);
|
||||
Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType);
|
||||
|
||||
var container = new List<NamedOnnxValue>();
|
||||
var tensorIn = new DenseTensor<float>(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 });
|
||||
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
|
||||
|
|
|
|||
|
|
@ -358,6 +358,11 @@ ORT_API_STATUS(OrtGetTensorMemSizeInBytesFromTensorProto, _In_ const void* input
|
|||
*/
|
||||
ORT_API(const OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ OrtTypeInfo*);
|
||||
|
||||
/**
|
||||
* Return OnnxType from OrtTypeInfo
|
||||
*/
|
||||
ORT_API(enum ONNXType, OrtOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*);
|
||||
|
||||
/**
|
||||
* The retured value should be released by calling OrtReleaseTensorTypeAndShapeInfo
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ OrtTypeInfo::~OrtTypeInfo() {
|
|||
OrtReleaseTensorTypeAndShapeInfo(data);
|
||||
}
|
||||
|
||||
ORT_API(enum ONNXType, OrtOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input) {
|
||||
return input->type;
|
||||
}
|
||||
|
||||
ORT_API(const struct OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ struct OrtTypeInfo* input) {
|
||||
return input->type == ONNX_TYPE_TENSOR ? input->data : nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ OrtGetValue
|
|||
OrtGetValueCount
|
||||
OrtGetValueType
|
||||
OrtIsTensor
|
||||
OrtOnnxTypeFromTypeInfo
|
||||
OrtReleaseAllocator
|
||||
OrtReleaseAllocatorInfo
|
||||
OrtReleaseCustomOpDomain
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ static void TestModelInfo(const OrtSession* inference_session, bool is_input, co
|
|||
input_type_info.reset(t);
|
||||
}
|
||||
ASSERT_NE(nullptr, input_type_info);
|
||||
enum ONNXType otype = OrtOnnxTypeFromTypeInfo(input_type_info.get());
|
||||
ASSERT_EQ(ONNX_TYPE_TENSOR, otype);
|
||||
|
||||
const OrtTensorTypeAndShapeInfo* p = OrtCastTypeInfoToTensorInfo(input_type_info.get());
|
||||
ASSERT_NE(nullptr, p);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue