diff --git a/.lintrunner.toml b/.lintrunner.toml
index 5ef9ad9337..74744277fa 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -61,7 +61,7 @@ is_formatter = true
[[linter]]
-code = 'BLACK-ISORT'
+code = 'RUFF-FORMAT'
include_patterns = [
'**/*.py',
]
@@ -76,7 +76,7 @@ command = [
'-m',
'lintrunner_adapters',
'run',
- 'black_isort_linter',
+ 'ruff_format_linter',
'--',
'@{{PATHSFILE}}'
]
diff --git a/docs/python/_common/onnx_sphinx.py b/docs/python/_common/onnx_sphinx.py
index 7562d23289..926a2b1d84 100644
--- a/docs/python/_common/onnx_sphinx.py
+++ b/docs/python/_common/onnx_sphinx.py
@@ -2,6 +2,7 @@
"""
Automates the generation of ONNX operators.
"""
+
import importlib
import inspect
import keyword
diff --git a/docs/python/examples/plot_backend.py b/docs/python/examples/plot_backend.py
index 58fb4cd84f..65b5fd0cf7 100644
--- a/docs/python/examples/plot_backend.py
+++ b/docs/python/examples/plot_backend.py
@@ -14,6 +14,7 @@ to run predictions using this runtime.
Let's use the API to compute the prediction
of a simple logistic regression model.
"""
+
import numpy as np
from onnx import load
diff --git a/docs/python/examples/plot_common_errors.py b/docs/python/examples/plot_common_errors.py
index dc7078831a..85cfbf6b97 100644
--- a/docs/python/examples/plot_common_errors.py
+++ b/docs/python/examples/plot_common_errors.py
@@ -15,6 +15,7 @@ It starts by loading the model trained in example
trained on *Iris* datasets. The model takes
a vector of dimension 2 and returns a class among three.
"""
+
import numpy
import onnxruntime as rt
diff --git a/docs/python/examples/plot_convert_pipeline_vectorizer.py b/docs/python/examples/plot_convert_pipeline_vectorizer.py
index 06e9e8d29e..2215cb73ee 100644
--- a/docs/python/examples/plot_convert_pipeline_vectorizer.py
+++ b/docs/python/examples/plot_convert_pipeline_vectorizer.py
@@ -16,6 +16,7 @@ Train a pipeline
The first step consists in creating a dummy datasets.
"""
+
import pandas
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
diff --git a/docs/python/examples/plot_profiling.py b/docs/python/examples/plot_profiling.py
index d35ef72556..6e575ec9eb 100644
--- a/docs/python/examples/plot_profiling.py
+++ b/docs/python/examples/plot_profiling.py
@@ -11,6 +11,7 @@ Profile the execution of a simple model
*ONNX Runtime* can profile the execution of the model.
This example shows how to interpret the results.
"""
+
import numpy
import onnx
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 9d533af616..c874df8153 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -7,6 +7,7 @@ ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exc
For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime/>`_
or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
"""
+
__version__ = "1.21.0"
__author__ = "Microsoft"
@@ -20,33 +21,35 @@ __author__ = "Microsoft"
# meaningful messages to the user.
# the saved exception is raised after device version validation.
try:
- from onnxruntime.capi._pybind_state import ExecutionMode # noqa: F401
- from onnxruntime.capi._pybind_state import ExecutionOrder # noqa: F401
- from onnxruntime.capi._pybind_state import GraphOptimizationLevel # noqa: F401
- from onnxruntime.capi._pybind_state import LoraAdapter # noqa: F401
- from onnxruntime.capi._pybind_state import ModelMetadata # noqa: F401
- from onnxruntime.capi._pybind_state import NodeArg # noqa: F401
- from onnxruntime.capi._pybind_state import OrtAllocatorType # noqa: F401
- from onnxruntime.capi._pybind_state import OrtArenaCfg # noqa: F401
- from onnxruntime.capi._pybind_state import OrtMemoryInfo # noqa: F401
- from onnxruntime.capi._pybind_state import OrtMemType # noqa: F401
- from onnxruntime.capi._pybind_state import OrtSparseFormat # noqa: F401
- from onnxruntime.capi._pybind_state import RunOptions # noqa: F401
- from onnxruntime.capi._pybind_state import SessionIOBinding # noqa: F401
- from onnxruntime.capi._pybind_state import SessionOptions # noqa: F401
- from onnxruntime.capi._pybind_state import create_and_register_allocator # noqa: F401
- from onnxruntime.capi._pybind_state import create_and_register_allocator_v2 # noqa: F401
- from onnxruntime.capi._pybind_state import disable_telemetry_events # noqa: F401
- from onnxruntime.capi._pybind_state import enable_telemetry_events # noqa: F401
- from onnxruntime.capi._pybind_state import get_all_providers # noqa: F401
- from onnxruntime.capi._pybind_state import get_available_providers # noqa: F401
- from onnxruntime.capi._pybind_state import get_build_info # noqa: F401
- from onnxruntime.capi._pybind_state import get_device # noqa: F401
- from onnxruntime.capi._pybind_state import get_version_string # noqa: F401
- from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401
- from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401
- from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401
- from onnxruntime.capi._pybind_state import set_seed # noqa: F401
+ from onnxruntime.capi._pybind_state import (
+ ExecutionMode, # noqa: F401
+ ExecutionOrder, # noqa: F401
+ GraphOptimizationLevel, # noqa: F401
+ LoraAdapter, # noqa: F401
+ ModelMetadata, # noqa: F401
+ NodeArg, # noqa: F401
+ OrtAllocatorType, # noqa: F401
+ OrtArenaCfg, # noqa: F401
+ OrtMemoryInfo, # noqa: F401
+ OrtMemType, # noqa: F401
+ OrtSparseFormat, # noqa: F401
+ RunOptions, # noqa: F401
+ SessionIOBinding, # noqa: F401
+ SessionOptions, # noqa: F401
+ create_and_register_allocator, # noqa: F401
+ create_and_register_allocator_v2, # noqa: F401
+ disable_telemetry_events, # noqa: F401
+ enable_telemetry_events, # noqa: F401
+ get_all_providers, # noqa: F401
+ get_available_providers, # noqa: F401
+ get_build_info, # noqa: F401
+ get_device, # noqa: F401
+ get_version_string, # noqa: F401
+ has_collective_ops, # noqa: F401
+ set_default_logger_severity, # noqa: F401
+ set_default_logger_verbosity, # noqa: F401
+ set_seed, # noqa: F401
+ )
import_capi_exception = None
except Exception as e:
@@ -57,12 +60,14 @@ from onnxruntime.capi import onnxruntime_validation
if import_capi_exception:
raise import_capi_exception
-from onnxruntime.capi.onnxruntime_inference_collection import AdapterFormat # noqa: F401
-from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession # noqa: F401
-from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa: F401
-from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401
-from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401
-from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401
+from onnxruntime.capi.onnxruntime_inference_collection import (
+ AdapterFormat, # noqa: F401
+ InferenceSession, # noqa: F401
+ IOBinding, # noqa: F401
+ OrtDevice, # noqa: F401
+ OrtValue, # noqa: F401
+ SparseTensor, # noqa: F401
+)
# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
try: # noqa: SIM105
diff --git a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py
index 67423fe9b5..19f46189e2 100644
--- a/onnxruntime/python/backend/backend.py
+++ b/onnxruntime/python/backend/backend.py
@@ -5,6 +5,7 @@
"""
Implements ONNX's backend API.
"""
+
import os
import unittest
diff --git a/onnxruntime/python/backend/backend_rep.py b/onnxruntime/python/backend/backend_rep.py
index c4dddaaba1..af785b71c5 100644
--- a/onnxruntime/python/backend/backend_rep.py
+++ b/onnxruntime/python/backend/backend_rep.py
@@ -5,6 +5,7 @@
"""
Implements ONNX's backend API.
"""
+
from typing import Any, Tuple # noqa: F401
from onnx.backend.base import BackendRep
diff --git a/onnxruntime/python/datasets/__init__.py b/onnxruntime/python/datasets/__init__.py
index ba64aa8a6e..1a04b37698 100644
--- a/onnxruntime/python/datasets/__init__.py
+++ b/onnxruntime/python/datasets/__init__.py
@@ -3,6 +3,7 @@
"""
Short examples used in the documentation.
"""
+
import os
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index d05fba1928..c12efc7fdf 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -115,8 +115,9 @@ def check_and_normalize_provider_args(
def set_provider_options(name, options):
if name not in available_provider_names:
warnings.warn(
- "Specified provider '{}' is not in available provider names."
- "Available providers: '{}'".format(name, ", ".join(available_provider_names))
+ "Specified provider '{}' is not in available provider names.Available providers: '{}'".format(
+ name, ", ".join(available_provider_names)
+ )
)
if name in provider_name_to_options:
diff --git a/onnxruntime/python/onnxruntime_validation.py b/onnxruntime/python/onnxruntime_validation.py
index 4f29c7f424..09ce886c8f 100644
--- a/onnxruntime/python/onnxruntime_validation.py
+++ b/onnxruntime/python/onnxruntime_validation.py
@@ -5,6 +5,7 @@
"""
Check OS requirements for ONNX Runtime Python Bindings.
"""
+
import linecache
import platform
import warnings
diff --git a/onnxruntime/python/tools/profile_explorer/profile_explorer.py b/onnxruntime/python/tools/profile_explorer/profile_explorer.py
index 6e07478839..3c3b8c90f4 100644
--- a/onnxruntime/python/tools/profile_explorer/profile_explorer.py
+++ b/onnxruntime/python/tools/profile_explorer/profile_explorer.py
@@ -86,7 +86,7 @@ def _shape_to_string(shape):
value = next(iter(dict_obj.values()))
if len(res) != 0:
res += ","
- res += f'{key}({"x".join(str(v) for v in value)})'
+ res += f"{key}({'x'.join(str(v) for v in value)})"
return res
diff --git a/onnxruntime/python/tools/pytorch_export_contrib_ops.py b/onnxruntime/python/tools/pytorch_export_contrib_ops.py
index d8cf3c1304..f3cd4c2c89 100644
--- a/onnxruntime/python/tools/pytorch_export_contrib_ops.py
+++ b/onnxruntime/python/tools/pytorch_export_contrib_ops.py
@@ -5,6 +5,7 @@
Support for registering ONNX Runtime's built-in contrib ops with
PyTorch-ONNX exporter (torch.onnx.export).
"""
+
import typing
try:
diff --git a/onnxruntime/python/tools/qnn/add_trans_cast.py b/onnxruntime/python/tools/qnn/add_trans_cast.py
index ced3e3519a..edeaa6b4e2 100644
--- a/onnxruntime/python/tools/qnn/add_trans_cast.py
+++ b/onnxruntime/python/tools/qnn/add_trans_cast.py
@@ -126,9 +126,9 @@ def parse_qnn_json_file(qnn_json_file_path, qnn_input_output_tensor_dic):
qnn_tensor.dim = qnn_tensor_attribute["dims"]
qnn_input_output_tensor_dic[qnn_tensor_name] = qnn_tensor
- assert (
- len(qnn_input_output_tensor_dic) > 1
- ), "Converted QNN model not valid. It should have at least 1 input & 1 output."
+ assert len(qnn_input_output_tensor_dic) > 1, (
+ "Converted QNN model not valid. It should have at least 1 input & 1 output."
+ )
def compare_onnx_shape_with_qnn_shape(onnx_dims, qnn_dims):
diff --git a/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py b/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py
index b7d32fd6b2..7a3e364a08 100644
--- a/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py
+++ b/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py
@@ -150,9 +150,9 @@ def parse_qnn_converter_json_file(qnn_convert_json, qnn_input_tensor_dic, qnn_ou
qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"]
qnn_output_tensor_dic[qnn_tensor_name] = qnn_tensor
- assert (
- len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1
- ), "Converted QNN model not valid. It should have at least 1 input & 1 output."
+ assert len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1, (
+ "Converted QNN model not valid. It should have at least 1 input & 1 output."
+ )
def generate_wrapper_onnx_file(
@@ -286,9 +286,9 @@ def parse_qnn_graph(qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic):
qnn_tensor.offset = 0 - tensor_info["quantizeParams"]["scaleOffset"]["offset"]
qnn_output_tensor_dic[qnn_tensor.name] = qnn_tensor
- assert (
- len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1
- ), "Converted QNN model not valid. It should have at least 1 input & 1 output."
+ assert len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1, (
+ "Converted QNN model not valid. It should have at least 1 input & 1 output."
+ )
return graph_name
diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py
index 712e15a6a1..ac99de348f 100644
--- a/onnxruntime/python/tools/quantization/__init__.py
+++ b/onnxruntime/python/tools/quantization/__init__.py
@@ -7,11 +7,13 @@ from .calibrate import ( # noqa: F401
)
from .qdq_quantizer import QDQQuantizer # noqa: F401
from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401
-from .quantize import DynamicQuantConfig # noqa: F401
-from .quantize import QuantizationMode # noqa: F401
-from .quantize import StaticQuantConfig # noqa: F401
-from .quantize import get_qdq_config # noqa: F401
-from .quantize import quantize # noqa: F401
-from .quantize import quantize_dynamic # noqa: F401
-from .quantize import quantize_static # noqa: F401
+from .quantize import (
+ DynamicQuantConfig, # noqa: F401
+ QuantizationMode, # noqa: F401
+ StaticQuantConfig, # noqa: F401
+ get_qdq_config, # noqa: F401
+ quantize, # noqa: F401
+ quantize_dynamic, # noqa: F401
+ quantize_static, # noqa: F401
+)
from .shape_inference import quant_pre_process # noqa: F401
diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py
index 6235db3234..0cd186bffd 100644
--- a/onnxruntime/python/tools/quantization/base_quantizer.py
+++ b/onnxruntime/python/tools/quantization/base_quantizer.py
@@ -331,9 +331,9 @@ class BaseQuantizer:
scale = np.array(quant_overrides["scale"])
q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
- assert (
- zero_point.dtype != np.float32 and zero_point.dtype != np.float16
- ), f"Unexpected dtype {zero_point.dtype}"
+ assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+ f"Unexpected dtype {zero_point.dtype}"
+ )
assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
else:
@@ -349,9 +349,9 @@ class BaseQuantizer:
)
assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
- assert (
- zero_point.dtype != np.float32 and zero_point.dtype != np.float16
- ), f"Unexpected dtype {zero_point.dtype}"
+ assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+ f"Unexpected dtype {zero_point.dtype}"
+ )
assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
scale_dtype = weight.data_type
@@ -465,13 +465,13 @@ class BaseQuantizer:
weight_qType, per_channel_data.flatten(), scale, zero_point
)
assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
- assert (
- zero_point.dtype != np.float32 and zero_point.dtype != np.float16
- ), f"Unexpected dtype {zero_point.dtype}"
+ assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+ f"Unexpected dtype {zero_point.dtype}"
+ )
assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
- assert isinstance(
- quantized_per_channel_data, np.ndarray
- ), f"Unexpected type {type(quantized_per_channel_data)}"
+ assert isinstance(quantized_per_channel_data, np.ndarray), (
+ f"Unexpected type {type(quantized_per_channel_data)}"
+ )
else:
zero_point, scale, quantized_per_channel_data = quantize_data(
@@ -485,13 +485,13 @@ class BaseQuantizer:
)
assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
- assert (
- zero_point.dtype != np.float32 and zero_point.dtype != np.float16
- ), f"Unexpected dtype {zero_point.dtype}"
+ assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16, (
+ f"Unexpected dtype {zero_point.dtype}"
+ )
assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
- assert isinstance(
- quantized_per_channel_data, np.ndarray
- ), f"Unexpected type {type(quantized_per_channel_data)}"
+ assert isinstance(quantized_per_channel_data, np.ndarray), (
+ f"Unexpected type {type(quantized_per_channel_data)}"
+ )
zero_point_list.append(zero_point)
scale_list.append(scale)
diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py
index 4bbb63fef3..7855f260a5 100644
--- a/onnxruntime/python/tools/quantization/calibrate.py
+++ b/onnxruntime/python/tools/quantization/calibrate.py
@@ -820,9 +820,9 @@ class HistogramCollector(CalibrationDataCollector):
for arr in data_arr:
assert isinstance(arr, np.ndarray), f"Unexpected type {type(arr)} for tensor={tensor!r}"
dtypes = set(a.dtype for a in data_arr)
- assert (
- len(dtypes) == 1
- ), f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
+ assert len(dtypes) == 1, (
+ f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
+ )
data_arr_np = np.asarray(data_arr)
elif not isinstance(data_arr, np.ndarray):
raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
@@ -842,9 +842,9 @@ class HistogramCollector(CalibrationDataCollector):
# first time it uses num_bins to compute histogram.
hist, hist_edges = np.histogram(data_arr_np, bins=self.num_bins)
hist_edges = hist_edges.astype(data_arr_np.dtype)
- assert (
- data_arr_np.dtype != np.float64
- ), "only float32 or float16 is supported, every constant must be explicitly typed"
+ assert data_arr_np.dtype != np.float64, (
+ "only float32 or float16 is supported, every constant must be explicitly typed"
+ )
self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value)
else:
old_histogram = self.histogram_dict[tensor]
@@ -864,9 +864,9 @@ class HistogramCollector(CalibrationDataCollector):
hist, hist_edges = np.histogram(data_arr_np, bins=old_hist_edges)
hist_edges = hist_edges.astype(data_arr_np.dtype)
hist[: len(old_hist)] += old_hist
- assert (
- data_arr_np.dtype != np.float64
- ), "only float32 or float16 is supported, every constant must be explicitly typed"
+ assert data_arr_np.dtype != np.float64, (
+ "only float32 or float16 is supported, every constant must be explicitly typed"
+ )
self.histogram_dict[tensor] = (hist, hist_edges, min(old_min, min_value), max(old_max, max_value))
def collect_value(self, name_to_arr):
diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 1d91141a11..4cf9adcd32 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -1259,7 +1259,6 @@ class MatMul4BitsQuantizer:
self._process_subgraph(graph_stack)
self.model.clean_initializers()
elif self.algo_config.algorithm == "nvidia_awq":
-
# Handle nvidia_awq quantization
logger.info("Processing nvidia_awq quantization...")
self.model = self.node_quantizer.quantize_awq(
@@ -1280,9 +1279,9 @@ class MatMul4BitsQuantizer:
import neural_compressor
- assert version.parse(neural_compressor.__version__) >= version.parse(
- "2.3.2"
- ), "Require neural-compressor >= 2.3.2 to support weight only quantization!"
+ assert version.parse(neural_compressor.__version__) >= version.parse("2.3.2"), (
+ "Require neural-compressor >= 2.3.2 to support weight only quantization!"
+ )
self.int4_quant_algo()
@@ -1446,7 +1445,6 @@ if __name__ == "__main__":
elif args.quant_method == "gptq":
quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize)
elif args.quant_method == "nvidia_awq":
-
if quant_format == QuantFormat.QOperator:
logger.warning("QOperator is not applicable to nvidia_awq. overriding the value to QDQ")
quant_format = QuantFormat.QDQ
diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py
index 922884a5f6..7c5248f90f 100644
--- a/onnxruntime/python/tools/quantization/operators/conv.py
+++ b/onnxruntime/python/tools/quantization/operators/conv.py
@@ -158,7 +158,9 @@ class QLinearConv(QuantOperatorBase):
nodes,
) = self.quantizer.quantize_activation(node, [0])
quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
- node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
+ node.input[1],
+ onnx_proto.TensorProto.INT8,
+ 0, # self.quantizer.weight_qType?
)
quantized_input_names.append(quant_weight_tuple[0])
zero_point_names.append(quant_weight_tuple[1])
diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py
index 5d7bf6e2cd..6b8a389824 100644
--- a/onnxruntime/python/tools/quantization/operators/gemm.py
+++ b/onnxruntime/python/tools/quantization/operators/gemm.py
@@ -3,9 +3,15 @@ import logging
import numpy as np # noqa: F401
import onnx
-from ..quant_utils import find_by_name # noqa: F401
-from ..quant_utils import get_mul_node # noqa: F401
-from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+from ..quant_utils import (
+ TENSOR_NAME_QUANT_SUFFIX,
+ QuantizedValue,
+ QuantizedValueType,
+ attribute_to_kwarg,
+ find_by_name, # noqa: F401
+ get_mul_node, # noqa: F401
+ ms_domain,
+)
from .base_operator import QuantOperatorBase # noqa: F401
from .matmul import QOpMatMul
from .qdq_base_operator import QDQOperatorBase
diff --git a/onnxruntime/python/tools/quantization/operators/lstm.py b/onnxruntime/python/tools/quantization/operators/lstm.py
index 3ad3147cb8..3a0c94aca6 100644
--- a/onnxruntime/python/tools/quantization/operators/lstm.py
+++ b/onnxruntime/python/tools/quantization/operators/lstm.py
@@ -47,10 +47,14 @@ class LSTMQuant(QuantOperatorBase):
R.dims[0] = R_num_dir * R_4_hidden_size
quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel(
- node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
+ node.input[1],
+ onnx_proto.TensorProto.INT8,
+ 0, # self.quantizer.weight_qType?
)
quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel(
- node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
+ node.input[2],
+ onnx_proto.TensorProto.INT8,
+ 0, # self.quantizer.weight_qType?
)
W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index 5552a4451c..1eed87ba53 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -1253,9 +1253,9 @@ class QDQQuantizer(BaseQuantizer):
scale = quant_params["scale"]
zero_point_type = quant_params["quant_type"]
axis: int | None = quant_params.get("axis")
- assert (axis is not None and len(scale.shape) == 1) or (
- axis is None and len(scale.shape) == 0
- ), "Wrong scale/zp shapes"
+ assert (axis is not None and len(scale.shape) == 1) or (axis is None and len(scale.shape) == 0), (
+ "Wrong scale/zp shapes"
+ )
assert len(scale.shape) == len(zero_point.shape), "Scale and zero-point must have the same rank"
zero_point_name = param_name + "_zero_point" + init_name_suffix
diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py
index df53aafeaf..7dd8a7cafc 100644
--- a/onnxruntime/python/tools/quantization/quant_utils.py
+++ b/onnxruntime/python/tools/quantization/quant_utils.py
@@ -197,9 +197,9 @@ def _check_type(*args, zero_point_index=-1):
def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
- assert (
- qType in ONNX_TYPE_TO_NP_TYPE
- ), f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported."
+ assert qType in ONNX_TYPE_TO_NP_TYPE, (
+ f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported."
+ )
if qType in (
onnx_proto.TensorProto.FLOAT8E4M3FN,
onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
@@ -918,10 +918,7 @@ def smooth_distribution(p, eps=0.0001):
def model_has_external_data(model_path: Path):
model = onnx.load(model_path.as_posix(), load_external_data=False)
- for intializer in model.graph.initializer:
- if external_data_helper.uses_external_data(intializer):
- return True
- return False
+ return any(external_data_helper.uses_external_data(intializer) for intializer in model.graph.initializer)
def optimize_model(model_path: Path, opt_model_path: Path):
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index f88011c7a2..b9ff215902 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -1814,12 +1814,12 @@ class SymbolicShapeInference:
def replace_min_with_arg(arg_idx):
replaced = list(expr.args)
- assert isinstance(
- replaced[min_pos], sympy.Min
- ), f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}"
- assert (
- len(replaced[min_pos].args) == 2
- ), f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}"
+ assert isinstance(replaced[min_pos], sympy.Min), (
+ f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}"
+ )
+ assert len(replaced[min_pos].args) == 2, (
+ f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}"
+ )
replaced[min_pos] = replaced[min_pos].args[arg_idx]
return sympy.Add(*replaced)
diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py
index 450474d96d..54027a5a70 100644
--- a/onnxruntime/python/tools/transformers/benchmark.py
+++ b/onnxruntime/python/tools/transformers/benchmark.py
@@ -13,33 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Benchmarking the inference of pretrained transformer models.
- PyTorch/TorchScript benchmark is based on https://github.com/huggingface/transformers/blob/master/examples/benchmarks.py.
- One difference is that random input_ids is generated in this benchmark.
+"""Benchmarking the inference of pretrained transformer models.
+PyTorch/TorchScript benchmark is based on https://github.com/huggingface/transformers/blob/master/examples/benchmarks.py.
+One difference is that random input_ids is generated in this benchmark.
- For onnxruntime, this script will convert a pretrained model to ONNX, and optimize it when -o parameter is used.
+For onnxruntime, this script will convert a pretrained model to ONNX, and optimize it when -o parameter is used.
- Example commands:
- Export all models to ONNX, optimize and validate them:
- python benchmark.py -b 0 -o -v -i 1 2 3
- Run OnnxRuntime on GPU for all models:
- python benchmark.py -g
- Run OnnxRuntime on GPU for all models with fp32 optimization:
- python benchmark.py -g -o
- Run OnnxRuntime on GPU with fp16 optimization:
- python benchmark.py -g -o -p "fp16"
- Run TorchScript on GPU for all models:
- python benchmark.py -e torchscript -g
- Run TorchScript on GPU for all models with fp16:
- python benchmark.py -e torchscript -g -p "fp16"
- Run ONNXRuntime and TorchScript on CPU for all models with quantization:
- python benchmark.py -e torchscript onnxruntime -p "int8" -o
- Run OnnxRuntime with the ROCM provider and graph optimization script:
- python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm
- Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support:
- python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm
+Example commands:
+ Export all models to ONNX, optimize and validate them:
+ python benchmark.py -b 0 -o -v -i 1 2 3
+ Run OnnxRuntime on GPU for all models:
+ python benchmark.py -g
+ Run OnnxRuntime on GPU for all models with fp32 optimization:
+ python benchmark.py -g -o
+ Run OnnxRuntime on GPU with fp16 optimization:
+ python benchmark.py -g -o -p "fp16"
+ Run TorchScript on GPU for all models:
+ python benchmark.py -e torchscript -g
+ Run TorchScript on GPU for all models with fp16:
+ python benchmark.py -e torchscript -g -p "fp16"
+ Run ONNXRuntime and TorchScript on CPU for all models with quantization:
+ python benchmark.py -e torchscript onnxruntime -p "int8" -o
+ Run OnnxRuntime with the ROCM provider and graph optimization script:
+ python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm
+ Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support:
+ python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm
- It is recommended to use run_benchmark.sh to launch benchmark.
+It is recommended to use run_benchmark.sh to launch benchmark.
"""
import argparse
@@ -439,9 +439,9 @@ def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
return func(*args, **kwargs)
if do_eager_mode is True:
- assert (
- use_xla is False
- ), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
+ assert use_xla is False, (
+ "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
+ )
return run_in_eager_mode
else:
return run_in_graph_mode
diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py
index 66f7a63447..d88e689521 100644
--- a/onnxruntime/python/tools/transformers/benchmark_helper.py
+++ b/onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -167,9 +167,9 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
if use_gpu:
if provider == "dml":
- assert (
- "DmlExecutionProvider" in onnxruntime.get_available_providers()
- ), "Please install onnxruntime-directml package to test GPU inference."
+ assert "DmlExecutionProvider" in onnxruntime.get_available_providers(), (
+ "Please install onnxruntime-directml package to test GPU inference."
+ )
else:
assert not set(onnxruntime.get_available_providers()).isdisjoint(
diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py
index 74adc951c4..e9ac4a64f9 100644
--- a/onnxruntime/python/tools/transformers/float16.py
+++ b/onnxruntime/python/tools/transformers/float16.py
@@ -201,9 +201,9 @@ def convert_float_to_float16(
Returns:
ModelProto: converted model.
"""
- assert (
- min_positive_val >= 5.96e-08
- ), "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05"
+ assert min_positive_val >= 5.96e-08, (
+ "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05"
+ )
assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504"
force_fp16_inputs_dict = {} if force_fp16_inputs is None else force_fp16_inputs
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index 048c13cdb1..9a353e7e2d 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -373,7 +373,9 @@ class FusionAttentionUnet(Fusion):
else "MultiHeadAttention ({})".format(
"self attention with packed qkv"
if self.enable_packed_qkv
- else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
+ else "cross attention with packed kv"
+ if self.enable_packed_kv
+ else "cross attention"
)
)
self.increase_counter(counter_name)
@@ -841,7 +843,9 @@ class FusionAttentionUnet(Fusion):
else "MultiHeadAttention ({})".format(
"self attention with packed qkv"
if self.enable_packed_qkv
- else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
+ else "cross attention with packed kv"
+ if self.enable_packed_kv
+ else "cross attention"
)
)
self.increase_counter(counter_name)
diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py
index 0eaccc0faf..f623102802 100644
--- a/onnxruntime/python/tools/transformers/large_model_exporter.py
+++ b/onnxruntime/python/tools/transformers/large_model_exporter.py
@@ -6,6 +6,7 @@
"""
Export LLM to onnx
"""
+
import argparse
import inspect
import math
@@ -173,8 +174,8 @@ def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.
"""
total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024
- print(f"Model_Size = {get_model_parameter_size(model)/1024} GB")
- print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB")
+ print(f"Model_Size = {get_model_parameter_size(model) / 1024} GB")
+ print(f"total_mem_per_cpu = {total_mem_per_cpu / 1024} GB")
if get_model_parameter_size(model) > total_mem_per_cpu * 0.45:
device_collection = [torch.device(i) for i in range(torch.cuda.device_count())]
if len(device_collection) > 1:
@@ -228,9 +229,9 @@ def fetch_onnx_inputs_outputs_name(
onnx_inp_names = tuple(
[torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
)
- assert (
- "input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names
- ), "input_ids and attention_mask must be existed in inputs"
+ assert "input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names, (
+ "input_ids and attention_mask must be existed in inputs"
+ )
onnx_out_names = ("logits",)
onnx_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py
index 9153193a49..1b12fe9005 100644
--- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py
+++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py
@@ -889,11 +889,11 @@ class Gpt2Helper:
result["nan_rate"] = (total_test_cases - len(max_abs_diff_list)) * 1.0 / total_test_cases
logger.info(
- f"Parity Test Cases={total_test_cases}; Passed={passed_test_cases}; Nan={total_test_cases-len(max_abs_diff_list)}; Top1_Matched={top1_matched_cases}"
+ f"Parity Test Cases={total_test_cases}; Passed={passed_test_cases}; Nan={total_test_cases - len(max_abs_diff_list)}; Top1_Matched={top1_matched_cases}"
)
if passed_test_cases > 0.95 * total_test_cases:
- logger.info(f"Parity is good: passed rate={int(passed_test_cases*100/total_test_cases):.0f}%")
+ logger.info(f"Parity is good: passed rate={int(passed_test_cases * 100 / total_test_cases):.0f}%")
return result
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
index d05de369b3..61bfc95073 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -642,9 +642,9 @@ def get_args(rank=0):
# Check that only one (batch_size, sequence_length) combination is set for profiling
if args.profile:
- assert (
- len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1
- ), "Please provide only one (batch_size, sequence_length) combination for profiling"
+ assert len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1, (
+ "Please provide only one (batch_size, sequence_length) combination for profiling"
+ )
return args
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 9f6f86fc28..db78d837f8 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -259,14 +259,16 @@ def get_args():
help="Use when GroupQueryAttention (GQA) is in ONNX model",
)
- parser.add_argument(
- "--anomaly-filtering",
- default=False,
- action="store_true",
- help="Use this flag to filter anomaly accelerator times for tokens generated. \
+ (
+ parser.add_argument(
+ "--anomaly-filtering",
+ default=False,
+ action="store_true",
+ help="Use this flag to filter anomaly accelerator times for tokens generated. \
This may give more accurate latency and throughput metrics for tokens generated. \
Wall-clock metrics are still reported with anomaly times though.",
- ),
+ ),
+ )
parser.add_argument(
"-b",
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index f5446ed718..7bf8bcb82e 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -455,9 +455,8 @@ def smooth_quant(
decoder_model_int8_path: str,
decoder_with_past_model_int8_path: str,
):
- from neural_compressor import PostTrainingQuantConfig
+ from neural_compressor import PostTrainingQuantConfig, set_workspace
from neural_compressor import quantization as intel_quantization
- from neural_compressor import set_workspace
from onnx.external_data_helper import load_external_data_for_model
from quant_kv_dataloader import QuantKVDataLoader
diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py
index ab92a12343..274d56df3f 100644
--- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py
+++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py
@@ -148,9 +148,9 @@ def test_ort_latency(
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
for global_length in global_lengths:
- assert (
- global_length <= model.config.attention_window[0]
- ), "Limitation of current implementation: number of global token <= attention_window"
+ assert global_length <= model.config.attention_window[0], (
+ "Limitation of current implementation: number of global token <= attention_window"
+ )
logger.info(
f"Testing batch_size={batch_size} sequence_length={sequence_length} global_length={global_length} "
diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py
index 5eafb29713..07ed150631 100644
--- a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py
+++ b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py
@@ -212,7 +212,6 @@ def test_decoder_onnx(
onnx_model_path: str,
multimask_output=False,
):
-
batch_size = 1
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py
index 9533e2652f..af6b0e17e7 100644
--- a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py
+++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py
@@ -76,7 +76,7 @@ def show_masks(
show_box(box_coords, plt.gca())
if len(scores) > 1:
- plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
+ plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis("off")
if output_image_file_prefix:
diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py
index 363b5daf46..3c0c886b87 100644
--- a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py
+++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py
@@ -136,9 +136,9 @@ class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
input_image = self._transforms(image)
input_image = input_image[None, ...].to(self.device)
- assert (
- len(input_image.shape) == 4 and input_image.shape[1] == 3
- ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+ assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
+ f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+ )
# Computing image embeddings for the provided image
io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
index 0452cff235..74652239bc 100755
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
@@ -1368,9 +1368,9 @@ def main():
use_io_binding=args.use_io_binding,
)
elif args.engine == "onnxruntime":
- assert args.pipeline and os.path.isdir(
- args.pipeline
- ), "--pipeline should be specified for the directory of ONNX models"
+ assert args.pipeline and os.path.isdir(args.pipeline), (
+ "--pipeline should be specified for the directory of ONNX models"
+ )
print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}")
result = run_ort(
model_name=sd_model,
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py
index 57cb51bbea..41d2d267c5 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py
@@ -156,8 +156,7 @@ class DDIMScheduler:
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
- f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction`"
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
)
# 4. Clip "predicted x_0"
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
index 522cc541c1..ac955f5014 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
@@ -568,7 +568,7 @@ class StableDiffusionPipeline:
prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20]
parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)]
image_path = os.path.join(self.output_dir, "-".join(parts) + ".png")
- print(f"Saving image {i+1} / {len(images)} to: {image_path}")
+ print(f"Saving image {i + 1} / {len(images)} to: {image_path}")
from PIL import PngImagePlugin
diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py
index 2a6f9c3d75..33506d6d00 100644
--- a/onnxruntime/python/tools/transformers/onnx_model.py
+++ b/onnxruntime/python/tools/transformers/onnx_model.py
@@ -1284,7 +1284,7 @@ class OnnxModel:
op_count[op] = 1 if op not in op_count else (op_count[op] + 1)
# Sorted by count in the descending order, then by key in alphabetical order.
- logger.info(f"Operators:{sorted(op_count.items(), key=lambda kv:(-kv[1], kv[0]))}")
+ logger.info(f"Operators:{sorted(op_count.items(), key=lambda kv: (-kv[1], kv[0]))}")
return op_count
diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py
index 6a25196dbc..9e44921bde 100644
--- a/onnxruntime/python/tools/transformers/quantize_helper.py
+++ b/onnxruntime/python/tools/transformers/quantize_helper.py
@@ -64,7 +64,7 @@ class QuantizeHelper:
from onnxruntime.quantization import quantize_dynamic
Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True)
- logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path)/(1024*1024)}")
+ logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path) / (1024 * 1024)}")
quantize_dynamic(
onnx_model_path,
quantized_model_path,
@@ -73,4 +73,4 @@ class QuantizeHelper:
)
logger.info(f"quantized model saved to:{quantized_model_path}")
# TODO: inlcude external data in total model size.
- logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path)/(1024*1024)}")
+ logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path) / (1024 * 1024)}")
diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py
index 7dcd6484a5..796a58f1a9 100644
--- a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py
+++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py
@@ -49,13 +49,13 @@ if args.dim is None or args.dim == 2:
print(f' OpTester test("AffineGrid", {opset_version});')
print(f' test.AddAttribute("align_corners", (int64_t){1 if align_corners else 0});')
print(
- f" test.AddInput(\"theta\", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{', '.join([f'{x:.6f}f' for x in theta.flatten()])}}});"
+ f' test.AddInput("theta", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{", ".join([f"{x:.6f}f" for x in theta.flatten()])}}});'
)
print(
f' test.AddInput("size", {{{len(size)}}}, {{{size[0]}, {size[1]}, {size[2]}, {size[3]}}});'
)
print(
- f" test.AddOutput(\"grid\", {{{size[0]}, {size[2]}, {size[3]}, 2}}, {{{', '.join([f'{x:.4f}f' for x in grid.flatten()])}}});"
+ f' test.AddOutput("grid", {{{size[0]}, {size[2]}, {size[3]}, 2}}, {{{", ".join([f"{x:.4f}f" for x in grid.flatten()])}}});'
)
print(" test.Run();")
print("}\n")
@@ -104,13 +104,13 @@ if args.dim is None or args.dim == 3:
print(f' OpTester test("AffineGrid", {opset_version});')
print(f' test.AddAttribute("align_corners", (int64_t){1 if align_corners else 0});')
print(
- f" test.AddInput(\"theta\", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{', '.join([f'{x:.6f}f' for x in theta.flatten()])}}});"
+ f' test.AddInput("theta", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{", ".join([f"{x:.6f}f" for x in theta.flatten()])}}});'
)
print(
f' test.AddInput("size", {{{len(size)}}}, {{{size[0]}, {size[1]}, {size[2]}, {size[3]}, {size[4]}}});'
)
print(
- f" test.AddOutput(\"grid\", {{{size[0]}, {size[2]}, {size[3]}, {size[4]}, 3}}, {{{', '.join([f'{x:.4f}f' for x in grid.flatten()])}}});"
+ f' test.AddOutput("grid", {{{size[0]}, {size[2]}, {size[3]}, {size[4]}, 3}}, {{{", ".join([f"{x:.4f}f" for x in grid.flatten()])}}});'
)
print(" test.Run();")
print("}\n")
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
index bf58a5d3fc..627b681793 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
@@ -80,11 +80,11 @@ for opset_version in [16, 20]:
print(f'{spaces}std::string padding_mode = "{padding_mode}";')
print(f"{spaces}int64_t align_corners = {onnx_align_corners};")
print(f"{spaces}std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};")
- print(f"{spaces}std::initializer_list X_data { X_data_str };")
+ print(f"{spaces}std::initializer_list X_data {X_data_str};")
print(f"{spaces}std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};")
- print(f"{spaces}std::initializer_list Grid_data { Grid_data_str };")
+ print(f"{spaces}std::initializer_list Grid_data {Grid_data_str};")
print(f"{spaces}std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};")
- print(f"{spaces}std::initializer_list Y_data { Y_data_str };")
+ print(f"{spaces}std::initializer_list Y_data {Y_data_str};")
print(f'{spaces}test.AddInput("X", X_shape, X_data);')
print(f'{spaces}test.AddInput("Grid", Grid_shape, Grid_data);')
diff --git a/onnxruntime/test/python/onnxruntime_test_float8.py b/onnxruntime/test/python/onnxruntime_test_float8.py
index bb63ea2344..29aede0784 100644
--- a/onnxruntime/test/python/onnxruntime_test_float8.py
+++ b/onnxruntime/test/python/onnxruntime_test_float8.py
@@ -354,8 +354,7 @@ class TestInferenceSession(unittest.TestCase):
assert_allclose(expect, y)
except AssertionError as e:
raise AssertionError(
- f"Discrepancies with name={name}, float_name={float_name}, "
- f"saturate={saturate}\nexpect={expect}\ny={y}"
+ f"Discrepancies with name={name}, float_name={float_name}, saturate={saturate}\nexpect={expect}\ny={y}"
) from e
self.assertEqual(expect.shape, y.shape)
self.assertEqual(expect.dtype, y.dtype)
@@ -394,8 +393,7 @@ class TestInferenceSession(unittest.TestCase):
assert_allclose(expect, y)
except AssertionError as e:
raise AssertionError(
- f"Discrepancies with name={name}, float_name={float_name}, "
- f"saturate={saturate}\nexpect={expect}\ny={y}"
+ f"Discrepancies with name={name}, float_name={float_name}, saturate={saturate}\nexpect={expect}\ny={y}"
) from e
self.assertEqual(expect.shape, y.shape)
self.assertEqual(expect.dtype, y.dtype)
@@ -608,8 +606,7 @@ class TestInferenceSession(unittest.TestCase):
if not saturate:
return
raise AssertionError(
- f"Discrepancies with name={name}, float_name={float_name}, "
- f"saturate={saturate}\nexpect={expect}\ny={y}"
+ f"Discrepancies with name={name}, float_name={float_name}, saturate={saturate}\nexpect={expect}\ny={y}"
) from e
self.assertEqual(expect.shape, y.shape)
self.assertEqual(expect.dtype, y.dtype)
diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
index 2dba8ff532..c9876d3d55 100644
--- a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
+++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
@@ -173,16 +173,16 @@ class TestFloat8Gemm8(unittest.TestCase):
raise AssertionError(
f"Gemm ERROR len(inputs)={len(feeds)}"
- f"\na@b=\n{check(lambda:a@b)}"
- f"\na.T@b=\n{check(lambda:a.T@b)}"
- f"\na@b.T=\n{check(lambda:a@b.T)}"
- f"\na.T@b.T=\n{check(lambda:a.T@b.T)}"
- f"\n----\nb@a=\n{check(lambda:b@a)}"
- f"\nb.T@a=\n{check(lambda:b.T@a)}"
- f"\nb@a.T=\n{check(lambda:b@a.T)}"
- f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}"
- f"\n----\nexpected=\n{expected[:2,:2]}"
- f"\n----\ngot=\n{y[:2,:2]}"
+ f"\na@b=\n{check(lambda: a @ b)}"
+ f"\na.T@b=\n{check(lambda: a.T @ b)}"
+ f"\na@b.T=\n{check(lambda: a @ b.T)}"
+ f"\na.T@b.T=\n{check(lambda: a.T @ b.T)}"
+ f"\n----\nb@a=\n{check(lambda: b @ a)}"
+ f"\nb.T@a=\n{check(lambda: b.T @ a)}"
+ f"\nb@a.T=\n{check(lambda: b @ a.T)}"
+ f"\nb.T@a.T=\n{check(lambda: b.T @ a.T)}"
+ f"\n----\nexpected=\n{expected[:2, :2]}"
+ f"\n----\ngot=\n{y[:2, :2]}"
f"\nkwargs={kwargs}"
) from e
@@ -225,16 +225,16 @@ class TestFloat8Gemm8(unittest.TestCase):
raise AssertionError(
f"Gemm ERROR len(inputs)={len(feeds)}"
- f"\na@b=\n{check(lambda:a@b)}"
- f"\na.T@b=\n{check(lambda:a.T@b)}"
- f"\na@b.T=\n{check(lambda:a@b.T)}"
- f"\na.T@b.T=\n{check(lambda:a.T@b.T)}"
- f"\n----\nb@a=\n{check(lambda:b@a)}"
- f"\nb.T@a=\n{check(lambda:b.T@a)}"
- f"\nb@a.T=\n{check(lambda:b@a.T)}"
- f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}"
- f"\n----\nexpected=\n{expected[:2,:2]}"
- f"\n----\ngot=\n{y[:2,:2]}"
+ f"\na@b=\n{check(lambda: a @ b)}"
+ f"\na.T@b=\n{check(lambda: a.T @ b)}"
+ f"\na@b.T=\n{check(lambda: a @ b.T)}"
+ f"\na.T@b.T=\n{check(lambda: a.T @ b.T)}"
+ f"\n----\nb@a=\n{check(lambda: b @ a)}"
+ f"\nb.T@a=\n{check(lambda: b.T @ a)}"
+ f"\nb@a.T=\n{check(lambda: b @ a.T)}"
+ f"\nb.T@a.T=\n{check(lambda: b.T @ a.T)}"
+ f"\n----\nexpected=\n{expected[:2, :2]}"
+ f"\n----\ngot=\n{y[:2, :2]}"
f"\nkwargs={kwargs}"
) from e
self.assertEqual(expected.shape, y.shape)
diff --git a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py
index 76fc78e376..77f9e6f5cf 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py
@@ -223,7 +223,6 @@ class TestIOBinding(unittest.TestCase):
for inner_device, provider in devices:
for onnx_dtype, torch_dtype in onnx_to_torch_type_map.items():
with self.subTest(onnx_dtype=onnx_dtype, inner_device=str(inner_device)):
-
# Create onnx graph with dynamic axes
X = helper.make_tensor_value_info("X", onnx_dtype, [None]) # noqa: N806
Y = helper.make_tensor_value_info("Y", onnx_dtype, [None]) # noqa: N806
diff --git a/onnxruntime/test/python/quantization/test_conv_dynamic.py b/onnxruntime/test/python/quantization/test_conv_dynamic.py
index f6ee3fe97a..5892e18bae 100644
--- a/onnxruntime/test/python/quantization/test_conv_dynamic.py
+++ b/onnxruntime/test/python/quantization/test_conv_dynamic.py
@@ -10,9 +10,13 @@ import unittest
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
-from op_test_utils import TestDataFeeds # noqa: F401
-from op_test_utils import check_op_type_order # noqa: F401
-from op_test_utils import check_model_correctness, check_op_type_count, check_qtype_by_node_type
+from op_test_utils import (
+ TestDataFeeds, # noqa: F401
+ check_model_correctness,
+ check_op_type_count,
+ check_op_type_order, # noqa: F401
+ check_qtype_by_node_type,
+)
from onnxruntime.quantization import DynamicQuantConfig, QuantType, quantize, quantize_dynamic
diff --git a/onnxruntime/test/python/quantization/test_op_pooling.py b/onnxruntime/test/python/quantization/test_op_pooling.py
index 539affc314..5364171307 100644
--- a/onnxruntime/test/python/quantization/test_op_pooling.py
+++ b/onnxruntime/test/python/quantization/test_op_pooling.py
@@ -10,8 +10,13 @@ import unittest
import numpy as np
import onnx
from onnx import TensorProto, helper
-from op_test_utils import check_op_nodes # noqa: F401
-from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
+from op_test_utils import (
+ TestDataFeeds,
+ check_model_correctness,
+ check_op_nodes, # noqa: F401
+ check_op_type_count,
+ check_qtype_by_node_type,
+)
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py
index 23b397ffd8..178cb9d876 100644
--- a/onnxruntime/test/python/quantization/test_qdq.py
+++ b/onnxruntime/test/python/quantization/test_qdq.py
@@ -759,12 +759,12 @@ class TestQDQFormatConvRelu(TestQDQFormat):
QuantType.QInt16: TensorProto.INT16,
QuantType.QUInt16: TensorProto.UINT16,
}
- assert (
- weight_type not in to_tensor_types or to_tensor_types[weight_type] in zero_types
- ), f"weight_type={weight_type} not in zero_types={zero_types}"
- assert (
- activation_type not in to_tensor_types or to_tensor_types[activation_type] in zero_types
- ), f"activation_type={activation_type} not in zero_types={zero_types}"
+ assert weight_type not in to_tensor_types or to_tensor_types[weight_type] in zero_types, (
+ f"weight_type={weight_type} not in zero_types={zero_types}"
+ )
+ assert activation_type not in to_tensor_types or to_tensor_types[activation_type] in zero_types, (
+ f"activation_type={activation_type} not in zero_types={zero_types}"
+ )
check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next(), rtol=rtol, atol=atol)
diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
index 41dae04f1c..5617a424cf 100644
--- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
+++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
@@ -1195,7 +1195,9 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
# get_qnn_qdq_config() should be able to validate the per-channel axis without having to load
# the external weight data.
qnn_config = get_qnn_qdq_config(
- str(model_path), DummyDataReader([]), init_overrides=init_overrides # Dummy data reader does nothing
+ str(model_path),
+ DummyDataReader([]),
+ init_overrides=init_overrides, # Dummy data reader does nothing
)
self.assertEqual(set(qnn_config.op_types_to_quantize), {"Conv"})
self.assertTrue(qnn_config.use_external_data_format)
diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py
index 53d015a029..5cef4ae863 100644
--- a/onnxruntime/test/python/transformers/benchmark_gqa.py
+++ b/onnxruntime/test/python/transformers/benchmark_gqa.py
@@ -6,6 +6,7 @@
"""
Benchmark performance of GroupQueryAttention.
"""
+
from typing import Optional
import torch
diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py
index 5b27a46ea0..71e4f2b63c 100644
--- a/onnxruntime/test/python/transformers/conformer_model_generator.py
+++ b/onnxruntime/test/python/transformers/conformer_model_generator.py
@@ -22,7 +22,9 @@ def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False
weights = (
[np.random.uniform(low, high) for _ in range(total_elements)]
if random
- else [0.0] * total_elements if zeros else [1.0] * total_elements
+ else [0.0] * total_elements
+ if zeros
+ else [1.0] * total_elements
)
return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights
diff --git a/onnxruntime/test/python/transformers/parity_utilities.py b/onnxruntime/test/python/transformers/parity_utilities.py
index d7f79304d2..376b684c76 100644
--- a/onnxruntime/test/python/transformers/parity_utilities.py
+++ b/onnxruntime/test/python/transformers/parity_utilities.py
@@ -115,9 +115,9 @@ def optimize_onnx(
onnx_model.save_model_to_file(optimized_onnx_path)
if expected_op is not None:
- assert (
- len(onnx_model.get_nodes_by_op_type(expected_op)) == 1
- ), f"Expected {expected_op} node not found in the optimized model {optimized_onnx_path}"
+ assert len(onnx_model.get_nodes_by_op_type(expected_op)) == 1, (
+ f"Expected {expected_op} node not found in the optimized model {optimized_onnx_path}"
+ )
def diff_outputs(torch_outputs, ort_outputs, index):
diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py
index 45726ecc7c..6f396f35f7 100644
--- a/onnxruntime/test/python/transformers/test_mha.py
+++ b/onnxruntime/test/python/transformers/test_mha.py
@@ -183,14 +183,14 @@ def mha_with_past_reference(
assert config.kv_sequence_length == config.sequence_length
assert config.use_kv_cache
if past_k is not None:
- assert (
- past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1)
- ), f"expect BNSH format: {past_k.shape=} {k.shape=}"
+ assert past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1), (
+ f"expect BNSH format: {past_k.shape=} {k.shape=}"
+ )
if past_v is not None:
- assert (
- past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1)
- ), f"expect BNSH format: {past_v.shape=} {v.shape=}"
+ assert past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1), (
+ f"expect BNSH format: {past_v.shape=} {v.shape=}"
+ )
present_k = torch.cat((past_k, k), dim=2) if past_k is not None else k
present_v = torch.cat((past_v, v), dim=2) if past_v is not None else v
@@ -533,7 +533,6 @@ def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=No
def merge_padding_and_causal_masks(config):
-
q_mask, k_mask, mask = config.right_side_padding_masks()
if config.causal:
query_padding_mask = q_mask.reshape(config.batch_size, config.sequence_length)
diff --git a/onnxruntime/test/python/transformers/test_parity_t5_mha.py b/onnxruntime/test/python/transformers/test_parity_t5_mha.py
index 84708ddcf8..7eae2f0a23 100644
--- a/onnxruntime/test/python/transformers/test_parity_t5_mha.py
+++ b/onnxruntime/test/python/transformers/test_parity_t5_mha.py
@@ -418,9 +418,9 @@ class T5Attention(nn.Module):
real_seq_length = seq_length
if past_key_value is not None:
- assert (
- len(past_key_value) == 2
- ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+ assert len(past_key_value) == 2, (
+ f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+ )
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -538,9 +538,9 @@ class T5Attention(nn.Module):
real_seq_length = seq_length
if past_key_value is not None:
- assert (
- len(past_key_value) == 2
- ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+ assert len(past_key_value) == 2, (
+ f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+ )
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
def project(hidden_states, proj_layer, key_value_states, past_key_value):
diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py
index 373ad86ced..aba0ccdac2 100644
--- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py
+++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py
@@ -1026,14 +1026,14 @@ class TestRotaryAttentionFusion(unittest.TestCase):
unsqueeze_0_node = helper.make_node(
"Unsqueeze",
inputs=[gather_0_node.output[0] if not use_mul_and_add_nodes_0 else "mul_extra_out", "zero"],
- outputs=[f"unsqueeze_extra_{2*i}"],
- name=f"Unsqueeze_extra_{2*i}",
+ outputs=[f"unsqueeze_extra_{2 * i}"],
+ name=f"Unsqueeze_extra_{2 * i}",
)
unsqueeze_1_node = helper.make_node(
"Unsqueeze",
inputs=[gather_1_node.output[0] if not use_mul_and_add_nodes_1 else "add_extra_out", "zero"],
- outputs=[f"unsqueeze_extra_{2*i + 1}"],
- name=f"Unsqueeze_extra_{2*i + 1}",
+ outputs=[f"unsqueeze_extra_{2 * i + 1}"],
+ name=f"Unsqueeze_extra_{2 * i + 1}",
)
reshape_name = reshape_node.name
diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py
index 5dbb9a277e..774761afdd 100644
--- a/onnxruntime/test/python/transformers/test_sparse_attention.py
+++ b/onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -6,6 +6,7 @@
"""
Parity test and benchmark performance of SparseAttention. Requires Nvidia GPU of Compute Capability 7.5 or above.
"""
+
import math
import unittest
from typing import Optional, Union
diff --git a/onnxruntime/test/python/transformers/whisper_model_generator.py b/onnxruntime/test/python/transformers/whisper_model_generator.py
index a57b45cbc5..71d1a4cbdc 100644
--- a/onnxruntime/test/python/transformers/whisper_model_generator.py
+++ b/onnxruntime/test/python/transformers/whisper_model_generator.py
@@ -22,7 +22,9 @@ def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False
weights = (
[np.random.uniform(low, high) for _ in range(total_elements)]
if random
- else [0.0] * total_elements if zeros else [1.0] * total_elements
+ else [0.0] * total_elements
+ if zeros
+ else [1.0] * total_elements
)
return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights
diff --git a/onnxruntime/test/testdata/dummy_t5_model_generator.py b/onnxruntime/test/testdata/dummy_t5_model_generator.py
index 1ecd8b9ee9..00d9231fc8 100644
--- a/onnxruntime/test/testdata/dummy_t5_model_generator.py
+++ b/onnxruntime/test/testdata/dummy_t5_model_generator.py
@@ -1,4 +1,4 @@
-""" Script to generate a dummy ONNX model emulating T5 model with BeamSearch op. """
+"""Script to generate a dummy ONNX model emulating T5 model with BeamSearch op."""
import argparse
diff --git a/onnxruntime/test/testdata/sparse_initializer_as_output.py b/onnxruntime/test/testdata/sparse_initializer_as_output.py
index 1f85f5690d..b10c84ccc1 100644
--- a/onnxruntime/test/testdata/sparse_initializer_as_output.py
+++ b/onnxruntime/test/testdata/sparse_initializer_as_output.py
@@ -6,13 +6,17 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Tuple, T
import numpy as np
import onnx
-from onnx import AttributeProto # noqa: F401
-from onnx import GraphProto # noqa: F401
-from onnx import SparseTensorProto # noqa: F401
-from onnx import mapping # noqa: F401
-from onnx import numpy_helper # noqa: F401
-from onnx import utils # noqa: F401
-from onnx import TensorProto, ValueInfoProto, helper
+from onnx import (
+ AttributeProto, # noqa: F401
+ GraphProto, # noqa: F401
+ SparseTensorProto, # noqa: F401
+ TensorProto,
+ ValueInfoProto,
+ helper,
+ mapping, # noqa: F401
+ numpy_helper, # noqa: F401
+ utils, # noqa: F401
+)
from onnx.helper import make_opsetid
diff --git a/onnxruntime/test/testdata/sparse_to_dense_matmul.py b/onnxruntime/test/testdata/sparse_to_dense_matmul.py
index ceabae9c2d..57a15ba723 100644
--- a/onnxruntime/test/testdata/sparse_to_dense_matmul.py
+++ b/onnxruntime/test/testdata/sparse_to_dense_matmul.py
@@ -6,13 +6,17 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Tuple, T
import numpy as np # noqa: F401
import onnx
-from onnx import AttributeProto # noqa: F401
-from onnx import GraphProto # noqa: F401
-from onnx import SparseTensorProto # noqa: F401
-from onnx import mapping # noqa: F401
-from onnx import numpy_helper # noqa: F401
-from onnx import utils # noqa: F401
-from onnx import TensorProto, ValueInfoProto, helper
+from onnx import (
+ AttributeProto, # noqa: F401
+ GraphProto, # noqa: F401
+ SparseTensorProto, # noqa: F401
+ TensorProto,
+ ValueInfoProto,
+ helper,
+ mapping, # noqa: F401
+ numpy_helper, # noqa: F401
+ utils, # noqa: F401
+)
from onnx.helper import make_opsetid
diff --git a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py
index 443444044b..430a9a345e 100644
--- a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py
+++ b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""This file is used to generate test data for Adam optimizer tests in
- orttraining/orttraining/test/training_ops/cuda/optimizer/adamw_test.cc."""
+orttraining/orttraining/test/training_ops/cuda/optimizer/adamw_test.cc."""
import torch
diff --git a/onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py
index c67faaca5c..e4ecae4b18 100644
--- a/onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py
+++ b/onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""This file is used to generate test data for LR scheduler optimizer tests in
- orttraining/orttraining/test/training_api/core/training_api_tests.cc."""
+orttraining/orttraining/test/training_api/core/training_api_tests.cc."""
import torch
from torch.optim.lr_scheduler import LambdaLR
@@ -33,7 +33,7 @@ class WarmupLinearSchedule(LambdaLR):
super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
- print(f"warmup_step_count_: {self.warmup_steps }, step: {step}, total_step_count_: {self.t_total}")
+ print(f"warmup_step_count_: {self.warmup_steps}, step: {step}, total_step_count_: {self.t_total}")
if step < self.warmup_steps:
return float(step) / float(max(1, self.warmup_steps))
return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
diff --git a/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py
index 173225a21a..e601385dc8 100644
--- a/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py
+++ b/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""This file is used to generate test data for SGD optimizer tests in
- orttraining/orttraining/test/training_ops/cuda/optimizer/sgd_test.cc."""
+orttraining/orttraining/test/training_ops/cuda/optimizer/sgd_test.cc."""
import torch
diff --git a/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py
index 70e8c4ac01..b2ad2463aa 100644
--- a/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py
+++ b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""This file is used to generate test data for ort format model tests in
- orttraining/orttraining/test/training_api/core/training_capi_tests.cc."""
+orttraining/orttraining/test/training_api/core/training_capi_tests.cc."""
import onnx
import torch
diff --git a/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py b/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py
index e7fd4ac70f..1dd4ae0aee 100644
--- a/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py
+++ b/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py
@@ -24,6 +24,7 @@ Models created with this script:
- fusion/constant_folding_qdq_node_unit.graph_output.qdq_contrib.onnx
- fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx
"""
+
from __future__ import annotations
import argparse
diff --git a/onnxruntime/test/testdata/transform/recompute/recompute_test_graph_generator.py b/onnxruntime/test/testdata/transform/recompute/recompute_test_graph_generator.py
index 2c734feda7..b7552d9a26 100644
--- a/onnxruntime/test/testdata/transform/recompute/recompute_test_graph_generator.py
+++ b/onnxruntime/test/testdata/transform/recompute/recompute_test_graph_generator.py
@@ -2,11 +2,11 @@
# Licensed under the MIT License.
"""This file is used to generate test data for MemoryOptimizer tests in
- onnxruntime/test/optimizer/memory_optimizer_test.cc.
+onnxruntime/test/optimizer/memory_optimizer_test.cc.
- Be noticed, after run this script, manually rename recompute_XXXX_execution_model_training.onnx to
- recompute_XXXX.onnx
- """
+Note: after running this script, manually rename recompute_XXXX_execution_model_training.onnx to
+recompute_XXXX.onnx
+"""
import torch
diff --git a/orttraining/orttraining/python/training/optim/__init__.py b/orttraining/orttraining/python/training/optim/__init__.py
index 3cace4d30c..2ce3a32b0b 100644
--- a/orttraining/orttraining/python/training/optim/__init__.py
+++ b/orttraining/orttraining/python/training/optim/__init__.py
@@ -1,8 +1,10 @@
from .config import AdamConfig, LambConfig, SGDConfig, _OptimizerConfig # noqa: F401
from .fp16_optimizer import FP16_Optimizer # noqa: F401
from .fused_adam import AdamWMode, FusedAdam # noqa: F401
-from .lr_scheduler import ConstantWarmupLRScheduler # noqa: F401
-from .lr_scheduler import CosineWarmupLRScheduler # noqa: F401
-from .lr_scheduler import LinearWarmupLRScheduler # noqa: F401
-from .lr_scheduler import PolyWarmupLRScheduler # noqa: F401
-from .lr_scheduler import _LRScheduler # noqa: F401
+from .lr_scheduler import (
+ ConstantWarmupLRScheduler, # noqa: F401
+ CosineWarmupLRScheduler, # noqa: F401
+ LinearWarmupLRScheduler, # noqa: F401
+ PolyWarmupLRScheduler, # noqa: F401
+ _LRScheduler, # noqa: F401
+)
diff --git a/orttraining/orttraining/python/training/optim/config.py b/orttraining/orttraining/python/training/optim/config.py
index d63c7ab40a..d509c8b06f 100644
--- a/orttraining/orttraining/python/training/optim/config.py
+++ b/orttraining/orttraining/python/training/optim/config.py
@@ -57,9 +57,9 @@ class _OptimizerConfig:
)
for k in group:
if k != "params":
- assert (
- k in defaults or k.replace("_coef", "") in defaults
- ), f"'params' has {k} hyper parameter not present at 'defaults'"
+ assert k in defaults or k.replace("_coef", "") in defaults, (
+ f"'params' has {k} hyper parameter not present at 'defaults'"
+ )
self.name = name
self.lr = float(defaults["lr"])
diff --git a/orttraining/orttraining/python/training/optim/lr_scheduler.py b/orttraining/orttraining/python/training/optim/lr_scheduler.py
index 2a9bf438fa..bef6abb4a2 100644
--- a/orttraining/orttraining/python/training/optim/lr_scheduler.py
+++ b/orttraining/orttraining/python/training/optim/lr_scheduler.py
@@ -273,9 +273,9 @@ class PolyWarmupLRScheduler(_LRScheduler):
self._num_warmup_steps = warmup * total_steps
def _warmup_poly(self, train_step_info):
- assert (
- train_step_info.optimizer_config.lr > self.lr_end
- ), f"lr_end ({lr_end}) must be be smaller than initial lr ({train_step_info.optimizer_config.lr})" # noqa: F821
+ assert train_step_info.optimizer_config.lr > self.lr_end, (
+ f"lr_end ({self.lr_end}) must be be smaller than initial lr ({train_step_info.optimizer_config.lr})"
+ )
if train_step_info.optimization_step < self._num_warmup_steps:
return float(train_step_info.optimization_step) / float(max(1, self._num_warmup_steps))
diff --git a/orttraining/orttraining/python/training/ort_triton/__init__.py b/orttraining/orttraining/python/training/ort_triton/__init__.py
index 5f2d0c62ff..f87f8d73e7 100644
--- a/orttraining/orttraining/python/training/ort_triton/__init__.py
+++ b/orttraining/orttraining/python/training/ort_triton/__init__.py
@@ -9,8 +9,12 @@ from functools import wraps
from onnxruntime.capi import _pybind_state as _C
from .kernel import * # noqa: F403
-from .triton_op_executor import register_triton_kernel # noqa: F401
-from .triton_op_executor import call_triton_by_name, call_triton_by_onnx, get_config
+from .triton_op_executor import (
+ call_triton_by_name,
+ call_triton_by_onnx,
+ get_config,
+ register_triton_kernel, # noqa: F401
+)
def run_once_register_triton_op_executor(f):
diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py
index 9a447d8019..c6759630b2 100644
--- a/orttraining/orttraining/python/training/ort_triton/_codegen.py
+++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py
@@ -105,9 +105,9 @@ class TritonCodegen(NodeVisitor):
name = node.tensor_arg.name
var_name = context.get_variable_name(name)
internal_var_name = context.get_internal_variable_name(name)
- assert (
- var_name != internal_var_name
- ), f"variable name {var_name} and its internal variable name should not be the same."
+ assert var_name != internal_var_name, (
+ f"variable name {var_name} and its internal variable name should not be the same."
+ )
offset_str, mask_str = self._get_offset_mask(node.offset_calc, node.tensor_arg.name)
if offset_str:
@@ -359,8 +359,7 @@ class TritonCodegen(NodeVisitor):
for reduce_node in node.reduce_nodes:
tmp_var_name = "tmp_" + context.get_internal_variable_name(reduce_node.outputs[0].name)
code_buffer += (
- f"{space_indent}{tmp_var_name} = "
- f"tl.zeros([XBLOCK, RBLOCK], tl.float32) + {reduce_node.default_value}\n"
+ f"{space_indent}{tmp_var_name} = tl.zeros([XBLOCK, RBLOCK], tl.float32) + {reduce_node.default_value}\n"
)
code_buffer += (
f"{space_indent}for roffset in range(0, rnumel, RBLOCK):\n{space_indent} rindex = rbase + roffset\n"
@@ -440,9 +439,7 @@ class TritonCodegen(NodeVisitor):
def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int): # noqa: N802
space_indent = " " * indent
code_buffer += (
- f"{space_indent}import triton\n"
- f"{space_indent}import triton.language as tl\n"
- f"{space_indent}import torch\n"
+ f"{space_indent}import triton\n{space_indent}import triton.language as tl\n{space_indent}import torch\n"
)
for kernel_node in node.kernels:
diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py
index f7b7c1ff08..3850d988ef 100644
--- a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py
+++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py
@@ -793,7 +793,7 @@ def flash_attn_forward(q, k, v, bias=None, **kwargs):
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = "matrix"
else:
- raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)")
+ raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
@@ -903,7 +903,7 @@ def flash_attn_backward(do, q, k, v, o, lse, bias=None, **kwargs):
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = "matrix"
else:
- raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)")
+ raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
index 1efc3a23ee..3e679c994f 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
@@ -191,7 +191,6 @@ def _export_pt_1_10(g, n, *args, **kwargs):
def _default_export(
g, func_full_qual_name, func_class, cconv, output_size, output_tensor_types, output_tensor_ranks, *args, **kwargs
):
-
input_tensor_types = []
input_tensor_ranks = []
diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback.py b/orttraining/orttraining/python/training/ortmodule/_fallback.py
index 56bb45d064..6a3793cf0f 100644
--- a/orttraining/orttraining/python/training/ortmodule/_fallback.py
+++ b/orttraining/orttraining/python/training/ortmodule/_fallback.py
@@ -11,7 +11,6 @@ from typing import Optional
import torch
from . import _logger, _utils
-from ._fallback_exceptions import wrap_exception # noqa: F401
from ._fallback_exceptions import (
ORTModuleDeviceException,
ORTModuleFallbackException,
@@ -19,6 +18,7 @@ from ._fallback_exceptions import (
ORTModuleIOError,
ORTModuleONNXModelException,
ORTModuleTorchModelException,
+ wrap_exception, # noqa: F401
)
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py
index d9cae8e1f9..bbf271e4e9 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py
@@ -580,9 +580,9 @@ class GraphTransitionManager:
parameter_names = {k: v for k, v in flatten_module.named_parameters()}
for input_name in exported_model_info.onnx_graph_input_names:
if input_name in exported_model_info.onnx_graph_input_names_user_defined:
- assert (
- input_name in model_info_for_export.onnx_graph_input_data_accessor_user_defined
- ), f"{input_name} model_info_for_export.onnx_graph_input_data_accessor_user_defined"
+ assert input_name in model_info_for_export.onnx_graph_input_data_accessor_user_defined, (
+ f"{input_name} model_info_for_export.onnx_graph_input_data_accessor_user_defined"
+ )
# We assume the data accessor should be the same as the one used for the previous export, because
# there is args and kwargs schema check during export check phase.
if model_info_for_export.onnx_graph_input_data_accessor_user_defined[input_name](
@@ -736,7 +736,6 @@ class GraphTransitionManager:
runtime_inspector: RuntimeInspector,
logger: logging.Logger,
) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]:
-
# Add hooks to check the sparsity of the embedding and label inputs during the export.
embedding_hook_handles = GraphTransitionManager._add_check_embedding_sparsity_hook(
enable_embedding_sparse_optimizer, device, logger, runtime_inspector, flattened_module
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
index 86fa4c9c9a..c739283e5c 100644
--- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
+++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
@@ -201,7 +201,6 @@ class MemoryObserver:
_MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE,
_MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE,
]:
-
apply_config = []
for cluster_id in self.cluster_id_combination_to_saving_symbolics_map:
diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
index 11d978e71d..7da3e18007 100644
--- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
+++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
@@ -102,9 +102,9 @@ def post_processing_enable_zero_stage3_compat(
func_name = _get_func_name(c)
if func_name == pre_forward_function_name:
- assert (
- pre_forward_pythonop_node is None
- ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen"
+ assert pre_forward_pythonop_node is None, (
+ "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen"
+ )
pre_forward_pythonop_node = c
if pre_forward_pythonop_node is None:
@@ -210,7 +210,7 @@ def post_processing_enable_zero_stage3_compat(
def _create_weight_retrieval_function(
- zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]
+ zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]],
) -> str:
"""This function is used to create a weight retrieving function using zero_stage3_named_params."""
diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py
index 76c8ce3bf3..7cda029524 100644
--- a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py
+++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py
@@ -59,9 +59,9 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
assert hasattr(data, _load_use_external_gpu_allocator.loading_key)
log.info(f"Found keyword {_load_use_external_gpu_allocator.loading_key} in json. Loading attributes from file.")
- assert isinstance(
- data.UseExternalGPUAllocator, bool
- ), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean"
+ assert isinstance(data.UseExternalGPUAllocator, bool), (
+ f"{_load_use_external_gpu_allocator.loading_key} must be a boolean"
+ )
ortmodule_config_accessor._runtime_options.use_external_gpu_allocator = data.UseExternalGPUAllocator
@@ -73,9 +73,9 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file."
)
- assert isinstance(
- data.EnableCustomAutogradFunction, bool
- ), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean"
+ assert isinstance(data.EnableCustomAutogradFunction, bool), (
+ f"{_load_enable_custom_autograd_function.loading_key} must be a boolean"
+ )
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
@@ -89,9 +89,9 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
assert hasattr(data, _load_enable_grad_acc_optimization.loading_key)
log.info(f"Found keyword {_load_enable_grad_acc_optimization.loading_key} in json. Loading attributes from file.")
- assert isinstance(
- data.EnableGradAccOptimization, bool
- ), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean"
+ assert isinstance(data.EnableGradAccOptimization, bool), (
+ f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean"
+ )
ortmodule_config_accessor._runtime_options.enable_grad_acc_optimization = data.EnableGradAccOptimization
@@ -101,9 +101,9 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
assert hasattr(data, _load_run_symbolic_shape_infer.loading_key)
log.info(f"Found keyword {_load_run_symbolic_shape_infer.loading_key} in json. Loading attributes from file.")
- assert isinstance(
- data.RunSymbolicShapeInference, bool
- ), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean"
+ assert isinstance(data.RunSymbolicShapeInference, bool), (
+ f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean"
+ )
ortmodule_config_accessor._runtime_options.run_symbolic_shape_infer = data.RunSymbolicShapeInference
@@ -175,9 +175,9 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data):
assert hasattr(data, _load_use_memory_efficient_gradient.loading_key)
log.info(f"Found keyword {_load_use_memory_efficient_gradient.loading_key} in json. Loading attributes from file.")
- assert isinstance(
- data.UseMemoryEfficientGradient, bool
- ), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean"
+ assert isinstance(data.UseMemoryEfficientGradient, bool), (
+ f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean"
+ )
ortmodule_config_accessor._runtime_options.use_memory_efficient_gradient = data.UseMemoryEfficientGradient
diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
index a8e730488d..d7ea3dc419 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
@@ -278,11 +278,11 @@ def _summarize_tensor(
std_value = torch.sqrt(s.sum() / (element_count - 1))
f.write(
- f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
+ f"{'>' * max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
f"min: {min_value} max: {max_value}, mean: {mean_value}, "
f"std: {std_value} \n"
f"nan: {num_nan}, inf: {num_inf}\n"
)
f.write(f"samples(top 128): {flatten_array[:128]}\n")
f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
- f.write(f"{'='*16}\n")
+ f.write(f"{'=' * 16}\n")
diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py
index e98fe48fc4..a6aa390a3e 100644
--- a/orttraining/orttraining/python/training/utils/torch_io_helper.py
+++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py
@@ -291,9 +291,9 @@ def unflatten_data_using_schema(
elif PrimitiveType.is_primitive_type(data_schema):
return data_schema
elif isinstance(data_schema, _TensorStub):
- assert isinstance(
- data[data_schema.tensor_idx], torch.Tensor
- ), f"Expecting torch.Tensor, got {type(data[data_schema.tensor_idx])}"
+ assert isinstance(data[data_schema.tensor_idx], torch.Tensor), (
+ f"Expecting torch.Tensor, got {type(data[data_schema.tensor_idx])}"
+ )
return data[data_schema.tensor_idx]
elif isinstance(data_schema, abc.Sequence):
sequence_type = type(data_schema)
diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py
index 65043c10d8..3d75b3f988 100644
--- a/orttraining/orttraining/test/python/_test_helpers.py
+++ b/orttraining/orttraining/test/python/_test_helpers.py
@@ -84,7 +84,12 @@ def _get_name(name):
# Depending on calling backward() from which outputs, it's possible that grad of some weights are not calculated.
# none_pt_params is to tell what these weights are, so we will not compare the tensors.
def assert_gradients_match_and_reset_gradient(
- ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-04, atol=1e-05 # noqa: B006
+ ort_model,
+ pt_model,
+ none_pt_params=(),
+ reset_gradient=True,
+ rtol=1e-04,
+ atol=1e-05,
):
ort_named_params = list(ort_model.named_parameters())
pt_named_params = list(pt_model.named_parameters())
diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py
index e57b615de0..bd36ebf545 100644
--- a/orttraining/orttraining/test/python/orttraining_test_dort.py
+++ b/orttraining/orttraining/test/python/orttraining_test_dort.py
@@ -165,9 +165,9 @@ class TestTorchDynamoOrt(unittest.TestCase):
for tensor, baseline_tensor in zip(tensors, baseline_tensors):
torch.testing.assert_close(tensor, baseline_tensor)
- assert (
- len(cached.keys()) == 2
- ), "Should only see two GraphModules so far. One for forward and the other one for backward."
+ assert len(cached.keys()) == 2, (
+ "Should only see two GraphModules so far. One for forward and the other one for backward."
+ )
for value in cached.values():
assert len(value) == 1, (
"One GraphModule should only be mapped to one ONNX model since "
diff --git a/orttraining/orttraining/test/python/orttraining_test_gru.py b/orttraining/orttraining/test/python/orttraining_test_gru.py
index c9e22bf738..fcb7e13b16 100644
--- a/orttraining/orttraining/test/python/orttraining_test_gru.py
+++ b/orttraining/orttraining/test/python/orttraining_test_gru.py
@@ -355,7 +355,9 @@ class GRU:
prev_h = (
all_hidden_states[t - 1, 0, idx, :]
if t > 0
- else initial_hidden_state[0, idx, :] if initial_hidden_state is not None else 0
+ else initial_hidden_state[0, idx, :]
+ if initial_hidden_state is not None
+ else 0
)
grad_update_gate = (prev_h - hidden_gate) * grad_h
diff --git a/orttraining/orttraining/test/python/orttraining_test_lstm.py b/orttraining/orttraining/test/python/orttraining_test_lstm.py
index 4debe73951..1d75f12801 100644
--- a/orttraining/orttraining/test/python/orttraining_test_lstm.py
+++ b/orttraining/orttraining/test/python/orttraining_test_lstm.py
@@ -480,7 +480,9 @@ class LSTM:
grad_forget_gate = grad_c * (
all_cell_states[t - 1, 0, idx, :]
if t > 0
- else initial_cell_state[0, idx, :] if initial_cell_state is not None else 0
+ else initial_cell_state[0, idx, :]
+ if initial_cell_state is not None
+ else 0
)
grad_control_gate = grad_c * input_gate
@@ -520,7 +522,9 @@ class LSTM:
prev_h = (
all_hidden_states[t - 1, 0, idx, :]
if t > 0
- else initial_hidden_state[0, idx, :] if initial_hidden_state is not None else 0
+ else initial_hidden_state[0, idx, :]
+ if initial_hidden_state is not None
+ else 0
)
grad_recurrence_weights[0, : self._hidden_size, :] += np.dot(
np.expand_dims(grad_input_activation, axis=0).T, np.expand_dims(prev_h, axis=0)
@@ -549,17 +553,22 @@ class LSTM:
grad_peephole_weights[0, : self._hidden_size] += grad_input_activation * (
all_cell_states[t - 1, 0, idx, :]
if t > 0
- else initial_cell_state[0, idx, :] if initial_cell_state is not None else 0
+ else initial_cell_state[0, idx, :]
+ if initial_cell_state is not None
+ else 0
)
grad_peephole_weights[0, self._hidden_size : 2 * self._hidden_size] += (
grad_output_activation * all_cell_states[t, 0, idx, :]
)
- grad_peephole_weights[
- 0, 2 * self._hidden_size : 3 * self._hidden_size
- ] += grad_forget_activation * (
- all_cell_states[t - 1, 0, idx, :]
- if t > 0
- else initial_cell_state[0, idx, :] if initial_cell_state is not None else 0
+ grad_peephole_weights[0, 2 * self._hidden_size : 3 * self._hidden_size] += (
+ grad_forget_activation
+ * (
+ all_cell_states[t - 1, 0, idx, :]
+ if t > 0
+ else initial_cell_state[0, idx, :]
+ if initial_cell_state is not None
+ else 0
+ )
)
grad_c = grad_prev_c
diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
index 0866d4a411..275d53daec 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
@@ -1102,7 +1102,6 @@ def test_custom_optimizer_block():
def test_generate_artifacts_path():
-
with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 0ab441ac93..912af9bc88 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6562,7 +6562,8 @@ def test_bert_memory_inspection(caplog):
os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"
pt_model.eval() # Put it in evaluate mode by intention, in case some initialization in ORTModule use the module.is_training for its checks by mistake.
ort_model = ORTModule(
- copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO) # The logged memory info is in INFO level.
+ copy.deepcopy(pt_model),
+ DebugOptions(log_level=LogLevel.INFO), # The logged memory info is in INFO level.
)
def run_step(model, x, y, z):
@@ -6776,11 +6777,9 @@ def test_enable_layerwise_recompute(memory_optimization_level, allow_gradient_ch
def test_layerwise_recompute_pythonop_deterministic():
-
original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
class DropoutFunction(torch.autograd.Function):
-
@staticmethod
def forward(ctx, x):
return torch.nn.functional.dropout(x, p=0.5, training=True)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py
index 95012aa050..5764a6a81e 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py
@@ -1414,13 +1414,9 @@ def test_pythonop_training_mode():
def check_pythonop_training_mode(model, is_eval_mode):
## make sure the ort's PythonOp's training_mode is correct
if is_eval_mode:
- onnx_nodes = (
- model._torch_module._execution_manager._inference_manager._graph_transition_manager._exported_model_info.exported_model.graph.node
- )
+ onnx_nodes = model._torch_module._execution_manager._inference_manager._graph_transition_manager._exported_model_info.exported_model.graph.node
else:
- onnx_nodes = (
- model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node
- )
+ onnx_nodes = model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node
found_pythonop = False
for node in onnx_nodes:
@@ -1642,14 +1638,14 @@ def test_customized_shape_inference():
_find_shape_and_dtype(graph.value_info)
assert all(s is not None for s in input_shapes), "PythonOp input shape should be found in the optimized_model"
- assert (
- all(d is not None for d in input_dtypes) is not None
- ), "PythonOp input dtype should be found in the optimized_model"
+ assert all(d is not None for d in input_dtypes) is not None, (
+ "PythonOp input dtype should be found in the optimized_model"
+ )
assert all(s is not None for s in output_shapes), "PythonOp output shape should be found in the optimized_model"
- assert (
- all(d is not None for d in output_dtypes) is not None
- ), "PythonOp output dtype should be found in the optimized_model"
+ assert all(d is not None for d in output_dtypes) is not None, (
+ "PythonOp output dtype should be found in the optimized_model"
+ )
def _compare_shape(shape1, shape2):
if len(shape1.dim) != len(shape2.dim):
@@ -1805,7 +1801,6 @@ def test_python_op_return_persistent_param_as_value():
def test_determistic_pythonop_export():
-
class TestFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
@@ -1839,9 +1834,7 @@ def test_determistic_pythonop_export():
ortmodule = ORTModule(TestModel(output_size)).train()
_ = ortmodule(torch.randn(output_size, dtype=torch.float))
- onnx_nodes = (
- ortmodule._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node
- )
+ onnx_nodes = ortmodule._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node
found_pythonop = False
for node in onnx_nodes:
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
index 877dcd2baa..0d5825fb31 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
@@ -12,10 +12,10 @@ import torch
import wget
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
-from transformers import BertConfig # noqa: F401
from transformers import (
AdamW,
AutoConfig,
+ BertConfig, # noqa: F401
BertForSequenceClassification,
BertTokenizer,
get_linear_schedule_with_warmup,
@@ -429,7 +429,9 @@ def main():
# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(
- optimizer, num_warmup_steps=0, num_training_steps=total_steps # Default value in run_glue.py
+ optimizer,
+ num_warmup_steps=0,
+ num_training_steps=total_steps, # Default value in run_glue.py
)
# Seed
random.seed(args.seed)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py
index 4930f73edf..50f411c02a 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py
@@ -12,9 +12,14 @@ import torch
import wget
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
-from transformers import AdamW # noqa: F401
-from transformers import BertConfig # noqa: F401
-from transformers import AutoConfig, BertForSequenceClassification, BertTokenizer, get_linear_schedule_with_warmup
+from transformers import (
+ AdamW, # noqa: F401
+ AutoConfig,
+ BertConfig, # noqa: F401
+ BertForSequenceClassification,
+ BertTokenizer,
+ get_linear_schedule_with_warmup,
+)
import onnxruntime
from onnxruntime.training.ortmodule import DebugOptions, ORTModule
@@ -432,7 +437,9 @@ def main():
# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(
- optimizer, num_warmup_steps=0, num_training_steps=total_steps # Default value in run_glue.py
+ optimizer,
+ num_warmup_steps=0,
+ num_training_steps=total_steps, # Default value in run_glue.py
)
scaler = torch.cuda.amp.GradScaler()
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py
index 46b172a396..174edf3775 100755
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py
@@ -108,7 +108,10 @@ ds = SampleData(x, y)
print("Initialize deepspeed")
model_engine, optimizer, _, _ = deepspeed.initialize(
- args=args, model=model, model_parameters=params, training_data=ds # (x,y)#
+ args=args,
+ model=model,
+ model_parameters=params,
+ training_data=ds, # (x,y)#
)
for step in range(args.steps):
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
index 35e5bae3ea..07d581b576 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
@@ -69,9 +69,7 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
self.assert_values_are_close(ort_prediction, pt_prediction, **kwargs)
self.assert_gradients_match_and_reset_gradient(ort_model, pt_model, **kwargs)
- onnx_graph_inf = (
- ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model
- )
+ onnx_graph_inf = ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model
onnx_graph_train = ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model
if debug:
with open(f"debug_{name}_ortmodule_infer.onnx", "wb") as f:
diff --git a/orttraining/orttraining/test/python/qat_poc_example/train.py b/orttraining/orttraining/test/python/qat_poc_example/train.py
index a25c071c58..45c0aa77ae 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/train.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/train.py
@@ -68,8 +68,8 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp
# Training loop
epochs = 5
for epoch in range(epochs):
- logging.info(f"Starting epoch: {epoch+1}")
+ logging.info(f"Starting epoch: {epoch + 1}")
training_loss = _train_epoch(model, optimizer, train_loader)
eval_loss = _eval(model, test_loader)
- logging.info(f"End of epoch: {epoch+1}, training loss: {training_loss:.4f}, eval loss: {eval_loss:.4f}")
+ logging.info(f"End of epoch: {epoch + 1}, training loss: {training_loss:.4f}, eval loss: {eval_loss:.4f}")
diff --git a/orttraining/tools/ci_test/run_batch_size_test.py b/orttraining/tools/ci_test/run_batch_size_test.py
index 348d490678..a1bf3fd71c 100755
--- a/orttraining/tools/ci_test/run_batch_size_test.py
+++ b/orttraining/tools/ci_test/run_batch_size_test.py
@@ -106,7 +106,7 @@ def main():
]
if config.enable_mixed_precision:
- cmds.append("--use_mixed_precision"),
+ cmds.append("--use_mixed_precision")
subprocess.run(cmds, timeout=120).check_returncode() # noqa: PLW1510
diff --git a/orttraining/tools/ci_test/run_bert_perf_test.py b/orttraining/tools/ci_test/run_bert_perf_test.py
index 13d5e9f140..c848621c88 100644
--- a/orttraining/tools/ci_test/run_bert_perf_test.py
+++ b/orttraining/tools/ci_test/run_bert_perf_test.py
@@ -94,8 +94,8 @@ def main():
]
if c.use_mixed_precision:
- cmds.append("--use_mixed_precision"),
- cmds.append("--allreduce_in_fp16"),
+ cmds.append("--use_mixed_precision")
+ cmds.append("--allreduce_in_fp16")
subprocess.run(cmds).check_returncode() # noqa: PLW1510
if c.expected_perf > 0.0:
diff --git a/orttraining/tools/ci_test/run_gpt2_perf_test.py b/orttraining/tools/ci_test/run_gpt2_perf_test.py
index 18e59d275b..1df71f02b7 100644
--- a/orttraining/tools/ci_test/run_gpt2_perf_test.py
+++ b/orttraining/tools/ci_test/run_gpt2_perf_test.py
@@ -60,7 +60,7 @@ def main():
]
if c.use_mixed_precision:
- cmds.append("--use_mixed_precision"),
+ cmds.append("--use_mixed_precision")
subprocess.run(cmds).check_returncode() # noqa: PLW1510
diff --git a/orttraining/tools/scripts/nv_run_pretraining.py b/orttraining/tools/scripts/nv_run_pretraining.py
index 8c57101f72..8f399263e1 100644
--- a/orttraining/tools/scripts/nv_run_pretraining.py
+++ b/orttraining/tools/scripts/nv_run_pretraining.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""BERT finetuning runner."""
-
import argparse
# ==================
diff --git a/pyproject.toml b/pyproject.toml
index 40e6eb96df..60fe630b13 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,43 +1,6 @@
-[tool.black]
-line-length = 120
-# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
-extend-exclude = "cmake|onnxruntime/core/flatbuffers/"
-# NOTE: use the minimum supported python version as target-version
-target-version = ["py310"]
-
-[tool.isort]
-# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
-profile = "black"
-line_length = 120
-extend_skip_glob = [
- "cmake/*",
- "orttraining/*",
- "onnxruntime/core/flatbuffers/*",
-]
-
[tool.pydocstyle]
convention = "google"
-[tool.pylint.BASIC]
-good-names = [
- "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n",
- "p", "q", "r", "s", "t", "u", "v", "w", "ex", "Run", "_", "x", "y", "z"
-]
-
-[tool.pylint.messages_control]
-disable = [
- "format",
- "line-too-long",
- "import-error",
- "no-name-in-module",
- "no-member",
- "too-many-arguments",
- "too-many-locals",
- "too-few-public-methods",
- "missing-docstring",
- "fixme",
-]
-
[tool.pyright]
exclude = ["onnxruntime/core/flatbuffers/*"]
reportMissingImports = false
@@ -45,6 +8,7 @@ reportMissingImports = false
[tool.ruff]
# NOTE: Do not create an exclude list. Edit .lintrunner.toml instead
target-version = "py38"
+line-length = 120
[tool.ruff.lint]
select = [
@@ -53,6 +17,7 @@ select = [
"F", # Pyflakes
"FURB", # refurb
"G", # flake8-logging-format
+ "I", # isort
"ISC", # flake8-implicit-str-concat
"N", # pep8-naming
"NPY", # numpy
@@ -92,10 +57,6 @@ ignore = [
"SIM116", # Don't use dict lookup to replace if-else
]
ignore-init-module-imports = true
-unfixable = [
- "F401", # Unused imports
- "SIM112", # Use upper case for env vars
-]
[tool.ruff.lint.per-file-ignores]
# NOTE: Refrain from growing the ignore list unless for exceptional cases.
diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt
index 2257e259a4..f51a828ff5 100644
--- a/requirements-lintrunner.txt
+++ b/requirements-lintrunner.txt
@@ -4,8 +4,5 @@ lintrunner==0.12.5
lintrunner-adapters==0.12.4
# RUFF
ruff==0.9.1
-# BLACK-ISORT
-black==24.10.0
-isort==5.13.2
# CLANGFORMAT
clang-format==19.1.6
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 9e567e1ceb..87180a242e 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -451,9 +451,7 @@ def parse_arguments():
parser.add_argument(
"--apple_deploy_target",
type=str,
- help="Specify the minimum version of the target platform "
- "(e.g. macOS or iOS)"
- "This is only supported on MacOS",
+ help="Specify the minimum version of the target platform (e.g. macOS or iOS). This is only supported on MacOS",
)
# A 32-bit progress doesn't have enough memory to run all the tests in onnxruntime_test_all.
# Mimalloc is incompatible with address sanitizer.
@@ -1248,8 +1246,7 @@ def generate_build_tree(
cmake_args += ["-Donnxruntime_MPI_HOME=" + mpi_home]
else:
log.warning(
- "mpi_home is supplied but use_mpi is set to false."
- " Build will continue without linking MPI libraries."
+ "mpi_home is supplied but use_mpi is set to false. Build will continue without linking MPI libraries."
)
if nccl_home and os.path.exists(nccl_home):
diff --git a/tools/ci_build/compile_triton.py b/tools/ci_build/compile_triton.py
index c1119aad49..abe95b31e8 100644
--- a/tools/ci_build/compile_triton.py
+++ b/tools/ci_build/compile_triton.py
@@ -93,9 +93,9 @@ def convert_and_save(metadata, header_file, out_dir, out_obj_file):
lib_name = m["lib_file"].replace(".", "_")
meta_ele.append(f'"_binary_{lib_name}_start"')
- meta_ele.append(f"\"{m['func_name']}\"")
- meta_ele.append(f"\"{m['group']}\"")
- meta_ele.append(f"\"{m['name']}\"")
+ meta_ele.append(f'"{m["func_name"]}"')
+ meta_ele.append(f'"{m["group"]}"')
+ meta_ele.append(f'"{m["name"]}"')
meta_ele.append(str(m["num_warps"]))
meta_ele.append(str(m["shared"]))
@@ -103,9 +103,9 @@ def convert_and_save(metadata, header_file, out_dir, out_obj_file):
constants = []
for k, v in m["constants"].items():
constants.append(f'{{ "{k}", {v!s}}}')
- meta_ele.append(f"{{ { ', '.join(constants) } }}")
+ meta_ele.append(f"{{ {', '.join(constants)} }}")
- c_metadata.append(f"{{ { ', '.join(meta_ele) } }}")
+ c_metadata.append(f"{{ {', '.join(meta_ele)} }}")
archive_obj_files(binary_files, out_dir, out_obj_file)
@@ -123,7 +123,7 @@ struct _TritonKernelInfo {{
}};
const _TritonKernelInfo kernel_infos[] = {{
- { ', '.join(c_metadata) },
+ {", ".join(c_metadata)},
}};
"""
diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py
index 1b34b3d302..e9f8fea951 100644
--- a/tools/ci_build/github/android/build_aar_package.py
+++ b/tools/ci_build/github/android/build_aar_package.py
@@ -41,10 +41,7 @@ def _parse_build_settings(args):
build_settings = {}
- if "build_abis" in build_settings_data:
- build_settings["build_abis"] = build_settings_data["build_abis"]
- else:
- build_settings["build_abis"] = DEFAULT_BUILD_ABIS
+ build_settings["build_abis"] = build_settings_data.get("build_abis", DEFAULT_BUILD_ABIS)
build_params = []
if "build_params" in build_settings_data:
diff --git a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py
index dd037c17ae..c18cb1d070 100755
--- a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py
+++ b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py
@@ -11,9 +11,10 @@ import sys
import tempfile
from c.assemble_c_pod_package import assemble_c_pod_package
-from objectivec.assemble_objc_pod_package import assemble_objc_pod_package
from package_assembly_utils import PackageVariant, get_ort_version
+from objectivec.assemble_objc_pod_package import assemble_objc_pod_package
+
SCRIPT_PATH = pathlib.Path(__file__).resolve()
SCRIPT_DIR = SCRIPT_PATH.parent
REPO_DIR = SCRIPT_PATH.parents[4]
diff --git a/tools/ci_build/github/apple/package_release_tasks.py b/tools/ci_build/github/apple/package_release_tasks.py
index 592a326d86..c8d78400c6 100755
--- a/tools/ci_build/github/apple/package_release_tasks.py
+++ b/tools/ci_build/github/apple/package_release_tasks.py
@@ -52,8 +52,7 @@ def _resolve_single_path_from_pattern(path_pattern: str) -> Path:
def _parse_args():
parser = argparse.ArgumentParser(
- description="Helper script to perform release tasks. "
- "Mostly useful for the CocoaPods package release pipeline.",
+ description="Helper script to perform release tasks. Mostly useful for the CocoaPods package release pipeline.",
)
parser.add_argument(
diff --git a/tools/python/dump_ort_model.py b/tools/python/dump_ort_model.py
index b9e3bfa0d3..9d7e23bf3a 100644
--- a/tools/python/dump_ort_model.py
+++ b/tools/python/dump_ort_model.py
@@ -80,7 +80,7 @@ class OrtFormatModelDumper:
outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]
print(
f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) "
- f'inputs=[{",".join(inputs)}] outputs=[{",".join(outputs)}]'
+ f"inputs=[{','.join(inputs)}] outputs=[{','.join(outputs)}]"
)
def _dump_graph(self, graph: fbs.Graph):
diff --git a/tools/python/gen_contrib_doc.py b/tools/python/gen_contrib_doc.py
index ab9421b395..ce6f0a1205 100644
--- a/tools/python/gen_contrib_doc.py
+++ b/tools/python/gen_contrib_doc.py
@@ -320,9 +320,7 @@ def main(output_path: str, domain_filter: [str]):
)
# domain -> support level -> name -> [schema]
- index = defaultdict(
- lambda: defaultdict(lambda: defaultdict(list))
- ) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
+ index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
for schema in rtpy.get_all_operator_schema():
index[schema.domain][int(schema.support_level)][schema.name].append(schema)
@@ -331,9 +329,7 @@ def main(output_path: str, domain_filter: [str]):
# Preprocess the Operator Schemas
# [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
- operator_schemas = (
- list()
- ) # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
+ operator_schemas = list() # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
exsting_ops = set() # type: Set[Text]
for domain, _supportmap in sorted(index.items()):
if not should_render_domain(domain, domain_filter):
@@ -394,7 +390,7 @@ if __name__ == "__main__":
parser.add_argument(
"--domains",
nargs="+",
- help="Filter to specified domains. " "e.g. `--domains com.microsoft com.microsoft.nchwc`", # noqa: ISC001
+ help="Filter to specified domains. e.g. `--domains com.microsoft com.microsoft.nchwc`",
)
parser.add_argument(
"--output_path",
diff --git a/tools/python/sparsify_initializers.py b/tools/python/sparsify_initializers.py
index f9cc8db38e..2c80b07cd0 100644
--- a/tools/python/sparsify_initializers.py
+++ b/tools/python/sparsify_initializers.py
@@ -54,9 +54,7 @@ def setup_logging(verbose): # type: (bool) -> None
logger.setLevel(logging_level)
-def convert_tensor_to_sparse(
- tensor, sparsity_threshold, tolerance
-): # type: (TensorProto, float, float) -> Tuple[SparseTensorProto, float]
+def convert_tensor_to_sparse(tensor, sparsity_threshold, tolerance): # type: (TensorProto, float, float) -> Tuple[SparseTensorProto, float]
"""returns a tuple of sparse_tensor and sparsity level"""
values = []
indices = []
@@ -140,9 +138,7 @@ def convert_tensor_to_sparse(
return (sparse_tensor, sparsity)
-def convert_initializers(
- model, exclude_names, sparsity_threshold, tolerance
-): # type: (ModelProto, List[str], float, float) -> None
+def convert_initializers(model, exclude_names, sparsity_threshold, tolerance): # type: (ModelProto, List[str], float, float) -> None
graph = model.graph
converted_sparse = []
remaining_initializers = []
diff --git a/tools/python/util/mobile_helpers/usability_checker.py b/tools/python/util/mobile_helpers/usability_checker.py
index e7948c43ba..81c3c07aa9 100644
--- a/tools/python/util/mobile_helpers/usability_checker.py
+++ b/tools/python/util/mobile_helpers/usability_checker.py
@@ -151,23 +151,23 @@ class PartitioningInfo:
if self.supported_groups:
logger.info(
- f'\tPartition sizes: [{", ".join([str(len(partition)) for partition in self.supported_groups])}]'
+ f"\tPartition sizes: [{', '.join([str(len(partition)) for partition in self.supported_groups])}]"
)
# dump full groups if debug output is enabled
for group in self.supported_groups:
- logger.debug(f'Nodes in group: {",".join([f"{node.op_type}:{node.name}" for node in group])}')
+ logger.debug(f"Nodes in group: {','.join([f'{node.op_type}:{node.name}' for node in group])}")
logger.info(f"Unsupported nodes due to operator={self.nodes_unsupported_due_to_op}")
if self.unsupported_ops:
- logger.info(f'\tUnsupported ops: {",".join(sorted(self.unsupported_ops))}')
+ logger.info(f"\tUnsupported ops: {','.join(sorted(self.unsupported_ops))}")
caveats = self.supported_ops_checker.get_caveats()
if caveats:
indent = " " * 5
logger.info(
"\tCaveats that have not been checked and may result in a node not actually being supported: "
- f'{"".join([os.linesep + indent + caveat for caveat in caveats])}'
+ f"{''.join([os.linesep + indent + caveat for caveat in caveats])}"
)
if self.nodes_unsupported_due_to_dynamic_input:
@@ -341,7 +341,7 @@ def _check_partitioning_for_graph(
continue
if not is_op_supported:
- unsupported_ops.add(f'{node.domain if node.domain else "ai.onnx"}:{node.op_type}')
+ unsupported_ops.add(f"{node.domain if node.domain else 'ai.onnx'}:{node.op_type}")
num_unsupported_nodes_due_to_op += 1
if not is_input_shape_supported:
@@ -349,7 +349,7 @@ def _check_partitioning_for_graph(
if not is_rank_supported:
num_unsupported_nodes_due_to_rank += 1
- ops_with_unsupported_rank.add(f'{node.domain if node.domain else "ai.onnx"}:{node.op_type}')
+ ops_with_unsupported_rank.add(f"{node.domain if node.domain else 'ai.onnx'}:{node.op_type}")
if is_node_supported:
num_supported_nodes += 1
@@ -569,8 +569,7 @@ def check_shapes(graph: onnx.GraphProto, logger: logging.Logger | None = None):
# a model where all inputs are dynamic (results in no value_info)
if not graph.value_info and not (len(graph.node) == 1 or len(dynamic_inputs) == len(graph.input)):
logger.warning(
- "Unable to check shapes within model. "
- "ONNX shape inferencing should be run on the model prior to checking."
+ "Unable to check shapes within model. ONNX shape inferencing should be run on the model prior to checking."
)
for vi in graph.value_info:
diff --git a/tools/python/util/onnx_model_utils.py b/tools/python/util/onnx_model_utils.py
index 5c970430a3..1938a2411e 100644
--- a/tools/python/util/onnx_model_utils.py
+++ b/tools/python/util/onnx_model_utils.py
@@ -227,7 +227,7 @@ def make_input_shape_fixed(graph: onnx.GraphProto, input_name: str, fixed_shape:
raise ValueError(
f"Input {input_name} was not found in graph inputs. "
- f'Valid input names are: {",".join([i.name for i in graph.input])}'
+ f"Valid input names are: {','.join([i.name for i in graph.input])}"
)
@@ -337,7 +337,7 @@ def get_producer_consumer_maps(graph: onnx.GraphProto):
# top level graph should have no implicit inputs
if implicit_inputs:
raise ValueError(
- f'This appears to be an invalid model with missing inputs of {",".join(sorted(implicit_inputs))}'
+ f"This appears to be an invalid model with missing inputs of {','.join(sorted(implicit_inputs))}"
)
return node_to_producers, node_to_consumers
diff --git a/tools/python/util/ort_format_model/__init__.py b/tools/python/util/ort_format_model/__init__.py
index 318851642d..29e8e70ed2 100644
--- a/tools/python/util/ort_format_model/__init__.py
+++ b/tools/python/util/ort_format_model/__init__.py
@@ -18,8 +18,10 @@ else:
sys.path.append(ort_fbs_py_parent_dir)
-from .operator_type_usage_processors import GloballyAllowedTypesOpTypeImplFilter # noqa: E402, F401
-from .operator_type_usage_processors import OperatorTypeUsageManager # noqa: E402, F401
-from .operator_type_usage_processors import OpTypeImplFilterInterface # noqa: E402, F401
+from .operator_type_usage_processors import ( # noqa: E402
+ GloballyAllowedTypesOpTypeImplFilter, # noqa: F401
+ OperatorTypeUsageManager, # noqa: F401
+ OpTypeImplFilterInterface, # noqa: F401
+)
from .ort_model_processor import OrtFormatModelProcessor # noqa: E402, F401
from .utils import create_config_from_models # noqa: E402, F401
diff --git a/tools/python/util/ort_format_model/types.py b/tools/python/util/ort_format_model/types.py
index ffeda6b2e7..9661eb33c9 100644
--- a/tools/python/util/ort_format_model/types.py
+++ b/tools/python/util/ort_format_model/types.py
@@ -6,6 +6,7 @@ import ort_flatbuffers_py.fbs as fbs
class FbsTypeInfo:
"Class to provide conversion between ORT flatbuffers schema values and C++ types"
+
tensordatatype_to_string = { # noqa: RUF012
fbs.TensorDataType.TensorDataType.FLOAT: "float",
fbs.TensorDataType.TensorDataType.UINT8: "uint8_t",