mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Serialize optimized onnx model (#1470)
* Model serialization * Removed duplicate symbol * Minor update * Review comments * add tests * Model serialization * Removed duplicate symbol * Minor update * Merged PR 1106437: Model Serialization in onnxruntime * Review comments * Merged PR 1107226: Review comments Review comments * add tests * Fixed merge conflict * Correct python tests * InferenceSession Refeed Test * Replace use of widechar const literal-L * Fixed failing tests * Updated comment * Removed unnecessary session options * Spell check on comments * Do not serialize when level 3 optimization specified * Updated error logs * Changed log severity to WARN
This commit is contained in:
parent
8a559d75ae
commit
a50a63aa9e
13 changed files with 150 additions and 1 deletions
|
|
@ -130,6 +130,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
// P/Invoke binding for the native OrtDisableSequentialExecution entry point.
// Takes an OrtSessionOptions* handle and returns an OrtStatus* handle as IntPtr.
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtDisableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options);
|
||||
|
||||
// P/Invoke binding for the native OrtSetOptimizedModelFilePath entry point.
// Takes an OrtSessionOptions* handle and returns an OrtStatus* handle as IntPtr.
// The path is explicitly marshalled as a null-terminated UTF-16 string (LPWStr),
// overriding the CharSet given on the DllImport attribute.
// NOTE(review): the native parameter is declared as `const ORTCHAR_T*`; if ORTCHAR_T is
// not a 16-bit wide character on every supported platform, this marshalling would hand
// the native side a mis-encoded path there -- confirm.
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, [MarshalAs(UnmanagedType.LPWStr)]string optimizedModelFilepath);
|
||||
|
||||
// P/Invoke binding for the native OrtEnableProfiling entry point.
// Takes an OrtSessionOptions* handle and returns an OrtStatus* handle as IntPtr.
// NOTE(review): profilePathPrefix has no explicit MarshalAs, so its encoding follows the
// DllImport CharSet, whereas OrtSetOptimizedModelFilePath forces LPWStr -- confirm that
// both match the native `const ORTCHAR_T*` parameter.
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix);
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,15 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level));
|
||||
}
|
||||
|
||||
/// <summary>
/// Set filepath to save optimized model after graph level transformations.
/// </summary>
/// <param name="optimizedModelFilepath">File path for saving optimized model. Must not be null.</param>
/// <exception cref="System.ArgumentNullException">
/// Thrown when <paramref name="optimizedModelFilepath"/> is null.
/// </exception>
public void SetOptimizedModelFilePath(string optimizedModelFilepath)
{
    // A null string is marshalled to the native side as a null pointer, and the native
    // setter stores the pointer into a string object without checking it. Reject null
    // here so the caller gets a managed exception instead.
    if (optimizedModelFilepath == null)
    {
        throw new System.ArgumentNullException(nameof(optimizedModelFilepath));
    }

    NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, optimizedModelFilepath));
}
|
||||
|
||||
/// <summary>
|
||||
/// Enable Sequential Execution. By default, it is enabled.
|
||||
/// </summary>
|
||||
|
|
|
|||
|
|
@ -638,6 +638,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
private void TestModelSerialization()
{
    string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
    string modelOutputPath = Path.Combine(Directory.GetCurrentDirectory(), "optimized-squeezenet.onnx");

    // Remove any output left behind by an earlier run; otherwise the File.Exists check
    // below would pass even if this run never serialized the model.
    // File.Delete does not throw when the file is absent.
    File.Delete(modelOutputPath);

    // Set the optimized model file path and assert that no exceptions are thrown.
    SessionOptions options = new SessionOptions();
    options.SetOptimizedModelFilePath(modelOutputPath);
    options.SetSessionGraphOptimizationLevel(1);

    // Dispose the session deterministically so its native resources are released
    // when the test finishes.
    using (var session = new InferenceSession(modelPath, options))
    {
        Assert.NotNull(session);
        Assert.True(File.Exists(modelOutputPath));
    }
}
|
||||
|
||||
[GpuFact]
|
||||
private void TestGpu()
|
||||
{
|
||||
|
|
@ -678,7 +692,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
|
||||
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
|
||||
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSetSessionGraphOptimizationLevel",
|
||||
"OrtSessionOptionsAppendExecutionProvider_CPU","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
|
||||
"OrtSetOptimizedModelFilePath", "OrtSessionOptionsAppendExecutionProvider_CPU","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
|
||||
"OrtCreateDefaultAllocator","OrtAllocatorFree","OrtAllocatorGetInfo",
|
||||
"OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo",
|
||||
"OrtCastTypeInfoToTensorInfo","OrtGetTensorTypeAndShape","OrtGetTensorElementType","OrtGetDimensionsCount",
|
||||
|
|
|
|||
|
|
@ -201,6 +201,9 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
|
|||
*/
|
||||
ORT_API_STATUS(OrtCreateSessionOptions, _Outptr_ OrtSessionOptions** options);
|
||||
|
||||
// Set filepath to save optimized model after graph level transformations.
// The implementation copies the string into the session options. When a session is
// initialized, the model is written only if this path is non-empty and the graph
// optimization level is below level 3; at level 3 a warning is logged and nothing is saved.
|
||||
ORT_API_STATUS(OrtSetOptimizedModelFilePath, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath);
|
||||
|
||||
// create a copy of an existing OrtSessionOptions
|
||||
ORT_API_STATUS(OrtCloneSessionOptions, _In_ const OrtSessionOptions* in_options, _Outptr_ OrtSessionOptions** out_options);
|
||||
ORT_API_STATUS(OrtEnableSequentialExecution, _Inout_ OrtSessionOptions* options);
|
||||
|
|
|
|||
|
|
@ -140,6 +140,8 @@ struct SessionOptions : Base<OrtSessionOptions> {
|
|||
SessionOptions& EnableCpuMemArena();
|
||||
SessionOptions& DisableCpuMemArena();
|
||||
|
||||
// Sets the file path for saving the optimized model; returns *this for call chaining.
// The parameter name matches the inline definition and the C API declaration.
SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath);
|
||||
|
||||
SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
|
||||
SessionOptions& DisableProfiling();
|
||||
|
||||
|
|
|
|||
|
|
@ -143,6 +143,11 @@ inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(int graph_optim
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Forwards the optimized-model output path to OrtSetOptimizedModelFilePath for the
// wrapped options handle, then returns this object so calls can be chained.
inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
  SessionOptions& self = *this;
  ORT_THROW_ON_ERROR(OrtSetOptimizedModelFilePath(self.p_, optimized_model_filepath));
  return self;
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
|
||||
ORT_THROW_ON_ERROR(OrtEnableProfiling(p_, profile_file_prefix));
|
||||
return *this;
|
||||
|
|
|
|||
|
|
@ -80,5 +80,6 @@ OrtSetDimensions
|
|||
OrtSetSessionGraphOptimizationLevel
|
||||
OrtSetSessionLogId
|
||||
OrtSetSessionLogVerbosityLevel
|
||||
OrtSetOptimizedModelFilePath
|
||||
OrtSetSessionThreadPoolSize
|
||||
OrtSetTensorElementType
|
||||
|
|
|
|||
|
|
@ -44,6 +44,12 @@ ORT_API_STATUS_IMPL(OrtDisableSequentialExecution, _In_ OrtSessionOptions* optio
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// set filepath to save optimized onnx model.
|
||||
ORT_API_STATUS_IMPL(OrtSetOptimizedModelFilePath, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath) {
|
||||
options->value.optimized_model_filepath = optimized_model_filepath;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// enable profiling for this session.
|
||||
ORT_API_STATUS_IMPL(OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix) {
|
||||
options->value.enable_profiling = true;
|
||||
|
|
|
|||
|
|
@ -528,6 +528,16 @@ common::Status InferenceSession::Initialize() {
|
|||
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
|
||||
if (!session_options_.optimized_model_filepath.empty()) {
|
||||
if (session_options_.graph_optimization_level < TransformerLevel::Level3) {
|
||||
// Serialize optimized ONNX model.
|
||||
ORT_RETURN_IF_ERROR(Model::Save(*model_, session_options_.optimized_model_filepath));
|
||||
} else {
|
||||
LOGS(*session_logger_, WARNING) << "Serializing Optimized ONNX model with Graph Optimization"
|
||||
" level greater than 2 is not supported.";
|
||||
}
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution));
|
||||
ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr));
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ struct SessionOptions {
|
|||
// enable profiling for this session.
|
||||
bool enable_profiling = false;
|
||||
|
||||
// non empty filepath enables serialization of the transformed optimized model to the specified filepath.
|
||||
std::basic_string<ORTCHAR_T> optimized_model_filepath;
|
||||
|
||||
// enable the memory pattern optimization.
|
||||
// The idea is if the input shapes are the same, we could trace the internal memory allocation
|
||||
// and generate a memory pattern for future request. So next time we could just do one allocation
|
||||
|
|
|
|||
|
|
@ -386,6 +386,8 @@ void addObjectMethods(py::module& m) {
|
|||
Set this option to false if you don't want it. Default is True.)pbdoc")
|
||||
.def_readwrite("enable_profiling", &SessionOptions::enable_profiling,
|
||||
R"pbdoc(Enable profiling for this session. Default is false.)pbdoc")
|
||||
.def_readwrite("optimized_model_filepath", &SessionOptions::optimized_model_filepath,
|
||||
R"pbdoc(File path to serialize optimized model. By default, optimized model is not serialized if optimized_model_filepath is not provided.)pbdoc")
|
||||
.def_readwrite("enable_mem_pattern", &SessionOptions::enable_mem_pattern,
|
||||
R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
|
||||
.def_readwrite("enable_sequential_execution", &SessionOptions::enable_sequential_execution,
|
||||
|
|
|
|||
|
|
@ -118,6 +118,18 @@ class FuseExecutionProvider : public IExecutionProvider {
|
|||
}
|
||||
};
|
||||
|
||||
// InferenceSession wrapper to expose loaded graph.
|
||||
class InferenceSessionGetGraphWrapper : public InferenceSession {
|
||||
public:
|
||||
explicit InferenceSessionGetGraphWrapper(const SessionOptions& session_options,
|
||||
logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) {
|
||||
}
|
||||
|
||||
const Graph& GetGraph() {
|
||||
return model_->MainGraph();
|
||||
}
|
||||
};
|
||||
|
||||
namespace test {
|
||||
static void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
|
||||
const std::vector<float>& expected_values);
|
||||
|
|
@ -330,6 +342,77 @@ TEST(InferenceSessionTests, DisableCPUArena) {
|
|||
RunModel(session_object, run_options);
|
||||
}
|
||||
|
||||
// Checks serialization of the optimized model:
//  * a level 1 session writes the transformed model to optimized_model_filepath,
//  * writing to the same path a second time succeeds,
//  * re-feeding the serialized model at the default level produces an identical file,
//  * an empty path does not break session initialization,
//  * level 3 optimization does not produce a serialized model.
TEST(InferenceSessionTests, TestModelSerialization) {
  // Load model with level 0 transform level
  // and assert that the model has Identity nodes.
  SessionOptions so;
  const string test_model = "testdata/transform/abs-id-max.onnx";
  so.session_logid = "InferenceSessionTests.TestModelSerialization";
  so.graph_optimization_level = TransformerLevel::Default;
  InferenceSessionGetGraphWrapper session_object_noopt{so, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object_noopt.Load(test_model).IsOK());
  ASSERT_TRUE(session_object_noopt.Initialize().IsOK());

  // Assert that model has Identity Nodes.
  const auto& graph_noopt = session_object_noopt.GetGraph();
  std::map<std::string, int> op_to_count_noopt = CountOpsInGraph(graph_noopt);
  ASSERT_TRUE(op_to_count_noopt["Identity"] > 0);

  // Load model with level 1 transform level.
  so.graph_optimization_level = TransformerLevel::Level1;
  so.optimized_model_filepath = ToWideString(test_model + "-TransformLevel-" + std::to_string(static_cast<uint32_t>(so.graph_optimization_level)));
  InferenceSessionGetGraphWrapper session_object{so, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object.Load(test_model).IsOK());
  ASSERT_TRUE(session_object.Initialize().IsOK());

  // Assert that model has been transformed and identity Node is removed.
  const auto& graph = session_object.GetGraph();
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Identity"] == 0);

  // Serialize model to the same file path again to make sure that rewrite doesn't fail.
  InferenceSession overwrite_session_object{so, &DefaultLoggingManager()};
  ASSERT_TRUE(overwrite_session_object.Load(test_model).IsOK());
  ASSERT_TRUE(overwrite_session_object.Initialize().IsOK());

  // Load serialized model with no transform level and serialize model.
  SessionOptions so_opt;
  so_opt.session_logid = "InferenceSessionTests.TestModelSerialization";
  so_opt.graph_optimization_level = TransformerLevel::Default;
  so_opt.optimized_model_filepath = ToWideString(so.optimized_model_filepath) + ToWideString("-TransformLevel-" + std::to_string(static_cast<uint32_t>(so_opt.graph_optimization_level)));
  InferenceSession session_object_opt{so_opt, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object_opt.Load(so.optimized_model_filepath).IsOK());
  ASSERT_TRUE(session_object_opt.Initialize().IsOK());

  // Assert that re-feed of optimized model with default transform level results
  // in same runtime model as abs-id-max.onnx with TransformLevel-1.
  // Both streams are opened positioned at the end (ios::ate) so that tellg() reports the
  // file size. Opened at the beginning, both positions would be 0 and the size check
  // would always pass, letting the prefix-only std::equal below accept a longer file.
  std::ifstream model_fs_session1(so.optimized_model_filepath, ios::in | ios::binary | ios::ate);
  ASSERT_TRUE(model_fs_session1.good());
  std::ifstream model_fs_session2(so_opt.optimized_model_filepath, ios::in | ios::binary | ios::ate);
  ASSERT_TRUE(model_fs_session2.good());
  ASSERT_TRUE(model_fs_session1.tellg() == model_fs_session2.tellg());
  // Rewind before the byte-wise comparison.
  model_fs_session1.seekg(0, std::ifstream::beg);
  model_fs_session2.seekg(0, std::ifstream::beg);
  ASSERT_TRUE(std::equal(std::istreambuf_iterator<char>(model_fs_session1.rdbuf()),
                         std::istreambuf_iterator<char>(),
                         std::istreambuf_iterator<char>(model_fs_session2.rdbuf())));

  // Assert that empty optimized model file-path doesn't fail loading.
  so_opt.optimized_model_filepath = ToWideString("");
  InferenceSession session_object_emptyValidation{so_opt, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object_emptyValidation.Load(test_model).IsOK());
  ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK());

  // Assert that level 3 optimization doesn't result in serialized model.
  so_opt.optimized_model_filepath = ToWideString("ShouldNotSerialize");
  so_opt.graph_optimization_level = TransformerLevel::Level3;
  InferenceSession session_object_Level3Test{so_opt, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object_Level3Test.Load(test_model).IsOK());
  ASSERT_TRUE(session_object_Level3Test.Initialize().IsOK());
  // Opening the file must fail because nothing was written.
  std::ifstream model_fs_Level3(so_opt.optimized_model_filepath, ios::in | ios::binary);
  ASSERT_TRUE(model_fs_Level3.fail());
}
|
||||
|
||||
#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS
|
||||
static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) {
|
||||
if (f_arg.size() != s_arg.size()) {
|
||||
|
|
|
|||
|
|
@ -34,6 +34,14 @@ class TestInferenceSession(unittest.TestCase):
|
|||
np.testing.assert_allclose(
|
||||
output_expected, res[0], rtol=1e-05, atol=1e-08)
|
||||
|
||||
def testModelSerialization(self):
    # Creating a session with optimized_model_filepath set must write the optimized
    # model to that path.
    so = onnxrt.SessionOptions()
    so.session_log_verbosity_level = 1
    so.session_logid = "TestModelSerialization"
    so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx"
    # Remove output left by an earlier run so the final check can only pass if this
    # run actually serialized the model.
    if os.path.isfile(so.optimized_model_filepath):
        os.remove(so.optimized_model_filepath)
    onnxrt.InferenceSession(self.get_name("mul_1.onnx"), sess_options=so)
    self.assertTrue(os.path.isfile(so.optimized_model_filepath))
|
||||
|
||||
def testRunModel(self):
|
||||
sess = onnxrt.InferenceSession(self.get_name("mul_1.onnx"))
|
||||
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
|
||||
|
|
|
|||
Loading…
Reference in a new issue