mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-15 20:50:42 +00:00
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.5.4 to 0.9.1.
234 lines
6.8 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# pylint: disable=C0411,C0412,C0413

"""
.. _l-logreg-example-speed:

Train, convert and predict with ONNX Runtime
============================================

This example demonstrates an end-to-end scenario,
from training a machine learning model
to using it in its converted form.

Train a logistic regression
+++++++++++++++++++++++++++

The first step consists in retrieving the iris dataset.
"""

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

####################################
# Then we fit a model.

clr = LogisticRegression()
clr.fit(X_train, y_train)

####################################
# We compute the predictions on the test set
# and we show the confusion matrix.
from sklearn.metrics import confusion_matrix  # noqa: E402

pred = clr.predict(X_test)
print(confusion_matrix(y_test, pred))

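####################################
# A minimal sketch of what a confusion matrix tallies, written in
# plain Python for illustration only. The helper name
# ``demo_confusion_matrix`` and the toy labels are hypothetical and
# not part of scikit-learn: entry ``[i][j]`` counts the samples whose
# true label is ``i`` and whose predicted label is ``j``.


def demo_confusion_matrix(y_true, y_pred, n_classes):
    # Start from an n_classes x n_classes grid of zeros.
    counts = [[0] * n_classes for _ in range(n_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        counts[true_label][predicted_label] += 1
    return counts


# Toy labels: one sample of class 1 is misclassified as class 2.
demo_counts = demo_confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2], n_classes=3)
print(demo_counts)
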
####################################
# Conversion to ONNX format
# +++++++++++++++++++++++++
#
# We use the module
# `sklearn-onnx <https://github.com/onnx/sklearn-onnx>`_
# to convert the model into ONNX format.

from skl2onnx import convert_sklearn  # noqa: E402
from skl2onnx.common.data_types import FloatTensorType  # noqa: E402

initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("logreg_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

##################################
# We load the model with ONNX Runtime and look at
# its input and output.

import onnxruntime as rt  # noqa: E402

sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())

print(f"input name='{sess.get_inputs()[0].name}' and shape={sess.get_inputs()[0].shape}")
print(f"output name='{sess.get_outputs()[0].name}' and shape={sess.get_outputs()[0].shape}")

##################################
# We compute the predictions.

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

import numpy  # noqa: E402

pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(confusion_matrix(pred, pred_onx))

###################################
# The predictions are perfectly identical.
#
# Probabilities
# +++++++++++++
#
# Probabilities are needed to compute other
# relevant metrics such as the ROC curve.
# Let's see how to get them first with
# scikit-learn.

prob_sklearn = clr.predict_proba(X_test)
print(prob_sklearn[:3])

#############################
# And then with ONNX Runtime.
# The probabilities appear to be

prob_name = sess.get_outputs()[1].name
prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0]

import pprint  # noqa: E402

pprint.pprint(prob_rt[0:3])

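#############################
# A hedged sketch, not taken from the original example: assuming the
# runtime returns one dictionary per row that maps each class label to
# its probability, the rows can be rearranged into a plain matrix
# ordered by class label. The sample values and the helper name
# ``demo_probabilities_to_rows`` are made up for illustration.


def demo_probabilities_to_rows(prob_dicts):
    # Sort the labels so that every row uses the same column order.
    return [[row[label] for label in sorted(row)] for row in prob_dicts]


demo_prob_dicts = [{0: 0.9, 1: 0.1, 2: 0.0}, {0: 0.2, 1: 0.5, 2: 0.3}]
demo_prob_rows = demo_probabilities_to_rows(demo_prob_dicts)
print(demo_prob_rows)
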
###############################
# Let's benchmark.
from timeit import Timer  # noqa: E402


def speed(inst, number=5, repeat=10):
    timer = Timer(inst, globals=globals())
    raw = numpy.array(timer.repeat(repeat, number=number))
    ave = raw.sum() / len(raw) / number
    mi, ma = raw.min() / number, raw.max() / number
    print(f"Average {ave:1.3g} min={mi:1.3g} max={ma:1.3g}")
    return ave


print("Execution time for clr.predict")
speed("clr.predict(X_test)")

print("Execution time for ONNX Runtime")
speed("sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]")

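###############################
# A small self-contained sketch of the timing pattern used by
# ``speed``: ``Timer.repeat(repeat, number)`` returns ``repeat``
# totals, each covering ``number`` executions, so dividing by
# ``number`` gives a per-execution time. The helper name
# ``demo_time_per_call`` and the timed statement are hypothetical.
from timeit import Timer  # noqa: E402


def demo_time_per_call(statement, namespace, number=5, repeat=3):
    totals = Timer(statement, globals=namespace).repeat(repeat, number=number)
    per_call = [total / number for total in totals]
    return sum(per_call) / len(per_call), min(per_call), max(per_call)


demo_average, demo_fastest, demo_slowest = demo_time_per_call(
    "sorted(demo_values)", {"demo_values": list(range(1000, 0, -1))}
)
print(f"Average {demo_average:1.3g} min={demo_fastest:1.3g} max={demo_slowest:1.3g}")
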
###############################
# Let's benchmark a scenario similar to what a web service
# experiences: the model has to do one prediction at a time
# as opposed to a batch of predictions.


def loop(X_test, fct, n=None):
    nrow = X_test.shape[0]
    if n is None:
        n = nrow
    for i in range(n):
        im = i % nrow
        fct(X_test[im : im + 1])


print("Execution time for clr.predict")
speed("loop(X_test, clr.predict, 50)")


def sess_predict(x):
    return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict")
speed("loop(X_test, sess_predict, 50)")

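###############################
# A minimal sketch of the one-row-at-a-time access pattern, using a
# plain list of rows instead of a NumPy array (an assumption made to
# keep the snippet dependency-free). Slicing with ``[im : im + 1]``
# keeps a batch of size one rather than a bare row, and the modulo
# lets the number of calls exceed the number of rows. ``demo_rows``
# and ``demo_batches`` are hypothetical names.

demo_rows = [[5.1, 3.5], [4.9, 3.0], [6.2, 2.9]]
demo_batches = []
for demo_index in range(5):
    demo_im = demo_index % len(demo_rows)
    demo_batches.append(demo_rows[demo_im : demo_im + 1])

print(demo_batches[3])
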
#####################################
# Let's do the same for the probabilities.

print("Execution time for predict_proba")
speed("loop(X_test, clr.predict_proba, 50)")


def sess_predict_proba(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba, 50)")

#####################################
# This second comparison is better as
# ONNX Runtime, in this experiment,
# computes the label and the probabilities
# in every case.

##########################################
# Benchmark with RandomForest
# +++++++++++++++++++++++++++
#
# We first train and save a model in ONNX format.
from sklearn.ensemble import RandomForestClassifier  # noqa: E402

rf = RandomForestClassifier(n_estimators=10)
rf.fit(X_train, y_train)

initial_type = [("float_input", FloatTensorType([1, 4]))]
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

###################################
# We compare.

sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers())


def sess_predict_proba_rf(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for predict_proba")
speed("loop(X_test, rf.predict_proba, 50)")

print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba_rf, 50)")

##################################
# Let's see how the timings change with the number of trees.

measures = []

for n_trees in range(5, 51, 5):
    print(n_trees)
    rf = RandomForestClassifier(n_estimators=n_trees)
    rf.fit(X_train, y_train)
    initial_type = [("float_input", FloatTensorType([1, 4]))]
    onx = convert_sklearn(rf, initial_types=initial_type)
    with open(f"rf_iris_{n_trees}.onnx", "wb") as f:
        f.write(onx.SerializeToString())
    sess = rt.InferenceSession(f"rf_iris_{n_trees}.onnx", providers=rt.get_available_providers())

    def sess_predict_proba_loop(x):
        return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]  # noqa: B023

    tsk = speed("loop(X_test, rf.predict_proba, 25)", number=5, repeat=4)
    trt = speed("loop(X_test, sess_predict_proba_loop, 25)", number=5, repeat=4)
    measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt})

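##################################
# The ``# noqa: B023`` marker silences the flake8-bugbear warning about
# a function defined in a loop that reads a loop variable. A minimal
# sketch of why the warning exists, with hypothetical names: closures
# look the variable up when they are called, not when they are defined.
# In the loop above this is harmless because ``sess_predict_proba_loop``
# is called within the same iteration that defines it.

demo_late = []
demo_bound = []
for demo_value in range(3):
    demo_late.append(lambda: demo_value)  # noqa: B023
    # A default argument captures the current value at definition time.
    demo_bound.append(lambda value=demo_value: value)

print([f() for f in demo_late], [f() for f in demo_bound])
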
from pandas import DataFrame  # noqa: E402

df = DataFrame(measures)
ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True)
df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True)
ax.set_xlabel("Number of trees")
ax.set_ylabel("Prediction time (s)")
ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset")
ax.legend()