mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
python session.run() fallback to CPU/CUDA provider for EP failures. (#1960)
* py fallback initial commit. * fixes. * update NGRAPHCustomOp::Initialize() to return Status * fixes in session.py * FAIL status to EP_FAIL in ngraph custom op * disable fallback for backend api
This commit is contained in:
parent
622ea4248d
commit
f9bf546e3c
4 changed files with 82 additions and 40 deletions
|
|
@ -39,9 +39,7 @@ static bool check_ngraph_dump_ops() {
|
|||
|
||||
NGRAPHCustomOp::NGRAPHCustomOp(const ComputeContext* context,
|
||||
const ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const std::shared_ptr<ngraph::runtime::Backend>& ng_backend) :
|
||||
ng_backend_{ng_backend}, model_proto_{model_proto}
|
||||
{
|
||||
const std::shared_ptr<ngraph::runtime::Backend>& ng_backend) : ng_backend_{ng_backend}, model_proto_{model_proto} {
|
||||
allocate_func_ = context->allocate_func;
|
||||
release_func_ = context->release_func;
|
||||
allocator_ = context->allocator_handle;
|
||||
|
|
@ -60,7 +58,7 @@ NGRAPHCustomOp::~NGRAPHCustomOp() {
|
|||
}
|
||||
|
||||
//This method gets called in critical path of execution: Optimize
|
||||
void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const {
|
||||
Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const {
|
||||
Ort::CustomOpApi ort{*api};
|
||||
|
||||
size_t num_inputs = ort.KernelContext_GetInputCount(context);
|
||||
|
|
@ -84,18 +82,18 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
|
|||
|
||||
// Get cache size from environment
|
||||
std::string tempSize;
|
||||
#ifdef _WIN32
|
||||
char *buf{nullptr};
|
||||
#ifdef _WIN32
|
||||
char* buf{nullptr};
|
||||
size_t bufSize = 0;
|
||||
if (!_dupenv_s(&buf, &bufSize, "ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE") && buf) {
|
||||
tempSize = buf;
|
||||
free(buf);
|
||||
}
|
||||
#else
|
||||
#else
|
||||
if (std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE")) {
|
||||
tempSize = std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE");
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
size_t cacheSize = tempSize.empty() ? NGRAPH_EP_LRU_CACHE_DEFAULT_SIZE : std::stoi(tempSize);
|
||||
|
||||
// Not in cache
|
||||
|
|
@ -104,20 +102,20 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
|
|||
if (keyCache.size() == cacheSize) {
|
||||
// Delete least recently used element
|
||||
std::string last = keyCache.back();
|
||||
|
||||
|
||||
// Pop the last element
|
||||
keyCache.pop_back();
|
||||
|
||||
|
||||
// Erase the last element from cache
|
||||
ng_exe_map_.erase(ng_exe_map_.find(last));
|
||||
}
|
||||
}
|
||||
|
||||
// Found in cache
|
||||
ng_exe_map_.erase(ng_exe_map_.find(last));
|
||||
}
|
||||
}
|
||||
|
||||
// Found in cache
|
||||
else {
|
||||
keyCache.remove(uniq_input_shape);
|
||||
}
|
||||
|
||||
|
||||
// update reference
|
||||
keyCache.push_front(uniq_input_shape);
|
||||
auto it = ng_exe_map_.insert({uniq_input_shape, nullptr});
|
||||
|
|
@ -125,7 +123,7 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
|
|||
//ng_exe with current shape already exists
|
||||
if (!it.second) {
|
||||
ng_curr_exe_ = it.first->second;
|
||||
return;
|
||||
return Status::OK();
|
||||
} else {
|
||||
auto graph_proto = model_proto_.mutable_graph();
|
||||
|
||||
|
|
@ -151,13 +149,11 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
|
|||
try {
|
||||
ng_function = ngraph::onnx_import::import_onnx_model(model_stream);
|
||||
} catch (const std::exception& exp) {
|
||||
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - "
|
||||
<< "Exception while importing model to nGraph: " << std::string(exp.what());
|
||||
throw;
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"[NGRAPHCustomOp] - " + name_ + " - Exception while importing model to nGraph: " + std::string(exp.what()));
|
||||
} catch (...) {
|
||||
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - "
|
||||
<< "Unknown exception while importing model to nGraph";
|
||||
throw;
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"[NGRAPHCustomOp] - " + name_ + " - Unknown exception while importing model to nGraph");
|
||||
}
|
||||
|
||||
for (auto& result : ng_function->get_results()) {
|
||||
|
|
@ -168,14 +164,16 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
|
|||
try {
|
||||
ng_curr_exe_ = ng_backend_->compile(ng_function);
|
||||
} catch (const std::exception& exp) {
|
||||
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - "
|
||||
<< "Exception while compiling ngraph::Function: " << std::string(exp.what());
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"[NGRAPHCustomOp] - " + name_ + " - Exception while compiling ngraph::Function: " + std::string(exp.what()));
|
||||
} catch (...) {
|
||||
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " << "Unknown exception while compiling ngraph::Function";
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"[NGRAPHCustomOp] - " + name_ + " - Unknown exception while compiling ngraph::Function");
|
||||
}
|
||||
it.first->second = ng_curr_exe_;
|
||||
}
|
||||
} // namespace ngraph_ep
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//This method gets called in critical path of execution: Optimize
|
||||
Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const {
|
||||
|
|
@ -184,7 +182,7 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
|
|||
// Initialize nGraph function if it is not already initialized.
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(compute_lock_);
|
||||
Initialize(api, context);
|
||||
ORT_RETURN_IF_ERROR(Initialize(api, context));
|
||||
}
|
||||
|
||||
ORT_ENFORCE(ng_curr_exe_ != nullptr);
|
||||
|
|
@ -202,9 +200,9 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
|
|||
ng_inputs.emplace_back(ng_backend_->create_tensor(ng_param->get_output_element_type(0), ng_param->get_output_shape(0), input_data));
|
||||
}
|
||||
} catch (const std::exception& exp) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while copying input data to nGraph: " + std::string(exp.what()));
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while copying input data to nGraph: " + std::string(exp.what()));
|
||||
} catch (...) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while copying input data to nGraph");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while copying input data to nGraph");
|
||||
}
|
||||
|
||||
// Initialize output tensors
|
||||
|
|
@ -222,20 +220,20 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
|
|||
ng_outputs.emplace_back(ng_backend_->create_tensor(dtype, shape, output_data));
|
||||
}
|
||||
} catch (const std::exception& exp) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while creating nGraph output Tensor: " + std::string(exp.what()));
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while creating nGraph output Tensor: " + std::string(exp.what()));
|
||||
} catch (...) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while creating nGraph output Tensor");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while creating nGraph output Tensor");
|
||||
}
|
||||
|
||||
// Run the graph through nGraph.
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(compute_lock_);
|
||||
if (!ng_curr_exe_->call(ng_outputs, ng_inputs))
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Error while executing nGraph computation");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Error while executing nGraph computation");
|
||||
} catch (const std::exception& exp) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while executing nGraph computation: " + std::string(exp.what()));
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while executing nGraph computation: " + std::string(exp.what()));
|
||||
} catch (...) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while executing nGraph computation");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while executing nGraph computation");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class NGRAPHCustomOp {
|
|||
~NGRAPHCustomOp();
|
||||
|
||||
private:
|
||||
void Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const;
|
||||
Status Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const;
|
||||
|
||||
std::shared_ptr<ngraph::runtime::Backend> ng_backend_;
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ class NGRAPHCustomOp {
|
|||
*/
|
||||
mutable std::unordered_map<std::string, std::shared_ptr<ngraph::runtime::Executable>> ng_exe_map_;
|
||||
mutable std::list<std::string> keyCache;
|
||||
|
||||
|
||||
mutable std::mutex compute_lock_;
|
||||
|
||||
mutable ONNX_NAMESPACE::ModelProto model_proto_;
|
||||
|
|
|
|||
|
|
@ -67,6 +67,9 @@ class OnnxRuntimeBackend(Backend):
|
|||
if hasattr(options, k):
|
||||
setattr(options, k, v)
|
||||
inf = InferenceSession(model, options)
|
||||
# backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
|
||||
# which may hide test failures.
|
||||
inf.disable_fallback()
|
||||
if device is not None and not cls.supports_device(device):
|
||||
raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))
|
||||
return cls.prepare(inf, device, **kwargs)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class InferenceSession:
|
|||
self._path_or_bytes = path_or_bytes
|
||||
self._sess_options = sess_options
|
||||
self._load_model()
|
||||
self._enable_fallback = True
|
||||
|
||||
def _load_model(self, providers=[]):
|
||||
if self._sess_options:
|
||||
|
|
@ -46,6 +47,12 @@ class InferenceSession:
|
|||
self._model_meta = self._sess.model_meta
|
||||
self._providers = self._sess.get_providers()
|
||||
|
||||
# TensorRT can fall back to CUDA. All others fall back to CPU.
|
||||
if 'TensorrtExecutionProvider' in C.get_available_providers():
|
||||
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
self._fallback_providers = ['CPUExecutionProvider']
|
||||
|
||||
def _reset_session(self):
|
||||
"release underlying session object."
|
||||
# meta data references session internal structures
|
||||
|
|
@ -78,12 +85,34 @@ class InferenceSession:
|
|||
return self._providers
|
||||
|
||||
def set_providers(self, providers):
|
||||
"Register the input list of execution providers. The underlying session is re-created."
|
||||
"""
|
||||
Register the input list of execution providers. The underlying session is re-created.
|
||||
|
||||
:param providers: list of execution providers
|
||||
|
||||
The list of providers is ordered by Priority. For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] means
|
||||
execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
|
||||
"""
|
||||
if not set(providers).issubset(C.get_available_providers()):
|
||||
raise ValueError("{} does not contain a subset of available providers {}".format(providers, C.get_available_providers()))
|
||||
raise ValueError("{} does not contain a subset of available providers {}".format(providers, C.get_available_providers()))
|
||||
self._reset_session()
|
||||
self._load_model(providers)
|
||||
|
||||
def disable_fallback(self):
|
||||
"""
|
||||
Disable session.run() fallback mechanism.
|
||||
"""
|
||||
self._enable_fallback = False
|
||||
|
||||
def enable_fallback(self):
|
||||
"""
|
||||
Enable session.run() fallback mechanism. If session.run() fails due to an internal Execution Provider failure, reset the Execution Providers
|
||||
enabled for this session.
|
||||
If GPU is enabled, fall back to CUDAExecutionProvider.
|
||||
Otherwise, fall back to CPUExecutionProvider.
|
||||
"""
|
||||
self._enable_fallback = True
|
||||
|
||||
def run(self, output_names, input_feed, run_options=None):
|
||||
"""
|
||||
Compute the predictions.
|
||||
|
|
@ -103,7 +132,19 @@ class InferenceSession:
|
|||
raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs))
|
||||
if not output_names:
|
||||
output_names = [output.name for output in self._outputs_meta]
|
||||
return self._sess.run(output_names, input_feed, run_options)
|
||||
try:
|
||||
return self._sess.run(output_names, input_feed, run_options)
|
||||
except C.EPFail as err:
|
||||
if self._enable_fallback:
|
||||
print("EP Error: {} using {}".format(str(err), self._providers))
|
||||
print("Falling back to {} and retrying.".format(self._fallback_providers))
|
||||
self.set_providers(self._fallback_providers)
|
||||
# Fallback only once.
|
||||
self.disable_fallback()
|
||||
return self._sess.run(output_names, input_feed, run_options)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def end_profiling(self):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue