"""Strip initializer-backed entries from an ONNX model's graph inputs.

Loads the model given by ``--input``, deletes every ``graph.input`` entry
whose name matches a ``graph.initializer`` name, and saves the result to
``--output``. Models with ``ir_version`` below 4 are left untouched (and no
output is written), because those IR versions require initializers to also
appear among the graph inputs.
"""
import argparse

import onnx


def get_args(argv=None):
    """Parse the command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behavior.

    Returns:
        An ``argparse.Namespace`` with ``input`` and ``output`` model paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="input model")
    parser.add_argument("--output", required=True, help="output model")
    return parser.parse_args(argv)


def _strip_initializer_inputs(graph):
    """Delete, in place, every ``graph.input`` whose name names an initializer.

    Args:
        graph: The model's graph (``model.graph``); only its repeated
            ``input`` and ``initializer`` fields and their ``name`` are used.

    Returns:
        The number of graph inputs removed.
    """
    # A set gives O(1) membership tests and tolerates duplicate initializer
    # names (removing the same input object twice would raise ValueError).
    initializer_names = {initializer.name for initializer in graph.initializer}
    inputs = graph.input
    removed = 0
    # Walk backwards so deleting by index never shifts entries still to visit;
    # this also removes every input sharing a matched name, not just one.
    for index in reversed(range(len(inputs))):
        if inputs[index].name in initializer_names:
            del inputs[index]
            removed += 1
    return removed


def remove_initializer_from_input():
    """Run the CLI: load ``--input``, strip initializer inputs, save ``--output``.

    If the model's ``ir_version`` is below 4, a message is printed and the
    function returns without writing any output file.
    """
    args = get_args()
    model = onnx.load(args.input)
    if model.ir_version < 4:
        print(
            "Model with ir_version below 4 requires to include initializer "
            "in graph input"
        )
        return

    _strip_initializer_inputs(model.graph)
    onnx.save(model, args.output)


if __name__ == "__main__":
    remove_initializer_from_input()