mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Prevent saving a model containing fused nodes as we don't have any way to save the compiled kernels so the saved model will be invalid. (#5840)
This commit is contained in:
parent
e8c0f5d0ff
commit
b3a6ed14d4
3 changed files with 27 additions and 0 deletions
|
|
@ -18,6 +18,8 @@ class FuncManager {
|
|||
|
||||
Status GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const;
|
||||
|
||||
size_t NumFuncs() const { return fused_funcs_->size(); }
|
||||
|
||||
void SetFusedFuncs(const FuncManager& func_mgr) {
|
||||
fused_funcs_ = func_mgr.fused_funcs_;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1218,6 +1218,13 @@ common::Status InferenceSession::Initialize() {
|
|||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
if (saving_model) {
|
||||
if (session_state_->GetFuncMgr().NumFuncs() > 0) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"Unable to serialize model as it contains compiled nodes. "
|
||||
"Please disable any execution providers which generate compiled nodes."));
|
||||
}
|
||||
|
||||
if (session_options_.graph_optimization_level >= TransformerLevel::Level3) {
|
||||
LOGS(*session_logger_, WARNING)
|
||||
<< "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. "
|
||||
|
|
|
|||
|
|
@ -126,6 +126,24 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
|
|||
ExecuteMnist(*session2, enable_custom_ep);
|
||||
}
|
||||
|
||||
TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort");
|
||||
|
||||
// make sure we can't save a model with compiled ops. input/output model format doesn't matter
|
||||
SessionOptions so;
|
||||
so.optimized_model_filepath = ORT_TSTR("invalid_model.ort");
|
||||
|
||||
auto session = onnxruntime::make_unique<InferenceSessionWrapper>(so, GetEnvironment());
|
||||
|
||||
const std::unordered_set<std::string> supported_ops{"Conv", "Add", "Relu", "MaxPool"};
|
||||
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
|
||||
onnxruntime::make_unique<InternalTestingExecutionProvider>(supported_ops)));
|
||||
|
||||
ASSERT_STATUS_OK(session->Load(ort_model_path));
|
||||
auto status = session->Initialize();
|
||||
ASSERT_FALSE(status.IsOK()) << "Initialize should have failed when trying to save model with compiled kernels";
|
||||
ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Unable to serialize model as it contains compiled nodes"));
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
// test to validate a minimal build
|
||||
|
|
|
|||
Loading…
Reference in a new issue