mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Revert Session and InferenceSession implementation
This commit is contained in:
parent
0b1e3f1e10
commit
e2afe5e054
2 changed files with 39 additions and 44 deletions
|
|
@ -1155,7 +1155,8 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
|
|||
}
|
||||
// execute the graph
|
||||
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
|
||||
session_options_.execution_mode, run_options.terminate, run_logger));
|
||||
session_options_.execution_mode, run_options.terminate, run_logger,
|
||||
run_options.only_execute_path_to_fetches));
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
retval = Status(common::ONNXRUNTIME, common::FAIL, e.what());
|
||||
|
|
|
|||
|
|
@ -20,49 +20,9 @@ class Session:
|
|||
"""
|
||||
This is the main class used to run a model.
|
||||
"""
|
||||
|
||||
def __init__(self, path_or_bytes, sess_options=None, providers=[]):
|
||||
"""
|
||||
:param path_or_bytes: filename or serialized model in a byte string
|
||||
:param sess_options: session options
|
||||
:param providers: providers to use for session. If empty, will use
|
||||
all available providers.
|
||||
"""
|
||||
self._path_or_bytes = path_or_bytes
|
||||
self._sess_options = sess_options
|
||||
self._load_model(providers)
|
||||
def __init__(self, sess):
|
||||
self._enable_fallback = True
|
||||
|
||||
def _load_model(self, providers=[]):
|
||||
if isinstance(self._path_or_bytes, str):
|
||||
self._sess = C.InferenceSession(
|
||||
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
|
||||
True)
|
||||
elif isinstance(self._path_or_bytes, bytes):
|
||||
self._sess = C.InferenceSession(
|
||||
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
|
||||
False)
|
||||
# elif isinstance(self._path_or_bytes, tuple):
|
||||
# to remove, hidden trick
|
||||
# self._sess.load_model_no_init(self._path_or_bytes[0], providers)
|
||||
else:
|
||||
raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))
|
||||
|
||||
self._sess.load_model(providers)
|
||||
|
||||
self._session_options = self._sess.session_options
|
||||
self._inputs_meta = self._sess.inputs_meta
|
||||
self._outputs_meta = self._sess.outputs_meta
|
||||
self._overridable_initializers = self._sess.overridable_initializers
|
||||
self._model_meta = self._sess.model_meta
|
||||
self._providers = self._sess.get_providers()
|
||||
|
||||
# Tensorrt can fall back to CUDA. All others fall back to CPU.
|
||||
if 'TensorrtExecutionProvider' in C.get_available_providers():
|
||||
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
self._fallback_providers = ['CPUExecutionProvider']
|
||||
|
||||
def _reset_session(self):
|
||||
"release underlying session object."
|
||||
# meta data references session internal structures
|
||||
|
|
@ -76,7 +36,7 @@ class Session:
|
|||
|
||||
def get_session_options(self):
|
||||
"Return the session options. See :class:`onnxruntime.SessionOptions`."
|
||||
return self._session_options
|
||||
return self._sess_options
|
||||
|
||||
def get_inputs(self):
|
||||
"Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
|
||||
|
|
@ -193,7 +153,41 @@ class InferenceSession(Session):
|
|||
:param providers: providers to use for session. If empty, will use
|
||||
all available providers.
|
||||
"""
|
||||
Session.__init__(self, path_or_bytes, sess_options, providers)
|
||||
self._path_or_bytes = path_or_bytes
|
||||
self._sess_options = sess_options
|
||||
self._load_model(providers)
|
||||
self._enable_fallback = True
|
||||
Session.__init__(self, self._sess)
|
||||
|
||||
def _load_model(self, providers=[]):
|
||||
if isinstance(self._path_or_bytes, str):
|
||||
self._sess = C.InferenceSession(
|
||||
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
|
||||
True)
|
||||
elif isinstance(self._path_or_bytes, bytes):
|
||||
self._sess = C.InferenceSession(
|
||||
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
|
||||
False)
|
||||
# elif isinstance(self._path_or_bytes, tuple):
|
||||
# to remove, hidden trick
|
||||
# self._sess.load_model_no_init(self._path_or_bytes[0], providers)
|
||||
else:
|
||||
raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))
|
||||
|
||||
self._sess.load_model(providers)
|
||||
|
||||
self._sess_options = self._sess.session_options
|
||||
self._inputs_meta = self._sess.inputs_meta
|
||||
self._outputs_meta = self._sess.outputs_meta
|
||||
self._overridable_initializers = self._sess.overridable_initializers
|
||||
self._model_meta = self._sess.model_meta
|
||||
self._providers = self._sess.get_providers()
|
||||
|
||||
# Tensorrt can fall back to CUDA. All others fall back to CPU.
|
||||
if 'TensorrtExecutionProvider' in C.get_available_providers():
|
||||
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
self._fallback_providers = ['CPUExecutionProvider']
|
||||
|
||||
|
||||
class IOBinding:
|
||||
|
|
|
|||
Loading…
Reference in a new issue