diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 4c213ec66d..140516a630 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -130,6 +130,9 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtDisableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options);
+ [DllImport(nativeLib, CharSet = charSet)]
+ public static extern IntPtr /*(OrtStatus*)*/ OrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, [MarshalAs(UnmanagedType.LPWStr)]string optimizedModelFilepath); // NOTE(review): LPWStr marshals the path as UTF-16 while the native parameter is declared as ORTCHAR_T* — confirm these agree on non-Windows builds
+
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix);
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index 4ce708687e..13a334e0fd 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -37,6 +37,15 @@ namespace Microsoft.ML.OnnxRuntime
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level));
}
+ /// <summary>
+ /// Set filepath to save optimized model after graph level transformations.
+ /// </summary>
+ /// <param name="optimizedModelFilepath">File path for saving optimized model.</param>
+ public void SetOptimizedModelFilePath(string optimizedModelFilepath)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, optimizedModelFilepath));
+ }
+
///
/// Enable Sequential Execution. By default, it is enabled.
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index 88bf5f83d4..471d7fb267 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -638,6 +638,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
+ [Fact]
+ private void TestModelSerialization()
+ {
+ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
+ string modelOutputPath = Path.Combine(Directory.GetCurrentDirectory(), "optimized-squeezenet.onnx");
+ // Remove any output left by an earlier run so the File.Exists check below proves that this run wrote the file.
+ File.Delete(modelOutputPath);
+ SessionOptions options = new SessionOptions();
+ options.SetOptimizedModelFilePath(modelOutputPath);
+ options.SetSessionGraphOptimizationLevel(1);
+ using (var session = new InferenceSession(modelPath, options)) { Assert.NotNull(session); } // creating the session must not throw; dispose it right away
+ Assert.True(File.Exists(modelOutputPath));
+ }
+
[GpuFact]
private void TestGpu()
{
@@ -678,7 +692,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSetSessionGraphOptimizationLevel",
- "OrtSessionOptionsAppendExecutionProvider_CPU","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
+ "OrtSetOptimizedModelFilePath", "OrtSessionOptionsAppendExecutionProvider_CPU","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
"OrtCreateDefaultAllocator","OrtAllocatorFree","OrtAllocatorGetInfo",
"OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo",
"OrtCastTypeInfoToTensorInfo","OrtGetTensorTypeAndShape","OrtGetTensorElementType","OrtGetDimensionsCount",
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index fad81a5359..f1cd4638c4 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -201,6 +201,9 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
*/
ORT_API_STATUS(OrtCreateSessionOptions, _Outptr_ OrtSessionOptions** options);
+// Set filepath to save optimized model after graph level transformations. The path string is copied into the session options.
+ORT_API_STATUS(OrtSetOptimizedModelFilePath, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath);
+
// create a copy of an existing OrtSessionOptions
ORT_API_STATUS(OrtCloneSessionOptions, _In_ const OrtSessionOptions* in_options, _Outptr_ OrtSessionOptions** out_options);
ORT_API_STATUS(OrtEnableSequentialExecution, _Inout_ OrtSessionOptions* options);
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index e21e875967..f5757c7940 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -140,6 +140,8 @@ struct SessionOptions : Base {
SessionOptions& EnableCpuMemArena();
SessionOptions& DisableCpuMemArena();
+ SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath);  // path where the optimized model is saved; returns *this for chaining
+
SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
SessionOptions& DisableProfiling();
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 0fbbbde445..9694603ad1 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -143,6 +143,11 @@ inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(int graph_optim
return *this;
}
+inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
+ ORT_THROW_ON_ERROR(OrtSetOptimizedModelFilePath(p_, optimized_model_filepath));  // forwards to the C API on the wrapped options handle
+ return *this;  // returning the object itself lets option setters be chained
+}
+
inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
ORT_THROW_ON_ERROR(OrtEnableProfiling(p_, profile_file_prefix));
return *this;
diff --git a/onnxruntime/core/providers/cpu/symbols.txt b/onnxruntime/core/providers/cpu/symbols.txt
index 7ee1a19443..e6324e2b6d 100644
--- a/onnxruntime/core/providers/cpu/symbols.txt
+++ b/onnxruntime/core/providers/cpu/symbols.txt
@@ -80,5 +80,6 @@ OrtSetDimensions
OrtSetSessionGraphOptimizationLevel
OrtSetSessionLogId
OrtSetSessionLogVerbosityLevel
+OrtSetOptimizedModelFilePath
OrtSetSessionThreadPoolSize
OrtSetTensorElementType
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index 710ab2db81..a3be9e8f59 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -44,6 +44,12 @@ ORT_API_STATUS_IMPL(OrtDisableSequentialExecution, _In_ OrtSessionOptions* optio
return nullptr;
}
+// Set filepath to save the optimized onnx model. The string is copied; a null path is stored as empty, which leaves serialization disabled.
+ORT_API_STATUS_IMPL(OrtSetOptimizedModelFilePath, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath) {
+ options->value.optimized_model_filepath = (optimized_model_filepath != nullptr) ? std::basic_string<ORTCHAR_T>(optimized_model_filepath) : std::basic_string<ORTCHAR_T>();  // constructing a basic_string from a null pointer is undefined behaviour
+ return nullptr;
+}
+
// enable profiling for this session.
ORT_API_STATUS_IMPL(OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix) {
options->value.enable_profiling = true;
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 10ceef943a..678163f92c 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -528,6 +528,16 @@ common::Status InferenceSession::Initialize() {
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR(graph.Resolve());
+ if (!session_options_.optimized_model_filepath.empty()) {
+ if (session_options_.graph_optimization_level < TransformerLevel::Level3) {
+ // Serialize optimized ONNX model.
+ ORT_RETURN_IF_ERROR(Model::Save(*model_, session_options_.optimized_model_filepath));
+ } else {
+ LOGS(*session_logger_, WARNING) << "Serializing Optimized ONNX model with Graph Optimization"
+ " level greater than 2 is not supported.";
+ }
+ }
+
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution));
ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr));
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 761c121e95..f8bd6cf57b 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -56,6 +56,9 @@ struct SessionOptions {
// enable profiling for this session.
bool enable_profiling = false;
+ // A non-empty filepath enables serialization of the transformed, optimized model to that path; empty (the default) disables it.
+ std::basic_string<ORTCHAR_T> optimized_model_filepath;
+
// enable the memory pattern optimization.
// The idea is if the input shapes are the same, we could trace the internal memory allocation
// and generate a memory pattern for future request. So next time we could just do one allocation
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index d4416d7c9d..4aebcc527f 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -386,6 +386,8 @@ void addObjectMethods(py::module& m) {
Set this option to false if you don't want it. Default is True.)pbdoc")
.def_readwrite("enable_profiling", &SessionOptions::enable_profiling,
R"pbdoc(Enable profiling for this session. Default is false.)pbdoc")
+ .def_readwrite("optimized_model_filepath", &SessionOptions::optimized_model_filepath,
+ R"pbdoc(File path to serialize optimized model. By default, optimized model is not serialized if optimized_model_filepath is not provided.)pbdoc")
.def_readwrite("enable_mem_pattern", &SessionOptions::enable_mem_pattern,
R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
.def_readwrite("enable_sequential_execution", &SessionOptions::enable_sequential_execution,
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index 1c4cdad4f8..9d4f2e4193 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -118,6 +118,18 @@ class FuseExecutionProvider : public IExecutionProvider {
}
};
+// Test-only InferenceSession subclass that exposes the loaded model's main graph for inspection.
+class InferenceSessionGetGraphWrapper : public InferenceSession {
+ public:
+ explicit InferenceSessionGetGraphWrapper(const SessionOptions& session_options,
+ logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) {
+ }
+
+ const Graph& GetGraph() const {  // read-only accessor, so it is const-qualified
+ return model_->MainGraph();
+ }
+};
+
namespace test {
static void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims,
const std::vector& expected_values);
@@ -330,6 +342,77 @@ TEST(InferenceSessionTests, DisableCPUArena) {
RunModel(session_object, run_options);
}
+TEST(InferenceSessionTests, TestModelSerialization) {
+ // Load model with level 0 transform level
+ // and assert that the model has Identity nodes.
+ SessionOptions so;
+ const string test_model = "testdata/transform/abs-id-max.onnx";
+ so.session_logid = "InferenceSessionTests.TestModelSerialization";
+ so.graph_optimization_level = TransformerLevel::Default;
+ InferenceSessionGetGraphWrapper session_object_noopt{so, &DefaultLoggingManager()};
+ ASSERT_TRUE(session_object_noopt.Load(test_model).IsOK());
+ ASSERT_TRUE(session_object_noopt.Initialize().IsOK());
+
+ // Assert that model has Identity Nodes.
+ const auto& graph_noopt = session_object_noopt.GetGraph();
+ std::map<std::string, int> op_to_count_noopt = CountOpsInGraph(graph_noopt);
+ ASSERT_TRUE(op_to_count_noopt["Identity"] > 0);
+
+ // Load model with level 1 transform level.
+ so.graph_optimization_level = TransformerLevel::Level1;
+ so.optimized_model_filepath = ToWideString(test_model + "-TransformLevel-" + std::to_string(static_cast<int>(so.graph_optimization_level)));
+ InferenceSessionGetGraphWrapper session_object{so, &DefaultLoggingManager()};
+ ASSERT_TRUE(session_object.Load(test_model).IsOK());
+ ASSERT_TRUE(session_object.Initialize().IsOK());
+
+ // Assert that model has been transformed and identity Node is removed.
+ const auto& graph = session_object.GetGraph();
+ std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+ ASSERT_TRUE(op_to_count["Identity"] == 0);
+
+ // Serialize model to the same file path again to make sure that rewrite doesn't fail.
+ InferenceSession overwrite_session_object{so, &DefaultLoggingManager()};
+ ASSERT_TRUE(overwrite_session_object.Load(test_model).IsOK());
+ ASSERT_TRUE(overwrite_session_object.Initialize().IsOK());
+
+ // Load serialized model with no transform level and serialize model.
+ SessionOptions so_opt;
+ so_opt.session_logid = "InferenceSessionTests.TestModelSerialization";
+ so_opt.graph_optimization_level = TransformerLevel::Default;
+ so_opt.optimized_model_filepath = ToWideString(so.optimized_model_filepath) + ToWideString("-TransformLevel-" + std::to_string(static_cast<int>(so_opt.graph_optimization_level)));
+ InferenceSession session_object_opt{so_opt, &DefaultLoggingManager()};
+ ASSERT_TRUE(session_object_opt.Load(so.optimized_model_filepath).IsOK());
+ ASSERT_TRUE(session_object_opt.Initialize().IsOK());
+
+ // Assert that re-feed of optimized model with default transform level results
+ // in same runtime model as abs-id-max.onnx with TransformLevel-1.
+ std::ifstream model_fs_session1(so.optimized_model_filepath, ios::in | ios::binary | ios::ate);
+ ASSERT_TRUE(model_fs_session1.good());
+ std::ifstream model_fs_session2(so_opt.optimized_model_filepath, ios::in | ios::binary | ios::ate);
+ ASSERT_TRUE(model_fs_session2.good());
+ ASSERT_TRUE(model_fs_session1.tellg() == model_fs_session2.tellg());  // streams were opened at the end (ios::ate), so tellg() is the file size
+ model_fs_session1.seekg(0, std::ifstream::beg);
+ model_fs_session2.seekg(0, std::ifstream::beg);
+ ASSERT_TRUE(std::equal(std::istreambuf_iterator<char>(model_fs_session1.rdbuf()),
+ std::istreambuf_iterator<char>(),
+ std::istreambuf_iterator<char>(model_fs_session2.rdbuf())));
+
+ // Assert that empty optimized model file-path doesn't fail loading.
+ so_opt.optimized_model_filepath = ToWideString("");
+ InferenceSession session_object_emptyValidation{so_opt, &DefaultLoggingManager()};
+ ASSERT_TRUE(session_object_emptyValidation.Load(test_model).IsOK());
+ ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK());
+
+ // Assert that level 3 optimization doesn't result in serialized model.
+ so_opt.optimized_model_filepath = ToWideString("ShouldNotSerialize");
+ so_opt.graph_optimization_level = TransformerLevel::Level3;
+ InferenceSession session_object_Level3Test{so_opt, &DefaultLoggingManager()};
+ ASSERT_TRUE(session_object_Level3Test.Load(test_model).IsOK());
+ ASSERT_TRUE(session_object_Level3Test.Initialize().IsOK());
+ std::ifstream model_fs_Level3(so_opt.optimized_model_filepath, ios::in | ios::binary);
+ ASSERT_TRUE(model_fs_Level3.fail());
+}
+
#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS
static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) {
if (f_arg.size() != s_arg.size()) {
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index ce4da1a0a6..87102060d5 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -34,6 +34,14 @@ class TestInferenceSession(unittest.TestCase):
np.testing.assert_allclose(
output_expected, res[0], rtol=1e-05, atol=1e-08)
+ def testModelSerialization(self):
+ so = onnxrt.SessionOptions()
+ so.session_log_verbosity_level = 1
+ so.session_logid = "TestModelSerialization"
+ so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx"  # relative path, resolved against the current working directory
+ onnxrt.InferenceSession(self.get_name("mul_1.onnx"), sess_options=so)  # session is created only for its side effect; the object is discarded
+ self.assertTrue(os.path.isfile(so.optimized_model_filepath))  # NOTE(review): a file left by an earlier run would also satisfy this — confirm cleanup
+
def testRunModel(self):
sess = onnxrt.InferenceSession(self.get_name("mul_1.onnx"))
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)