Fix run_model api. (#1111)

This commit is contained in:
Dmitri Smirnov 2019-05-28 16:36:57 -07:00 committed by GitHub
parent e19bc2d074
commit 8c7e4eb3fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 11 deletions

View file

@ -7,7 +7,7 @@ Implements ONNX's backend API.
"""
from onnx.checker import check_model
from onnx.backend.base import Backend
from onnxruntime import InferenceSession, RunOptions, SessionOptions, get_device
from onnxruntime import InferenceSession, SessionOptions, get_device
from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep
@ -91,11 +91,7 @@ class OnnxRuntimeBackend(Backend):
:return: predictions
"""
rep = cls.prepare(model, device, **kwargs)
options = RunOptions()
for k, v in kwargs.items():
if hasattr(options, k):
setattr(options, k, v)
return rep.run(inputs, options)
return rep.run(inputs, **kwargs)
@classmethod
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):

View file

@ -6,6 +6,7 @@
Implements ONNX's backend API.
"""
import numpy as np
from onnxruntime import RunOptions
from onnx.backend.base import BackendRep
@ -27,11 +28,17 @@ class OnnxRuntimeBackendRep(BackendRep):
Computes the prediction.
See :meth:`onnxruntime.InferenceSession.run`.
"""
options = RunOptions()
for k, v in kwargs.items():
if hasattr(options, k):
setattr(options, k, v)
if isinstance(inputs, list):
inps = {}
for i, inp in enumerate(self._session.get_inputs()):
inps[inp.name] = inputs[i]
outs = self._session.run(None, inps)
outs = self._session.run(None, inps, options)
if isinstance(outs, list):
return outs
else:
@ -42,4 +49,4 @@ class OnnxRuntimeBackendRep(BackendRep):
if len(inp) != 1:
raise RuntimeError("Model expect {0} inputs".format(len(inp)))
inps = {inp[0].name: inputs}
return self._session.run(None, inps)
return self._session.run(None, inps, options)

View file

@ -9,11 +9,12 @@ import numpy as np
import onnxruntime as onnxrt
from onnxruntime import datasets
import onnxruntime.backend as backend
from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend
from onnx import load
class TestBackend(unittest.TestCase):
def get_name(self, name):
if os.path.exists(name):
return name
@ -43,10 +44,10 @@ class TestBackend(unittest.TestCase):
output_expected = np.array([[49.752754]], dtype=np.float32)
np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
def testRunModelProto(self):
def testRunModelProto(self):
name = datasets.get_example("logreg_iris.onnx")
model = load(name)
rep = backend.prepare(model)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
res = rep.run(x)
@ -57,6 +58,20 @@ class TestBackend(unittest.TestCase):
{0: 0.9997311234474182, 1: 1.1918064757310276e-07, 2: 0.00026869276189245284}]
self.assertEqual(output_expected, res[1])
def testRunModelProtoApi(self):
name = datasets.get_example("logreg_iris.onnx")
model = load(name)
inputs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
outputs = ort_backend.run_model(model, inputs)
output_expected = np.array([0, 0, 0], dtype=np.float32)
np.testing.assert_allclose(output_expected, outputs[0], rtol=1e-05, atol=1e-08)
output_expected = [{0: 0.950599730014801, 1: 0.027834169566631317, 2: 0.02156602405011654},
{0: 0.9974970817565918, 1: 5.6299926654901356e-05, 2: 0.0024466661270707846},
{0: 0.9997311234474182, 1: 1.1918064757310276e-07, 2: 0.00026869276189245284}]
self.assertEqual(output_expected, outputs[1])
if __name__ == '__main__':
unittest.main(module=__name__, buffer=True)