Serizalize model only once to reduce backend preparation overhead (#8270)

* The serialization can be very heavy for large models
* Only use the serialized model check on compatible onnx versions
* onnx version >= 1.10.0 supports serialized model check
Signed-off-by: IceTDrinker <49040125+IceTDrinker@users.noreply.github.com>
This commit is contained in:
Arthur Meyre 2021-10-13 22:58:22 +02:00 committed by GitHub
parent e8ba5145ce
commit bccd09c688
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,6 +7,7 @@ Implements ONNX's backend API.
"""
from onnx import ModelProto
from onnx import helper
from onnx import version
from onnx.checker import check_model
from onnx.backend.base import Backend
from onnxruntime import InferenceSession, SessionOptions, get_device
@ -115,11 +116,21 @@ class OnnxRuntimeBackend(Backend):
return cls.prepare(inf, device, **kwargs)
else:
# type: ModelProto
check_model(model)
# check_model serializes the model anyways, so serialize the model once here
# and reuse it below in the cls.prepare call to avoid an additional serialization
# only works with onnx >= 1.10.0 hence the version check
onnx_version = tuple(map(int, (version.version.split(".")[:3])))
onnx_supports_serialized_model_check = onnx_version >= (1, 10, 0)
bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model
check_model(bin_or_model)
opset_supported, error_message = cls.is_opset_supported(model)
if not opset_supported:
raise unittest.SkipTest(error_message)
bin = model.SerializeToString()
# Now bin might be serialized, if it's not we need to serialize it otherwise we'll have
# an infinite recursive call
bin = bin_or_model
if not isinstance(bin, (str, bytes)):
bin = bin.SerializeToString()
return cls.prepare(bin, device, **kwargs)
@classmethod