fix py packaging pipeline (#5038)

* add test skip logic when opset > allowed opset

* fix attribute error

* plus fix
This commit is contained in:
Ashwini Khade 2020-09-03 09:32:10 -07:00 committed by GitHub
parent 22ba266bd6
commit 9ba2cfb71b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,10 +6,13 @@
Implements ONNX's backend API.
"""
from onnx import ModelProto
from onnx import helper
from onnx.checker import check_model
from onnx.backend.base import Backend
from onnxruntime import InferenceSession, SessionOptions, get_device
from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep
import unittest
import os
class OnnxRuntimeBackend(Backend):
@ -24,6 +27,8 @@ class OnnxRuntimeBackend(Backend):
Note: This is not the official Python API.
""" # noqa: E501
allowReleasedOpsetsOnly = bool(os.getenv('ALLOW_RELEASED_ONNX_OPSET_ONLY', '1') == '1')
@classmethod
def is_compatible(cls, model, device=None, **kwargs):
"""
@ -37,6 +42,37 @@ class OnnxRuntimeBackend(Backend):
device = get_device()
return cls.supports_device(device)
@classmethod
def is_opset_supported(cls, model):
"""
Return whether the opset for the model is supported by the backend.
When By default only released onnx opsets are allowed by the backend
To test new opsets env variable ALLOW_RELEASED_ONNX_OPSET_ONLY should be set to 0
:param model: Model whose opsets needed to be verified.
:return: boolean and error message if opset is not supported.
"""
if cls.allowReleasedOpsetsOnly:
for opset in model.opset_import:
domain = opset.domain if opset.domain else 'ai.onnx'
try:
key = (domain, opset.version)
if not (key in helper.OP_SET_ID_VERSION_MAP):
error_message = ("Skipping this test as only released onnx opsets are supported."
"To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
" Got Domain '{0}' version '{1}'.".format(domain, opset.version))
return False, error_message
except AttributeError:
# for some CI pipelines accessing helper.OP_SET_ID_VERSION_MAP
# is generating attribute error. TODO investigate the pipelines to
# fix this error. Falling back to a simple version check when this error is encountered
if (domain == 'ai.onnx' and opset.version > 12) or (domain == 'ai.ommx.ml' and opset.version > 2):
error_message = ("Skipping this test as only released onnx opsets are supported."
"To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
" Got Domain '{0}' version '{1}'.".format(domain, opset.version))
return False, error_message
return True, ""
@classmethod
def supports_device(cls, device):
"""
@ -78,6 +114,9 @@ class OnnxRuntimeBackend(Backend):
else:
# type: ModelProto
check_model(model)
opset_supported, error_message = cls.is_opset_supported(model)
if not opset_supported:
raise unittest.SkipTest(error_message)
bin = model.SerializeToString()
return cls.prepare(bin, device, **kwargs)