Check kMSDomain already exists before registering it (#10078)

* Check domain before registration
This commit is contained in:
ashari4 2021-12-17 17:55:15 -08:00 committed by GitHub
parent 12ee2e942f
commit cdbd678192
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 7 deletions

View file

@ -212,11 +212,16 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
#if !defined(ORT_MINIMAL_BUILD)
// Register Microsoft domain with min/max op_set version as 1/1.
std::call_once(schemaRegistrationOnceFlag, []() {
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSDomain, 1, 1);
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1);
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1);
auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
if (domainToVersionRangeInstance.Map().find(onnxruntime::kMSDomain) == domainToVersionRangeInstance.Map().end())
{
// External shared providers may have already added kMSDomain
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSDomain, 1, 1);
}
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1);
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1);
#ifdef USE_DML
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1);
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1);
#endif
// Register contributed schemas.
// The corresponding kernels are registered inside the appropriate execution provider.

View file

@ -3444,9 +3444,14 @@ Return true if all elements are true and false otherwise.
} // namespace training
void RegisterOrtOpSchemas() {
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSDomain, 1, 1);
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1);
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1);
auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
if (domainToVersionRangeInstance.Map().find(onnxruntime::kMSDomain) == domainToVersionRangeInstance.Map().end())
{
// External shared providers may have already added kMSDomain
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSDomain, 1, 1);
}
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1);
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1);
onnxruntime::contrib::RegisterContribSchemas();
onnxruntime::training::RegisterTrainingOpSchemas();