onnxruntime/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
Dmitri Smirnov a5dec8eedf
[C# ] Improve string marshalling and reduce GC pressure (#15545)
### Description

  Reduce the number of auxiliary objects created in order to lower GC pressure.
Eliminate GCHandle type of memory pinning in most of the places.
Improve string marshalling by allocating unmanaged memory that does not
require pinning. Change native methods from `IntPtr` to `byte[]`
(marshalling pinning is more efficient).

Allocate input/output UTF-8 names in the unmanaged heap for the lifetime of
the InferenceSession, so that we do not keep converting and pinning them on
every Run.

Introduce a new native API that allows allocating and converting/copying
strings directly into a native tensor.

The PR delivers around a 50% latency improvement and fewer GC pauses.

Inspired by: https://github.com/microsoft/onnxruntime/pull/15520

### Motivation and Context
Clients experience GC pressure and performance degradation when dealing
with string tensors.


Co-Authored-By: @tannergooding
2023-04-20 15:12:51 -07:00

215 lines
7.4 KiB
C#

using System;
using Xunit;
namespace Microsoft.ML.OnnxRuntime.Tests
{
/// <summary>
/// Collection of OrtEnv tests that must be ran sequentially
/// </summary>
[Collection("Ort Inference Tests")]
public class OrtEnvCollectionTest
{
[Fact(DisplayName = "EnablingAndDisablingTelemetryEventCollection")]
public void EnablingAndDisablingTelemetryEventCollection()
{
var ortEnvInstance = OrtEnv.Instance();
ortEnvInstance.DisableTelemetryEvents();
// no-op on non-Windows builds
// may be no-op on certain Windows builds based on build configuration
ortEnvInstance.EnableTelemetryEvents();
}
}
[Collection("Ort Inference Tests")]
public class OrtEnvGetVersion
{
[Fact(DisplayName = "GetVersionString")]
public void GetVersionString()
{
var ortEnvInstance = OrtEnv.Instance();
string versionString = ortEnvInstance.GetVersionString();
Assert.False(versionString.Length == 0);
}
}
[Collection("Ort Inference Tests")]
public class OrtEnvGetAvailableProviders
{
[Fact(DisplayName = "GetAvailableProviders")]
public void GetAvailableProviders()
{
var ortEnvInstance = OrtEnv.Instance();
string[] providers = ortEnvInstance.GetAvailableProviders();
Assert.True(providers.Length > 0);
Assert.Equal("CPUExecutionProvider", providers[providers.Length - 1]);
#if USE_CUDA
Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"));
#endif
#if USE_ROCM
Assert.True(Array.Exists(providers, provider => provider == "ROCMExecutionProvider"));
#endif
}
}
[Collection("Ort Inference Tests")]
public class OrtEnvWithCustomLogLevel
{
[Fact(DisplayName = "TestUpdatingEnvWithCustomLogLevel")]
public void TestUpdatingEnvWithCustomLogLevel()
{
var ortEnvInstance = OrtEnv.Instance();
Assert.True(OrtEnv.IsCreated);
ortEnvInstance.Dispose();
Assert.False(OrtEnv.IsCreated);
// Must be default level of warning
ortEnvInstance = OrtEnv.Instance();
ortEnvInstance.Dispose();
Assert.False(OrtEnv.IsCreated);
var envOptions = new EnvironmentCreationOptions
{
// Everything else is unpopulated
logLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL
};
ortEnvInstance = OrtEnv.CreateInstanceWithOptions(envOptions);
Assert.True(OrtEnv.IsCreated);
Assert.Equal(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, ortEnvInstance.EnvLogLevel);
ortEnvInstance.Dispose();
Assert.False(OrtEnv.IsCreated);
envOptions = new EnvironmentCreationOptions
{
// Everything else is unpopulated
logId = "CSharpOnnxRuntimeTestLogid"
};
ortEnvInstance = OrtEnv.CreateInstanceWithOptions(envOptions);
Assert.Equal(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, ortEnvInstance.EnvLogLevel);
// Change and see if this takes effect
ortEnvInstance.EnvLogLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
Assert.Equal(OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO, ortEnvInstance.EnvLogLevel);
}
}
[Collection("Ort Inference Tests")]
public class OrtEnvWithThreadingOptions
{
[Fact(DisplayName = "TestUpdatingEnvWithThreadingOptions")]
public void TestUpdatingEnvWithThreadingOptions()
{
OrtEnv.Instance().Dispose();
Assert.False(OrtEnv.IsCreated);
using (var opt = new OrtThreadingOptions())
{
var envOptions = new EnvironmentCreationOptions
{
threadOptions = opt
};
// Make sure we start anew
var env = OrtEnv.CreateInstanceWithOptions(envOptions);
Assert.True(OrtEnv.IsCreated);
}
}
}
public class CustomLoggingFunctionTestBase
{
// Custom logging constants
protected static readonly string TestLogId = "CSharpTestLogId";
protected static readonly IntPtr TestLogParam = (IntPtr)5;
protected static int LoggingInvokes = 0;
protected static void CustomLoggingFunction(IntPtr param,
OrtLoggingLevel severity,
string category,
string logId,
string codeLocation,
string message)
{
Assert.Equal(TestLogParam, param); // Passing test param
Assert.False(string.IsNullOrEmpty(codeLocation));
Assert.False(string.IsNullOrEmpty(message));
LoggingInvokes++;
}
}
[Collection("Ort Inference Tests")]
public class OrtEnvWithCustomLogger : CustomLoggingFunctionTestBase
{
[Fact(DisplayName = "TesEnvWithCustomLogger")]
public void TesEnvWithCustomLogger()
{
// Make sure we start anew
OrtEnv.Instance().Dispose();
Assert.False(OrtEnv.IsCreated);
var envOptions = new EnvironmentCreationOptions
{
logId = TestLogId,
logLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE,
loggingFunction = CustomLoggingFunction,
loggingParam = TestLogParam
};
LoggingInvokes = 0;
var env = OrtEnv.CreateInstanceWithOptions(envOptions);
Assert.True(OrtEnv.IsCreated);
var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
// Trigger some logging
// Empty stmt intentional
using (var session = new InferenceSession(model))
;
Assert.True(LoggingInvokes > 0);
}
}
[Collection("Ort Inference Tests")]
public class OrtEnvWithCustomLoggerAndThreadindOptions : CustomLoggingFunctionTestBase
{
[Fact(DisplayName = "TestEnvWithCustomLoggerAndThredingOptions")]
public void TestEnvWithCustomLoggerAndThredingOptions()
{
OrtEnv.Instance().Dispose();
Assert.False(OrtEnv.IsCreated);
using (var opt = new OrtThreadingOptions())
{
var envOptions = new EnvironmentCreationOptions
{
logId = TestLogId,
logLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE,
threadOptions = opt,
loggingFunction = CustomLoggingFunction,
loggingParam = TestLogParam
};
LoggingInvokes = 0;
var env = OrtEnv.CreateInstanceWithOptions(envOptions);
Assert.True(OrtEnv.IsCreated);
var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
// Trigger some logging
// Empty stmt intentional
using (var session = new InferenceSession(model))
;
Assert.True(LoggingInvokes > 0);
}
}
}
}