Expose SessionOtions.DisablePerSessionThreads (#19730)

### Description

### Motivation and Context
ML.NET needs to run mltiple sessions on a single threadpool.
This commit is contained in:
Dmitri Smirnov 2024-03-04 13:46:51 -08:00 committed by GitHub
parent 27b1dc91ab
commit 0cdf36faeb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 18 additions and 1 deletions

View file

@ -362,6 +362,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern));
OrtEnableCpuMemArena = (DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena));
OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena));
OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads));
OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId));
OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel));
OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel));
@ -992,6 +993,10 @@ namespace Microsoft.ML.OnnxRuntime
public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
public static DOrtDisableCpuMemArena OrtDisableCpuMemArena;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options);
public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId);
public static DOrtSetSessionLogId OrtSetSessionLogId;

View file

@ -696,6 +696,15 @@ namespace Microsoft.ML.OnnxRuntime
}
private bool _enableCpuMemArena = true;
/// <summary>
/// Disables the per session threads. Default is true.
/// This makes all sessions in the process use a global TP.
/// </summary>
public void DisablePerSessionThreads()
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle));
}
/// <summary>
/// Log Id to be used for the session. Default is empty string.
/// </summary>

View file

@ -55,6 +55,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.Equal(0, opt.InterOpNumThreads);
Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel);
// No get, so no verify
opt.DisablePerSessionThreads();
// try setting options
opt.ExecutionMode = ExecutionMode.ORT_PARALLEL;
Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode);
@ -98,7 +101,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message);
// SessionOptions.RegisterOrtExtensions can be manually tested by referencing the
// Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.
// Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.RegisterOrtExtensions(); });
Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message);