2020-04-14 16:06:04 +00:00
|
|
|
import argparse
|
|
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
import onnx
|
|
|
|
|
|
2020-04-14 16:06:04 +00:00
|
|
|
|
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``input`` (path of the model to load) and
        ``output`` (path the rewritten model is saved to).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(
        description="Remove graph inputs that are backed by initializers from an ONNX model."
    )
    parser.add_argument("--input", required=True, help="input model")
    parser.add_argument("--output", required=True, help="output model")
    # Forwarding argv lets callers and tests supply arguments explicitly.
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_initializer_inputs(graph):
    """Remove every entry of ``graph.input`` whose name matches an initializer.

    Args:
        graph: The model graph. ``graph.initializer`` must yield objects with a
            ``name`` attribute; ``graph.input`` must be iterable, support
            ``remove()``, and hold objects with a ``name`` attribute.

    Returns:
        The number of graph inputs that were removed.
    """
    # A set makes the lookup O(1). It also keeps duplicate initializer names
    # from triggering a second remove() of an already-removed input.
    initializer_names = {initializer.name for initializer in graph.initializer}
    # Collect matches first so graph.input is never mutated while iterating it.
    redundant = [graph_input for graph_input in graph.input if graph_input.name in initializer_names]
    for graph_input in redundant:
        graph.input.remove(graph_input)
    return len(redundant)


def remove_initializer_from_input():
    """Rewrite an ONNX model so initializers no longer appear as graph inputs.

    Reads ``--input``/``--output`` from the command line via ``get_args``,
    loads the model from ``--input``, and drops every graph input that shares
    its name with an initializer. The result is saved to ``--output``.

    Models whose ``ir_version`` is below 4 must keep initializers in the graph
    inputs. For those models a message is printed and nothing is written.
    """
    args = get_args()
    model = onnx.load(args.input)

    if model.ir_version < 4:
        print("Model with ir_version below 4 requires to include initializer in graph input")
        return

    _strip_initializer_inputs(model.graph)
    onnx.save(model, args.output)
|
|
|
|
|
|
|
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
if __name__ == "__main__":
    # Executed only when run as a script (not on import): parse the CLI
    # options and rewrite the model in a single step.
    remove_initializer_from_input()
|