Fix model path marshalling in csharp, and re-enable the pretrained model tests (#2236)

This commit is contained in:
shahasad 2019-10-24 20:39:16 -07:00 committed by GitHub
parent 8be48f47dd
commit 6a0ee7eff6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 10 deletions

View file

@ -215,11 +215,7 @@ namespace Microsoft.ML.OnnxRuntime
{
var envHandle = OnnxRuntime.Handle;
var session = IntPtr.Zero;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out session));
else
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out session));
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeMethods.GetPlatformSerializedString(modelPath), options.Handle, out session));
InitWithSessionHandle(session, options);
}

View file

@ -364,10 +364,10 @@ namespace Microsoft.ML.OnnxRuntime
ExecutionMode execution_mode);
public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode;
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, [MarshalAs(UnmanagedType.LPWStr)]string optimizedModelFilepath);
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath);
public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath;
public delegate IntPtr /*(OrtStatus*)*/ DOrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix);
public delegate IntPtr /*(OrtStatus*)*/ DOrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, byte[] profilePathPrefix);
public static DOrtEnableProfiling OrtEnableProfiling;
public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableProfiling(IntPtr /* OrtSessionOptions* */ options);
@ -659,5 +659,13 @@ namespace Microsoft.ML.OnnxRuntime
public static DOrtReleaseValue OrtReleaseValue;
#endregion
public static byte[] GetPlatformSerializedString(string str)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return System.Text.Encoding.Unicode.GetBytes(str + Char.MinValue);
else
return System.Text.Encoding.UTF8.GetBytes(str + Char.MinValue);
}
} //class NativeMethods
} //namespace

View file

@ -201,7 +201,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (!_enableProfiling && value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(_nativePtr, ProfileOutputPathPrefix));
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(_nativePtr, NativeMethods.GetPlatformSerializedString(ProfileOutputPathPrefix)));
_enableProfiling = true;
}
else if (_enableProfiling && !value)
@ -226,7 +226,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (value != _optimizedModelFilePath)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, value));
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, NativeMethods.GetPlatformSerializedString(value)));
_optimizedModelFilePath = value;
}
}

View file

@ -344,6 +344,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests
if (System.Environment.Is64BitProcess == false)
{
skipModels["test_vgg19"] = "Get preallocated buffer for initializer conv4_4_b_0 failed";
skipModels["tf_pnasnet_large"] = "Get preallocated buffer for initializer ConvBnFusion_BN_B_cell_5/comb_iter_1/left/bn_sep_7x7_1/beta:0_203 failed";
skipModels["tf_nasnet_large"] = "Get preallocated buffer for initializer ConvBnFusion_BN_B_cell_11/beginning_bn/beta:0_331 failed";
}
return skipModels;
@ -390,7 +392,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
[Theory(Skip = "TestPreTrainedModels is flaky and is blocking CI build progress. Enable it once this is fixed.")]
[Theory]
[MemberData(nameof(GetModelsForTest))]
[MemberData(nameof(GetSkippedModelForTest), Skip = "Skipped due to Error, please fix the error and enable the test")]
private void TestPreTrainedModels(string opset, string modelName)