diff --git a/onnxruntime/core/framework/fuse_nodes_funcs.h b/onnxruntime/core/framework/fuse_nodes_funcs.h index 6b36b3fb1e..aa3dc01628 100644 --- a/onnxruntime/core/framework/fuse_nodes_funcs.h +++ b/onnxruntime/core/framework/fuse_nodes_funcs.h @@ -18,6 +18,8 @@ class FuncManager { Status GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const; + size_t NumFuncs() const { return fused_funcs_->size(); } + void SetFusedFuncs(const FuncManager& func_mgr) { fused_funcs_ = func_mgr.fused_funcs_; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 87fe26849d..ee5d9987f1 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1218,6 +1218,13 @@ common::Status InferenceSession::Initialize() { #if !defined(ORT_MINIMAL_BUILD) if (saving_model) { + if (session_state_->GetFuncMgr().NumFuncs() > 0) { + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unable to serialize model as it contains compiled nodes. " + "Please disable any execution providers which generate compiled nodes.")); + } + if (session_options_.graph_optimization_level >= TransformerLevel::Level3) { LOGS(*session_logger_, WARNING) << "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. " diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 9784f497dd..56b4252808 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -126,6 +126,24 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { ExecuteMnist(*session2, enable_custom_ep); } +TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { + const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort"); + + // make sure we can't save a model with compiled ops. input/output model format doesn't matter + SessionOptions so; + so.optimized_model_filepath = ORT_TSTR("invalid_model.ort"); + + auto session = onnxruntime::make_unique(so, GetEnvironment()); + + const std::unordered_set supported_ops{"Conv", "Add", "Relu", "MaxPool"}; + ASSERT_STATUS_OK(session->RegisterExecutionProvider( + onnxruntime::make_unique(supported_ops))); + + ASSERT_STATUS_OK(session->Load(ort_model_path)); + auto status = session->Initialize(); + ASSERT_FALSE(status.IsOK()) << "Initialize should have failed when trying to save model with compiled kernels"; + ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Unable to serialize model as it contains compiled nodes")); +} #endif // !defined(ORT_MINIMAL_BUILD) // test to validate a minimal build