From f9bf546e3c19d1739590a24cfdf3b08a503f8c73 Mon Sep 17 00:00:00 2001 From: George Wu Date: Wed, 2 Oct 2019 02:38:03 -0700 Subject: [PATCH] python session.run() fallback to CPU/CUDA provider for EP failures. (#1960) * py fallback initial commit. * fixes. * update NGRAPHCustomOp::Initialize() to return Status * fixes in session.py * FAIL status to EP_FAIL in ngraph custom op * disable fallback for backend api --- .../core/providers/ngraph/ngraph_custom_op.cc | 68 +++++++++---------- .../core/providers/ngraph/ngraph_custom_op.h | 4 +- onnxruntime/python/backend/backend.py | 3 + onnxruntime/python/session.py | 47 ++++++++++++- 4 files changed, 82 insertions(+), 40 deletions(-) diff --git a/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc b/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc index 326e878cbc..3aeefac818 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc +++ b/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc @@ -39,9 +39,7 @@ static bool check_ngraph_dump_ops() { NGRAPHCustomOp::NGRAPHCustomOp(const ComputeContext* context, const ONNX_NAMESPACE::ModelProto& model_proto, - const std::shared_ptr& ng_backend) : - ng_backend_{ng_backend}, model_proto_{model_proto} -{ + const std::shared_ptr& ng_backend) : ng_backend_{ng_backend}, model_proto_{model_proto} { allocate_func_ = context->allocate_func; release_func_ = context->release_func; allocator_ = context->allocator_handle; @@ -60,7 +58,7 @@ NGRAPHCustomOp::~NGRAPHCustomOp() { } //This method gets called in critical path of execution: Optimize -void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const { +Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const { Ort::CustomOpApi ort{*api}; size_t num_inputs = ort.KernelContext_GetInputCount(context); @@ -84,18 +82,18 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con // Get cache size from environment std::string tempSize; - #ifdef _WIN32 - char *buf{nullptr}; +#ifdef _WIN32 + char* buf{nullptr}; size_t bufSize = 0; if (!_dupenv_s(&buf, &bufSize, "ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE") && buf) { tempSize = buf; free(buf); } - #else +#else if (std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE")) { tempSize = std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE"); } - #endif +#endif size_t cacheSize = tempSize.empty() ? NGRAPH_EP_LRU_CACHE_DEFAULT_SIZE : std::stoi(tempSize); // Not in cache @@ -104,20 +102,20 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con if (keyCache.size() == cacheSize) { // Delete least recently used element std::string last = keyCache.back(); - + // Pop the last elmeent keyCache.pop_back(); - + // Erase the last element from cache - ng_exe_map_.erase(ng_exe_map_.find(last)); - } - } - - // Found in cache + ng_exe_map_.erase(ng_exe_map_.find(last)); + } + } + + // Found in cache else { keyCache.remove(uniq_input_shape); } - + // update reference keyCache.push_front(uniq_input_shape); auto it = ng_exe_map_.insert({uniq_input_shape, nullptr}); @@ -125,7 +123,7 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con //ng_exe with current shape already exists if (!it.second) { ng_curr_exe_ = it.first->second; - return; + return Status::OK(); } else { auto graph_proto = model_proto_.mutable_graph(); @@ -151,13 +149,11 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con try { ng_function = ngraph::onnx_import::import_onnx_model(model_stream); } catch (const std::exception& exp) { - LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " - << "Exception while importing model to nGraph: " << std::string(exp.what()); - throw; + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NGRAPHCustomOp] - " + name_ + " - Exception while importing model to nGraph: " + std::string(exp.what())); } catch (...) { - LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " - << "Unknown exception while importing model to nGraph"; - throw; + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NGRAPHCustomOp] - " + name_ + " - Unknown exception while importing model to nGraph"); } for (auto& result : ng_function->get_results()) { @@ -168,14 +164,16 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con try { ng_curr_exe_ = ng_backend_->compile(ng_function); } catch (const std::exception& exp) { - LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " - << "Exception while compiling ngraph::Function: " << std::string(exp.what()); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NGRAPHCustomOp] - " + name_ + " - Exception while compiling ngraph::Function: " + std::string(exp.what())); } catch (...) { - LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " << "Unknown exception while compiling ngraph::Function"; + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NGRAPHCustomOp] - " + name_ + " - Unknown exception while compiling ngraph::Function"); } it.first->second = ng_curr_exe_; } -} // namespace ngraph_ep + return Status::OK(); +} //This method gets called in critical path of execution: Optimize Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const { @@ -184,7 +182,7 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont // Initialize nGraph function if it is not already initialized. { std::lock_guard lock(compute_lock_); - Initialize(api, context); + ORT_RETURN_IF_ERROR(Initialize(api, context)); } ORT_ENFORCE(ng_curr_exe_ != nullptr); @@ -202,9 +200,9 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont ng_inputs.emplace_back(ng_backend_->create_tensor(ng_param->get_output_element_type(0), ng_param->get_output_shape(0), input_data)); } } catch (const std::exception& exp) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while copying input data to nGraph: " + std::string(exp.what())); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while copying input data to nGraph: " + std::string(exp.what())); } catch (...) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while copying input data to nGraph"); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while copying input data to nGraph"); } // Initialize output tensors @@ -222,20 +220,20 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont ng_outputs.emplace_back(ng_backend_->create_tensor(dtype, shape, output_data)); } } catch (const std::exception& exp) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while creating nGraph output Tensor: " + std::string(exp.what())); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while creating nGraph output Tensor: " + std::string(exp.what())); } catch (...) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while creating nGraph output Tensor"); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while creating nGraph output Tensor"); } // Run the graph through nGraph. try { std::lock_guard lock(compute_lock_); if (!ng_curr_exe_->call(ng_outputs, ng_inputs)) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Error while executing nGraph computation"); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Error while executing nGraph computation"); } catch (const std::exception& exp) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while executing nGraph computation: " + std::string(exp.what())); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while executing nGraph computation: " + std::string(exp.what())); } catch (...) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while executing nGraph computation"); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while executing nGraph computation"); } return Status::OK(); diff --git a/onnxruntime/core/providers/ngraph/ngraph_custom_op.h b/onnxruntime/core/providers/ngraph/ngraph_custom_op.h index ad9955872d..beb7ece220 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_custom_op.h +++ b/onnxruntime/core/providers/ngraph/ngraph_custom_op.h @@ -34,7 +34,7 @@ class NGRAPHCustomOp { ~NGRAPHCustomOp(); private: - void Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const; + Status Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const; std::shared_ptr ng_backend_; @@ -57,7 +57,7 @@ class NGRAPHCustomOp { */ mutable std::unordered_map> ng_exe_map_; mutable std::list keyCache; - + mutable std::mutex compute_lock_; mutable ONNX_NAMESPACE::ModelProto model_proto_; diff --git a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py index 3c88a4eb3f..e3eaf9cd35 100644 --- a/onnxruntime/python/backend/backend.py +++ b/onnxruntime/python/backend/backend.py @@ -67,6 +67,9 @@ class OnnxRuntimeBackend(Backend): if hasattr(options, k): setattr(options, k, v) inf = InferenceSession(model, options) + # backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback + # which may hide test failures. + inf.disable_fallback() if device is not None and not cls.supports_device(device): raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device())) return cls.prepare(inf, device, **kwargs) diff --git a/onnxruntime/python/session.py b/onnxruntime/python/session.py index e46ed675d4..7e3a7ffb88 100644 --- a/onnxruntime/python/session.py +++ b/onnxruntime/python/session.py @@ -21,6 +21,7 @@ class InferenceSession: self._path_or_bytes = path_or_bytes self._sess_options = sess_options self._load_model() + self._enable_fallback = True def _load_model(self, providers=[]): if self._sess_options: @@ -46,6 +47,12 @@ class InferenceSession: self._model_meta = self._sess.model_meta self._providers = self._sess.get_providers() + # Tensorrt can fall back to CUDA. All others fall back to CPU. + if 'TensorrtExecutionProvider' in C.get_available_providers(): + self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + self._fallback_providers = ['CPUExecutionProvider'] + def _reset_session(self): "release underlying session object." # meta data references session internal structures @@ -78,12 +85,34 @@ class InferenceSession: return self._providers def set_providers(self, providers): - "Register the input list of execution providers. The underlying session is re-created." + """ + Register the input list of execution providers. The underlying session is re-created. + + :param providers: list of execution providers + + The list of providers is ordered by Priority. For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] means + execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider. + """ if not set(providers).issubset(C.get_available_providers()): - raise ValueError("{} does not contain a subset of available providers {}".format(providers, C.get_available_providers())) + raise ValueError("{} does not contain a subset of available providers {}".format(providers, C.get_available_providers())) self._reset_session() self._load_model(providers) + def disable_fallback(self): + """ + Disable session.run() fallback mechanism. + """ + self._enable_fallback = False + + def enable_fallback(self): + """ + Enable session.Run() fallback mechanism. If session.Run() fails due to an internal Execution Provider failure, reset the Execution Providers + enabled for this session. + If GPU is enabled, fall back to CUDAExecutionProvider. + otherwise fall back to CPUExecutionProvider. + """ + self._enable_fallback = True + def run(self, output_names, input_feed, run_options=None): """ Compute the predictions. @@ -103,7 +132,19 @@ class InferenceSession: raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs)) if not output_names: output_names = [output.name for output in self._outputs_meta] - return self._sess.run(output_names, input_feed, run_options) + try: + return self._sess.run(output_names, input_feed, run_options) + except C.EPFail as err: + if self._enable_fallback: + print("EP Error: {} using {}".format(str(err), self._providers)) + print("Falling back to {} and retrying.".format(self._fallback_providers)) + self.set_providers(self._fallback_providers) + # Fallback only once. + self.disable_fallback() + return self._sess.run(output_names, input_feed, run_options) + else: + raise + def end_profiling(self): """