mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Description: Format all Python files under onnxruntime with black and isort. Once this is checked in, the formatting PR can be listed in .git-blame-ignore-revs so that git blame ignores it. #11315, #11316
80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
import onnx
|
|
from onnx import TensorProto, helper
|
|
|
|
# Then/else branch used by the If node inside the Loop body.
# It declares no graph inputs; both values it reads come from enclosing scopes:
#   * state_var_in           -- an input of the main graph
#   * main_graph_initializer -- an initializer of the main graph
# Combining main_graph_initializer with a runtime input keeps this use of it from
# being constant folded, so the outer-scope reference survives two levels down.
_if_branch_add = helper.make_node("Add", ["state_var_in", "main_graph_initializer"], ["add_out"], "If_add")
_if_branch_cast = helper.make_node("Cast", ["add_out"], ["output"], to=TensorProto.BOOL)

# The branch output is declared BOOL to match the Cast above; it becomes the If
# node's output (subgraph_keep_going_out) in the Loop body.
_if_branch_outputs = [helper.make_tensor_value_info("output", TensorProto.BOOL, [1])]

if_body = helper.make_graph(
    [_if_branch_add, _if_branch_cast],
    "if_branch_body",
    [],  # no explicit inputs: everything used is captured from outer scopes
    _if_branch_outputs,
)

# Loop body graph. It uses main_graph_initializer (an outer-scope value) at this
# level, and again one level further down through the If node's branches.

# Add1 only combines two initializers (sub_graph_initializer from this graph and
# main_graph_initializer from the main graph), so it can be constant folded. Its
# implicit outer-scope use of main_graph_initializer creates a NodeArg when the
# graph is built, but that use disappears after constant folding.
_loop_sum_initializers = helper.make_node(
    "Add",
    ["sub_graph_initializer", "main_graph_initializer"],
    ["initializer_sum"],
    "Add1",
)
_loop_update_state = helper.make_node("Add", ["initializer_sum", "loop_state_in"], ["loop_state_out"], "Add2")

# If node whose branches (if_body) use main_graph_initializer another level down.
_loop_if = helper.make_node(
    "If",
    ["subgraph_keep_going_in"],
    ["subgraph_keep_going_out"],
    "If1",
    then_branch=if_body,
    else_branch=if_body,
)

# Loop body signature: iteration number, condition and loop-carried state in;
# condition and loop-carried state out.
_loop_inputs = [
    helper.make_tensor_value_info("iteration_num", TensorProto.INT64, [1]),
    helper.make_tensor_value_info("subgraph_keep_going_in", TensorProto.BOOL, [1]),
    helper.make_tensor_value_info("loop_state_in", TensorProto.FLOAT, [1]),
]
_loop_outputs = [
    helper.make_tensor_value_info("subgraph_keep_going_out", TensorProto.BOOL, [1]),
    helper.make_tensor_value_info("loop_state_out", TensorProto.FLOAT, [1]),
]
_loop_initializers = [helper.make_tensor("sub_graph_initializer", TensorProto.FLOAT, [1], [1.0])]

body = helper.make_graph(
    [_loop_sum_initializers, _loop_update_state, _loop_if],
    "Loop_body",
    _loop_inputs,
    _loop_outputs,
    _loop_initializers,
)

# Main graph: a single Loop node running the body defined above.
# max_trip_count, keep_going and main_graph_initializer are initializers;
# state_var_in is the only declared graph input.
_main_loop = helper.make_node(
    "Loop",
    ["max_trip_count", "keep_going", "state_var_in"],
    ["state_var_out"],
    "Loop1",
    body=body,
)

_main_initializers = [
    helper.make_tensor("max_trip_count", TensorProto.INT64, [1], [1]),  # at most one iteration
    helper.make_tensor("main_graph_initializer", TensorProto.FLOAT, [1], [1.0]),  # read from nested subgraphs
    helper.make_tensor("keep_going", TensorProto.BOOL, [1], [True]),
]

graph_proto = helper.make_graph(
    [_main_loop],
    "Main_graph",
    [helper.make_tensor_value_info("state_var_in", TensorProto.FLOAT, [1])],
    [helper.make_tensor_value_info("state_var_out", TensorProto.FLOAT, [1])],
    _main_initializers,
)

# Wrap the main graph in a model. No opset_imports are passed, so onnx's
# defaults apply. The model is written to a path relative to the current
# working directory.
model = helper.make_model(graph_proto)
_output_file = "ort_github_issue_4031.onnx"
onnx.save(model, _output_file)