From bccd09c6884088c67b68d95a85b232a13f860d91 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 13 Oct 2021 22:58:22 +0200 Subject: [PATCH] Serizalize model only once to reduce backend preparation overhead (#8270) * The serialization can be very heavy for large models * Only use the serialized model check on compatible onnx versions * onnx version >= 1.10.0 supports serialized model check Signed-off-by: IceTDrinker <49040125+IceTDrinker@users.noreply.github.com> --- onnxruntime/python/backend/backend.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py index 4436ac2c79..ddb7e79539 100644 --- a/onnxruntime/python/backend/backend.py +++ b/onnxruntime/python/backend/backend.py @@ -7,6 +7,7 @@ Implements ONNX's backend API. """ from onnx import ModelProto from onnx import helper +from onnx import version from onnx.checker import check_model from onnx.backend.base import Backend from onnxruntime import InferenceSession, SessionOptions, get_device @@ -115,11 +116,21 @@ class OnnxRuntimeBackend(Backend): return cls.prepare(inf, device, **kwargs) else: # type: ModelProto - check_model(model) + # check_model serializes the model anyways, so serialize the model once here + # and reuse it below in the cls.prepare call to avoid an additional serialization + # only works with onnx >= 1.10.0 hence the version check + onnx_version = tuple(map(int, (version.version.split(".")[:3]))) + onnx_supports_serialized_model_check = onnx_version >= (1, 10, 0) + bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model + check_model(bin_or_model) opset_supported, error_message = cls.is_opset_supported(model) if not opset_supported: raise unittest.SkipTest(error_message) - bin = model.SerializeToString() + # Now bin might be serialized, if it's not we need to serialize it otherwise we'll have + # an infinite recursive call + bin = bin_or_model + if not isinstance(bin, (str, bytes)): + bin = bin.SerializeToString() return cls.prepare(bin, device, **kwargs) @classmethod