onnxruntime/docs/python/examples/plot_train_convert_predict.py

235 lines
6.8 KiB
Python
Raw Normal View History

2018-11-20 00:48:22 +00:00
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# pylint: disable=C0411,C0412,C0413
2018-11-20 00:48:22 +00:00
"""
.. _l-logreg-example-speed:
2018-11-20 00:48:22 +00:00
Train, convert and predict with ONNX Runtime
============================================
This example demonstrates an end to end scenario
starting with the training of a machine learned model
to its use in its converted from.
Train a logistic regression
+++++++++++++++++++++++++++
The first step consists in retrieving the iris datset.
"""
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
2018-11-20 00:48:22 +00:00
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
####################################
# Then we fit a model.
clr = LogisticRegression()
clr.fit(X_train, y_train)
####################################
# We compute the prediction on the test set
# and we show the confusion matrix.
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
from sklearn.metrics import confusion_matrix # noqa: E402
2018-11-20 00:48:22 +00:00
pred = clr.predict(X_test)
print(confusion_matrix(y_test, pred))
####################################
# Conversion to ONNX format
# +++++++++++++++++++++++++
#
# We use module
# `sklearn-onnx <https://github.com/onnx/sklearn-onnx>`_
2018-11-20 00:48:22 +00:00
# to convert the model into ONNX format.
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
from skl2onnx import convert_sklearn # noqa: E402
from skl2onnx.common.data_types import FloatTensorType # noqa: E402
2018-11-20 00:48:22 +00:00
initial_type = [("float_input", FloatTensorType([None, 4]))]
2018-11-20 00:48:22 +00:00
onx = convert_sklearn(clr, initial_types=initial_type)
with open("logreg_iris.onnx", "wb") as f:
f.write(onx.SerializeToString())
2018-11-20 00:48:22 +00:00
##################################
# We load the model with ONNX Runtime and look at
# its input and output.
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
import onnxruntime as rt # noqa: E402
sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())
2018-11-20 00:48:22 +00:00
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
print(f"input name='{sess.get_inputs()[0].name}' and shape={sess.get_inputs()[0].shape}")
print(f"output name='{sess.get_outputs()[0].name}' and shape={sess.get_outputs()[0].shape}")
2018-11-20 00:48:22 +00:00
##################################
# We compute the predictions.
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
import numpy # noqa: E402
2018-11-20 00:48:22 +00:00
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(confusion_matrix(pred, pred_onx))
###################################
# The prediction are perfectly identical.
#
# Probabilities
# +++++++++++++
#
# Probabilities are needed to compute other
# relevant metrics such as the ROC Curve.
# Let's see how to get them first with
# scikit-learn.
prob_sklearn = clr.predict_proba(X_test)
print(prob_sklearn[:3])
#############################
# And then with ONNX Runtime.
# The probabilies appear to be
2018-11-20 00:48:22 +00:00
prob_name = sess.get_outputs()[1].name
prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0]
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
import pprint # noqa: E402
2018-11-20 00:48:22 +00:00
pprint.pprint(prob_rt[0:3])
###############################
# Let's benchmark.
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
from timeit import Timer # noqa: E402
2018-11-20 00:48:22 +00:00
def speed(inst, number=5, repeat=10):
2018-11-20 00:48:22 +00:00
timer = Timer(inst, globals=globals())
raw = numpy.array(timer.repeat(repeat, number=number))
ave = raw.sum() / len(raw) / number
mi, ma = raw.min() / number, raw.max() / number
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
print(f"Average {ave:1.3g} min={mi:1.3g} max={ma:1.3g}")
2018-11-20 00:48:22 +00:00
return ave
2018-11-20 00:48:22 +00:00
print("Execution time for clr.predict")
speed("clr.predict(X_test)")
print("Execution time for ONNX Runtime")
speed("sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]")
###############################
# Let's benchmark a scenario similar to what a webservice
# experiences: the model has to do one prediction at a time
# as opposed to a batch of prediction.
2018-11-20 00:48:22 +00:00
def loop(X_test, fct, n=None):
nrow = X_test.shape[0]
if n is None:
n = nrow
for i in range(0, n):
im = i % nrow
fct(X_test[im : im + 1])
2018-11-20 00:48:22 +00:00
print("Execution time for clr.predict")
speed("loop(X_test, clr.predict, 50)")
2018-11-20 00:48:22 +00:00
2018-11-20 00:48:22 +00:00
def sess_predict(x):
return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0]
2018-11-20 00:48:22 +00:00
print("Execution time for sess_predict")
speed("loop(X_test, sess_predict, 50)")
2018-11-20 00:48:22 +00:00
#####################################
# Let's do the same for the probabilities.
print("Execution time for predict_proba")
speed("loop(X_test, clr.predict_proba, 50)")
2018-11-20 00:48:22 +00:00
2018-11-20 00:48:22 +00:00
def sess_predict_proba(x):
return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]
2018-11-20 00:48:22 +00:00
print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba, 50)")
2018-11-20 00:48:22 +00:00
#####################################
# This second comparison is better as
2018-11-20 00:48:22 +00:00
# ONNX Runtime, in this experience,
# computes the label and the probabilities
# in every case.
##########################################
# Benchmark with RandomForest
# +++++++++++++++++++++++++++
#
# We first train and save a model in ONNX format.
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
from sklearn.ensemble import RandomForestClassifier # noqa: E402
rf = RandomForestClassifier(n_estimators=10)
2018-11-20 00:48:22 +00:00
rf.fit(X_train, y_train)
initial_type = [("float_input", FloatTensorType([1, 4]))]
2018-11-20 00:48:22 +00:00
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
f.write(onx.SerializeToString())
2018-11-20 00:48:22 +00:00
###################################
# We compare.
sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers())
2018-11-20 00:48:22 +00:00
2018-11-20 00:48:22 +00:00
def sess_predict_proba_rf(x):
return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]
2018-11-20 00:48:22 +00:00
print("Execution time for predict_proba")
speed("loop(X_test, rf.predict_proba, 50)")
2018-11-20 00:48:22 +00:00
print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba_rf, 50)")
2018-11-20 00:48:22 +00:00
##################################
# Let's see with different number of trees.
measures = []
for n_trees in range(5, 51, 5):
2018-11-20 00:48:22 +00:00
print(n_trees)
rf = RandomForestClassifier(n_estimators=n_trees)
rf.fit(X_train, y_train)
initial_type = [("float_input", FloatTensorType([1, 4]))]
2018-11-20 00:48:22 +00:00
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris_%d.onnx" % n_trees, "wb") as f:
f.write(onx.SerializeToString())
sess = rt.InferenceSession("rf_iris_%d.onnx" % n_trees, providers=rt.get_available_providers())
2018-11-20 00:48:22 +00:00
def sess_predict_proba_loop(x):
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0] # noqa: B023
tsk = speed("loop(X_test, rf.predict_proba, 25)", number=5, repeat=4)
trt = speed("loop(X_test, sess_predict_proba_loop, 25)", number=5, repeat=4)
measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt})
2018-11-20 00:48:22 +00:00
Adopt linrtunner as the linting tool - take 2 (#15085) ### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot: ![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png) 5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 22:29:03 +00:00
from pandas import DataFrame # noqa: E402
2018-11-20 00:48:22 +00:00
df = DataFrame(measures)
ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True)
df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True)
2018-11-20 00:48:22 +00:00
ax.set_xlabel("Number of trees")
ax.set_ylabel("Prediction time (s)")
ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset")
ax.legend()