diff --git a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py index 4436ac2c79..ddb7e79539 100644 --- a/onnxruntime/python/backend/backend.py +++ b/onnxruntime/python/backend/backend.py @@ -7,6 +7,7 @@ Implements ONNX's backend API. """ from onnx import ModelProto from onnx import helper +from onnx import version from onnx.checker import check_model from onnx.backend.base import Backend from onnxruntime import InferenceSession, SessionOptions, get_device @@ -115,11 +116,21 @@ class OnnxRuntimeBackend(Backend): return cls.prepare(inf, device, **kwargs) else: # type: ModelProto - check_model(model) + # check_model serializes the model anyways, so serialize the model once here + # and reuse it below in the cls.prepare call to avoid an additional serialization + # only works with onnx >= 1.10.0 hence the version check + onnx_version = tuple(map(int, (version.version.split(".")[:3]))) + onnx_supports_serialized_model_check = onnx_version >= (1, 10, 0) + bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model + check_model(bin_or_model) opset_supported, error_message = cls.is_opset_supported(model) if not opset_supported: raise unittest.SkipTest(error_message) - bin = model.SerializeToString() + # Now bin might be serialized, if it's not we need to serialize it otherwise we'll have + # an infinite recursive call + bin = bin_or_model + if not isinstance(bin, (str, bytes)): + bin = bin.SerializeToString() return cls.prepare(bin, device, **kwargs) @classmethod