onnxruntime/tools/ci_build/gen_def.py
2019-08-20 12:04:10 -07:00

64 lines
No EOL
2 KiB
Python
Executable file

#!/usr/bin/python3
import sys
import argparse
import os
def parse_arguments(argv=None):
    """Parse the command-line options of the symbol-file generator.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, so existing callers that
            pass nothing keep reading the real command line.

    Returns:
        argparse.Namespace with ``src_root``, ``output``, ``output_source``,
        ``version_file``, ``style`` and ``config`` attributes.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--style`` is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--src_root", required=True,
        help="source root directory; symbols are read from "
             "<src_root>/core/providers/<config>/symbols.txt")
    parser.add_argument(
        "--output", required=True,
        help="output export file (.def file for vc, linker version script for gcc)")
    parser.add_argument(
        "--output_source", required=True,
        help="output C source file defining GetFunctionEntryByName")
    parser.add_argument("--version_file", required=True, help="VERSION_NUMBER file")
    parser.add_argument(
        "--style", required=True, choices=["gcc", "vc"],
        help="export file flavor: gcc linker version script or MSVC .def file")
    parser.add_argument(
        "--config", required=True, nargs="+",
        help="one or more provider names whose symbols.txt files are merged")
    return parser.parse_args(argv)
# Read the command line, then load the release version. VERSION_STRING is
# used later as the suffix of the VERS_<version> node in the gcc export file.
args = parse_arguments()
config_desc = str(args.config)
print("Generating symbol file for %s" % config_desc)
with open(args.version_file, 'r') as f:
    VERSION_STRING = f.read().strip()
print("VERSION:%s" % (VERSION_STRING,))
# Merge the per-provider export lists into one sorted list of unique symbols.
# Each requested provider contributes <src_root>/core/providers/<c>/symbols.txt,
# read one symbol per line. A symbol listed twice aborts the build.
symbols = set()
for c in args.config:
    file_name = os.path.join(args.src_root, 'core', 'providers', c, 'symbols.txt')
    with open(file_name, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                # Blank lines are not symbols. Without this check, two blank
                # lines would also be misreported as a duplicate symbol.
                continue
            if line in symbols:
                # Format the message properly and send it to stderr. Use
                # sys.exit: the bare `exit` builtin comes from the `site`
                # module and is only intended for interactive use.
                print("dup symbol: %s (in %s)" % (line, file_name), file=sys.stderr)
                sys.exit(1)
            symbols.add(line)
symbols = sorted(symbols)
# Write the export list in the requested linker format:
#   vc  -> MSVC module-definition file exporting each symbol with an explicit
#          ordinal (1-based, following the sorted symbol order);
#   gcc -> GNU ld version script that makes the listed symbols global under
#          the VERS_<version> node and everything else local.
if args.style == 'vc':
    export_lines = ['LIBRARY "onnxruntime.dll"\n', 'EXPORTS\n']
    export_lines.extend(" %s @%d\n" % (name, ordinal)
                        for ordinal, name in enumerate(symbols, start=1))
else:
    export_lines = ['VERS_%s {\n' % VERSION_STRING, ' global:\n']
    export_lines.extend(" %s;\n" % name for name in symbols)
if args.style == 'gcc':
    export_lines.extend([" local:\n", " *;\n", "}; \n"])
with open(args.output, 'w') as file:
    file.writelines(export_lines)
# Emit a C source file defining GetFunctionEntryByName(). It maps an exported
# symbol's name to its address through a chain of strcmp() checks and returns
# NULL for names it does not know.
with open(args.output_source, 'w') as file:
    # The generated code calls strcmp() and uses NULL. Include <string.h>
    # explicitly instead of relying on another header to pull it in; an
    # implicit strcmp declaration is rejected by modern C compilers.
    file.write("#include <string.h>\n")
    file.write("#include <onnxruntime_c_api.h>\n")
    for c in args.config:
        file.write("#include <core/providers/%s/%s_provider_factory.h>\n" % (c, c))
    file.write("void* GetFunctionEntryByName(const char* name){\n")
    for symbol in symbols:
        file.write("if(strcmp(name,\"%s\") ==0) return (void*)&%s;\n" % (symbol, symbol))
    file.write("return NULL;\n")
    file.write("}\n")