diff --git a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py index 2a29ccdaa5..3c88a4eb3f 100644 --- a/onnxruntime/python/backend/backend.py +++ b/onnxruntime/python/backend/backend.py @@ -7,7 +7,7 @@ Implements ONNX's backend API. """ from onnx.checker import check_model from onnx.backend.base import Backend -from onnxruntime import InferenceSession, RunOptions, SessionOptions, get_device +from onnxruntime import InferenceSession, SessionOptions, get_device from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep @@ -91,11 +91,7 @@ class OnnxRuntimeBackend(Backend): :return: predictions """ rep = cls.prepare(model, device, **kwargs) - options = RunOptions() - for k, v in kwargs.items(): - if hasattr(options, k): - setattr(options, k, v) - return rep.run(inputs, options) + return rep.run(inputs, **kwargs) @classmethod def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): diff --git a/onnxruntime/python/backend/backend_rep.py b/onnxruntime/python/backend/backend_rep.py index 63aa586ade..fdc2569709 100644 --- a/onnxruntime/python/backend/backend_rep.py +++ b/onnxruntime/python/backend/backend_rep.py @@ -6,6 +6,7 @@ Implements ONNX's backend API. """ import numpy as np +from onnxruntime import RunOptions from onnx.backend.base import BackendRep @@ -27,11 +28,17 @@ class OnnxRuntimeBackendRep(BackendRep): Computes the prediction. See :meth:`onnxruntime.InferenceSession.run`. """ + + options = RunOptions() + for k, v in kwargs.items(): + if hasattr(options, k): + setattr(options, k, v) + if isinstance(inputs, list): inps = {} for i, inp in enumerate(self._session.get_inputs()): inps[inp.name] = inputs[i] - outs = self._session.run(None, inps) + outs = self._session.run(None, inps, options) if isinstance(outs, list): return outs else: @@ -42,4 +49,4 @@ class OnnxRuntimeBackendRep(BackendRep): if len(inp) != 1: raise RuntimeError("Model expect {0} inputs".format(len(inp))) inps = {inp[0].name: inputs} - return self._session.run(None, inps) + return self._session.run(None, inps, options) diff --git a/onnxruntime/test/python/onnxruntime_test_python_backend.py b/onnxruntime/test/python/onnxruntime_test_python_backend.py index 2df6fed0cc..a2e6864f51 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_backend.py +++ b/onnxruntime/test/python/onnxruntime_test_python_backend.py @@ -9,11 +9,12 @@ import numpy as np import onnxruntime as onnxrt from onnxruntime import datasets import onnxruntime.backend as backend +from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend from onnx import load class TestBackend(unittest.TestCase): - + def get_name(self, name): if os.path.exists(name): return name @@ -43,10 +44,10 @@ class TestBackend(unittest.TestCase): output_expected = np.array([[49.752754]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - def testRunModelProto(self): + def testRunModelProto(self): name = datasets.get_example("logreg_iris.onnx") model = load(name) - + rep = backend.prepare(model) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) res = rep.run(x) @@ -57,6 +58,20 @@ class TestBackend(unittest.TestCase): {0: 0.9997311234474182, 1: 1.1918064757310276e-07, 2: 0.00026869276189245284}] self.assertEqual(output_expected, res[1]) + def testRunModelProtoApi(self): + name = datasets.get_example("logreg_iris.onnx") + model = load(name) + + inputs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + outputs = ort_backend.run_model(model, inputs) + + output_expected = np.array([0, 0, 0], dtype=np.float32) + np.testing.assert_allclose(output_expected, outputs[0], rtol=1e-05, atol=1e-08) + output_expected = [{0: 0.950599730014801, 1: 0.027834169566631317, 2: 0.02156602405011654}, + {0: 0.9974970817565918, 1: 5.6299926654901356e-05, 2: 0.0024466661270707846}, + {0: 0.9997311234474182, 1: 1.1918064757310276e-07, 2: 0.00026869276189245284}] + self.assertEqual(output_expected, outputs[1]) + if __name__ == '__main__': unittest.main(module=__name__, buffer=True)