From ee3532097477adf4e53aec270a782dd228d3625d Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Mon, 8 Jun 2020 10:27:32 -0700 Subject: [PATCH] The fixings for python scripts in ONNXRuntime (#4135) * The fixings for python scripts in ONNXRuntime * update according the comments --- onnxruntime/python/session.py | 6 +++--- setup.py | 2 ++ tools/ci_build/build.py | 5 +++-- tools/python/dump_subgraphs.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/session.py b/onnxruntime/python/session.py index 03608c418b..d187e1b500 100644 --- a/onnxruntime/python/session.py +++ b/onnxruntime/python/session.py @@ -146,7 +146,7 @@ class InferenceSession(Session): """ This is the main class used to run a model. """ - def __init__(self, path_or_bytes, sess_options=None, providers=[]): + def __init__(self, path_or_bytes, sess_options=None, providers=None): """ :param path_or_bytes: filename or serialized model in a byte string :param sess_options: session options @@ -155,11 +155,11 @@ class InferenceSession(Session): """ self._path_or_bytes = path_or_bytes self._sess_options = sess_options - self._load_model(providers) + self._load_model(providers or []) self._enable_fallback = True Session.__init__(self, self._sess) - def _load_model(self, providers=[]): + def _load_model(self, providers): if isinstance(self._path_or_bytes, str): self._sess = C.InferenceSession( self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes, diff --git a/setup.py b/setup.py index c62eae7af1..5d47d51401 100644 --- a/setup.py +++ b/setup.py @@ -326,6 +326,8 @@ setup( 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', 'Operating System :: POSIX :: Linux', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: MacOS', 'Programming Language :: Python', 'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3.5', diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 0a0c1cbc57..6907161e85 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1403,10 +1403,10 @@ def nuphar_run_python_tests(build_dir, configs): def build_python_wheel( source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl, use_tensorrt, use_openvino, use_nuphar, use_vitisai, wheel_name_suffix, - use_acl, nightly_build=False, featurizers_build=False): + use_acl, nightly_build=False, featurizers_build=False, use_ninja=False): for config in configs: cwd = get_config_build_dir(build_dir, config) - if is_windows(): + if is_windows() and not use_ninja: cwd = os.path.join(cwd, config) args = [sys.executable, os.path.join(source_dir, 'setup.py'), @@ -1796,6 +1796,7 @@ def main(): args.use_acl, nightly_build=nightly_build, featurizers_build=args.use_featurizers, + use_ninja=(args.cmake_generator == 'Ninja') ) if args.gen_doc and (args.build or args.test): diff --git a/tools/python/dump_subgraphs.py b/tools/python/dump_subgraphs.py index 12f6884352..52036eda45 100644 --- a/tools/python/dump_subgraphs.py +++ b/tools/python/dump_subgraphs.py @@ -42,7 +42,7 @@ def main(): out = os.path.abspath(args.out) if not os.path.exists(out): - os.mkdirs(out) + os.makedirs(out) model = onnx.load_model(model_path) dump_subgraph(model, out)