# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import importlib.util
import os
import sys

import numpy as np
import torch
from packaging.version import Version


def get_device_index(device):
    if isinstance(device, str):
        # Could be 'cuda:0', 'cuda:1', or 'cpu'. A device without an explicit
        # index (e.g. 'cpu' or 'cuda') resolves to index 0 below.
        device = torch.device(device)
    elif isinstance(device, int):
        return device
    return 0 if device.index is None else device.index


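# Illustrative sketch (hypothetical helper, not used elsewhere in this module):
# the 'type[:index]' parsing that get_device_index delegates to torch.device,
# written with the standard library only. A string without an explicit index
# ('cpu', 'cuda') resolves to 0, matching the fallback in get_device_index.
def _example_device_index_from_str(device_str):
    # str.partition returns an empty separator when ':' is absent.
    _, sep, index = device_str.partition(":")
    return int(index) if sep else 0


# Usage: _example_device_index_from_str("cuda:1") returns 1 and
# _example_device_index_from_str("cpu") returns 0.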
def get_device_str(device):
    if isinstance(device, str):
        # Could be 'cuda', 'cuda:0', 'cuda:1', or 'cpu'. When no index is given
        # (including 'cpu'), the current CUDA device index is appended.
        if ":" not in device:
            device += ":" + str(torch.cuda.current_device())
    elif isinstance(device, int):
        device = "cuda:" + str(device)
    elif isinstance(device, torch.device):
        if device.index is None:
            device = device.type + ":" + str(torch.cuda.current_device())
        else:
            device = device.type + ":" + str(device.index)
    else:
        raise RuntimeError("Unsupported device type")
    return device


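# Illustrative sketch (hypothetical helper): the string/int normalization rules
# of get_device_str without a CUDA runtime. `default_index` is an assumption
# that stands in for torch.cuda.current_device(); torch.device inputs are omitted.
def _example_normalize_device_str(device, default_index=0):
    if isinstance(device, str):
        # Append the default index only when none is present ('cuda' -> 'cuda:0').
        return device if ":" in device else f"{device}:{default_index}"
    if isinstance(device, int):
        # A bare integer is interpreted as a CUDA device ordinal.
        return f"cuda:{device}"
    raise RuntimeError("Unsupported device type")


# Usage: _example_normalize_device_str("cuda", default_index=2) returns "cuda:2";
# _example_normalize_device_str(3) returns "cuda:3".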
def dtype_torch_to_numpy(torch_dtype):
    """Converts PyTorch types to NumPy types

    Also must map to types accepted by:
        MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type)

    References:
        https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html
        https://pytorch.org/docs/stable/tensors.html
    """
    if torch_dtype == torch.float64 or torch_dtype == torch.double:
        return np.float64
    elif torch_dtype == torch.float32 or torch_dtype == torch.float:
        return np.float32
    elif torch_dtype == torch.float16 or torch_dtype == torch.half or torch_dtype == torch.bfloat16:
        # NOTE: NumPy has no bfloat16 dtype, so bfloat16 falls back to float16 (lossy)
        return np.float16
    elif torch_dtype == torch.int64 or torch_dtype == torch.long:
        return np.longlong  # NOTE: np.int64 did not work here, so np.longlong is used
    elif torch_dtype == torch.int32 or torch_dtype == torch.int:
        return np.int32
    elif torch_dtype == torch.int16 or torch_dtype == torch.short:
        return np.int16
    elif torch_dtype == torch.int8:
        return np.int8
    elif torch_dtype == torch.uint8:
        return np.uint8
    elif torch_dtype == torch.complex64 or (
        # complex32 is missing in torch-1.11.
        (Version(torch.__version__) < Version("1.11.0") or Version(torch.__version__) >= Version("1.12.0"))
        and torch_dtype == torch.complex32
    ):
        # NOTE: NumPy doesn't support complex32
        return np.complex64
    elif torch_dtype == torch.complex128 or torch_dtype == torch.cdouble:
        return np.complex128
    elif torch_dtype == torch.bool:
        return np.bool_
    else:
        raise ValueError(f"torch_dtype ({torch_dtype!s}) type is not supported by NumPy")


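# Illustrative sketch (hypothetical helpers, standard library only): why the
# bfloat16 -> np.float16 fallback in dtype_torch_to_numpy is lossy. bfloat16
# keeps float32's 8-bit exponent (it is the top 16 bits of a float32), while
# IEEE float16 has a 5-bit exponent and a largest finite value of 65504.
def _example_round_trip_bfloat16(x):
    import struct

    # Keep the sign, exponent and top 7 mantissa bits of the float32 encoding.
    # This truncates; real conversions usually round to nearest even.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y


def _example_fits_float16(x):
    import struct

    # struct's 'e' format is IEEE 754 half precision; packing a value beyond
    # the float16 range raises OverflowError.
    try:
        struct.pack("<e", x)
    except OverflowError:
        return False
    return True


# Usage: _example_round_trip_bfloat16(1e5) stays close to 1e5 (99840.0),
# while _example_fits_float16(1e5) is False.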
def dtype_onnx_to_torch(onnx_type):
    """Converts ONNX types to PyTorch types

    Returns None for UNDEFINED and for types without a PyTorch mapping here
    (the FLOAT8 variants). STRING maps to the Python ``str`` type.

    Reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.in.proto (enum DataType)
               https://pytorch.org/docs/stable/tensors.html
    """
    onnx_types = [
        "UNDEFINED",
        "FLOAT",
        "UINT8",
        "INT8",
        "UINT16",
        "INT16",
        "INT32",
        "INT64",
        "STRING",
        "BOOL",
        "FLOAT16",
        "DOUBLE",
        "UINT32",
        "UINT64",
        "COMPLEX64",
        "COMPLEX128",
        "BFLOAT16",
        "FLOAT8E4M3FN",
        "FLOAT8E4M3FNUZ",
        "FLOAT8E5M2",
        "FLOAT8E5M2FNUZ",
    ]

    if isinstance(onnx_type, int):
        assert onnx_type < len(onnx_types), "Invalid onnx_type integer"
    elif isinstance(onnx_type, str):
        onnx_type = onnx_type.upper()
        assert onnx_type in onnx_types, "Invalid onnx_type string"
        onnx_type = onnx_types.index(onnx_type)
    else:
        raise ValueError("'onnx_type' must be an ONNX type represented by either a string or integer")

    if onnx_type == 0:
        return None
    elif onnx_type == 1:
        return torch.float
    elif onnx_type == 2:
        return torch.uint8
    elif onnx_type == 3:
        return torch.int8
    elif 4 <= onnx_type <= 5:
        # NOTE: older PyTorch releases have no uint16 dtype, so UINT16 maps to int16
        return torch.int16
    elif onnx_type == 6 or onnx_type == 12:
        # NOTE: older PyTorch releases have no uint32 dtype, so UINT32 maps to int32
        return torch.int32
    elif onnx_type == 7 or onnx_type == 13:
        # NOTE: older PyTorch releases have no uint64 dtype, so UINT64 maps to int64
        return torch.int64
    elif onnx_type == 8:
        return str
    elif onnx_type == 9:
        return torch.bool
    elif onnx_type == 10:
        return torch.float16
    elif onnx_type == 11:
        return torch.double
    elif onnx_type == 14:
        return torch.complex64
    elif onnx_type == 15:
        return torch.complex128
    elif onnx_type == 16:
        return torch.bfloat16


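# Illustrative sketch (hypothetical names): the name/integer resolution at the
# top of dtype_onnx_to_torch, using explicit exceptions instead of `assert`.
# Asserts are stripped under `python -O`, and the integer check in
# dtype_onnx_to_torch also lets negative values through. The names are listed
# in ONNX DataType enum order, so a name's index is its enum value.
_EXAMPLE_ONNX_TYPE_NAMES = (
    "UNDEFINED", "FLOAT", "UINT8", "INT8", "UINT16", "INT16", "INT32", "INT64",
    "STRING", "BOOL", "FLOAT16", "DOUBLE", "UINT32", "UINT64", "COMPLEX64",
    "COMPLEX128", "BFLOAT16", "FLOAT8E4M3FN", "FLOAT8E4M3FNUZ", "FLOAT8E5M2",
    "FLOAT8E5M2FNUZ",
)


def _example_resolve_onnx_type(onnx_type):
    """Return the integer ONNX DataType value for a type name or integer."""
    if isinstance(onnx_type, str):
        name = onnx_type.upper()
        if name not in _EXAMPLE_ONNX_TYPE_NAMES:
            raise ValueError(f"Invalid onnx_type string: {onnx_type!r}")
        return _EXAMPLE_ONNX_TYPE_NAMES.index(name)
    if isinstance(onnx_type, int):
        # Reject negative values as well as values past the end of the table.
        if not 0 <= onnx_type < len(_EXAMPLE_ONNX_TYPE_NAMES):
            raise ValueError(f"Invalid onnx_type integer: {onnx_type}")
        return onnx_type
    raise TypeError("'onnx_type' must be a string or an integer")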
def static_vars(**kwargs):
    r"""Decorator to add :py:attr:`kwargs` as static vars to 'func'

    Example:

        .. code-block:: python

            >>> @static_vars(counter=0)
            ... def myfunc():
            ...     myfunc.counter += 1
            ...     return myfunc.counter
            ...
            >>> print(myfunc())
            1
            >>> print(myfunc())
            2
            >>> print(myfunc())
            3
            >>> myfunc.counter = 100
            >>> print(myfunc())
            101
    """

    def decorate(func):
        for k, v in kwargs.items():
            setattr(func, k, v)
        return func

    return decorate


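# Illustrative sketch (hypothetical function): what @static_vars(calls=0) does,
# written out by hand. The decorator only calls setattr() on the function
# object, so the "static" state is an ordinary function attribute. That state
# is shared by every caller, and `+=` on it is not atomic across threads.
def _example_call_counter():
    _example_call_counter.calls += 1
    return _example_call_counter.calls


_example_call_counter.calls = 0  # equivalent to decorating with @static_vars(calls=0)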
def import_module_from_file(file_path, module_name=None):
    """Import a Python module from a file into the interpreter"""

    if not isinstance(file_path, str) or not os.path.exists(file_path):
        raise AssertionError(
            f"'file_path' must be a full path string with the python file to load. file_path={file_path!r}."
        )
    if module_name is not None and (not isinstance(module_name, str) or not module_name):
        raise AssertionError(
            f"'module_name' must be a string with the python module name to load. module_name={module_name!r}."
        )

    if not module_name:
        module_name = os.path.basename(file_path).split(".")[0]

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
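

# Illustrative sketch (hypothetical helper and module name): the same importlib
# steps as import_module_from_file, exercised on a throwaway file written to a
# temporary directory. Unlike import_module_from_file, it unregisters the
# module afterwards so the example leaves nothing behind in sys.modules.
def _example_load_temp_module(source="GREETING = 'hello'\n", module_name="example_temp_module"):
    import importlib.util
    import os
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_path = os.path.join(tmp_dir, module_name + ".py")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(source)
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as import_module_from_file does.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        finally:
            sys.modules.pop(module_name, None)
    return module


# Usage: _example_load_temp_module().GREETING returns "hello".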