remove keras example from python documentation (#6574)

This commit is contained in:
Xavier Dupré 2021-02-05 01:10:11 +01:00 committed by GitHub
parent 4e61e254ec
commit 615acf156c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 2 additions and 96 deletions

View file

@ -1,94 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""
.. _l-example-backend-api-tensorflow:
ONNX Runtime for Keras
======================
The following demonstrates how to compute the predictions
of a pretrained deep learning model obtained from
`keras <https://keras.io/>`_
with *onnxruntime*. The conversion requires
`keras <https://keras.io/>`_,
`tensorflow <https://www.tensorflow.org/>`_,
`keras-onnx <https://github.com/onnx/keras-onnx/>`_,
`onnxmltools <https://pypi.org/project/onnxmltools/>`_
but then only *onnxruntime* is required
to compute the predictions.
"""
import os
if not os.path.exists('dense121.onnx'):
from keras.applications.densenet import DenseNet121
model = DenseNet121(include_top=True, weights='imagenet')
from keras2onnx import convert_keras
onx = convert_keras(model, 'dense121.onnx')
with open("dense121.onnx", "wb") as f:
f.write(onx.SerializeToString())
##################################
# Let's load an image (source: wikipedia).
from keras.preprocessing.image import array_to_img, img_to_array, load_img
img = load_img('Sannosawa1.jpg')
ximg = img_to_array(img)
import matplotlib.pyplot as plt
plt.imshow(ximg / 255)
plt.axis('off')
#############################################
# Let's load the model with onnxruntime.
import onnxruntime as rt
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidGraph
try:
sess = rt.InferenceSession('dense121.onnx')
ok = True
except (InvalidGraph, TypeError, RuntimeError) as e:
# Probably a mismatch between onnxruntime and onnx version.
print(e)
ok = False
if ok:
print("The model expects input shape:", sess.get_inputs()[0].shape)
print("image shape:", ximg.shape)
#######################################
# Let's resize the image.
if ok:
from skimage.transform import resize
import numpy
ximg224 = resize(ximg / 255, (224, 224, 3), anti_aliasing=True)
ximg = ximg224[numpy.newaxis, :, :, :]
ximg = ximg.astype(numpy.float32)
print("new shape:", ximg.shape)
##################################
# Let's compute the output.
if ok:
input_name = sess.get_inputs()[0].name
res = sess.run(None, {input_name: ximg})
prob = res[0]
print(prob.ravel()[:10]) # Too big to be displayed.
##################################
# Let's get more comprehensive results.
if ok:
from keras.applications.densenet import decode_predictions
decoded = decode_predictions(prob)
import pandas
df = pandas.DataFrame(decoded[0], columns=["class_id", "name", "P"])
print(df)

View file

@ -1,5 +1,3 @@
keras
keras-onnx
sphinx
sphinx-gallery
pyquickhelper

View file

@ -23,6 +23,8 @@ def rename_folder(root):
renamed.append((r, name, into))
full_src = os.path.join(r, name)
full_into = os.path.join(r, into)
if os.path.exists(full_into):
raise RuntimeError("%r already exists, previous documentation should be removed.")
print("rename %r" % full_src)
os.rename(full_src, full_into)