2018-11-20 00:48:22 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
|
|
|
|
using System;
|
2018-11-23 04:56:43 +00:00
|
|
|
using System.Runtime.InteropServices;
|
|
|
|
|
|
2018-11-20 00:48:22 +00:00
|
|
|
|
|
|
|
|
namespace Microsoft.ML.OnnxRuntime
|
|
|
|
|
{
|
2018-11-23 04:56:43 +00:00
|
|
|
/// <summary>
/// Execution providers that can be requested through
/// <c>SessionOptions.AppendExecutionProvider(ExecutionProvider)</c>.
/// </summary>
public enum ExecutionProvider
{
    /// <summary>The CPU execution provider.</summary>
    Cpu = 0,

    /// <summary>The MKL-DNN execution provider.</summary>
    MklDnn = 1,

    // TODO: add more providers gradually
}
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Managed wrapper around a native ONNX Runtime session-options object, used to choose
/// which execution providers a session should use.
/// </summary>
public class SessionOptions
{
    /// <summary>Owning handle to the native session-options object created in the constructor.</summary>
    protected SafeHandle _nativeOption;

    /// <summary>
    /// Lazily-created backing instance for <see cref="Default"/>; built on first access by
    /// <see cref="MakeSessionOptionWithMklDnnProvider"/>.
    /// </summary>
    protected static readonly Lazy<SessionOptions> _default = new Lazy<SessionOptions>(MakeSessionOptionWithMklDnnProvider);

    /// <summary>
    /// Creates a new native session-options object and takes ownership of its handle.
    /// </summary>
    public SessionOptions()
    {
        _nativeOption = new NativeOnnxObjectHandle(NativeMethods.ONNXRuntimeCreateSessionOptions());
    }

    /// <summary>
    /// A shared (static) instance with the MKL-DNN provider appended first and the CPU provider second.
    /// </summary>
    /// <remarks>
    /// Every access returns the same object, so calling
    /// <see cref="AppendExecutionProvider(ExecutionProvider)"/> on it changes the options seen by
    /// every other user of <see cref="Default"/>. Construct a new <see cref="SessionOptions"/>
    /// when a different provider set is needed.
    /// </remarks>
    public static SessionOptions Default => _default.Value;

    /// <summary>
    /// Appends the native factory for <paramref name="provider"/> to this options object.
    /// </summary>
    /// <param name="provider">The execution provider to append.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="provider"/> is not a value this method can map to a native provider factory.
    /// </exception>
    public void AppendExecutionProvider(ExecutionProvider provider)
    {
        switch (provider)
        {
            case ExecutionProvider.Cpu:
                AppendExecutionProvider(CpuExecutionProviderFactory.Default);
                break;
            case ExecutionProvider.MklDnn:
                AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
                break;
            default:
                // Fail loudly instead of silently ignoring the request: an out-of-range cast, or a new
                // enum member added without a matching case here, would otherwise go unnoticed.
                throw new ArgumentOutOfRangeException(nameof(provider), provider, "Unsupported execution provider.");
        }
    }

    /// <summary>
    /// Factory for <see cref="Default"/>: a new instance with MKL-DNN appended before CPU.
    /// </summary>
    private static SessionOptions MakeSessionOptionWithMklDnnProvider()
    {
        SessionOptions options = new SessionOptions();
        // NOTE(review): append order presumably defines provider priority (MKL-DNN preferred,
        // CPU as fallback) — confirm against the native API.
        options.AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
        options.AppendExecutionProvider(CpuExecutionProviderFactory.Default);
        return options;
    }

    /// <summary>
    /// Raw pointer to the native session-options object.
    /// </summary>
    /// <remarks>
    /// Obtained via <see cref="SafeHandle.DangerousGetHandle"/>. The handle's reference count is not
    /// incremented, so the pointer is only valid while this instance (and its handle) stays alive.
    /// </remarks>
    internal IntPtr NativeHandle
    {
        get
        {
            return _nativeOption.DangerousGetHandle(); // Note: this is unsafe, and not ref counted, use with caution
        }
    }

    /// <summary>
    /// Appends a native execution-provider factory to the native session-options object.
    /// </summary>
    /// <param name="providerFactory">Handle to the native provider factory.</param>
    /// <exception cref="ArgumentNullException"><paramref name="providerFactory"/> is null.</exception>
    private void AppendExecutionProvider(NativeOnnxObjectHandle providerFactory)
    {
        if (providerFactory == null)
        {
            throw new ArgumentNullException(nameof(providerFactory));
        }

        bool factoryRefAdded = false;
        bool optionsRefAdded = false;
        try
        {
            // Pin both handles for the duration of the native call: raw IntPtrs are passed, so without
            // an added reference either SafeHandle could be released (e.g. finalized) mid-call.
            providerFactory.DangerousAddRef(ref factoryRefAdded);
            _nativeOption.DangerousAddRef(ref optionsRefAdded);
            NativeMethods.ONNXRuntimeSessionOptionsAppendExecutionProvider(_nativeOption.DangerousGetHandle(), providerFactory.DangerousGetHandle());
        }
        finally
        {
            // Release in finally so an exception from the native call does not leak a reference.
            if (optionsRefAdded)
            {
                _nativeOption.DangerousRelease();
            }
            if (factoryRefAdded)
            {
                providerFactory.DangerousRelease();
            }
        }
    }
}
|
|
|
|
|
}
|