diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs b/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs
index f794ed4bda..865acb9123 100644
--- a/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs
+++ b/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs
@@ -26,7 +26,7 @@ namespace CSharpUsage
// Optional : Create session options and set the graph optimization level for the session
SessionOptions options = new SessionOptions();
- options.GraphOptimizationLevel = 2;
+ options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
using (var session = new InferenceSession(modelPath, options))
{
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 5b9ed79559..e1b57b4acf 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -161,7 +161,7 @@ namespace Microsoft.ML.OnnxRuntime
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize);
[DllImport(nativeLib, CharSet = charSet)]
- public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel);
+ public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, GraphOptimizationLevel graphOptimizationLevel);
///**
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index 4ac90200f4..ad6b7a0a22 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -8,6 +8,17 @@ using System.IO;
namespace Microsoft.ML.OnnxRuntime
{
+ ///
+ /// Graph optimization level applied to a session. TODO: document which optimizations each value enables.
+ ///
+ public enum GraphOptimizationLevel
+ {
+ ORT_DISABLE_ALL = 0,
+ ORT_ENABLE_BASIC = 1,
+ ORT_ENABLE_EXTENDED = 2,
+ ORT_ENABLE_ALL = 99
+ }
+
///
/// Holds the options for creating an InferenceSession
///
@@ -117,7 +128,7 @@ namespace Microsoft.ML.OnnxRuntime
}
private bool _enableMemoryPattern = true;
-
+
///
/// Path prefix to use for output of profiling data
///
@@ -158,7 +169,7 @@ namespace Microsoft.ML.OnnxRuntime
///
public string OptimizedModelFilePath
{
- get
+ get
{
return _optimizedModelFilePath;
}
@@ -174,7 +185,7 @@ namespace Microsoft.ML.OnnxRuntime
private string _optimizedModelFilePath = "";
-
+
///
/// Enables Arena allocator for the CPU memory allocations. Default is true.
///
@@ -190,7 +201,7 @@ namespace Microsoft.ML.OnnxRuntime
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableCpuMemArena(_nativePtr));
_enableCpuMemArena = true;
- }
+ }
else if (_enableCpuMemArena && !value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableCpuMemArena(_nativePtr));
@@ -259,13 +270,9 @@ namespace Microsoft.ML.OnnxRuntime
///
- /// Sets the graph optimization level for the session. Default is set to 1.
+ /// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_BASIC.
///
- /// Available options are : 0, 1, 2
- /// 0 -> Disable all optimizations
- /// 1 -> Enable basic optimizations
- /// 2 -> Enable all optimizations
- public uint GraphOptimizationLevel
+ public GraphOptimizationLevel GraphOptimizationLevel
{
get
{
@@ -277,7 +284,7 @@ namespace Microsoft.ML.OnnxRuntime
_graphOptimizationLevel = value;
}
}
- private uint _graphOptimizationLevel = 1;
+ private GraphOptimizationLevel _graphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_BASIC;
#endregion
@@ -298,16 +305,16 @@ namespace Microsoft.ML.OnnxRuntime
{
IntPtr handle = LoadLibrary(dll);
if (handle != IntPtr.Zero)
- continue;
+ continue;
var sysdir = new StringBuilder(String.Empty, 2048);
GetSystemDirectory(sysdir, (uint)sysdir.Capacity);
throw new OnnxRuntimeException(
- ErrorCode.NoSuchFile,
+ ErrorCode.NoSuchFile,
$"kernel32.LoadLibrary():'{dll}' not found. CUDA is required for GPU execution. " +
$". Verify it is available in the system directory={sysdir}. Else copy it to the output folder."
- );
+ );
}
- }
+ }
return true;
}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp
index 11dae1ab52..974a636389 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp
@@ -34,11 +34,7 @@ int main(int argc, char* argv[]) {
OrtSetSessionThreadPoolSize(session_options, 1);
// Sets graph optimization level
- // Available levels are
- // 0 -> To disable all optimizations
- // 1 -> To enable basic optimizations (Such as redundant node removals)
- // 2 -> To enable all optimizations (Includes level 1 + more complex optimizations like node fusions)
- OrtSetSessionGraphOptimizationLevel(session_options, 1);
+ OrtSetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC);
// Optionally add more execution providers via session_options
// E.g. for CUDA include cuda_provider_factory.h and uncomment the following line:
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index 96ccddc975..e1b9a4f568 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -34,7 +34,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.Equal("", opt.LogId);
Assert.Equal(LogLevel.Verbose, opt.LogVerbosityLevel);
Assert.Equal(0, opt.ThreadPoolSize);
- Assert.Equal(1u, opt.GraphOptimizationLevel);
+ Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_BASIC, opt.GraphOptimizationLevel);
// try setting options
opt.EnableSequentialExecution = false;
@@ -62,12 +62,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
opt.ThreadPoolSize = 4;
Assert.Equal(4, opt.ThreadPoolSize);
- opt.GraphOptimizationLevel = 3;
- Assert.Equal(3u, opt.GraphOptimizationLevel);
+ opt.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
+ Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, opt.GraphOptimizationLevel);
Assert.Throws(() => { opt.ThreadPoolSize = -2; });
- Assert.Throws(() => { opt.GraphOptimizationLevel = 10; });
-
+ Assert.Throws(() => { opt.GraphOptimizationLevel = (GraphOptimizationLevel)10; });
+
}
}
@@ -129,11 +129,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
[Theory]
- [InlineData(0, true)]
- [InlineData(0, false)]
- [InlineData(2, true)]
- [InlineData(2, false)]
- private void CanRunInferenceOnAModel(uint graphOptimizationLevel, bool disableSequentialExecution)
+ [InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, true)]
+ [InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, false)]
+ [InlineData(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, true)]
+ [InlineData(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, false)]
+ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLevel, bool disableSequentialExecution)
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
@@ -743,7 +743,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
// Set the optimized model file path to assert that no exception are thrown.
SessionOptions options = new SessionOptions();
options.OptimizedModelFilePath = modelOutputPath;
- options.GraphOptimizationLevel = 1;
+ options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_BASIC;
var session = new InferenceSession(modelPath, options);
Assert.NotNull(session);
Assert.True(File.Exists(modelOutputPath));
@@ -792,7 +792,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSetSessionGraphOptimizationLevel",
- "OrtSetOptimizedModelFilePath", "OrtSessionOptionsAppendExecutionProvider_CPU",
+ "OrtSetOptimizedModelFilePath", "OrtSessionOptionsAppendExecutionProvider_CPU",
"OrtCreateRunOptions", "OrtReleaseRunOptions", "OrtRunOptionsSetRunLogVerbosityLevel", "OrtRunOptionsSetRunTag",
"OrtRunOptionsGetRunLogVerbosityLevel", "OrtRunOptionsGetRunTag","OrtRunOptionsEnableTerminate", "OrtRunOptionsDisableTerminate",
"OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs
index 1764f79d19..9444967705 100644
--- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs
+++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs
@@ -33,7 +33,7 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
public bool ParallelExecution { get; set; } = false;
[Option('o', "optimization_level", Required = false, HelpText = "Optimization Level. Default is 1, partial optimization.")]
- public uint OptimizationLevel { get; set; } = 1;
+ public GraphOptimizationLevel OptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_BASIC;
}
class Program
@@ -42,7 +42,8 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
{
var cmdOptions = Parser.Default.ParseArguments(args);
cmdOptions.WithParsed(
- options => {
+ options =>
+ {
Run(options);
});
}
@@ -52,7 +53,7 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
string inputPath = options.InputFile;
int iteration = options.IterationCount;
bool parallelExecution = options.ParallelExecution;
- uint optLevel = options.OptimizationLevel;
+ GraphOptimizationLevel optLevel = options.OptimizationLevel;
Console.WriteLine("Running model {0} in OnnxRuntime:", modelPath);
Console.WriteLine("input:{0}", inputPath);
Console.WriteLine("iteration count:{0}", iteration);
@@ -84,11 +85,11 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
return tensorData.ToArray();
}
- static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteration, DateTime[] timestamps, bool parallelExecution, uint optLevel)
+ static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteration, DateTime[] timestamps, bool parallelExecution, GraphOptimizationLevel optLevel)
{
if (timestamps.Length != (int)TimingPoint.TotalCount)
{
- throw new ArgumentException("Timestamps array must have "+(int)TimingPoint.TotalCount+" size");
+ throw new ArgumentException("Timestamps array must have " + (int)TimingPoint.TotalCount + " size");
}
timestamps[(int)TimingPoint.Start] = DateTime.Now;
@@ -108,12 +109,12 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
container.Add(NamedOnnxValue.CreateFromTensor(name, tensor));
}
-
+
timestamps[(int)TimingPoint.InputLoaded] = DateTime.Now;
// Run the inference
- for (int i=0; i < iteration; i++)
+ for (int i = 0; i < iteration; i++)
{
var results = session.Run(container); // results is an IReadOnlyList container
Debug.Assert(results != null);
@@ -132,7 +133,7 @@ namespace Microsoft.ML.OnnxRuntime.PerfTool
static void PrintUsage()
{
Console.WriteLine("Usage:\n"
- +"dotnet Microsoft.ML.OnnxRuntime.PerfTool "
+ + "dotnet Microsoft.ML.OnnxRuntime.PerfTool "
);
}
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index f1cd4638c4..2ff53c9b50 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -234,11 +234,15 @@ ORT_API_STATUS(OrtSetSessionLogId, _Inout_ OrtSessionOptions* options, const cha
ORT_API_STATUS(OrtSetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level);
// Set Graph optimization level.
-// Available options are : 0, 1, 2.
-// 0 -> Disable all optimizations
-// 1 -> Enable basic optimizations
-// 2 -> Enable all optimizations
-ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, int graph_optimization_level);
+// Levels accepted by OrtSetSessionGraphOptimizationLevel. TODO: document which optimizations each value enables.
+typedef enum GraphOptimizationLevel {
+ ORT_DISABLE_ALL = 0,
+ ORT_ENABLE_BASIC = 1,
+ ORT_ENABLE_EXTENDED = 2,
+ ORT_ENABLE_ALL = 99
+} GraphOptimizationLevel;
+ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options,
+ GraphOptimizationLevel graph_optimization_level);
// How many threads in the session thread pool.
ORT_API_STATUS(OrtSetSessionThreadPoolSize, _Inout_ OrtSessionOptions* options, int session_thread_pool_size);
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index f5757c7940..e1397105c3 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -135,7 +135,7 @@ struct SessionOptions : Base {
SessionOptions Clone() const;
SessionOptions& SetThreadPoolSize(int session_thread_pool_size);
- SessionOptions& SetGraphOptimizationLevel(int graph_optimization_level);
+ SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
SessionOptions& EnableCpuMemArena();
SessionOptions& DisableCpuMemArena();
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 9694603ad1..5e7a499a24 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -138,7 +138,7 @@ inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool
return *this;
}
-inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(int graph_optimization_level) {
+inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
ORT_THROW_ON_ERROR(OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level));
return *this;
}
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 29e8f5fb33..ed7908896d 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -18,4 +18,4 @@ __author__ = "Microsoft"
from onnxruntime.capi import onnxruntime_validation
onnxruntime_validation.check_distro_info()
from onnxruntime.capi.session import InferenceSession
-from onnxruntime.capi._pybind_state import RunOptions, SessionOptions, set_default_logger_severity, get_device, NodeArg, ModelMetadata
+from onnxruntime.capi._pybind_state import RunOptions, SessionOptions, set_default_logger_severity, get_device, NodeArg, ModelMetadata, GraphOptimizationLevel
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index a3be9e8f59..647e5c3722 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -101,14 +101,29 @@ ORT_API_STATUS_IMPL(OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* opti
}
// Set Graph optimization level.
-// Available options are : 0, 1, 2.
-ORT_API_STATUS_IMPL(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, int graph_optimization_level) {
+ORT_API_STATUS_IMPL(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options,
+ GraphOptimizationLevel graph_optimization_level) {
if (graph_optimization_level < 0) {
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid");
}
- if (graph_optimization_level >= static_cast(onnxruntime::TransformerLevel::MaxTransformerLevel))
- return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid");
- options->value.graph_optimization_level = static_cast(graph_optimization_level);
+
+ switch (graph_optimization_level) {
+ case ORT_DISABLE_ALL:
+ options->value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
+ break;
+ case ORT_ENABLE_BASIC:
+ options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
+ break;
+ case ORT_ENABLE_EXTENDED:
+ options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
+ break;
+ case ORT_ENABLE_ALL:
+ options->value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
+ break;
+ default:
+ return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid");
+ }
+
return nullptr;
}
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 4aebcc527f..b22b274518 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -8,6 +8,8 @@
#include
#include "core/graph/graph_viewer.h"
+#include "core/common/logging/logging.h"
+#include "core/common/logging/severity.h"
#if USE_CUDA
#define BACKEND_PROC "GPU"
@@ -379,7 +381,14 @@ void addOpSchemaSubmodule(py::module& m) {
#endif //onnxruntime_PYBIND_EXPORT_OPSCHEMA
void addObjectMethods(py::module& m) {
- py::class_(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc")
+ py::enum_(m, "GraphOptimizationLevel")
+ .value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
+ .value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
+ .value("ORT_ENABLE_EXTENDED", GraphOptimizationLevel::ORT_ENABLE_EXTENDED)
+ .value("ORT_ENABLE_ALL", GraphOptimizationLevel::ORT_ENABLE_ALL);
+
+ py::class_ sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
+ sess
.def(py::init())
.def_readwrite("enable_cpu_mem_arena", &SessionOptions::enable_cpu_mem_arena,
R"pbdoc(Enables the memory arena on CPU. Arena may pre-allocate memory for future usage.
@@ -407,17 +416,48 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
This parameter is unused unless *enable_sequential_execution* is false.)pbdoc")
.def_property_readonly(
"graph_optimization_level",
- [](const SessionOptions* options) -> uint32_t {
- return static_cast(options->graph_optimization_level);
+ [](const SessionOptions* options) -> GraphOptimizationLevel {
+ GraphOptimizationLevel retval = ORT_ENABLE_BASIC;
+ switch (options->graph_optimization_level) {
+ case onnxruntime::TransformerLevel::Default:
+ retval = ORT_DISABLE_ALL;
+ break;
+ case onnxruntime::TransformerLevel::Level1:
+ retval = ORT_ENABLE_BASIC;
+ break;
+ case onnxruntime::TransformerLevel::Level2:
+ retval = ORT_ENABLE_EXTENDED;
+ break;
+ case onnxruntime::TransformerLevel::Level3:
+ retval = ORT_ENABLE_ALL;
+ break;
+ default:
+ retval = ORT_ENABLE_BASIC;
+ LOGS_DEFAULT(WARNING) << "Got invalid graph optimization level; defaulting to ORT_ENABLE_BASIC";
+ break;
+ }
+ return retval;
},
R"pbdoc(Graph optimization level for this session.)pbdoc")
.def(
"set_graph_optimization_level",
- [](SessionOptions* options, uint32_t level) -> void {
- options->graph_optimization_level = static_cast(level);
+ [](SessionOptions* options, GraphOptimizationLevel level) -> void {
+ switch (level) {
+ case ORT_DISABLE_ALL:
+ options->graph_optimization_level = onnxruntime::TransformerLevel::Default;
+ break;
+ case ORT_ENABLE_BASIC:
+ options->graph_optimization_level = onnxruntime::TransformerLevel::Level1;
+ break;
+ case ORT_ENABLE_EXTENDED:
+ options->graph_optimization_level = onnxruntime::TransformerLevel::Level2;
+ break;
+ case ORT_ENABLE_ALL:
+ options->graph_optimization_level = onnxruntime::TransformerLevel::Level3;
+ break;
+ }
},
- R"pbdoc(Graph optimization level for this session. 0 disables all optimizations.
-Whereas 1 enables basic optimizations and 2 enables all optimizations.)pbdoc");
+ R"pbdoc(Graph optimization level for this session.)pbdoc");
py::class_(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc")
.def(py::init())
@@ -450,57 +490,55 @@ including arg name, arg type (contains both type and shape).)pbdoc")
return *(na.Type());
},
"node type")
- .def(
- "__str__", [](const onnxruntime::NodeArg& na) -> std::string {
- std::ostringstream res;
- res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
- auto shape = na.Shape();
- std::vector arr;
- if (shape == nullptr || shape->dim_size() == 0) {
- res << "[]";
+ .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
+ std::ostringstream res;
+ res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
+ auto shape = na.Shape();
+ std::vector arr;
+ if (shape == nullptr || shape->dim_size() == 0) {
+ res << "[]";
+ } else {
+ res << "[";
+ for (int i = 0; i < shape->dim_size(); ++i) {
+ if (shape->dim(i).has_dim_value()) {
+ res << shape->dim(i).dim_value();
+ } else if (shape->dim(i).has_dim_param()) {
+ res << "'" << shape->dim(i).dim_param() << "'";
} else {
- res << "[";
- for (int i = 0; i < shape->dim_size(); ++i) {
- if (shape->dim(i).has_dim_value()) {
- res << shape->dim(i).dim_value();
- } else if (shape->dim(i).has_dim_param()) {
- res << "'" << shape->dim(i).dim_param() << "'";
- } else {
- res << "None";
- }
-
- if (i < shape->dim_size() - 1) {
- res << ", ";
- }
- }
- res << "]";
- }
- res << ")";
-
- return std::string(res.str());
- },
- "converts the node into a readable string")
- .def_property_readonly(
- "shape", [](const onnxruntime::NodeArg& na) -> std::vector {
- auto shape = na.Shape();
- std::vector arr;
- if (shape == nullptr || shape->dim_size() == 0) {
- return arr;
+ res << "None";
}
- arr.resize(shape->dim_size());
- for (int i = 0; i < shape->dim_size(); ++i) {
- if (shape->dim(i).has_dim_value()) {
- arr[i] = py::cast(shape->dim(i).dim_value());
- } else if (shape->dim(i).has_dim_param()) {
- arr[i] = py::cast(shape->dim(i).dim_param());
- } else {
- arr[i] = py::none();
- }
+ if (i < shape->dim_size() - 1) {
+ res << ", ";
}
- return arr;
- },
- "node shape (assuming the node holds a tensor)");
+ }
+ res << "]";
+ }
+ res << ")";
+
+ return std::string(res.str());
+ },
+ "converts the node into a readable string")
+ .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector {
+ auto shape = na.Shape();
+ std::vector arr;
+ if (shape == nullptr || shape->dim_size() == 0) {
+ return arr;
+ }
+
+ arr.resize(shape->dim_size());
+ for (int i = 0; i < shape->dim_size(); ++i) {
+ if (shape->dim(i).has_dim_value()) {
+ arr[i] = py::cast(shape->dim(i).dim_value());
+ } else if (shape->dim(i).has_dim_param()) {
+ arr[i] = py::cast(shape->dim(i).dim_param());
+ } else {
+ arr[i] = py::none();
+ }
+ }
+ return arr;
+ },
+ "node shape (assuming the node holds a tensor)");
py::class_(m, "SessionObjectInitializer");
py::class_(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
@@ -515,16 +553,15 @@ including arg name, arg type (contains both type and shape).)pbdoc")
InitializeSession(sess);
},
R"pbdoc(Load a model saved in ONNX format.)pbdoc")
- .def(
- "read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel) {
- std::istringstream buffer(serializedModel);
- auto status = sess->Load(buffer);
- if (!status.IsOK()) {
- throw std::runtime_error(status.ToString().c_str());
- }
- InitializeSession(sess);
- },
- R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
+ .def("read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel) {
+ std::istringstream buffer(serializedModel);
+ auto status = sess->Load(buffer);
+ if (!status.IsOK()) {
+ throw std::runtime_error(status.ToString().c_str());
+ }
+ InitializeSession(sess);
+ },
+ R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
.def("run", [](InferenceSession* sess, std::vector output_names, std::map pyfeeds, RunOptions* run_options = nullptr) -> std::vector {
NameMLValMap feeds;
for (auto _ : pyfeeds) {
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 45950799d8..2f9595df6c 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -98,7 +98,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
bool enable_mem_pattern = true;
bool enable_openvino = false;
bool enable_nnapi = false;
- uint32_t graph_optimization_level{};
+ GraphOptimizationLevel graph_optimization_level = ORT_DISABLE_ALL;
bool user_graph_optimization_level_set = false;
OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING;
@@ -166,15 +166,29 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
case 'x':
enable_sequential_execution = false;
break;
- case 'o':
- graph_optimization_level = static_cast(OrtStrtol(optarg, nullptr));
- if (graph_optimization_level >= static_cast(TransformerLevel::MaxTransformerLevel)) {
- fprintf(stderr, "See usage for valid values of graph optimization level\n");
- usage();
- return -1;
+ case 'o': {
+ int tmp = static_cast(OrtStrtol(optarg, nullptr));
+ switch (tmp) {
+ case ORT_DISABLE_ALL:
+ graph_optimization_level = ORT_DISABLE_ALL;
+ break;
+ case ORT_ENABLE_BASIC:
+ graph_optimization_level = ORT_ENABLE_BASIC;
+ break;
+ case ORT_ENABLE_EXTENDED:
+ graph_optimization_level = ORT_ENABLE_EXTENDED;
+ break;
+ case ORT_ENABLE_ALL:
+ graph_optimization_level = ORT_ENABLE_ALL;
+ break;
+ default:
+ fprintf(stderr, "See usage for valid values of graph optimization level\n");
+ usage();
+ return -1;
}
user_graph_optimization_level_set = true;
break;
+ }
case '?':
case 'h':
default:
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 849e4f1723..de492fa87a 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -87,7 +87,7 @@ namespace perftest {
} else if (!CompareCString(optarg, ORT_TSTR("openvino"))) {
test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) {
- test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider;
+ test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider;
} else {
return false;
}
@@ -128,12 +128,26 @@ namespace perftest {
return false;
}
break;
- case 'o':
- test_config.run_config.optimization_level = static_cast(OrtStrtol(optarg, nullptr));
- if (test_config.run_config.optimization_level >= static_cast(TransformerLevel::MaxTransformerLevel)) {
- return false;
+ case 'o': {
+ int tmp = static_cast(OrtStrtol(optarg, nullptr));
+ switch (tmp) {
+ case ORT_DISABLE_ALL:
+ test_config.run_config.optimization_level = ORT_DISABLE_ALL;
+ break;
+ case ORT_ENABLE_BASIC:
+ test_config.run_config.optimization_level = ORT_ENABLE_BASIC;
+ break;
+ case ORT_ENABLE_EXTENDED:
+ test_config.run_config.optimization_level = ORT_ENABLE_EXTENDED;
+ break;
+ case ORT_ENABLE_ALL:
+ test_config.run_config.optimization_level = ORT_ENABLE_ALL;
+ break;
+ default:
+ return false;
}
break;
+ }
case '?':
case 'h':
default:
diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h
index 9c37937425..95f793bc94 100644
--- a/onnxruntime/test/perftest/test_configuration.h
+++ b/onnxruntime/test/perftest/test_configuration.h
@@ -45,7 +45,7 @@ struct RunConfig {
bool enable_cpu_mem_arena{true};
bool enable_sequential_execution{true};
int session_thread_pool_size{0};
- uint32_t optimization_level{2};
+ GraphOptimizationLevel optimization_level{ORT_ENABLE_EXTENDED};
};
struct PerformanceTestConfig {
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 87102060d5..ca2715aba9 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -8,7 +8,6 @@ import numpy as np
import onnxruntime as onnxrt
import threading
-
class TestInferenceSession(unittest.TestCase):
def get_name(self, name):
@@ -494,6 +493,12 @@ class TestInferenceSession(unittest.TestCase):
total = mat.sum()
self.assertEqual(total, 0)
+    def testGraphOptimizationLevel(self):
+        # The level lives on SessionOptions: read-only property plus an explicit setter in the binding.
+        opt = onnxrt.SessionOptions()
+        opt.set_graph_optimization_level(onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL)
+        self.assertEqual(opt.graph_optimization_level, onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL)
+
if __name__ == '__main__':
unittest.main()
diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc
index 60fbc6b819..258bd62dc3 100644
--- a/onnxruntime/test/shared_lib/test_session_options.cc
+++ b/onnxruntime/test/shared_lib/test_session_options.cc
@@ -9,15 +9,6 @@ using namespace onnxruntime;
TEST_F(CApiTest, session_options_graph_optimization_level) {
// Test set optimization level succeeds when valid level is provided.
- uint32_t valid_optimization_level = static_cast(TransformerLevel::Level2);
Ort::SessionOptions options;
- options.SetGraphOptimizationLevel(valid_optimization_level);
-
- // Test set optimization level fails when invalid level is provided.
- try {
- uint32_t invalid_level = static_cast(TransformerLevel::MaxTransformerLevel);
- options.SetGraphOptimizationLevel(invalid_level);
- } catch (const Ort::Exception& e) {
- ASSERT_EQ(e.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
- }
+ options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
}