onnxruntime/onnxruntime/python/backend/backend.py
Prabhat dd43623da2
Remove ONNX from requirements.txt (#4073)
* Avoid installing ONNX package on aarch64

* Removed onnx from requirements

* Add note in backend.py
2020-05-29 21:44:20 +05:30

113 lines
4.4 KiB
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""
from onnx import ModelProto
from onnx.checker import check_model
from onnx.backend.base import Backend
from onnxruntime import InferenceSession, SessionOptions, get_device
from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep
class OnnxRuntimeBackend(Backend):
"""
Implements
`ONNX's backend API <https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md>`_
with *ONNX Runtime*.
The backend is mostly used when you need to switch between
multiple runtimes with the same API.
`Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
shows how to use *caffe2* as a backend for a converted model.
Note: This is not the official Python API.
""" # noqa: E501
@classmethod
def is_compatible(cls, model, device=None, **kwargs):
"""
Return whether the model is compatible with the backend.
:param model: unused
:param device: None to use the default device or a string (ex: `'CPU'`)
:return: boolean
"""
if device is None:
device = get_device()
return cls.supports_device(device)
@classmethod
def supports_device(cls, device):
"""
Check whether the backend is compiled with particular device support.
In particular it's used in the testing suite.
"""
return device in get_device()
@classmethod
def prepare(cls, model, device=None, **kwargs):
"""
Load the model and creates a :class:`onnxruntime.InferenceSession`
ready to be used as a backend.
:param model: ModelProto (returned by `onnx.load`),
string for a filename or bytes for a serialized model
:param device: requested device for the computation,
None means the default one which depends on
the compilation settings
:param kwargs: see :class:`onnxruntime.SessionOptions`
:return: :class:`onnxruntime.InferenceSession`
"""
if isinstance(model, OnnxRuntimeBackendRep):
return model
elif isinstance(model, InferenceSession):
return OnnxRuntimeBackendRep(model)
elif isinstance(model, (str, bytes)):
options = SessionOptions()
for k, v in kwargs.items():
if hasattr(options, k):
setattr(options, k, v)
inf = InferenceSession(model, options)
# backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
# which may hide test failures.
inf.disable_fallback()
if device is not None and not cls.supports_device(device):
raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))
return cls.prepare(inf, device, **kwargs)
else:
# type: ModelProto
check_model(model)
bin = model.SerializeToString()
return cls.prepare(bin, device, **kwargs)
@classmethod
def run_model(cls, model, inputs, device=None, **kwargs):
"""
Compute the prediction.
:param model: :class:`onnxruntime.InferenceSession` returned
by function *prepare*
:param inputs: inputs
:param device: requested device for the computation,
None means the default one which depends on
the compilation settings
:param kwargs: see :class:`onnxruntime.RunOptions`
:return: predictions
"""
rep = cls.prepare(model, device, **kwargs)
return rep.run(inputs, **kwargs)
@classmethod
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
'''
This method is not implemented as it is much more efficient
to run a whole model than every node independently.
'''
raise NotImplementedError("It is much more efficient to run a whole model than every node independently.")
is_compatible = OnnxRuntimeBackend.is_compatible
prepare = OnnxRuntimeBackend.prepare
run = OnnxRuntimeBackend.run_model
supports_device = OnnxRuntimeBackend.supports_device