mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Fix reading of onnx domain causing one of the automl models to break in 0.5 release. (#1694)
* Mention OrtCreateSessionFromArray in C API doc * Fix registration of Equal op causing one of the automl models to break in 0.5 release. * updates...
This commit is contained in:
parent
e54904e6a3
commit
25d02a33c8
3 changed files with 42 additions and 6 deletions
|
|
@ -108,18 +108,25 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
|
|||
const auto& domain = opSet.domain();
|
||||
const auto version = opSet.version();
|
||||
// empty domain and 'ai.onnx' are equivalent
|
||||
if ((domain.empty() || domain == "ai.onnx") && version < 7) {
|
||||
if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) {
|
||||
// TODO: Check if we can upgrade all the current opset 6 models that are being tested
|
||||
// in CI to opset 7 or above
|
||||
LOGS_DEFAULT(WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
|
||||
"with opset version 7 or above for opset domain 'ai.onnx'. "
|
||||
"Please upgrade your model to opset 7 or higher. "
|
||||
"For now, this opset "
|
||||
<< version
|
||||
<< version
|
||||
<< " model may run depending upon legacy support "
|
||||
"of some older opset version operators.";
|
||||
}
|
||||
domain_to_version[domain] = gsl::narrow_cast<int>(version);
|
||||
// We need to overwrite the domain here with ("") or else the loop below will try to find ("")
|
||||
// in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11).
|
||||
// This effectively ignores the opset version specified by the model for the onnx domain.
|
||||
if (domain == kOnnxDomainAlias) {
|
||||
domain_to_version[kOnnxDomain] = gsl::narrow_cast<int>(version);
|
||||
} else {
|
||||
domain_to_version[domain] = gsl::narrow_cast<int>(version);
|
||||
}
|
||||
}
|
||||
|
||||
auto domain_map = schema_registry->GetLatestOpsetVersions(false);
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ class FuseExecutionProvider : public IExecutionProvider {
|
|||
class InferenceSessionGetGraphWrapper : public InferenceSession {
|
||||
public:
|
||||
explicit InferenceSessionGetGraphWrapper(const SessionOptions& session_options,
|
||||
logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) {
|
||||
logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) {
|
||||
}
|
||||
|
||||
const Graph& GetGraph() {
|
||||
|
|
@ -364,7 +364,7 @@ TEST(InferenceSessionTests, TestModelSerialization) {
|
|||
InferenceSessionGetGraphWrapper session_object{so, &DefaultLoggingManager()};
|
||||
ASSERT_TRUE(session_object.Load(test_model).IsOK());
|
||||
ASSERT_TRUE(session_object.Initialize().IsOK());
|
||||
|
||||
|
||||
// Assert that model has been transformed and identity Node is removed.
|
||||
const auto& graph = session_object.GetGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -383,7 +383,7 @@ TEST(InferenceSessionTests, TestModelSerialization) {
|
|||
InferenceSession session_object_opt{so_opt, &DefaultLoggingManager()};
|
||||
ASSERT_TRUE(session_object_opt.Load(so.optimized_model_filepath).IsOK());
|
||||
ASSERT_TRUE(session_object_opt.Initialize().IsOK());
|
||||
|
||||
|
||||
// Assert that re-feed of optimized model with default transform level results
|
||||
// in same runtime model as abs-id-max.onnx with TransformLevel-1.
|
||||
std::ifstream model_fs_session1(so.optimized_model_filepath, ios::in | ios::binary);
|
||||
|
|
@ -1481,5 +1481,16 @@ TEST(InferenceSessionTests, TestParallelExecutionWithCudaProvider) {
|
|||
|
||||
#endif
|
||||
|
||||
TEST(InferenceSessionTests, ModelWithKOnnxDomainAlias) {
|
||||
SessionOptions so;
|
||||
so.session_logid = "InferenceSessionTests.NoTimeout";
|
||||
InferenceSession session_object{so, &DefaultLoggingManager()};
|
||||
std::string file_name = "testdata/test_model_with_fullonnxdomain.onnx";
|
||||
auto ret_status = session_object.Load(file_name);
|
||||
ASSERT_TRUE(ret_status.IsOK()) << ret_status.ErrorMessage();
|
||||
ret_status = session_object.Initialize();
|
||||
ASSERT_TRUE(ret_status.IsOK()) << ret_status.ErrorMessage();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
18
onnxruntime/test/testdata/test_model_with_fullonnxdomain.onnx
vendored
Normal file
18
onnxruntime/test/testdata/test_model_with_fullonnxdomain.onnx
vendored
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
onnx-example"ai.onnx:j
|
||||
|
||||
X1
|
||||
X2Y"Equal:ai.onnx
|
||||
test-modelZ
|
||||
X1
|
||||
|
||||
|
||||
Z
|
||||
X2
|
||||
|
||||
|
||||
b
|
||||
Y
|
||||
|
||||
|
||||
B
|
||||
ai.onnx
|
||||
Loading…
Reference in a new issue