Revert Session and InferenceSession implementation

This commit is contained in:
Thiago Crepaldi 2020-04-02 03:13:46 +00:00
parent 0b1e3f1e10
commit e2afe5e054
2 changed files with 39 additions and 44 deletions

View file

@ -1155,7 +1155,8 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
}
// execute the graph
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
session_options_.execution_mode, run_options.terminate, run_logger));
session_options_.execution_mode, run_options.terminate, run_logger,
run_options.only_execute_path_to_fetches));
} catch (const std::exception& e) {
retval = Status(common::ONNXRUNTIME, common::FAIL, e.what());

View file

@ -20,49 +20,9 @@ class Session:
"""
This is the main class used to run a model.
"""
def __init__(self, path_or_bytes, sess_options=None, providers=[]):
"""
:param path_or_bytes: filename or serialized model in a byte string
:param sess_options: session options
:param providers: providers to use for session. If empty, will use
all available providers.
"""
self._path_or_bytes = path_or_bytes
self._sess_options = sess_options
self._load_model(providers)
def __init__(self, sess):
self._enable_fallback = True
def _load_model(self, providers=[]):
if isinstance(self._path_or_bytes, str):
self._sess = C.InferenceSession(
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
True)
elif isinstance(self._path_or_bytes, bytes):
self._sess = C.InferenceSession(
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
False)
# elif isinstance(self._path_or_bytes, tuple):
# to remove, hidden trick
# self._sess.load_model_no_init(self._path_or_bytes[0], providers)
else:
raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))
self._sess.load_model(providers)
self._session_options = self._sess.session_options
self._inputs_meta = self._sess.inputs_meta
self._outputs_meta = self._sess.outputs_meta
self._overridable_initializers = self._sess.overridable_initializers
self._model_meta = self._sess.model_meta
self._providers = self._sess.get_providers()
# Tensorrt can fall back to CUDA. All others fall back to CPU.
if 'TensorrtExecutionProvider' in C.get_available_providers():
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
self._fallback_providers = ['CPUExecutionProvider']
def _reset_session(self):
"release underlying session object."
# meta data references session internal structures
@ -76,7 +36,7 @@ class Session:
def get_session_options(self):
"Return the session options. See :class:`onnxruntime.SessionOptions`."
return self._session_options
return self._sess_options
def get_inputs(self):
"Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
@ -193,7 +153,41 @@ class InferenceSession(Session):
:param providers: providers to use for session. If empty, will use
all available providers.
"""
Session.__init__(self, path_or_bytes, sess_options, providers)
self._path_or_bytes = path_or_bytes
self._sess_options = sess_options
self._load_model(providers)
self._enable_fallback = True
Session.__init__(self, self._sess)
def _load_model(self, providers=[]):
if isinstance(self._path_or_bytes, str):
self._sess = C.InferenceSession(
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
True)
elif isinstance(self._path_or_bytes, bytes):
self._sess = C.InferenceSession(
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
False)
# elif isinstance(self._path_or_bytes, tuple):
# to remove, hidden trick
# self._sess.load_model_no_init(self._path_or_bytes[0], providers)
else:
raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))
self._sess.load_model(providers)
self._sess_options = self._sess.session_options
self._inputs_meta = self._sess.inputs_meta
self._outputs_meta = self._sess.outputs_meta
self._overridable_initializers = self._sess.overridable_initializers
self._model_meta = self._sess.model_meta
self._providers = self._sess.get_providers()
# Tensorrt can fall back to CUDA. All others fall back to CPU.
if 'TensorrtExecutionProvider' in C.get_available_providers():
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
self._fallback_providers = ['CPUExecutionProvider']
class IOBinding: