Fix python examples in documentation (#3379)

This commit is contained in:
Xavier Dupré 2020-04-01 22:48:32 +02:00 committed by GitHub
parent accffded5d
commit edec8043d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 4 deletions

View file

@ -25,7 +25,8 @@ if not os.path.exists('dense121.onnx'):
model = DenseNet121(include_top=True, weights='imagenet')
from keras2onnx import convert_keras
onx = convert_keras(model, 'dense121.onnx')
onx = convert_keras(model, 'dense121.onnx')
onx.ir_version = 6
with open("dense121.onnx", "wb") as f:
f.write(onx.SerializeToString())

View file

@ -11,16 +11,31 @@ Profile the execution of a simple model
*ONNX Runtime* can profile the execution of the model.
This example shows how to interpret the results.
"""
import onnx
import onnxruntime as rt
import numpy
from onnxruntime.datasets import get_example
def change_ir_version(filename, ir_version=6):
"onnxruntime==1.2.0 does not support opset <= 7 and ir_version > 6"
with open(filename, "rb") as f:
model = onnx.load(f)
model.ir_version = 6
if model.opset_import[0].version <= 7:
model.opset_import[0].version = 11
return model
#########################
# Let's load a very simple model and compute some prediction.
example1 = get_example("mul_1.onnx")
sess = rt.InferenceSession(example1)
onnx_model = change_ir_version(example1)
onnx_model_str = onnx_model.SerializeToString()
sess = rt.InferenceSession(onnx_model_str)
input_name = sess.get_inputs()[0].name
x = numpy.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=numpy.float32)
@ -33,7 +48,7 @@ print(res)
options = rt.SessionOptions()
options.enable_profiling = True
sess_profile = rt.InferenceSession(example1, options)
sess_profile = rt.InferenceSession(onnx_model_str, options)
input_name = sess.get_inputs()[0].name
x = numpy.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=numpy.float32)