onnxruntime/orttraining/tools/scripts/model_transform.py
dependabot[bot] 1461a16e71
Bump ruff from 0.5.4 to 0.9.1 (#23328)
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.5.4 to 0.9.1.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/astral-sh/ruff/releases">ruff's
releases</a>.</em></p>
<blockquote>
<h2>0.9.1</h2>
<h2>Release Notes</h2>
<h3>Preview features</h3>
<ul>
<li>[<code>pycodestyle</code>] Run
<code>too-many-newlines-at-end-of-file</code> on each cell in notebooks
(<code>W391</code>) (<a
href="https://redirect.github.com/astral-sh/ruff/pull/15308">#15308</a>)</li>
<li>[<code>ruff</code>] Omit diagnostic for shadowed private function
parameters in <code>used-dummy-variable</code> (<code>RUF052</code>) (<a
href="https://redirect.github.com/astral-sh/ruff/pull/15376">#15376</a>)</li>
</ul>
<h3>Rule changes</h3>
<ul>
<li>[<code>flake8-bugbear</code>] Improve
<code>assert-raises-exception</code> message (<code>B017</code>) (<a
href="https://redirect.github.com/astral-sh/ruff/pull/15389">#15389</a>)</li>
</ul>
<h3>Formatter</h3>
<ul>
<li>Preserve trailing end-of line comments for the last string literal
in implicitly concatenated strings (<a
href="https://redirect.github.com/astral-sh/ruff/pull/15378">#15378</a>)</li>
</ul>
<h3>Server</h3>
<ul>
<li>Fix a bug where the server and client notebooks were out of sync
after reordering cells (<a
href="https://redirect.github.com/astral-sh/ruff/pull/15398">#15398</a>)</li>
</ul>
<h3>Bug fixes</h3>
<ul>
<li>[<code>flake8-pie</code>] Correctly remove wrapping parentheses
(<code>PIE800</code>) (<a
href="https://redirect.github.com/astral-sh/ruff/pull/15394">#15394</a>)</li>
<li>[<code>pyupgrade</code>] Handle comments and multiline expressions
correctly (<code>UP037</code>) (<a
href="https://redirect.github.com/astral-sh/ruff/pull/15337">#15337</a>)</li>
</ul>
<h2>Contributors</h2>
<ul>
<li><a
href="https://github.com/AntoineD"><code>@​AntoineD</code></a></li>
<li><a
href="https://github.com/InSyncWithFoo"><code>@​InSyncWithFoo</code></a></li>
<li><a
href="https://github.com/MichaReiser"><code>@​MichaReiser</code></a></li>
<li><a href="https://github.com/calumy"><code>@​calumy</code></a></li>
<li><a
href="https://github.com/dcreager"><code>@​dcreager</code></a></li>
<li><a
href="https://github.com/dhruvmanila"><code>@​dhruvmanila</code></a></li>
<li><a href="https://github.com/dylwil3"><code>@​dylwil3</code></a></li>
<li><a href="https://github.com/sharkdp"><code>@​sharkdp</code></a></li>
<li><a href="https://github.com/tjkuson"><code>@​tjkuson</code></a></li>
</ul>
<h2>Install ruff 0.9.1</h2>
<h3>Install prebuilt binaries via shell script</h3>
<pre lang="sh"><code>curl --proto '=https' --tlsv1.2 -LsSf
https://github.com/astral-sh/ruff/releases/download/0.9.1/ruff-installer.sh
| sh
</code></pre>
<h3>Install prebuilt binaries via powershell script</h3>
<pre lang="sh"><code>powershell -ExecutionPolicy ByPass -c &quot;irm
https://github.com/astral-sh/ruff/releases/download/0.9.1/ruff-installer.ps1
| iex&quot;
</code></pre>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Changelog</summary>
<p><em>Sourced from <a
href="https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md">ruff's
changelog</a>.</em></p>
<blockquote>
<h2>0.9.0</h2>
<p>Check out the <a href="https://astral.sh/blog/ruff-v0.9.0">blog
post</a> for a migration guide and overview of the changes!</p>
<h3>Breaking changes</h3>
<p>Ruff now formats your code according to the 2025 style guide. As a
result, your code might now get formatted differently. See the formatter
section for a detailed list of changes.</p>
<p>This release doesn’t remove or remap any existing stable rules.</p>
<h3>Stabilization</h3>
<p>The following rules have been stabilized and are no longer in
preview:</p>
<ul>
<li><a
href="https://docs.astral.sh/ruff/rules/stdlib-module-shadowing/"><code>stdlib-module-shadowing</code></a>
(<code>A005</code>).
This rule has also been renamed: previously, it was called
<code>builtin-module-shadowing</code>.</li>
<li><a
href="https://docs.astral.sh/ruff/rules/builtin-lambda-argument-shadowing/"><code>builtin-lambda-argument-shadowing</code></a>
(<code>A006</code>)</li>
<li><a
href="https://docs.astral.sh/ruff/rules/slice-to-remove-prefix-or-suffix/"><code>slice-to-remove-prefix-or-suffix</code></a>
(<code>FURB188</code>)</li>
<li><a
href="https://docs.astral.sh/ruff/rules/boolean-chained-comparison/"><code>boolean-chained-comparison</code></a>
(<code>PLR1716</code>)</li>
<li><a
href="https://docs.astral.sh/ruff/rules/decimal-from-float-literal/"><code>decimal-from-float-literal</code></a>
(<code>RUF032</code>)</li>
<li><a
href="https://docs.astral.sh/ruff/rules/post-init-default/"><code>post-init-default</code></a>
(<code>RUF033</code>)</li>
<li><a
href="https://docs.astral.sh/ruff/rules/useless-if-else/"><code>useless-if-else</code></a>
(<code>RUF034</code>)</li>
</ul>
<p>The following behaviors have been stabilized:</p>
<ul>
<li><a
href="https://docs.astral.sh/ruff/rules/pytest-parametrize-names-wrong-type/"><code>pytest-parametrize-names-wrong-type</code></a>
(<code>PT006</code>): Detect <a
href="https://docs.pytest.org/en/7.1.x/how-to/parametrize.html#parametrize"><code>pytest.parametrize</code></a>
calls outside decorators and calls with keyword arguments.</li>
</ul>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="12f86f39a4"><code>12f86f3</code></a>
Ruff 0.9.1 (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15407">#15407</a>)</li>
<li><a
href="2b28d566a4"><code>2b28d56</code></a>
Associate a trailing end-of-line comment in a parenthesized implicit
concaten...</li>
<li><a
href="adca7bd95c"><code>adca7bd</code></a>
Remove pygments pin (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15404">#15404</a>)</li>
<li><a
href="6b98a26452"><code>6b98a26</code></a>
[red-knot] Support <code>assert_type</code> (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15194">#15194</a>)</li>
<li><a
href="c87463842a"><code>c874638</code></a>
[red-knot] Move tuple-containing-Never tests to Markdown (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15402">#15402</a>)</li>
<li><a
href="c364b586f9"><code>c364b58</code></a>
[<code>flake8-pie</code>] Correctly remove wrapping parentheses
(<code>PIE800</code>) (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15394">#15394</a>)</li>
<li><a
href="73d424ee5e"><code>73d424e</code></a>
Fix outdated doc for handling the default file types with the pre-commit
hook...</li>
<li><a
href="6e9ff445fd"><code>6e9ff44</code></a>
Insert the cells from the <code>start</code> position (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15398">#15398</a>)</li>
<li><a
href="f2c3ddc5ea"><code>f2c3ddc</code></a>
[red-knot] Move intersection type tests to Markdown (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15396">#15396</a>)</li>
<li><a
href="b861551b6a"><code>b861551</code></a>
Remove unnecessary backticks (<a
href="https://redirect.github.com/astral-sh/ruff/issues/15393">#15393</a>)</li>
<li>Additional commits viewable in <a
href="https://github.com/astral-sh/ruff/compare/0.5.4...0.9.1">compare
view</a></li>
</ul>
</details>
---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-01-15 11:11:17 -08:00

327 lines
11 KiB
Python

from __future__ import annotations

import sys

import numpy as np
import onnx
from onnx import numpy_helper

if len(sys.argv) < 2:
    print("Please give model path...")
    sys.exit(1)

input_model_name = sys.argv[1]
# Strips the last five characters, so the path is assumed to end in ".onnx".
output_model_name = input_model_name[:-5] + "_optimized.onnx"
model = onnx.load(input_model_name)


def add_name(model):
    for i, node in enumerate(model.graph.node):
        node.name = f"{node.op_type}_{i}"


def find_input_node(model, arg):
    result = []
    for node in model.graph.node:
        for output in node.output:
            if output == arg:
                result.append(node)
    return result[0] if len(result) == 1 else None


def find_output_node(model, arg):
    result = []
    for node in model.graph.node:
        for input in node.input:
            if input == arg:
                result.append(node)
    return result[0] if len(result) == 1 else None


def find_input(model, arg):
    for initializer in model.graph.initializer:
        if initializer.name == arg:
            return initializer
    return None


def find_all_fused_nodes(model, concat_node):
    result = []
    candidate = [concat_node]
    while len(candidate) > 0:
        node = candidate[0]
        candidate.pop(0)
        result.append(node)
        if node.op_type == "Shape":
            continue
        for input in node.input:
            input_node = find_input_node(model, input)
            if input_node is not None:
                candidate.append(input_node)
    return result


def get_node_index(model, node):
    i = 0
    while i < len(model.graph.node):
        if model.graph.node[i] == node:
            break
        i += 1
    return i if i < len(model.graph.node) else None


def add_const(model, name, output, t_value=None, f_value=None):
    const_node = model.graph.node.add()
    const_node.op_type = "Constant"
    const_node.name = name
    const_node.output.extend([output])
    attr = const_node.attribute.add()
    attr.name = "value"
    if t_value is not None:
        attr.type = 4  # AttributeProto.TENSOR
        attr.t.CopyFrom(t_value)
    else:
        attr.type = 1  # AttributeProto.FLOAT
        attr.f = f_value
    return const_node


def process_concat(model):
    new_nodes = {}
    delete_nodes = []
    for node in model.graph.node:
        if node.op_type == "Concat":
            input_nodes = []
            for input in node.input:
                input_nodes.append(find_input_node(model, input))
            # figure out target shape
            shape = []
            for input_node in input_nodes:
                assert input_node.op_type == "Unsqueeze"
                const_input = find_input_node(model, input_node.input[0])
                if const_input.op_type != "Constant":
                    shape.append(0)
                else:
                    attr = const_input.attribute
                    assert len(attr) == 1
                    assert attr[0].name == "value"
                    assert attr[0].type == 4
                    data = numpy_helper.to_array(attr[0].t)
                    # np.asscalar was removed in NumPy 1.23; ndarray.item() is the replacement.
                    shape.append(data.item())
            print(f"concat node: {node.name}, new_shape is: {shape}")
            # find the nodes that need to be deleted
            fuse_nodes = find_all_fused_nodes(model, node)
            reshape_node = find_output_node(model, node.output[0])
            assert reshape_node.op_type == "Reshape"
            new_nodes[get_node_index(model, reshape_node)] = shape
            for n in fuse_nodes:
                delete_nodes.append(get_node_index(model, n))
    # insert new shape to reshape
    for index, reshape_node_index in enumerate(new_nodes):
        shape_tensor = numpy_helper.from_array(np.asarray(new_nodes[reshape_node_index], dtype=np.int64))
        const_node = add_const(model, f"concat_shape_node_{index}", f"concat_shape_{index}", shape_tensor)
        reshape_node = model.graph.node[reshape_node_index]
        reshape_node.input[1] = const_node.output[0]
    # delete nodes
    delete_nodes.sort(reverse=True)
    for delete_node in delete_nodes:
        del model.graph.node[delete_node]


def add_cast(model, name, input, output, type):
    cast_node = model.graph.node.add()
    cast_node.name = name
    cast_node.op_type = "Cast"
    attr = cast_node.attribute.add()
    attr.name = "to"
    attr.type = 2  # AttributeProto.INT
    attr.i = type
    cast_node.input.extend([input])
    cast_node.output.extend([output])
    return cast_node


def fix_expand(model):
    # find expand node
    expand_node = None
    for node in model.graph.node:
        if node.op_type == "Expand":
            expand_node = node
            break
    assert expand_node is not None
    const_expand_input = find_input_node(model, expand_node.input[0])
    assert const_expand_input.op_type == "Constant"
    shape_node = find_input_node(model, expand_node.input[1])
    assert shape_node.op_type == "Shape"
    # insert cast --> min --> cast (1 is TensorProto.FLOAT, 7 is TensorProto.INT64)
    cast_1 = add_cast(model, "new_cast_01", shape_node.output[0], "to_min_01", 1)
    min_target = numpy_helper.from_array(np.asarray([1, 9999], dtype=np.float32))
    min_target_node = add_const(model, "op_min_node_10", "op_min_ends_expand_10", min_target)
    min_node = model.graph.node.add()
    min_node.name = "new_min_01"
    min_node.op_type = "Min"
    min_node.input.extend([cast_1.output[0], min_target_node.output[0]])
    min_node.output.extend(["from_min_01"])
    cast_2 = add_cast(model, "new_cast_02", min_node.output[0], "to_slice_01", 7)
    # insert slice
    position = numpy_helper.from_array(np.expand_dims(np.arange(512, dtype=np.int64), axis=0))
    position_node = add_const(model, "position_01_node", "position_01", position)
    start_extend = numpy_helper.from_array(np.asarray([0, 0], dtype=np.int64), "start_expand_10")
    start_extend_node = add_const(model, "start_expand_10_node", "start_expand_10", start_extend)
    axes = numpy_helper.from_array(np.asarray([0, 1], dtype=np.int64), "axes_expand_10")
    axes_node = add_const(model, "axes_expand_10_node", "axes_expand_10", axes)
    slice_node = model.graph.node.add()
    slice_node.name = "new_slice_01"
    slice_node.op_type = "Slice"
    slice_node.input.extend(
        [position_node.output[0], start_extend_node.output[0], cast_2.output[0], axes_node.output[0]]
    )
    slice_node.output.extend(["from_slice_01"])
    # connect to expand
    expand_node.input[0] = slice_node.output[0]
    # delete the const input
    del model.graph.node[get_node_index(model, const_expand_input)]


def fix_dim(model):
    del model.graph.input[3:]


def replace_input_arg(model, arg, new_arg):
    for node in model.graph.node:
        i = 0
        while i < len(node.input):
            if node.input[i] == arg:
                node.input[i] = new_arg
            i += 1


def find_weight_index(model, name: str) -> int | None:
    for index, w in enumerate(model.graph.initializer):
        if w.name == name:
            return index
    return None


def fix_transpose(model):
    transpose = []
    for node in model.graph.node:
        if node.op_type == "Transpose":
            weight = find_input(model, node.input[0])
            if weight is not None:
                result = []
                for n in model.graph.node:
                    for input in n.input:
                        if input == weight.name:
                            result.append(n)
                if len(result) > 1:
                    continue
                perm = node.attribute[0]
                assert perm.name == "perm"
                perm = perm.ints
                assert len(perm) == 2 and perm[0] == 1 and perm[1] == 0
                transpose.append((get_node_index(model, node), weight))
    for t in transpose:
        node = model.graph.node[t[0]]
        weight = numpy_helper.to_array(t[1])
        assert len(weight.shape) == 2
        # perm is left over from the loop above; the assert there guarantees it is [1, 0].
        weight = weight.transpose(perm)
        new_weight = numpy_helper.from_array(weight, f"{t[1].name}_transposed")
        model.graph.initializer.extend([new_weight])
        replace_input_arg(model, node.output[0], new_weight.name)
    transpose.sort(reverse=True)
    for t in transpose:
        del model.graph.node[t[0]]
    old_ws = []
    for t in transpose:
        if find_output_node(model, t[1].name) is None:
            old_ws.append(find_weight_index(model, t[1].name))
    old_ws.sort(reverse=True)
    for w_i in old_ws:
        del model.graph.initializer[w_i]


def process_dropout(model):
    dropouts = []
    index = 0
    for node in model.graph.node:
        if node.op_type == "Dropout":
            new_dropout = model.graph.node.add()
            new_dropout.op_type = "TrainableDropout"
            new_dropout.name = f"TrainableDropout_{index}"
            # make ratio node
            ratio = np.asarray([node.attribute[0].f], dtype=np.float32)
            print(ratio.shape)
            ratio_value = numpy_helper.from_array(ratio)
            ratio_node = add_const(
                model, f"dropout_node_ratio_{index}", f"dropout_node_ratio_{index}", t_value=ratio_value
            )
            print(ratio_node)
            new_dropout.input.extend([node.input[0], ratio_node.output[0]])
            new_dropout.output.extend(node.output)
            dropouts.append(get_node_index(model, node))
            index += 1
    dropouts.sort(reverse=True)
    for d in dropouts:
        del model.graph.node[d]


# The value_info name below must be set differently for different versions of BERT,
# e.g. expand_out.name = '412'
def add_expand_shape(model):
    expand_out = model.graph.value_info.add()
    expand_out.name = "74"  # '410' # 74 for base model
    expand_out.type.CopyFrom(model.graph.input[0].type)


# add name to nodes
add_name(model)
# replace Gather & Concat with Reshape
process_concat(model)
# fix the expand with dynamic shape
fix_expand(model)
# use dynamic batch/sequence
fix_dim(model)
# constant fold transpose
fix_transpose(model)
# replace dropout with trainable dropout
process_dropout(model)
# add output shape of expand
add_expand_shape(model)
# set opset version to 10
model.opset_import[0].version = 10

with open(output_model_name, "wb") as f:
    f.write(model.SerializeToString())

# Use ORT to verify the converted model. Notice that you must use python package from the
# training branch because training requires some extra ops.
import onnxruntime as ort # noqa: E402
# We convert model to accept variable-length batch size, so it can be any positive integer.
batch = 3
# This should match --max_seq_length when calling nv_run_pretraining.py.
sq_length = 512
# This should match vocab_size in bert_config.json in DeepLearningExamples/PyTorch/LanguageModeling/BERT.
vocab_size = 30528
# Create a fake data point.
input_ids = np.random.randint(low=0, high=vocab_size, size=(batch, sq_length), dtype=np.int64)
segment_ids = np.random.randint(low=0, high=2, size=(batch, sq_length), dtype=np.int64)
input_mask = np.ones((batch, sq_length), dtype=np.int64)
# Do forward using the original model.
sess = ort.InferenceSession(input_model_name, providers=ort.get_available_providers())
result = sess.run(None, {"input1": input_ids, "input2": segment_ids, "input3": input_mask})
# Do forward using the new model.
new_sess = ort.InferenceSession(output_model_name, providers=ort.get_available_providers())
new_result = new_sess.run(None, {"input1": input_ids, "input2": segment_ids, "input3": input_mask})
# Compare the outcomes from the two models.
print(np.linalg.norm(result[0] - new_result[0]))
print(np.linalg.norm(result[1] - new_result[1]))
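
The two norms printed at the end of the script are absolute L2 distances, which are hard to judge without knowing how large the outputs are. Below is a minimal sketch of a scale-aware comparison, assuming the session outputs are plain NumPy arrays. The helper name `compare_outputs` and its default tolerances are hypothetical and not part of the original script; the stand-in arrays take the place of `result` and `new_result`.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-4, atol=1e-5):
    """Return one (abs_err, rel_err, is_close) tuple per pair of outputs."""
    report = []
    for ref, cand in zip(reference, candidate):
        ref = np.asarray(ref, dtype=np.float64)
        cand = np.asarray(cand, dtype=np.float64)
        abs_err = float(np.linalg.norm(ref - cand))
        # Fall back to the absolute error when the reference is all zeros.
        denom = float(np.linalg.norm(ref))
        rel_err = abs_err / denom if denom > 0 else abs_err
        is_close = bool(np.allclose(ref, cand, rtol=rtol, atol=atol))
        report.append((abs_err, rel_err, is_close))
    return report


# Stand-in arrays; in the script these would be `result` and `new_result`.
reference = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([0.5, 0.25])]
candidate = [reference[0] + 1e-7, reference[1].copy()]

for abs_err, rel_err, is_close in compare_outputs(reference, candidate):
    print(f"abs={abs_err:.3e} rel={rel_err:.3e} allclose={is_close}")
```

In the script this could be called as `compare_outputs(result, new_result)`; which tolerances are appropriate depends on the model and the execution provider.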