mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Fix python examples in documentation (#3379)
This commit is contained in:
parent
accffded5d
commit
edec8043d4
2 changed files with 20 additions and 4 deletions
|
|
@ -25,7 +25,8 @@ if not os.path.exists('dense121.onnx'):
|
|||
model = DenseNet121(include_top=True, weights='imagenet')
|
||||
|
||||
from keras2onnx import convert_keras
|
||||
onx = convert_keras(model, 'dense121.onnx')
|
||||
onx = convert_keras(model, 'dense121.onnx')
|
||||
onx.ir_version = 6
|
||||
with open("dense121.onnx", "wb") as f:
|
||||
f.write(onx.SerializeToString())
|
||||
|
||||
|
|
|
|||
|
|
@ -11,16 +11,31 @@ Profile the execution of a simple model
|
|||
*ONNX Runtime* can profile the execution of the model.
|
||||
This example shows how to interpret the results.
|
||||
"""
|
||||
|
||||
import onnx
|
||||
import onnxruntime as rt
|
||||
import numpy
|
||||
from onnxruntime.datasets import get_example
|
||||
|
||||
|
||||
def change_ir_version(filename, ir_version=6):
|
||||
"onnxruntime==1.2.0 does not support opset <= 7 and ir_version > 6"
|
||||
with open(filename, "rb") as f:
|
||||
model = onnx.load(f)
|
||||
model.ir_version = 6
|
||||
if model.opset_import[0].version <= 7:
|
||||
model.opset_import[0].version = 11
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
#########################
|
||||
# Let's load a very simple model and compute some prediction.
|
||||
|
||||
example1 = get_example("mul_1.onnx")
|
||||
sess = rt.InferenceSession(example1)
|
||||
onnx_model = change_ir_version(example1)
|
||||
onnx_model_str = onnx_model.SerializeToString()
|
||||
sess = rt.InferenceSession(onnx_model_str)
|
||||
input_name = sess.get_inputs()[0].name
|
||||
|
||||
x = numpy.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=numpy.float32)
|
||||
|
|
@ -33,7 +48,7 @@ print(res)
|
|||
|
||||
options = rt.SessionOptions()
|
||||
options.enable_profiling = True
|
||||
sess_profile = rt.InferenceSession(example1, options)
|
||||
sess_profile = rt.InferenceSession(onnx_model_str, options)
|
||||
input_name = sess.get_inputs()[0].name
|
||||
|
||||
x = numpy.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=numpy.float32)
|
||||
|
|
|
|||
Loading…
Reference in a new issue