csharp api for graph transformers (#741)

* add graph optimization level to csharp api

* format documentation

* changes per review comments
This commit is contained in:
Ashwini Khade 2019-04-02 17:23:14 -07:00 committed by GitHub
parent 06888437dd
commit 2dbce4ebcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 7 deletions

View file

@ -24,8 +24,11 @@ namespace CSharpUsage
{
string modelPath = Directory.GetCurrentDirectory() + @"\squeezenet.onnx";
// Optional : Create session options and set the graph optimization level for the session
SessionOptions options = new SessionOptions();
options.SetSessionGraphOptimizationLevel(2);
using (var session = new InferenceSession(modelPath))
using (var session = new InferenceSession(modelPath, options))
{
var inputMeta = session.InputMetadata;
var container = new List<NamedOnnxValue>();

View file

@ -157,6 +157,9 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern int OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize);
[DllImport(nativeLib, CharSet = charSet)]
public static extern int OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel);
///**
// * The order of invocation indicates the preference order as well. In other words call this method

View file

@ -35,6 +35,21 @@ namespace Microsoft.ML.OnnxRuntime
_nativePtr = NativeMethods.OrtCreateSessionOptions();
}
/// <summary>
/// Sets the graph optimization level for the session. Default is set to 1.
/// </summary>
/// <param name="optimization_level">optimization level for the session</param>
/// Available options are : 0, 1, 2
/// 0 -> Disable all optimizations
/// 1 -> Enable basic optimizations
/// 2 -> Enable all optimizations
/// <returns>True on success and false otherwise</returns>
public bool SetSessionGraphOptimizationLevel(uint optimization_level)
{
var result = NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level);
return result == 0;
}
/// <summary>
/// Default instance
/// </summary>

View file

@ -50,12 +50,18 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
[Fact]
private void CanRunInferenceOnAModel()
[Theory]
[InlineData(0)]
[InlineData(2)]
private void CanRunInferenceOnAModel(uint graphOptimizationLevel)
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
using (var session = new InferenceSession(modelPath))
// Set the graph optimization level for this session.
SessionOptions options = new SessionOptions();
Assert.True(options.SetSessionGraphOptimizationLevel(graphOptimizationLevel));
using (var session = new InferenceSession(modelPath, options))
{
var inputMeta = session.InputMetadata;
var container = new List<NamedOnnxValue>();
@ -664,8 +670,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests
"OrtSessionGetOutputTypeInfo","OrtReleaseSession","OrtCreateSessionOptions","OrtCloneSessionOptions",
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider_CPU",
"OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSetSessionGraphOptimizationLevel",
"OrtSessionOptionsAppendExecutionProvider_CPU","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
"OrtCreateDefaultAllocator","OrtAllocatorFree","OrtAllocatorGetInfo",
"OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo",
"OrtCastTypeInfoToTensorInfo","OrtGetTensorShapeAndType","OrtGetTensorElementType","OrtGetNumOfDimensions",
@ -677,7 +683,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var x = GetProcAddress(hModule, ep);
Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}");
}
}
}
static string GetTestModelsDir()
{

View file

@ -105,6 +105,12 @@ Constructs a SessionOptions will all options at default/unset values.
Accessor to the default static option object
#### Methods
SetSessionGraphOptimizationLevel(uint optimization_level);
Sets the graph optimization level for the session. Default is set to 1. Available options are : {0, 1, 2}.
* 0 -> Disable all optimizations
* 1 -> Enable basic optimizations such as redundant node removals and constant folding
* 2 -> Enable all optimizations (includes Level1 and more complex optimizations such as node fusions)
AppendExecutionProvider(ExecutionProvider provider);
Appends execution provider to the session. For any operator in the graph the first execution provider that implements the operator will be user. ExecutionProvider is defined as the following enum.