Fix reading of onnx domain causing one of the automl models to break in 0.5 release. (#1694)

* Mention OrtCreateSessionFromArray in C API doc

* Fix registration of Equal op causing one of the automl models to break in 0.5 release.

* updates...
This commit is contained in:
Pranav Sharma 2019-08-29 12:18:39 -07:00 committed by GitHub
parent e54904e6a3
commit 25d02a33c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 6 deletions

View file

@ -108,18 +108,25 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
const auto& domain = opSet.domain();
const auto version = opSet.version();
// empty domain and 'ai.onnx' are equivalent
if ((domain.empty() || domain == "ai.onnx") && version < 7) {
if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) {
// TODO: Check if we can upgrade all the current opset 6 models that are being tested
// in CI to opset 7 or above
LOGS_DEFAULT(WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
"with opset version 7 or above for opset domain 'ai.onnx'. "
"Please upgrade your model to opset 7 or higher. "
"For now, this opset "
<< version
<< version
<< " model may run depending upon legacy support "
"of some older opset version operators.";
}
domain_to_version[domain] = gsl::narrow_cast<int>(version);
// We need to overwrite the domain here with ("") or else the loop below will try to find ("")
// in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11).
// This effectively ignores the opset version specified by the model for the onnx domain.
if (domain == kOnnxDomainAlias) {
domain_to_version[kOnnxDomain] = gsl::narrow_cast<int>(version);
} else {
domain_to_version[domain] = gsl::narrow_cast<int>(version);
}
}
auto domain_map = schema_registry->GetLatestOpsetVersions(false);

View file

@ -122,7 +122,7 @@ class FuseExecutionProvider : public IExecutionProvider {
class InferenceSessionGetGraphWrapper : public InferenceSession {
public:
explicit InferenceSessionGetGraphWrapper(const SessionOptions& session_options,
logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) {
logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) {
}
const Graph& GetGraph() {
@ -364,7 +364,7 @@ TEST(InferenceSessionTests, TestModelSerialization) {
InferenceSessionGetGraphWrapper session_object{so, &DefaultLoggingManager()};
ASSERT_TRUE(session_object.Load(test_model).IsOK());
ASSERT_TRUE(session_object.Initialize().IsOK());
// Assert that model has been transformed and identity Node is removed.
const auto& graph = session_object.GetGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -383,7 +383,7 @@ TEST(InferenceSessionTests, TestModelSerialization) {
InferenceSession session_object_opt{so_opt, &DefaultLoggingManager()};
ASSERT_TRUE(session_object_opt.Load(so.optimized_model_filepath).IsOK());
ASSERT_TRUE(session_object_opt.Initialize().IsOK());
// Assert that re-feed of optimized model with default transform level results
// in same runtime model as abs-id-max.onnx with TransformLevel-1.
std::ifstream model_fs_session1(so.optimized_model_filepath, ios::in | ios::binary);
@ -1481,5 +1481,16 @@ TEST(InferenceSessionTests, TestParallelExecutionWithCudaProvider) {
#endif
TEST(InferenceSessionTests, ModelWithKOnnxDomainAlias) {
SessionOptions so;
so.session_logid = "InferenceSessionTests.NoTimeout";
InferenceSession session_object{so, &DefaultLoggingManager()};
std::string file_name = "testdata/test_model_with_fullonnxdomain.onnx";
auto ret_status = session_object.Load(file_name);
ASSERT_TRUE(ret_status.IsOK()) << ret_status.ErrorMessage();
ret_status = session_object.Initialize();
ASSERT_TRUE(ret_status.IsOK()) << ret_status.ErrorMessage();
}
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,18 @@
 onnx-example"ai.onnx:j

X1
X2Y"Equal:ai.onnx
test-modelZ
X1


Z
X2


b
Y
 

B
ai.onnx