added unit test to guard against native API changes (#337)

* added unit test to guard against native API changes

* Removed cuda and mkldnn from API checks

* Updated per some code comments
This commit is contained in:
jignparm 2019-01-16 16:53:06 -08:00 committed by GitHub
parent 790cda6ea7
commit b3f0d0b659
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 16 deletions

View file

@ -64,7 +64,7 @@ namespace Microsoft.ML.OnnxRuntime
private static void Delete(IntPtr nativePtr)
{
NativeMethods.ReleaseOrtAllocatorInfo(nativePtr);
NativeMethods.OrtReleaseAllocatorInfo(nativePtr);
}
protected override bool ReleaseHandle()

View file

@ -170,11 +170,11 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateCUDAExecutionProviderFactory(int device_id, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateNupharExecutionProviderFactory(int device_id, string target_str, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern IntPtr /*(OrtStatus*)*/ OrtCreateNupharExecutionProviderFactory(int device_id, string target_str, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
[DllImport(nativeLib, CharSet = charSet)]
public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path);
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path);
#endregion
@ -215,7 +215,7 @@ namespace Microsoft.ML.OnnxRuntime
);
[DllImport(nativeLib, CharSet = charSet)]
public static extern void ReleaseOrtAllocatorInfo(IntPtr /*(OrtAllocatorInfo*)*/ allocatorInfo);
public static extern void OrtReleaseAllocatorInfo(IntPtr /*(OrtAllocatorInfo*)*/ allocatorInfo);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/OrtCreateDefaultAllocator(out IntPtr /*(OrtAllocator**)*/ allocator);
@ -258,12 +258,6 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtGetTensorMutableData(IntPtr /*(OrtValue*)*/ value, out IntPtr /* (void**)*/ dataBufferHandle);
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern IntPtr /*(OrtStatus*)*/ OrtGetTensorShapeDimCount(IntPtr /*(OrtValue*)*/ value, out ulong dimension); //size_t TODO: make it portable for x86, arm
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern IntPtr /*(OrtStatus*)*/ OrtGetTensorShapeElementCount(IntPtr /*(OrtValue*)*/value, out ulong count);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/
OrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo);

View file

@ -5,6 +5,7 @@ using System;
using System.IO;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Numerics.Tensors;
using System.Threading.Tasks;
using Xunit;
@ -13,10 +14,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
public class InferenceTest
{
private const string module = "onnxruntime.dll";
[Fact]
public void CanCreateAndDisposeSessionWithModelPath()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory() , "squeezenet.onnx");
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
using (var session = new InferenceSession(modelPath))
{
Assert.NotNull(session);
@ -49,7 +52,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
[Fact]
private void CanRunInferenceOnAModel()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
using (var session = new InferenceSession(modelPath))
{
@ -121,7 +124,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var tensor = new DenseTensor<int>(inputDataInt, inputMeta["data_0"].Dimensions);
container.Add(NamedOnnxValue.CreateFromTensor<int>("data_0", tensor));
var ex = Assert.Throws<OnnxRuntimeException>(() => session.Run(container));
var msg = ex.ToString().Substring(0,101);
var msg = ex.ToString().Substring(0, 101);
// TODO: message is diff in LInux. Use substring match
Assert.Equal("Microsoft.ML.OnnxRuntime.OnnxRuntimeException: [ErrorCode:InvalidArgument] Unexpected input data type", msg);
session.Dispose();
@ -239,7 +242,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var innodepair = inMeta.First();
var innodename = innodepair.Key;
var innodedims = innodepair.Value.Dimensions;
for (int i=0; i < innodedims.Length; i++)
for (int i = 0; i < innodedims.Length; i++)
{
if (innodedims[i] < 0)
innodedims[i] = -1 * innodedims[i];
@ -470,6 +473,40 @@ namespace Microsoft.ML.OnnxRuntime.Tests
session.Dispose();
}
[DllImport("kernel32", SetLastError = true)]
static extern IntPtr LoadLibrary(string lpFileName);
[DllImport("kernel32", CharSet = CharSet.Ansi)]
static extern UIntPtr GetProcAddress(IntPtr hModule, string procName);
[Fact]
private void VerifyNativeMethodsExist()
{
// Check for external API changes
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return;
var entryPointNames = new[]{
"OrtInitialize","OrtReleaseEnv","OrtGetErrorCode","OrtGetErrorMessage",
"OrtReleaseStatus","OrtCreateSession","OrtRun","OrtSessionGetInputCount",
"OrtSessionGetOutputCount","OrtSessionGetInputName","OrtSessionGetOutputName","OrtSessionGetInputTypeInfo",
"OrtSessionGetOutputTypeInfo","OrtReleaseSession","OrtCreateSessionOptions","OrtCloneSessionOptions",
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider",
"OrtCreateCpuExecutionProviderFactory","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
"OrtCreateDefaultAllocator","OrtReleaseObject","OrtAllocatorFree","OrtAllocatorGetInfo",
"OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo",
"OrtCastTypeInfoToTensorInfo","OrtGetTensorShapeAndType","OrtGetTensorElementType","OrtGetNumOfDimensions",
"OrtGetDimensions","OrtGetTensorShapeElementCount","OrtReleaseValue"};
var hModule = LoadLibrary(module);
foreach (var ep in entryPointNames)
{
var x = GetProcAddress(hModule, ep);
Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}");
}
}
static float[] LoadTensorFromFile(string filename, bool skipheader = true)
{
var tensorData = new List<float>();