mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
add pipeline graph split script (#3275)
* pipeline graph cut * add element type * add input wait event and shape info * shape inference * support multiple cuts * format script * address feedback * address feedback
This commit is contained in:
parent
fb2f97a002
commit
efc8bd738f
1 changed files with 403 additions and 0 deletions
403
orttraining/tools/scripts/pipeline_model_split.py
Normal file
403
orttraining/tools/scripts/pipeline_model_split.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
import sys
|
||||
import os
|
||||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
from onnx import OperatorSetIdProto
|
||||
|
||||
# Edge that needs to be cut for the split.
|
||||
# If the edge is feeding into more than one nodes, and not all the nodes belong to the same cut,
|
||||
# specify those consuming nodes that need to be cut
|
||||
|
||||
|
||||
class CutEdge:
|
||||
def __init__(self, edgeId, consumingNodes=None):
|
||||
self.edgeId = edgeId
|
||||
self.consumingNodes = consumingNodes
|
||||
|
||||
|
||||
def add_expand_type(model, name, type):
|
||||
expand_edge = model.graph.value_info.add()
|
||||
expand_edge.name = name
|
||||
expand_edge.type.CopyFrom(type)
|
||||
|
||||
# Add wait/record/send/recv nodes and split the graph into disconnected subgraphs
|
||||
|
||||
|
||||
def split_graph(model, split_edge_groups):
|
||||
ms_domain = 'com.microsoft'
|
||||
|
||||
new_send_nodes = []
|
||||
new_recv_nodes = []
|
||||
# Add wait for initial inputs. This needs to be done first before new inputs
|
||||
# are introduced from split
|
||||
initializer_lists = [a.name for a in model.graph.initializer]
|
||||
input_tensors = [
|
||||
value.name for value in model.graph.input if value.name not in initializer_lists]
|
||||
|
||||
input_wait_signal = model.graph.input.add()
|
||||
input_wait_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'input_wait_signal', onnx.TensorProto.INT64, None))
|
||||
|
||||
input_wait = model.graph.node.add()
|
||||
input_wait.CopyFrom(helper.make_node(
|
||||
'WaitEvent',
|
||||
inputs=['input_wait_signal'],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
for i in input_tensors:
|
||||
for node in model.graph.node:
|
||||
for j in range(len(node.input)):
|
||||
if node.input[j] == i:
|
||||
node.input[j] = i + '_sync'
|
||||
|
||||
input_wait.input.extend(input_tensors)
|
||||
input_wait.output.extend([i + '_sync' for i in input_tensors])
|
||||
|
||||
for cut_index in range(len(split_edge_groups)):
|
||||
edgeIds = split_edge_groups[cut_index]
|
||||
|
||||
# split the graph based on edgeIds
|
||||
upstream_nodes = []
|
||||
upstream_nodes_output_index = []
|
||||
output_shapes = []
|
||||
|
||||
for id in edgeIds:
|
||||
for node in model.graph.node:
|
||||
if len(node.output) >= 1:
|
||||
for i, j in enumerate(node.output):
|
||||
if j == id:
|
||||
upstream_nodes.append(node)
|
||||
upstream_nodes_output_index.append(i)
|
||||
for info in model.graph.value_info:
|
||||
if info.name == id:
|
||||
output_shapes.append(info.type)
|
||||
|
||||
record_signal = model.graph.input.add()
|
||||
record_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'record_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
|
||||
|
||||
wait_signal = model.graph.input.add()
|
||||
wait_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'wait_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
|
||||
|
||||
send_signal = model.graph.input.add()
|
||||
send_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'send_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
|
||||
recv_signal = model.graph.input.add()
|
||||
recv_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'recv_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
|
||||
# output signal from send after cut
|
||||
send_output_signal = model.graph.output.add()
|
||||
send_output_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'send_output_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
|
||||
# output signal from receive after cut
|
||||
receive_output_signal = model.graph.output.add()
|
||||
receive_output_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'receive_output_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
|
||||
new_send = model.graph.node.add()
|
||||
new_send.CopyFrom(helper.make_node(
|
||||
'Send',
|
||||
inputs=['send_input_signal' + str(cut_index)],
|
||||
outputs=['send_output_signal' + str(cut_index)],
|
||||
tag=0,
|
||||
src=cut_index,
|
||||
dst=cut_index + 1,
|
||||
domain=ms_domain,
|
||||
element_type=7, # assuming all tensors are of type float
|
||||
name='send'))
|
||||
|
||||
new_receive = model.graph.node.add()
|
||||
new_receive.CopyFrom(helper.make_node(
|
||||
'Recv',
|
||||
inputs=['recv_input_signal' + str(cut_index)],
|
||||
outputs=['receive_output_signal' + str(cut_index)],
|
||||
tag=1,
|
||||
src=cut_index,
|
||||
dst=cut_index + 1,
|
||||
domain=ms_domain,
|
||||
element_type=7, # assuming all tensors are of type float
|
||||
name='receive'))
|
||||
|
||||
new_wait = model.graph.node.add()
|
||||
new_wait.CopyFrom(helper.make_node(
|
||||
'WaitEvent',
|
||||
inputs=['wait_input_signal' + str(cut_index)],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
new_record = model.graph.node.add()
|
||||
new_record.CopyFrom(helper.make_node(
|
||||
'RecordEvent',
|
||||
inputs=['record_input_signal' + str(cut_index)],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
for i in range(len(upstream_nodes)):
|
||||
n = upstream_nodes[i]
|
||||
idx = upstream_nodes_output_index[i]
|
||||
output_type = output_shapes[i]
|
||||
output_edge_name = n.output[idx]
|
||||
|
||||
output_nodes = find_all_output_nodes_by_edge(
|
||||
model, output_edge_name)
|
||||
|
||||
# deal with shape inference for newly added edge
|
||||
new_send_input_name = output_edge_name + '_send' + str(cut_index)
|
||||
add_expand_type(model, new_send_input_name, output_type)
|
||||
|
||||
new_receive_output_name = output_edge_name + \
|
||||
'_recv' + str(cut_index)
|
||||
add_expand_type(model, new_receive_output_name, output_type)
|
||||
|
||||
new_wait_output_name = output_edge_name + '_wait' + str(cut_index)
|
||||
add_expand_type(model, new_wait_output_name, output_type)
|
||||
|
||||
# the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input
|
||||
new_record.input.extend([output_edge_name])
|
||||
new_record.output.extend([new_send_input_name])
|
||||
|
||||
new_send.input.extend([new_send_input_name])
|
||||
new_receive.output.extend([new_receive_output_name])
|
||||
|
||||
new_wait.input.extend([new_receive_output_name])
|
||||
new_wait.output.extend([new_wait_output_name])
|
||||
|
||||
for output_node in output_nodes:
|
||||
for i in range(len(output_node.input)):
|
||||
for edgeId in edgeIds:
|
||||
if output_node.input[i] == edgeId:
|
||||
output_node.input[i] = new_wait_output_name
|
||||
|
||||
new_send_nodes.append(new_send)
|
||||
new_recv_nodes.append(new_receive)
|
||||
|
||||
model = onnx.shape_inference.infer_shapes(model)
|
||||
|
||||
return new_send_nodes, new_recv_nodes
|
||||
|
||||
|
||||
def find_all_input_nodes(model, node):
|
||||
nodes = []
|
||||
inputs = []
|
||||
|
||||
if node:
|
||||
for inputId in node.input:
|
||||
nodes.extend([n for n in model.graph.node if inputId in n.output])
|
||||
inputs.extend([n for n in model.graph.input if inputId in n.name])
|
||||
return nodes, inputs
|
||||
|
||||
|
||||
def find_all_output_nodes(model, node):
|
||||
nodes = []
|
||||
outputs = []
|
||||
if node:
|
||||
for outputId in node.output:
|
||||
nodes.extend([n for n in model.graph.node if outputId in n.input])
|
||||
outputs.extend(
|
||||
[n for n in model.graph.output if outputId in n.name])
|
||||
return nodes, outputs
|
||||
|
||||
|
||||
def find_all_output_nodes_by_edge(model, arg):
|
||||
result = [n for n in model.graph.node if arg in n.input]
|
||||
return result
|
||||
|
||||
# Insert identity nodes to separate same output edge which feeds into different sub-graph.
|
||||
|
||||
|
||||
def add_identity(model, cuttingEdge, newEdgeIdName):
|
||||
output_nodes = None
|
||||
edgeId = cuttingEdge.edgeId
|
||||
for node in model.graph.node:
|
||||
if len(node.output) >= 1:
|
||||
for output in node.output:
|
||||
if output == edgeId:
|
||||
output_nodes = find_all_output_nodes_by_edge(model, output)
|
||||
break
|
||||
|
||||
assert output_nodes, "no output node"
|
||||
|
||||
new_identity = model.graph.node.add()
|
||||
new_identity.op_type = 'Identity'
|
||||
|
||||
new_identity.input.extend([edgeId])
|
||||
new_identity.output.extend([newEdgeIdName])
|
||||
|
||||
for i in range(len(output_nodes)):
|
||||
for output in output_nodes[i].output:
|
||||
if output in cuttingEdge.consumingNodes:
|
||||
for j in range(len(output_nodes[i].input)):
|
||||
if output_nodes[i].input[j] == edgeId:
|
||||
output_nodes[i].input[j] = newEdgeIdName
|
||||
|
||||
return newEdgeIdName
|
||||
|
||||
|
||||
def find_all_connected_nodes(model, node):
|
||||
nodes0, inputs = find_all_input_nodes(model, node)
|
||||
nodes1, outputs = find_all_output_nodes(model, node)
|
||||
|
||||
connected_nodes = nodes0 + nodes1
|
||||
return connected_nodes, inputs, outputs
|
||||
|
||||
|
||||
def get_index(node_list, node):
|
||||
found = [i for i, n in enumerate(node_list) if n == node]
|
||||
return found[0] if found else None
|
||||
|
||||
# traverse the graph, group connected nodes and generate subgraph
|
||||
|
||||
|
||||
def generate_subgraph(model, start_nodes):
|
||||
subgraphs = []
|
||||
|
||||
main_graph = onnx.ModelProto()
|
||||
main_graph.CopyFrom(model)
|
||||
|
||||
all_visited_nodes = []
|
||||
model_count = len(start_nodes)
|
||||
for start in reversed(start_nodes):
|
||||
stack0 = [start]
|
||||
|
||||
visited0 = []
|
||||
tranversed_node = 0
|
||||
inputs0 = []
|
||||
outputs0 = []
|
||||
while stack0:
|
||||
node = stack0.pop()
|
||||
if not node in visited0:
|
||||
tranversed_node += 1
|
||||
visited0.append(node)
|
||||
all_visited_nodes.append(node)
|
||||
connected_nodes, inputs, outputs = find_all_connected_nodes(
|
||||
main_graph, node)
|
||||
|
||||
stack0 = stack0 + connected_nodes
|
||||
inputs0 = inputs0 + inputs
|
||||
outputs0 = outputs0 + outputs
|
||||
|
||||
subgraph = onnx.ModelProto()
|
||||
subgraph.CopyFrom(main_graph)
|
||||
|
||||
# gather visited nodes
|
||||
visited_nodes = []
|
||||
for n in visited0:
|
||||
visited_nodes.append(get_index(main_graph.graph.node, n))
|
||||
visited_nodes.sort(reverse=True)
|
||||
|
||||
# gather visited inputs
|
||||
visited_inputs = []
|
||||
for n in inputs0:
|
||||
visited_inputs.append(get_index(main_graph.graph.input, n))
|
||||
visited_inputs.sort(reverse=True)
|
||||
|
||||
# gather visited outputs
|
||||
visited_outputs = []
|
||||
for n in outputs0:
|
||||
visited_outputs.append(
|
||||
get_index(main_graph.graph.output, n))
|
||||
visited_outputs.sort(reverse=True)
|
||||
|
||||
for i in reversed(range(len(main_graph.graph.node))):
|
||||
try:
|
||||
if i not in visited_nodes:
|
||||
del subgraph.graph.node[i]
|
||||
else:
|
||||
del main_graph.graph.node[i]
|
||||
except:
|
||||
print("error deleting node", i)
|
||||
|
||||
for i in reversed(range(len(main_graph.graph.input))):
|
||||
try:
|
||||
if i not in visited_inputs:
|
||||
del subgraph.graph.input[i]
|
||||
else:
|
||||
del main_graph.graph.input[i]
|
||||
except:
|
||||
print("error deleting inputs", i)
|
||||
|
||||
for i in reversed(range(len(main_graph.graph.output))):
|
||||
try:
|
||||
if i not in visited_outputs:
|
||||
del subgraph.graph.output[i]
|
||||
else:
|
||||
del main_graph.graph.output[i]
|
||||
except:
|
||||
print("error deleting outputs ", i)
|
||||
|
||||
print("model", str(model_count), " length ", len(subgraph.graph.node))
|
||||
subgraphs.append(subgraph)
|
||||
model_count -= 1
|
||||
|
||||
print("model", str(model_count), " length ", len(main_graph.graph.node))
|
||||
subgraphs.append(main_graph)
|
||||
|
||||
# as the subgraphs were added in reverse order (the last split is added first), reverse the order back before return
|
||||
subgraphs.reverse()
|
||||
return subgraphs
|
||||
|
||||
|
||||
def main():
|
||||
# temporary hard coded the cutting edge structure
|
||||
# TODO: move this info to a file (json?) and load the data from there.
|
||||
input_model_name = 'bert-tiny-uncased_L_3_H_128_A_2_V_30528_S_512_Dp_0.1.onnx'
|
||||
stage_count = 3
|
||||
|
||||
cut0_input = {CutEdge('186'), CutEdge('71', {'273', '395'})}
|
||||
cut1_input = {CutEdge('308'), CutEdge('71', {'395'})}
|
||||
all_cut_inputs = [cut0_input, cut1_input]
|
||||
|
||||
model = onnx.load(input_model_name)
|
||||
if len(model.graph.value_info) == 0:
|
||||
model = onnx.shape_inference.infer_shapes(model)
|
||||
|
||||
print("original model length ", len(model.graph.node))
|
||||
|
||||
output_model_names = [os.path.splitext(input_model_name)[0] + '_' +
|
||||
str(i) + '.onnx' for i in range(stage_count)]
|
||||
|
||||
split_edge_groups = []
|
||||
count = 0
|
||||
updated_edges = {}
|
||||
need_shape_inference = False
|
||||
# Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so,
|
||||
# insert identity node after those edges with a new ID to distinguish the rest.
|
||||
for cut_input in all_cut_inputs:
|
||||
split_edges = []
|
||||
for i in cut_input:
|
||||
if i.consumingNodes:
|
||||
# if this edge has previously been modified, update its edgeId before inserting new identity
|
||||
if i.edgeId in updated_edges:
|
||||
i.edgeId = updated_edges[i.edgeId]
|
||||
|
||||
new_edge_name = 'identity_output_' + str(count)
|
||||
add_identity(model, i, new_edge_name)
|
||||
count += 1
|
||||
split_edges.append(new_edge_name)
|
||||
updated_edges[i.edgeId] = new_edge_name
|
||||
need_shape_inference = True
|
||||
else:
|
||||
split_edges.append(i.edgeId)
|
||||
split_edge_groups.append(split_edges)
|
||||
|
||||
# new edge is being added, need to re-inference shape
|
||||
if need_shape_inference:
|
||||
model = onnx.shape_inference.infer_shapes(model)
|
||||
|
||||
# after all need-to-be-cut edges identified, split the graph
|
||||
new_sends, new_receives = split_graph(model, split_edge_groups)
|
||||
sub_graphs = generate_subgraph(model, new_receives)
|
||||
|
||||
for i in range(stage_count):
|
||||
sub_graphs[i] = onnx.shape_inference.infer_shapes(sub_graphs[i])
|
||||
onnx.save(sub_graphs[i], output_model_names[i])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in a new issue