onnxruntime/orttraining/tools/scripts/pipeline_model_split.py
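
"""Split an ONNX model into pipeline-parallel stages.

A short summary of what the script below does (wording added here, not in the original file):
the cut edges are hard-coded in main(); for each cut, Send/Recv nodes in the com.microsoft
domain are inserted, and the graph is separated into disconnected subgraphs that are saved
as separate .onnx files, one per stage.
"""
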
import os
import sys  # noqa: F401

import onnx
from onnx import OperatorSetIdProto, TensorProto, helper  # noqa: F401


# Edge that needs to be cut for the split.
# If the edge feeds into more than one node, and not all of those nodes belong to the same cut,
# specify the consuming nodes that need to be cut.
class CutEdge:
    def __init__(self, edgeId, consumingNodes=None):
        self.edgeId = edgeId
        self.consumingNodes = consumingNodes
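
# Example (edge names taken from main() below): CutEdge("186") cuts edge "186" for all of its
# consumers, while CutEdge("71", {"273", "395"}) cuts edge "71" only where its consumers produce
# the outputs "273" and "395".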


def add_expand_type(model, name, type):
    expand_edge = model.graph.value_info.add()
    expand_edge.name = name
    expand_edge.type.CopyFrom(type)


# Add wait/record/send/recv nodes and split the graph into disconnected subgraphs
def split_graph(model, split_edge_groups):
    ms_domain = "com.microsoft"

    new_send_nodes = []
    new_recv_nodes = []

    for cut_index in range(len(split_edge_groups)):
        edgeIds = split_edge_groups[cut_index]  # noqa: N806

        # split the graph based on edgeIds
        upstream_nodes = []
        upstream_nodes_output_index = []
        output_shapes = []
        element_types = []
        for id in edgeIds:
            for node in model.graph.node:
                if len(node.output) >= 1:
                    for i, j in enumerate(node.output):
                        if j == id:
                            upstream_nodes.append(node)
                            upstream_nodes_output_index.append(i)
                            # assuming all tensors are of type float
                            element_types.append(1)
            for info in model.graph.value_info:
                if info.name == id:
                    output_shapes.append(info.type)
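
        # Boolean "signal" values trigger the Send/Recv, and INT64 rank scalars identify the
        # destination/source pipeline stage; each is exposed as a graph input and backed by an
        # initializer.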
        send_input_signal_name = "send_input_signal" + str(cut_index)
        send_signal = model.graph.input.add()
        send_signal.CopyFrom(helper.make_tensor_value_info(send_input_signal_name, onnx.TensorProto.BOOL, None))
        send_signal = helper.make_tensor(send_input_signal_name, TensorProto.BOOL, (), (True,))
        model.graph.initializer.extend([send_signal])

        recv_input_signal_name = "recv_input_signal" + str(cut_index)
        recv_signal = model.graph.input.add()
        recv_signal.CopyFrom(helper.make_tensor_value_info(recv_input_signal_name, onnx.TensorProto.BOOL, None))
        recv_signal = helper.make_tensor(recv_input_signal_name, TensorProto.BOOL, (), (True,))
        model.graph.initializer.extend([recv_signal])

        send_dst_rank_name = "send_dst_rank" + str(cut_index)
        send_dst_rank = model.graph.input.add()
        send_dst_rank.CopyFrom(helper.make_tensor_value_info(send_dst_rank_name, onnx.TensorProto.INT64, None))
        send_dst_rank = helper.make_tensor(send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,))
        model.graph.initializer.extend([send_dst_rank])

        recv_src_rank_name = "recv_src_rank" + str(cut_index)
        recv_src_rank = model.graph.input.add()
        recv_src_rank.CopyFrom(helper.make_tensor_value_info(recv_src_rank_name, onnx.TensorProto.INT64, None))
        recv_src_rank = helper.make_tensor(recv_src_rank_name, TensorProto.INT64, (), (cut_index,))
        model.graph.initializer.extend([recv_src_rank])

        # output signal from send after cut
        send_output_signal = model.graph.output.add()
        send_output_signal.CopyFrom(
            helper.make_tensor_value_info("send_output_signal" + str(cut_index), onnx.TensorProto.BOOL, None)
        )

        # output signal from receive after cut
        receive_output_signal = model.graph.output.add()
        receive_output_signal.CopyFrom(
            helper.make_tensor_value_info("receive_output_signal" + str(cut_index), onnx.TensorProto.BOOL, None)
        )

        new_send = model.graph.node.add()
        new_send.CopyFrom(
            helper.make_node(
                "Send",
                inputs=[send_input_signal_name, send_dst_rank_name],
                outputs=["send_output_signal" + str(cut_index)],
                tag=0,
                domain=ms_domain,
                element_types=element_types,
                name="send",
            )
        )

        new_receive = model.graph.node.add()
        new_receive.CopyFrom(
            helper.make_node(
                "Recv",
                inputs=[recv_input_signal_name, recv_src_rank_name],
                outputs=["receive_output_signal" + str(cut_index)],
                tag=0,
                domain=ms_domain,
                element_types=element_types,
                name="receive",
            )
        )
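
        # Wire every cut tensor into the shared Send node and route its former consumers to the
        # matching Recv output, so the producer and consumer stages become disconnected.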
        for i in range(len(upstream_nodes)):
            n = upstream_nodes[i]
            idx = upstream_nodes_output_index[i]
            output_type = output_shapes[i]
            output_edge_name = n.output[idx]

            output_nodes = find_all_output_nodes_by_edge(model, output_edge_name)

            # deal with shape inference for newly added edge
            new_send_input_name = output_edge_name + "_send" + str(cut_index)
            add_expand_type(model, new_send_input_name, output_type)

            new_receive_output_name = output_edge_name + "_recv" + str(cut_index)
            add_expand_type(model, new_receive_output_name, output_type)

            # the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input
            new_send.input.extend([output_edge_name])
            new_receive.output.extend([new_receive_output_name])

            for output_node in output_nodes:
                for i in range(len(output_node.input)):  # noqa: PLW2901
                    for edgeId in edgeIds:  # noqa: N806
                        if output_node.input[i] == edgeId:
                            output_node.input[i] = new_receive_output_name

        new_send_nodes.append(new_send)
        new_recv_nodes.append(new_receive)

    model = onnx.shape_inference.infer_shapes(model)

    return new_send_nodes, new_recv_nodes


def find_all_input_nodes(model, node):
    nodes = []
    inputs = []

    if node:
        for inputId in node.input:  # noqa: N806
            nodes.extend([n for n in model.graph.node if inputId in n.output])
            inputs.extend([n for n in model.graph.input if inputId in n.name])
    return nodes, inputs


def find_all_output_nodes(model, node):
    nodes = []
    outputs = []

    if node:
        for outputId in node.output:  # noqa: N806
            nodes.extend([n for n in model.graph.node if outputId in n.input])
            outputs.extend([n for n in model.graph.output if outputId in n.name])
    return nodes, outputs


def find_all_output_nodes_by_edge(model, arg):
    result = [n for n in model.graph.node if arg in n.input]
    return result


# Insert identity nodes to separate an output edge that feeds into different sub-graphs.
def add_identity(model, cuttingEdge, newEdgeIdName):
    output_nodes = None
    edgeId = cuttingEdge.edgeId  # noqa: N806

    for node in model.graph.node:
        if len(node.output) >= 1:
            for output in node.output:
                if output == edgeId:
                    output_nodes = find_all_output_nodes_by_edge(model, output)
                    break

    assert output_nodes, "no output node"

    new_identity = model.graph.node.add()
    new_identity.op_type = "Identity"
    new_identity.input.extend([edgeId])
    new_identity.output.extend([newEdgeIdName])

    for i in range(len(output_nodes)):
        for output in output_nodes[i].output:
            if output in cuttingEdge.consumingNodes:
                for j in range(len(output_nodes[i].input)):
                    if output_nodes[i].input[j] == edgeId:
                        output_nodes[i].input[j] = newEdgeIdName

    return new_identity


def insert_identity(model, all_cut_inputs):
    count = 0
    updated_edges = {}
    new_added_identity = []
    split_edge_groups = []
    need_shape_inference = False

    # Sweep the cut edges to see if any edge feeds nodes in two different sub-graphs. If so,
    # insert an identity node after that edge, with a new ID to distinguish it from the rest.
    for cut_input in all_cut_inputs:
        split_edges = []
        for i in cut_input:
            if i.consumingNodes:
                # if this edge has previously been modified, update its edgeId before inserting new identity
                if i.edgeId in updated_edges:
                    i.edgeId = updated_edges[i.edgeId]

                new_edge_name = "identity_output_" + str(count)
                new_added_identity.append(add_identity(model, i, new_edge_name))
                count += 1
                split_edges.append(new_edge_name)
                updated_edges[i.edgeId] = new_edge_name
                need_shape_inference = True
            else:
                split_edges.append(i.edgeId)
        split_edge_groups.append(split_edges)
    return split_edge_groups, new_added_identity, need_shape_inference


# After the graph is split, remove the added identity nodes because the Identity op is not
# registered in the gradient builder.
def remove_identity(model, new_added_identity):
    for node in new_added_identity:
        assert node.op_type == "Identity"
        output_nodes = [n for n in model.graph.node if node.output[0] in n.input]
        for output_node in output_nodes:
            for i in range(len(output_node.input)):
                if output_node.input[i] == node.output[0]:
                    output_node.input[i] = node.input[0]


def find_all_connected_nodes(model, node):
    nodes0, inputs = find_all_input_nodes(model, node)
    nodes1, outputs = find_all_output_nodes(model, node)

    connected_nodes = nodes0 + nodes1
    return connected_nodes, inputs, outputs


def get_index(node_list, node):
    found = [i for i, n in enumerate(node_list) if n == node]
    return found[0] if found else None


def get_identity_index_for_deleting(node_list, node):
    for i, n in enumerate(node_list):
        # The node's input name has been changed during send/recv insertion,
        # but it is sufficient to just compare the type and outputs.
        if n.op_type == "Identity" and n.output == node.output:
            return i
    return None


# traverse the graph, group connected nodes, and generate the subgraphs
def generate_subgraph(model, start_nodes, identity_node_list):
    subgraphs = []

    main_graph = onnx.ModelProto()
    main_graph.CopyFrom(model)

    # remove the added identity nodes before copying to the subgraphs
    identity_node_index = []
    for n in identity_node_list:
        identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n))
    identity_node_index.sort(reverse=True)

    for i in reversed(range(len(main_graph.graph.node))):
        try:
            if i in identity_node_index:
                del main_graph.graph.node[i]
        except Exception:
            print("error deleting identity node", i)

    all_visited_nodes = []
    model_count = len(start_nodes)
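
    # Starting from each Recv node (last cut first), collect the connected component of nodes,
    # inputs, and outputs that forms one pipeline stage, then peel it off from main_graph.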
    for start in reversed(start_nodes):
        stack0 = [start]
        visited0 = []
        traversed_node = 0
        inputs0 = []
        outputs0 = []
        while stack0:
            node = stack0.pop()
            if node not in visited0:
                traversed_node += 1
                visited0.append(node)
                all_visited_nodes.append(node)
                connected_nodes, inputs, outputs = find_all_connected_nodes(main_graph, node)

                stack0 = stack0 + connected_nodes
                inputs0 = inputs0 + inputs
                outputs0 = outputs0 + outputs

        subgraph = onnx.ModelProto()
        subgraph.CopyFrom(main_graph)

        # gather visited nodes
        visited_nodes = []
        for n in visited0:
            visited_nodes.append(get_index(main_graph.graph.node, n))
        visited_nodes.sort(reverse=True)

        # gather visited inputs
        visited_inputs = []
        for n in inputs0:
            visited_inputs.append(get_index(main_graph.graph.input, n))
        visited_inputs.sort(reverse=True)

        # gather visited outputs
        visited_outputs = []
        for n in outputs0:
            visited_outputs.append(get_index(main_graph.graph.output, n))
        visited_outputs.sort(reverse=True)

        for i in reversed(range(len(main_graph.graph.node))):
            try:
                if i not in visited_nodes:
                    del subgraph.graph.node[i]
                else:
                    del main_graph.graph.node[i]
            except Exception:
                print("error deleting node", i)

        for i in reversed(range(len(main_graph.graph.input))):
            try:
                if i not in visited_inputs:
                    del subgraph.graph.input[i]
                else:
                    del main_graph.graph.input[i]
            except Exception:
                print("error deleting inputs", i)

        for i in reversed(range(len(main_graph.graph.output))):
            try:
                if i not in visited_outputs:
                    del subgraph.graph.output[i]
                else:
                    del main_graph.graph.output[i]
            except Exception:
                print("error deleting outputs ", i)

        print("model", str(model_count), " length ", len(subgraph.graph.node))
        subgraphs.append(subgraph)
        model_count -= 1

    print("model", str(model_count), " length ", len(main_graph.graph.node))
    subgraphs.append(main_graph)

    # as the subgraphs were added in reverse order (the last split is added first), reverse the order back before return
    subgraphs.reverse()
    return subgraphs


def main():
    # temporarily hard-coded cutting-edge structure
    # TODO: move this info to a file (json?) and load the data from there.
    input_model_name = "bert-tiny-uncased_L_3_H_128_A_2_V_30528_S_512_Dp_0.1.onnx"
    stage_count = 3

    cut0_input = {CutEdge("186"), CutEdge("71", {"273", "395"})}
    cut1_input = {CutEdge("308"), CutEdge("71", {"395"})}
    all_cut_inputs = [cut0_input, cut1_input]
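    # Two cut groups yield three disconnected subgraphs, matching stage_count above.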

    model = onnx.load(input_model_name)
    if len(model.graph.value_info) == 0:
        model = onnx.shape_inference.infer_shapes(model)

    print("original model length ", len(model.graph.node))

    output_model_names = [os.path.splitext(input_model_name)[0] + "_" + str(i) + ".onnx" for i in range(stage_count)]

    split_edge_groups, new_identity, need_shape_inference = insert_identity(model, all_cut_inputs)

    # new edges were added, so shapes need to be re-inferred
    if need_shape_inference:
        model = onnx.shape_inference.infer_shapes(model)

    # after all the need-to-be-cut edges are identified, split the graph
    new_sends, new_receives = split_graph(model, split_edge_groups)
    remove_identity(model, new_identity)
    sub_graphs = generate_subgraph(model, new_receives, new_identity)

    for i in range(stage_count):
        sub_graphs[i] = onnx.shape_inference.infer_shapes(sub_graphs[i])
        onnx.save(sub_graphs[i], output_model_names[i])
        print("save to file: ", output_model_names[i])


if __name__ == "__main__":
    main()