mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Fix run_model api. (#1111)
This commit is contained in:
parent
e19bc2d074
commit
8c7e4eb3fb
3 changed files with 29 additions and 11 deletions
|
|
@ -7,7 +7,7 @@ Implements ONNX's backend API.
|
|||
"""
|
||||
from onnx.checker import check_model
|
||||
from onnx.backend.base import Backend
|
||||
from onnxruntime import InferenceSession, RunOptions, SessionOptions, get_device
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_device
|
||||
from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep
|
||||
|
||||
|
||||
|
|
@ -91,11 +91,7 @@ class OnnxRuntimeBackend(Backend):
|
|||
:return: predictions
|
||||
"""
|
||||
rep = cls.prepare(model, device, **kwargs)
|
||||
options = RunOptions()
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(options, k):
|
||||
setattr(options, k, v)
|
||||
return rep.run(inputs, options)
|
||||
return rep.run(inputs, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
Implements ONNX's backend API.
|
||||
"""
|
||||
import numpy as np
|
||||
from onnxruntime import RunOptions
|
||||
from onnx.backend.base import BackendRep
|
||||
|
||||
|
||||
|
|
@ -27,11 +28,17 @@ class OnnxRuntimeBackendRep(BackendRep):
|
|||
Computes the prediction.
|
||||
See :meth:`onnxruntime.InferenceSession.run`.
|
||||
"""
|
||||
|
||||
options = RunOptions()
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(options, k):
|
||||
setattr(options, k, v)
|
||||
|
||||
if isinstance(inputs, list):
|
||||
inps = {}
|
||||
for i, inp in enumerate(self._session.get_inputs()):
|
||||
inps[inp.name] = inputs[i]
|
||||
outs = self._session.run(None, inps)
|
||||
outs = self._session.run(None, inps, options)
|
||||
if isinstance(outs, list):
|
||||
return outs
|
||||
else:
|
||||
|
|
@ -42,4 +49,4 @@ class OnnxRuntimeBackendRep(BackendRep):
|
|||
if len(inp) != 1:
|
||||
raise RuntimeError("Model expect {0} inputs".format(len(inp)))
|
||||
inps = {inp[0].name: inputs}
|
||||
return self._session.run(None, inps)
|
||||
return self._session.run(None, inps, options)
|
||||
|
|
|
|||
|
|
@ -9,11 +9,12 @@ import numpy as np
|
|||
import onnxruntime as onnxrt
|
||||
from onnxruntime import datasets
|
||||
import onnxruntime.backend as backend
|
||||
from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend
|
||||
from onnx import load
|
||||
|
||||
|
||||
class TestBackend(unittest.TestCase):
|
||||
|
||||
|
||||
def get_name(self, name):
|
||||
if os.path.exists(name):
|
||||
return name
|
||||
|
|
@ -43,10 +44,10 @@ class TestBackend(unittest.TestCase):
|
|||
output_expected = np.array([[49.752754]], dtype=np.float32)
|
||||
np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
|
||||
|
||||
def testRunModelProto(self):
|
||||
def testRunModelProto(self):
|
||||
name = datasets.get_example("logreg_iris.onnx")
|
||||
model = load(name)
|
||||
|
||||
|
||||
rep = backend.prepare(model)
|
||||
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
|
||||
res = rep.run(x)
|
||||
|
|
@ -57,6 +58,20 @@ class TestBackend(unittest.TestCase):
|
|||
{0: 0.9997311234474182, 1: 1.1918064757310276e-07, 2: 0.00026869276189245284}]
|
||||
self.assertEqual(output_expected, res[1])
|
||||
|
||||
def testRunModelProtoApi(self):
|
||||
name = datasets.get_example("logreg_iris.onnx")
|
||||
model = load(name)
|
||||
|
||||
inputs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
|
||||
outputs = ort_backend.run_model(model, inputs)
|
||||
|
||||
output_expected = np.array([0, 0, 0], dtype=np.float32)
|
||||
np.testing.assert_allclose(output_expected, outputs[0], rtol=1e-05, atol=1e-08)
|
||||
output_expected = [{0: 0.950599730014801, 1: 0.027834169566631317, 2: 0.02156602405011654},
|
||||
{0: 0.9974970817565918, 1: 5.6299926654901356e-05, 2: 0.0024466661270707846},
|
||||
{0: 0.9997311234474182, 1: 1.1918064757310276e-07, 2: 0.00026869276189245284}]
|
||||
self.assertEqual(output_expected, outputs[1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(module=__name__, buffer=True)
|
||||
|
|
|
|||
Loading…
Reference in a new issue