mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
Use a friendly enum for graph optimization level. (#1586)
* Mention OrtCreateSessionFromArray in C API doc * review changes * use enum for graph optimization level * Use explicit values for enums * updates... * Add friendly enum for graph optimization levels in C, C# and Python APIs. * Fix linux build * Fix build breakage due to master merge * PR comments
This commit is contained in:
parent
24d17f4353
commit
8d12ce45cf
17 changed files with 227 additions and 143 deletions
|
|
@ -26,7 +26,7 @@ namespace CSharpUsage
|
|||
|
||||
// Optional : Create session options and set the graph optimization level for the session
|
||||
SessionOptions options = new SessionOptions();
|
||||
options.GraphOptimizationLevel = 2;
|
||||
options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
|
||||
|
||||
using (var session = new InferenceSession(modelPath, options))
|
||||
{
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, GraphOptimizationLevel graphOptimizationLevel);
|
||||
|
||||
|
||||
///**
|
||||
|
|
|
|||
|
|
@ -8,6 +8,17 @@ using System.IO;
|
|||
|
||||
namespace Microsoft.ML.OnnxRuntime
|
||||
{
|
||||
/// <summary>
|
||||
/// TODO Add documentation about which optimizations are enabled for each value.
|
||||
/// </summary>
|
||||
public enum GraphOptimizationLevel
|
||||
{
|
||||
ORT_DISABLE_ALL = 0,
|
||||
ORT_ENABLE_BASIC = 1,
|
||||
ORT_ENABLE_EXTENDED = 2,
|
||||
ORT_ENABLE_ALL = 99
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Holds the options for creating an InferenceSession
|
||||
/// </summary>
|
||||
|
|
@ -117,7 +128,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
private bool _enableMemoryPattern = true;
|
||||
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Path prefix to use for output of profiling data
|
||||
/// </summary>
|
||||
|
|
@ -158,7 +169,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// </summary>
|
||||
public string OptimizedModelFilePath
|
||||
{
|
||||
get
|
||||
get
|
||||
{
|
||||
return _optimizedModelFilePath;
|
||||
}
|
||||
|
|
@ -174,7 +185,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
private string _optimizedModelFilePath = "";
|
||||
|
||||
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Enables Arena allocator for the CPU memory allocations. Default is true.
|
||||
/// </summary>
|
||||
|
|
@ -190,7 +201,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableCpuMemArena(_nativePtr));
|
||||
_enableCpuMemArena = true;
|
||||
}
|
||||
}
|
||||
else if (_enableCpuMemArena && !value)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableCpuMemArena(_nativePtr));
|
||||
|
|
@ -259,13 +270,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
|
||||
/// <summary>
|
||||
/// Sets the graph optimization level for the session. Default is set to 1.
|
||||
/// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_BASIC.
|
||||
/// </summary>
|
||||
/// Available options are : 0, 1, 2
|
||||
/// 0 -> Disable all optimizations
|
||||
/// 1 -> Enable basic optimizations
|
||||
/// 2 -> Enable all optimizations
|
||||
public uint GraphOptimizationLevel
|
||||
public GraphOptimizationLevel GraphOptimizationLevel
|
||||
{
|
||||
get
|
||||
{
|
||||
|
|
@ -277,7 +284,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
_graphOptimizationLevel = value;
|
||||
}
|
||||
}
|
||||
private uint _graphOptimizationLevel = 1;
|
||||
private GraphOptimizationLevel _graphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_BASIC;
|
||||
|
||||
#endregion
|
||||
|
||||
|
|
@ -298,16 +305,16 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
IntPtr handle = LoadLibrary(dll);
|
||||
if (handle != IntPtr.Zero)
|
||||
continue;
|
||||
continue;
|
||||
var sysdir = new StringBuilder(String.Empty, 2048);
|
||||
GetSystemDirectory(sysdir, (uint)sysdir.Capacity);
|
||||
throw new OnnxRuntimeException(
|
||||
ErrorCode.NoSuchFile,
|
||||
ErrorCode.NoSuchFile,
|
||||
$"kernel32.LoadLibrary():'{dll}' not found. CUDA is required for GPU execution. " +
|
||||
$". Verify it is available in the system directory={sysdir}. Else copy it to the output folder."
|
||||
);
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,11 +34,7 @@ int main(int argc, char* argv[]) {
|
|||
OrtSetSessionThreadPoolSize(session_options, 1);
|
||||
|
||||
// Sets graph optimization level
|
||||
// Available levels are
|
||||
// 0 -> To disable all optimizations
|
||||
// 1 -> To enable basic optimizations (Such as redundant node removals)
|
||||
// 2 -> To enable all optimizations (Includes level 1 + more complex optimizations like node fusions)
|
||||
OrtSetSessionGraphOptimizationLevel(session_options, 1);
|
||||
OrtSetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC);
|
||||
|
||||
// Optionally add more execution providers via session_options
|
||||
// E.g. for CUDA include cuda_provider_factory.h and uncomment the following line:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
Assert.Equal("", opt.LogId);
|
||||
Assert.Equal(LogLevel.Verbose, opt.LogVerbosityLevel);
|
||||
Assert.Equal(0, opt.ThreadPoolSize);
|
||||
Assert.Equal(1u, opt.GraphOptimizationLevel);
|
||||
Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_BASIC, opt.GraphOptimizationLevel);
|
||||
|
||||
// try setting options
|
||||
opt.EnableSequentialExecution = false;
|
||||
|
|
@ -62,12 +62,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
opt.ThreadPoolSize = 4;
|
||||
Assert.Equal(4, opt.ThreadPoolSize);
|
||||
|
||||
opt.GraphOptimizationLevel = 3;
|
||||
Assert.Equal(3u, opt.GraphOptimizationLevel);
|
||||
opt.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
|
||||
Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, opt.GraphOptimizationLevel);
|
||||
|
||||
Assert.Throws<OnnxRuntimeException>(() => { opt.ThreadPoolSize = -2; });
|
||||
Assert.Throws<OnnxRuntimeException>(() => { opt.GraphOptimizationLevel = 10; });
|
||||
|
||||
Assert.Throws<OnnxRuntimeException>(() => { opt.GraphOptimizationLevel = (GraphOptimizationLevel)10; });
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -129,11 +129,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(0, true)]
|
||||
[InlineData(0, false)]
|
||||
[InlineData(2, true)]
|
||||
[InlineData(2, false)]
|
||||
private void CanRunInferenceOnAModel(uint graphOptimizationLevel, bool disableSequentialExecution)
|
||||
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, true)]
|
||||
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, false)]
|
||||
[InlineData(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, true)]
|
||||
[InlineData(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, false)]
|
||||
private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLevel, bool disableSequentialExecution)
|
||||
{
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
|
||||
|
||||
|
|
@ -743,7 +743,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
// Set the optimized model file path to assert that no exception are thrown.
|
||||
SessionOptions options = new SessionOptions();
|
||||
options.OptimizedModelFilePath = modelOutputPath;
|
||||
options.GraphOptimizationLevel = 1;
|
||||
options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_BASIC;
|
||||
var session = new InferenceSession(modelPath, options);
|
||||
Assert.NotNull(session);
|
||||
Assert.True(File.Exists(modelOutputPath));
|
||||
|
|
@ -792,7 +792,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
|
||||
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
|
||||
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSetSessionGraphOptimizationLevel",
|
||||
"OrtSetOptimizedModelFilePath", "OrtSessionOptionsAppendExecutionProvider_CPU",
|
||||
"OrtSetOptimizedModelFilePath", "OrtSessionOptionsAppendExecutionProvider_CPU",
|
||||
"OrtCreateRunOptions", "OrtReleaseRunOptions", "OrtRunOptionsSetRunLogVerbosityLevel", "OrtRunOptionsSetRunTag",
|
||||
"OrtRunOptionsGetRunLogVerbosityLevel", "OrtRunOptionsGetRunTag","OrtRunOptionsEnableTerminate", "OrtRunOptionsDisableTerminate",
|
||||
"OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
|
|||
public bool ParallelExecution { get; set; } = false;
|
||||
|
||||
[Option('o', "optimization_level", Required = false, HelpText = "Optimization Level. Default is 1, partial optimization.")]
|
||||
public uint OptimizationLevel { get; set; } = 1;
|
||||
public GraphOptimizationLevel OptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_BASIC;
|
||||
}
|
||||
|
||||
class Program
|
||||
|
|
@ -42,7 +42,8 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
|
|||
{
|
||||
var cmdOptions = Parser.Default.ParseArguments<CommandOptions>(args);
|
||||
cmdOptions.WithParsed(
|
||||
options => {
|
||||
options =>
|
||||
{
|
||||
Run(options);
|
||||
});
|
||||
}
|
||||
|
|
@ -52,7 +53,7 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
|
|||
string inputPath = options.InputFile;
|
||||
int iteration = options.IterationCount;
|
||||
bool parallelExecution = options.ParallelExecution;
|
||||
uint optLevel = options.OptimizationLevel;
|
||||
GraphOptimizationLevel optLevel = options.OptimizationLevel;
|
||||
Console.WriteLine("Running model {0} in OnnxRuntime:", modelPath);
|
||||
Console.WriteLine("input:{0}", inputPath);
|
||||
Console.WriteLine("iteration count:{0}", iteration);
|
||||
|
|
@ -84,11 +85,11 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
|
|||
return tensorData.ToArray();
|
||||
}
|
||||
|
||||
static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteration, DateTime[] timestamps, bool parallelExecution, uint optLevel)
|
||||
static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteration, DateTime[] timestamps, bool parallelExecution, GraphOptimizationLevel optLevel)
|
||||
{
|
||||
if (timestamps.Length != (int)TimingPoint.TotalCount)
|
||||
{
|
||||
throw new ArgumentException("Timestamps array must have "+(int)TimingPoint.TotalCount+" size");
|
||||
throw new ArgumentException("Timestamps array must have " + (int)TimingPoint.TotalCount + " size");
|
||||
}
|
||||
|
||||
timestamps[(int)TimingPoint.Start] = DateTime.Now;
|
||||
|
|
@ -108,12 +109,12 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
|
|||
container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
timestamps[(int)TimingPoint.InputLoaded] = DateTime.Now;
|
||||
|
||||
// Run the inference
|
||||
for (int i=0; i < iteration; i++)
|
||||
for (int i = 0; i < iteration; i++)
|
||||
{
|
||||
var results = session.Run(container); // results is an IReadOnlyList<NamedOnnxValue> container
|
||||
Debug.Assert(results != null);
|
||||
|
|
@ -132,7 +133,7 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
|
|||
static void PrintUsage()
|
||||
{
|
||||
Console.WriteLine("Usage:\n"
|
||||
+"dotnet Microsoft.ML.OnnxRuntime.PerfTool <onnx-model-path> <input-file-path> <iteration-count>"
|
||||
+ "dotnet Microsoft.ML.OnnxRuntime.PerfTool <onnx-model-path> <input-file-path> <iteration-count>"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -234,11 +234,15 @@ ORT_API_STATUS(OrtSetSessionLogId, _Inout_ OrtSessionOptions* options, const cha
|
|||
ORT_API_STATUS(OrtSetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level);
|
||||
|
||||
// Set Graph optimization level.
|
||||
// Available options are : 0, 1, 2.
|
||||
// 0 -> Disable all optimizations
|
||||
// 1 -> Enable basic optimizations
|
||||
// 2 -> Enable all optimizations
|
||||
ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, int graph_optimization_level);
|
||||
// TODO Add documentation about which optimizations are enabled for each value.
|
||||
typedef enum GraphOptimizationLevel {
|
||||
ORT_DISABLE_ALL = 0,
|
||||
ORT_ENABLE_BASIC = 1,
|
||||
ORT_ENABLE_EXTENDED = 2,
|
||||
ORT_ENABLE_ALL = 99
|
||||
} GraphOptimizationLevel;
|
||||
ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options,
|
||||
GraphOptimizationLevel graph_optimization_level);
|
||||
|
||||
// How many threads in the session thread pool.
|
||||
ORT_API_STATUS(OrtSetSessionThreadPoolSize, _Inout_ OrtSessionOptions* options, int session_thread_pool_size);
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
|
|||
SessionOptions Clone() const;
|
||||
|
||||
SessionOptions& SetThreadPoolSize(int session_thread_pool_size);
|
||||
SessionOptions& SetGraphOptimizationLevel(int graph_optimization_level);
|
||||
SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
|
||||
|
||||
SessionOptions& EnableCpuMemArena();
|
||||
SessionOptions& DisableCpuMemArena();
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool
|
|||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(int graph_optimization_level) {
|
||||
inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
|
||||
ORT_THROW_ON_ERROR(OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level));
|
||||
return *this;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,4 +18,4 @@ __author__ = "Microsoft"
|
|||
from onnxruntime.capi import onnxruntime_validation
|
||||
onnxruntime_validation.check_distro_info()
|
||||
from onnxruntime.capi.session import InferenceSession
|
||||
from onnxruntime.capi._pybind_state import RunOptions, SessionOptions, set_default_logger_severity, get_device, NodeArg, ModelMetadata
|
||||
from onnxruntime.capi._pybind_state import RunOptions, SessionOptions, set_default_logger_severity, get_device, NodeArg, ModelMetadata, GraphOptimizationLevel
|
||||
|
|
|
|||
|
|
@ -101,14 +101,29 @@ ORT_API_STATUS_IMPL(OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* opti
|
|||
}
|
||||
|
||||
// Set Graph optimization level.
|
||||
// Available options are : 0, 1, 2.
|
||||
ORT_API_STATUS_IMPL(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, int graph_optimization_level) {
|
||||
ORT_API_STATUS_IMPL(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options,
|
||||
GraphOptimizationLevel graph_optimization_level) {
|
||||
if (graph_optimization_level < 0) {
|
||||
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid");
|
||||
}
|
||||
if (graph_optimization_level >= static_cast<int>(onnxruntime::TransformerLevel::MaxTransformerLevel))
|
||||
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid");
|
||||
options->value.graph_optimization_level = static_cast<onnxruntime::TransformerLevel>(graph_optimization_level);
|
||||
|
||||
switch (graph_optimization_level) {
|
||||
case ORT_DISABLE_ALL:
|
||||
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
|
||||
break;
|
||||
case ORT_ENABLE_BASIC:
|
||||
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
|
||||
break;
|
||||
case ORT_ENABLE_EXTENDED:
|
||||
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
|
||||
break;
|
||||
case ORT_ENABLE_ALL:
|
||||
options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
|
||||
break;
|
||||
default:
|
||||
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid");
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@
|
|||
#include <numpy/arrayobject.h>
|
||||
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/logging/severity.h"
|
||||
|
||||
#if USE_CUDA
|
||||
#define BACKEND_PROC "GPU"
|
||||
|
|
@ -379,7 +381,14 @@ void addOpSchemaSubmodule(py::module& m) {
|
|||
#endif //onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
||||
|
||||
void addObjectMethods(py::module& m) {
|
||||
py::class_<SessionOptions>(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc")
|
||||
py::enum_<GraphOptimizationLevel>(m, "GraphOptimizationLevel")
|
||||
.value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
|
||||
.value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
|
||||
.value("ORT_ENABLE_EXTENDED", GraphOptimizationLevel::ORT_ENABLE_EXTENDED)
|
||||
.value("ORT_ENABLE_ALL", GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||
|
||||
py::class_<SessionOptions> sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
|
||||
sess
|
||||
.def(py::init())
|
||||
.def_readwrite("enable_cpu_mem_arena", &SessionOptions::enable_cpu_mem_arena,
|
||||
R"pbdoc(Enables the memory arena on CPU. Arena may pre-allocate memory for future usage.
|
||||
|
|
@ -407,17 +416,48 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
|
|||
This parameter is unused unless *enable_sequential_execution* is false.)pbdoc")
|
||||
.def_property_readonly(
|
||||
"graph_optimization_level",
|
||||
[](const SessionOptions* options) -> uint32_t {
|
||||
return static_cast<uint32_t>(options->graph_optimization_level);
|
||||
[](const SessionOptions* options) -> GraphOptimizationLevel {
|
||||
GraphOptimizationLevel retval = ORT_ENABLE_BASIC;
|
||||
switch (options->graph_optimization_level) {
|
||||
case onnxruntime::TransformerLevel::Default:
|
||||
retval = ORT_DISABLE_ALL;
|
||||
break;
|
||||
case onnxruntime::TransformerLevel::Level1:
|
||||
retval = ORT_ENABLE_BASIC;
|
||||
break;
|
||||
case onnxruntime::TransformerLevel::Level2:
|
||||
retval = ORT_ENABLE_EXTENDED;
|
||||
break;
|
||||
case onnxruntime::TransformerLevel::Level3:
|
||||
retval = ORT_ENABLE_ALL;
|
||||
break;
|
||||
default:
|
||||
retval = ORT_ENABLE_BASIC;
|
||||
LOGS_DEFAULT(WARNING) << "Got invalid graph optimization level; defaulting to ORT_ENABLE_BASIC";
|
||||
break;
|
||||
}
|
||||
return retval;
|
||||
},
|
||||
R"pbdoc(Graph optimization level for this session.)pbdoc")
|
||||
.def(
|
||||
"set_graph_optimization_level",
|
||||
[](SessionOptions* options, uint32_t level) -> void {
|
||||
options->graph_optimization_level = static_cast<TransformerLevel>(level);
|
||||
[](SessionOptions* options, GraphOptimizationLevel level) -> void {
|
||||
switch (level) {
|
||||
case ORT_DISABLE_ALL:
|
||||
options->graph_optimization_level = onnxruntime::TransformerLevel::Default;
|
||||
break;
|
||||
case ORT_ENABLE_BASIC:
|
||||
options->graph_optimization_level = onnxruntime::TransformerLevel::Level1;
|
||||
break;
|
||||
case ORT_ENABLE_EXTENDED:
|
||||
options->graph_optimization_level = onnxruntime::TransformerLevel::Level2;
|
||||
break;
|
||||
case ORT_ENABLE_ALL:
|
||||
options->graph_optimization_level = onnxruntime::TransformerLevel::Level3;
|
||||
break;
|
||||
}
|
||||
},
|
||||
R"pbdoc(Graph optimization level for this session. 0 disables all optimizations.
|
||||
Whereas 1 enables basic optimizations and 2 enables all optimizations.)pbdoc");
|
||||
R"pbdoc(Graph optimization level for this session.)pbdoc");
|
||||
|
||||
py::class_<RunOptions>(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc")
|
||||
.def(py::init())
|
||||
|
|
@ -450,57 +490,55 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
return *(na.Type());
|
||||
},
|
||||
"node type")
|
||||
.def(
|
||||
"__str__", [](const onnxruntime::NodeArg& na) -> std::string {
|
||||
std::ostringstream res;
|
||||
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
res << "[]";
|
||||
.def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
|
||||
std::ostringstream res;
|
||||
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
res << "[]";
|
||||
} else {
|
||||
res << "[";
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (shape->dim(i).has_dim_value()) {
|
||||
res << shape->dim(i).dim_value();
|
||||
} else if (shape->dim(i).has_dim_param()) {
|
||||
res << "'" << shape->dim(i).dim_param() << "'";
|
||||
} else {
|
||||
res << "[";
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (shape->dim(i).has_dim_value()) {
|
||||
res << shape->dim(i).dim_value();
|
||||
} else if (shape->dim(i).has_dim_param()) {
|
||||
res << "'" << shape->dim(i).dim_param() << "'";
|
||||
} else {
|
||||
res << "None";
|
||||
}
|
||||
|
||||
if (i < shape->dim_size() - 1) {
|
||||
res << ", ";
|
||||
}
|
||||
}
|
||||
res << "]";
|
||||
}
|
||||
res << ")";
|
||||
|
||||
return std::string(res.str());
|
||||
},
|
||||
"converts the node into a readable string")
|
||||
.def_property_readonly(
|
||||
"shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
return arr;
|
||||
res << "None";
|
||||
}
|
||||
|
||||
arr.resize(shape->dim_size());
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (shape->dim(i).has_dim_value()) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_value());
|
||||
} else if (shape->dim(i).has_dim_param()) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_param());
|
||||
} else {
|
||||
arr[i] = py::none();
|
||||
}
|
||||
if (i < shape->dim_size() - 1) {
|
||||
res << ", ";
|
||||
}
|
||||
return arr;
|
||||
},
|
||||
"node shape (assuming the node holds a tensor)");
|
||||
}
|
||||
res << "]";
|
||||
}
|
||||
res << ")";
|
||||
|
||||
return std::string(res.str());
|
||||
},
|
||||
"converts the node into a readable string")
|
||||
.def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
return arr;
|
||||
}
|
||||
|
||||
arr.resize(shape->dim_size());
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (shape->dim(i).has_dim_value()) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_value());
|
||||
} else if (shape->dim(i).has_dim_param()) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_param());
|
||||
} else {
|
||||
arr[i] = py::none();
|
||||
}
|
||||
}
|
||||
return arr;
|
||||
},
|
||||
"node shape (assuming the node holds a tensor)");
|
||||
|
||||
py::class_<SessionObjectInitializer>(m, "SessionObjectInitializer");
|
||||
py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
|
||||
|
|
@ -515,16 +553,15 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
InitializeSession(sess);
|
||||
},
|
||||
R"pbdoc(Load a model saved in ONNX format.)pbdoc")
|
||||
.def(
|
||||
"read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel) {
|
||||
std::istringstream buffer(serializedModel);
|
||||
auto status = sess->Load(buffer);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ToString().c_str());
|
||||
}
|
||||
InitializeSession(sess);
|
||||
},
|
||||
R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
|
||||
.def("read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel) {
|
||||
std::istringstream buffer(serializedModel);
|
||||
auto status = sess->Load(buffer);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ToString().c_str());
|
||||
}
|
||||
InitializeSession(sess);
|
||||
},
|
||||
R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
|
||||
.def("run", [](InferenceSession* sess, std::vector<std::string> output_names, std::map<std::string, py::object> pyfeeds, RunOptions* run_options = nullptr) -> std::vector<py::object> {
|
||||
NameMLValMap feeds;
|
||||
for (auto _ : pyfeeds) {
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
bool enable_mem_pattern = true;
|
||||
bool enable_openvino = false;
|
||||
bool enable_nnapi = false;
|
||||
uint32_t graph_optimization_level{};
|
||||
GraphOptimizationLevel graph_optimization_level = ORT_DISABLE_ALL;
|
||||
bool user_graph_optimization_level_set = false;
|
||||
|
||||
OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING;
|
||||
|
|
@ -166,15 +166,29 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
case 'x':
|
||||
enable_sequential_execution = false;
|
||||
break;
|
||||
case 'o':
|
||||
graph_optimization_level = static_cast<uint32_t>(OrtStrtol<PATH_CHAR_TYPE>(optarg, nullptr));
|
||||
if (graph_optimization_level >= static_cast<uint32_t>(TransformerLevel::MaxTransformerLevel)) {
|
||||
fprintf(stderr, "See usage for valid values of graph optimization level\n");
|
||||
usage();
|
||||
return -1;
|
||||
case 'o': {
|
||||
int tmp = static_cast<int>(OrtStrtol<PATH_CHAR_TYPE>(optarg, nullptr));
|
||||
switch (tmp) {
|
||||
case ORT_DISABLE_ALL:
|
||||
graph_optimization_level = ORT_DISABLE_ALL;
|
||||
break;
|
||||
case ORT_ENABLE_BASIC:
|
||||
graph_optimization_level = ORT_ENABLE_BASIC;
|
||||
break;
|
||||
case ORT_ENABLE_EXTENDED:
|
||||
graph_optimization_level = ORT_ENABLE_EXTENDED;
|
||||
break;
|
||||
case ORT_ENABLE_ALL:
|
||||
graph_optimization_level = ORT_ENABLE_ALL;
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "See usage for valid values of graph optimization level\n");
|
||||
usage();
|
||||
return -1;
|
||||
}
|
||||
user_graph_optimization_level_set = true;
|
||||
break;
|
||||
}
|
||||
case '?':
|
||||
case 'h':
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ namespace perftest {
|
|||
} else if (!CompareCString(optarg, ORT_TSTR("openvino"))) {
|
||||
test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) {
|
||||
test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider;
|
||||
test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -128,12 +128,26 @@ namespace perftest {
|
|||
return false;
|
||||
}
|
||||
break;
|
||||
case 'o':
|
||||
test_config.run_config.optimization_level = static_cast<uint32_t>(OrtStrtol<PATH_CHAR_TYPE>(optarg, nullptr));
|
||||
if (test_config.run_config.optimization_level >= static_cast<uint32_t>(TransformerLevel::MaxTransformerLevel)) {
|
||||
return false;
|
||||
case 'o': {
|
||||
int tmp = static_cast<int>(OrtStrtol<PATH_CHAR_TYPE>(optarg, nullptr));
|
||||
switch (tmp) {
|
||||
case ORT_DISABLE_ALL:
|
||||
test_config.run_config.optimization_level = ORT_DISABLE_ALL;
|
||||
break;
|
||||
case ORT_ENABLE_BASIC:
|
||||
test_config.run_config.optimization_level = ORT_ENABLE_BASIC;
|
||||
break;
|
||||
case ORT_ENABLE_EXTENDED:
|
||||
test_config.run_config.optimization_level = ORT_ENABLE_EXTENDED;
|
||||
break;
|
||||
case ORT_ENABLE_ALL:
|
||||
test_config.run_config.optimization_level = ORT_ENABLE_ALL;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case '?':
|
||||
case 'h':
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ struct RunConfig {
|
|||
bool enable_cpu_mem_arena{true};
|
||||
bool enable_sequential_execution{true};
|
||||
int session_thread_pool_size{0};
|
||||
uint32_t optimization_level{2};
|
||||
GraphOptimizationLevel optimization_level{ORT_ENABLE_EXTENDED};
|
||||
};
|
||||
|
||||
struct PerformanceTestConfig {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import numpy as np
|
|||
import onnxruntime as onnxrt
|
||||
import threading
|
||||
|
||||
|
||||
class TestInferenceSession(unittest.TestCase):
|
||||
|
||||
def get_name(self, name):
|
||||
|
|
@ -494,6 +493,12 @@ class TestInferenceSession(unittest.TestCase):
|
|||
total = mat.sum()
|
||||
self.assertEqual(total, 0)
|
||||
|
||||
def testGraphOptimizationLevel(self):
|
||||
sess = onnxrt.InferenceSession(self.get_name("logicaland.onnx"))
|
||||
sess.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
self.assertEqual(sess.graph_optimization_level,
|
||||
onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -9,15 +9,6 @@ using namespace onnxruntime;
|
|||
|
||||
TEST_F(CApiTest, session_options_graph_optimization_level) {
|
||||
// Test set optimization level succeeds when valid level is provided.
|
||||
uint32_t valid_optimization_level = static_cast<uint32_t>(TransformerLevel::Level2);
|
||||
Ort::SessionOptions options;
|
||||
options.SetGraphOptimizationLevel(valid_optimization_level);
|
||||
|
||||
// Test set optimization level fails when invalid level is provided.
|
||||
try {
|
||||
uint32_t invalid_level = static_cast<uint32_t>(TransformerLevel::MaxTransformerLevel);
|
||||
options.SetGraphOptimizationLevel(invalid_level);
|
||||
} catch (const Ort::Exception& e) {
|
||||
ASSERT_EQ(e.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue