The fixings for python scripts in ONNXRuntime (#4135)

* The fixings for python scripts in ONNXRuntime

* update according the comments
This commit is contained in:
Wenbing Li 2020-06-08 10:27:32 -07:00 committed by GitHub
parent 3390431d80
commit ee35320974
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 6 deletions

View file

@ -146,7 +146,7 @@ class InferenceSession(Session):
"""
This is the main class used to run a model.
"""
def __init__(self, path_or_bytes, sess_options=None, providers=[]):
def __init__(self, path_or_bytes, sess_options=None, providers=None):
"""
:param path_or_bytes: filename or serialized model in a byte string
:param sess_options: session options
@ -155,11 +155,11 @@ class InferenceSession(Session):
"""
self._path_or_bytes = path_or_bytes
self._sess_options = sess_options
self._load_model(providers)
self._load_model(providers or [])
self._enable_fallback = True
Session.__init__(self, self._sess)
def _load_model(self, providers=[]):
def _load_model(self, providers):
if isinstance(self._path_or_bytes, str):
self._sess = C.InferenceSession(
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,

View file

@ -326,6 +326,8 @@ setup(
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Operating System :: POSIX :: Linux',
'Operating System :: Microsoft :: Windows',
'Operating System :: MacOS',
'Programming Language :: Python',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.5',

View file

@ -1403,10 +1403,10 @@ def nuphar_run_python_tests(build_dir, configs):
def build_python_wheel(
source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl,
use_tensorrt, use_openvino, use_nuphar, use_vitisai, wheel_name_suffix,
use_acl, nightly_build=False, featurizers_build=False):
use_acl, nightly_build=False, featurizers_build=False, use_ninja=False):
for config in configs:
cwd = get_config_build_dir(build_dir, config)
if is_windows():
if is_windows() and not use_ninja:
cwd = os.path.join(cwd, config)
args = [sys.executable, os.path.join(source_dir, 'setup.py'),
@ -1796,6 +1796,7 @@ def main():
args.use_acl,
nightly_build=nightly_build,
featurizers_build=args.use_featurizers,
use_ninja=(args.cmake_generator == 'Ninja')
)
if args.gen_doc and (args.build or args.test):

View file

@ -42,7 +42,7 @@ def main():
out = os.path.abspath(args.out)
if not os.path.exists(out):
os.mkdirs(out)
os.makedirs(out)
model = onnx.load_model(model_path)
dump_subgraph(model, out)