onnxruntime/orttraining/tools/scripts/layer_norm_transform.py
mindest 5b9369e93c
Fix typos according to reviewdog report. (#21335)
### Description
Fix typos based on reviewdog report but with some
exceptions/corrections.
2024-07-22 13:37:32 -07:00

186 lines
6.8 KiB
Python

import os.path
import sys
import numpy as np
import onnx
from onnx import * # noqa: F403
def find_node(graph_proto, op_type):
    """Collect the nodes of type ``op_type`` and index them by one of their inputs.

    ``Div`` and ``Mul`` nodes are keyed by their second input (``input[1]``);
    every other op type is keyed by its first input (``input[0]``). If several
    matching nodes share a key, the one that appears last in the graph wins.

    Returns a ``(nodes, map_input_node)`` tuple: the matching nodes in graph
    order, and the dictionary mapping an input name to the node consuming it.
    """
    # The operand used as the lookup key depends only on the requested op type.
    key_index = 1 if op_type in ("Div", "Mul") else 0
    matches = [candidate for candidate in graph_proto.node if candidate.op_type == op_type]
    by_input = {match.input[key_index]: match for match in matches}
    return matches, by_input
def gen_attribute(key, value):
    """Build an ``INTS``-typed ``AttributeProto`` named ``key``.

    Every element of the iterable ``value`` is converted with ``int`` before
    it is stored in the attribute's ``ints`` field.
    """
    ints_attribute = AttributeProto()  # noqa: F405
    ints_attribute.name = key
    ints_attribute.type = AttributeProto.INTS  # noqa: F405
    ints_attribute.ints.extend([int(element) for element in value])
    return ints_attribute
def _layer_norm_model_path(model_file_path):
    """Return ``model_file_path`` with its extension replaced by ``_layer_norm.onnx``."""
    # splitext removes whatever extension is actually present, so the stem is never
    # truncated when the input path does not end in a five-character ".onnx" suffix.
    stem, _ = os.path.splitext(model_file_path)
    return stem + "_layer_norm.onnx"


def _fuse_layer_norm(graph_proto):
    """Replace each decomposed layer-norm subgraph in ``graph_proto`` with a LayerNormalization node.

    The chain that is followed, starting from a ReduceMean node, is
    ReduceMean -> Sub -> Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul -> Add.
    The matched nodes, plus the Constant nodes feeding Pow's and the first Add's
    second input, are dropped and one ``LayerNormalization`` node per chain is
    appended to the end of the node list. ``graph_proto`` is modified in place.

    Raises:
        KeyError: if a chain that starts matching cannot be followed to its end, or
            if the second input of Pow / the first Add is not the output of a
            Constant node.
    """
    _, map_input_div = find_node(graph_proto, "Div")
    _, map_input_sqrt = find_node(graph_proto, "Sqrt")
    _, map_input_add = find_node(graph_proto, "Add")
    nodes_reduce_mean, map_input_reduce_mean = find_node(graph_proto, "ReduceMean")
    _, map_input_pow = find_node(graph_proto, "Pow")
    _, map_input_mul = find_node(graph_proto, "Mul")

    # "Right side" Sub: a Sub whose output is consumed by Pow, keyed by its second input.
    map_input_sub = {
        node.input[1]: node
        for node in graph_proto.node
        if node.op_type == "Sub" and node.output[0] in map_input_pow
    }

    # A "first" ReduceMean is one whose output is the second input of such a Sub.
    first_reduce_means = [node for node in nodes_reduce_mean if node.output[0] in map_input_sub]
    first_reduce_mean_outputs = [node.output[0] for node in first_reduce_means]

    map_output_constant = {node.output[0]: node for node in graph_proto.node if node.op_type == "Constant"}

    # Counter used to build unique output/node names; it advances twice per fused node,
    # so the generated names are saved_mean_1, saved_inv_std_var_2, LayerNormalization_2, ...
    name_counter = 0
    removed_nodes = []
    layer_norm_nodes = []
    for node in first_reduce_means:
        # Walk the chain by following each node's output to its consumer.
        node_sub = map_input_sub[node.output[0]]
        node_pow = map_input_pow[node_sub.output[0]]
        node_reduce = map_input_reduce_mean[node_pow.output[0]]
        node_add = map_input_add[node_reduce.output[0]]
        node_sqrt = map_input_sqrt[node_add.output[0]]
        node_div = map_input_div[node_sqrt.output[0]]
        node_mul = map_input_mul[node_div.output[0]]
        node_add1 = map_input_add[node_mul.output[0]]

        # Inputs: the tensor being normalized, Mul's first operand, the final Add's second operand.
        layer_norm_input = [node.input[0], node_mul.input[0], node_add1.input[1]]

        removed_nodes.extend(
            [
                node,
                node_sub,
                node_pow,
                node_reduce,
                node_add,
                node_sqrt,
                node_div,
                node_mul,
                node_add1,
                map_output_constant[node_pow.input[1]],
                map_output_constant[node_add.input[1]],
            ]
        )

        name_counter += 1
        saved_mean_name = "saved_mean_" + str(name_counter)
        name_counter += 1
        saved_inv_std_var_name = "saved_inv_std_var_" + str(name_counter)
        # The fused node takes over the final Add's output name, so consumers need no rewiring.
        layer_norm_output = [node_add1.output[0], saved_mean_name, saved_inv_std_var_name]

        layer_norm = helper.make_node(  # noqa: F405
            "LayerNormalization",
            layer_norm_input,
            layer_norm_output,
            "LayerNormalization_" + str(name_counter),
            None,
            # NOTE(review): assumes the first attribute of the second ReduceMean holds its axes - confirm.
            axis=node_reduce.attribute[0].ints[0],
            # NOTE(review): epsilon is hard-coded instead of being read from the removed Add constant;
            # presumably the float32 value of 1e-12 - confirm against the model being converted.
            epsilon=9.999999960041972e-13,
        )
        layer_norm_nodes.append(layer_norm)

    # Remove the "left side" Subs: every Sub whose second input is a first ReduceMean output.
    for node in graph_proto.node:
        if node.op_type == "Sub" and node.input[1] in first_reduce_mean_outputs:
            removed_nodes.append(node)

    all_nodes = [node for node in graph_proto.node if node not in removed_nodes]
    all_nodes.extend(layer_norm_nodes)

    graph_proto.ClearField("node")
    graph_proto.node.extend(all_nodes)


def _compare_with_ort(model_file_path, new_model_file_path):
    """Run both models in ONNX Runtime on the same random inputs and print the output differences.

    The norm of the difference is printed for each of the first two outputs. The models
    are fed through inputs named ``input1``, ``input2`` and ``input3``.
    """
    # Use ORT to verify the converted model. Notice that you must use python package from the
    # training branch because training requires some extra ops.
    import onnxruntime as ort

    # We convert model to accept variable-length batch size, so it can be any positive integer.
    batch = 3
    # This should match --max_seq_length when calling nv_run_pretraining.py.
    sq_length = 512
    # This should match vocab_size in bert_config.json in DeepLearningExamples/PyTorch/LanguageModeling/BERT.
    vocab_size = 30528

    # Create a fake data point.
    input_ids = np.random.randint(low=0, high=vocab_size, size=(batch, sq_length), dtype=np.int64)
    segment_ids = np.random.randint(low=0, high=2, size=(batch, sq_length), dtype=np.int64)
    input_mask = np.ones((batch, sq_length), dtype=np.int64)
    feeds = {"input1": input_ids, "input2": segment_ids, "input3": input_mask}

    # Do forward using the original model.
    sess = ort.InferenceSession(model_file_path, providers=ort.get_available_providers())
    result = sess.run(None, feeds)

    # Do forward using the new model.
    new_sess = ort.InferenceSession(new_model_file_path, providers=ort.get_available_providers())
    new_result = new_sess.run(None, feeds)

    # Compare the outcomes from the two models.
    print(np.linalg.norm(result[0] - new_result[0]))
    print(np.linalg.norm(result[1] - new_result[1]))


def main():
    """Fuse decomposed layer-norm subgraphs of the model given on the command line.

    Reads the model path from ``sys.argv[1]``, writes the rewritten model next to it
    with the extension replaced by ``_layer_norm.onnx``, then runs both models in
    ONNX Runtime and prints how far their first two outputs differ. Prints a usage
    hint and returns when no model path is supplied.
    """
    if len(sys.argv) < 2:
        print("Please give model path...")
        return

    model_file_path = sys.argv[1]
    print("model_file_path: " + model_file_path)
    model_file_name = os.path.basename(model_file_path)
    print("model_file_name: " + model_file_name)
    new_model_file_path = _layer_norm_model_path(model_file_path)
    print(new_model_file_path)

    model_proto = onnx.load(model_file_path)
    _fuse_layer_norm(model_proto.graph)

    with open(new_model_file_path, "wb") as f:
        f.write(model_proto.SerializeToString())

    _compare_with_ort(model_file_path, new_model_file_path)
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()