Prevent saving a model containing fused nodes as we don't have any way to save the compiled kernels so the saved model will be invalid. (#5840)

This commit is contained in:
Scott McKay 2020-11-18 16:17:07 +10:00 committed by GitHub
parent e8c0f5d0ff
commit b3a6ed14d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 0 deletions

View file

@ -18,6 +18,8 @@ class FuncManager {
Status GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const;
size_t NumFuncs() const { return fused_funcs_->size(); }
void SetFusedFuncs(const FuncManager& func_mgr) {
fused_funcs_ = func_mgr.fused_funcs_;
}

View file

@ -1218,6 +1218,13 @@ common::Status InferenceSession::Initialize() {
#if !defined(ORT_MINIMAL_BUILD)
if (saving_model) {
if (session_state_->GetFuncMgr().NumFuncs() > 0) {
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Unable to serialize model as it contains compiled nodes. "
"Please disable any execution providers which generate compiled nodes."));
}
if (session_options_.graph_optimization_level >= TransformerLevel::Level3) {
LOGS(*session_logger_, WARNING)
<< "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. "

View file

@ -126,6 +126,24 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
ExecuteMnist(*session2, enable_custom_ep);
}
TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort");
// make sure we can't save a model with compiled ops. input/output model format doesn't matter
SessionOptions so;
so.optimized_model_filepath = ORT_TSTR("invalid_model.ort");
auto session = onnxruntime::make_unique<InferenceSessionWrapper>(so, GetEnvironment());
const std::unordered_set<std::string> supported_ops{"Conv", "Add", "Relu", "MaxPool"};
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
onnxruntime::make_unique<InternalTestingExecutionProvider>(supported_ops)));
ASSERT_STATUS_OK(session->Load(ort_model_path));
auto status = session->Initialize();
ASSERT_FALSE(status.IsOK()) << "Initialize should have failed when trying to save model with compiled kernels";
ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Unable to serialize model as it contains compiled nodes"));
}
#endif // !defined(ORT_MINIMAL_BUILD)
// test to validate a minimal build