python session.run() fallback to CPU/CUDA provider for EP failures. (#1960)

* py fallback initial commit.

* fixes.

* update NGRAPHCustomOp::Initialize() to return Status

* fixes in session.py

* FAIL status to EP_FAIL in ngraph custom op

* disable fallback for backend api
This commit is contained in:
George Wu 2019-10-02 02:38:03 -07:00 committed by GitHub
parent 622ea4248d
commit f9bf546e3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 40 deletions

View file

@ -39,9 +39,7 @@ static bool check_ngraph_dump_ops() {
NGRAPHCustomOp::NGRAPHCustomOp(const ComputeContext* context,
const ONNX_NAMESPACE::ModelProto& model_proto,
const std::shared_ptr<ngraph::runtime::Backend>& ng_backend) :
ng_backend_{ng_backend}, model_proto_{model_proto}
{
const std::shared_ptr<ngraph::runtime::Backend>& ng_backend) : ng_backend_{ng_backend}, model_proto_{model_proto} {
allocate_func_ = context->allocate_func;
release_func_ = context->release_func;
allocator_ = context->allocator_handle;
@ -60,7 +58,7 @@ NGRAPHCustomOp::~NGRAPHCustomOp() {
}
//This method gets called in critical path of execution: Optimize
void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const {
Status NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const {
Ort::CustomOpApi ort{*api};
size_t num_inputs = ort.KernelContext_GetInputCount(context);
@ -84,18 +82,18 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
// Get cache size from environment
std::string tempSize;
#ifdef _WIN32
char *buf{nullptr};
#ifdef _WIN32
char* buf{nullptr};
size_t bufSize = 0;
if (!_dupenv_s(&buf, &bufSize, "ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE") && buf) {
tempSize = buf;
free(buf);
}
#else
#else
if (std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE")) {
tempSize = std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE");
}
#endif
#endif
size_t cacheSize = tempSize.empty() ? NGRAPH_EP_LRU_CACHE_DEFAULT_SIZE : std::stoi(tempSize);
// Not in cache
@ -104,20 +102,20 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
if (keyCache.size() == cacheSize) {
// Delete least recently used element
std::string last = keyCache.back();
// Pop the last elmeent
keyCache.pop_back();
// Erase the last element from cache
ng_exe_map_.erase(ng_exe_map_.find(last));
}
}
// Found in cache
ng_exe_map_.erase(ng_exe_map_.find(last));
}
}
// Found in cache
else {
keyCache.remove(uniq_input_shape);
}
// update reference
keyCache.push_front(uniq_input_shape);
auto it = ng_exe_map_.insert({uniq_input_shape, nullptr});
@ -125,7 +123,7 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
//ng_exe with current shape already exists
if (!it.second) {
ng_curr_exe_ = it.first->second;
return;
return Status::OK();
} else {
auto graph_proto = model_proto_.mutable_graph();
@ -151,13 +149,11 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
try {
ng_function = ngraph::onnx_import::import_onnx_model(model_stream);
} catch (const std::exception& exp) {
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - "
<< "Exception while importing model to nGraph: " << std::string(exp.what());
throw;
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[NGRAPHCustomOp] - " + name_ + " - Exception while importing model to nGraph: " + std::string(exp.what()));
} catch (...) {
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - "
<< "Unknown exception while importing model to nGraph";
throw;
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[NGRAPHCustomOp] - " + name_ + " - Unknown exception while importing model to nGraph");
}
for (auto& result : ng_function->get_results()) {
@ -168,14 +164,16 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
try {
ng_curr_exe_ = ng_backend_->compile(ng_function);
} catch (const std::exception& exp) {
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - "
<< "Exception while compiling ngraph::Function: " << std::string(exp.what());
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[NGRAPHCustomOp] - " + name_ + " - Exception while compiling ngraph::Function: " + std::string(exp.what()));
} catch (...) {
LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " << "Unknown exception while compiling ngraph::Function";
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[NGRAPHCustomOp] - " + name_ + " - Unknown exception while compiling ngraph::Function");
}
it.first->second = ng_curr_exe_;
}
} // namespace ngraph_ep
return Status::OK();
}
//This method gets called in critical path of execution: Optimize
Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const {
@ -184,7 +182,7 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
// Initialize nGraph function if it is not already initialized.
{
std::lock_guard<std::mutex> lock(compute_lock_);
Initialize(api, context);
ORT_RETURN_IF_ERROR(Initialize(api, context));
}
ORT_ENFORCE(ng_curr_exe_ != nullptr);
@ -202,9 +200,9 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
ng_inputs.emplace_back(ng_backend_->create_tensor(ng_param->get_output_element_type(0), ng_param->get_output_shape(0), input_data));
}
} catch (const std::exception& exp) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while copying input data to nGraph: " + std::string(exp.what()));
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while copying input data to nGraph: " + std::string(exp.what()));
} catch (...) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while copying input data to nGraph");
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while copying input data to nGraph");
}
// Initialize output tensors
@ -222,20 +220,20 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
ng_outputs.emplace_back(ng_backend_->create_tensor(dtype, shape, output_data));
}
} catch (const std::exception& exp) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while creating nGraph output Tensor: " + std::string(exp.what()));
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while creating nGraph output Tensor: " + std::string(exp.what()));
} catch (...) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while creating nGraph output Tensor");
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while creating nGraph output Tensor");
}
// Run the graph through nGraph.
try {
std::lock_guard<std::mutex> lock(compute_lock_);
if (!ng_curr_exe_->call(ng_outputs, ng_inputs))
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Error while executing nGraph computation");
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Error while executing nGraph computation");
} catch (const std::exception& exp) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while executing nGraph computation: " + std::string(exp.what()));
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Exception while executing nGraph computation: " + std::string(exp.what()));
} catch (...) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while executing nGraph computation");
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, name_ + ": Unknown exception while executing nGraph computation");
}
return Status::OK();

View file

@ -34,7 +34,7 @@ class NGRAPHCustomOp {
~NGRAPHCustomOp();
private:
void Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const;
Status Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const;
std::shared_ptr<ngraph::runtime::Backend> ng_backend_;
@ -57,7 +57,7 @@ class NGRAPHCustomOp {
*/
mutable std::unordered_map<std::string, std::shared_ptr<ngraph::runtime::Executable>> ng_exe_map_;
mutable std::list<std::string> keyCache;
mutable std::mutex compute_lock_;
mutable ONNX_NAMESPACE::ModelProto model_proto_;

View file

@ -67,6 +67,9 @@ class OnnxRuntimeBackend(Backend):
if hasattr(options, k):
setattr(options, k, v)
inf = InferenceSession(model, options)
# backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
# which may hide test failures.
inf.disable_fallback()
if device is not None and not cls.supports_device(device):
raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))
return cls.prepare(inf, device, **kwargs)

View file

@ -21,6 +21,7 @@ class InferenceSession:
self._path_or_bytes = path_or_bytes
self._sess_options = sess_options
self._load_model()
self._enable_fallback = True
def _load_model(self, providers=[]):
if self._sess_options:
@ -46,6 +47,12 @@ class InferenceSession:
self._model_meta = self._sess.model_meta
self._providers = self._sess.get_providers()
# Tensorrt can fall back to CUDA. All others fall back to CPU.
if 'TensorrtExecutionProvider' in C.get_available_providers():
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
self._fallback_providers = ['CPUExecutionProvider']
def _reset_session(self):
"release underlying session object."
# meta data references session internal structures
@ -78,12 +85,34 @@ class InferenceSession:
return self._providers
def set_providers(self, providers):
"Register the input list of execution providers. The underlying session is re-created."
"""
Register the input list of execution providers. The underlying session is re-created.
:param providers: list of execution providers
The list of providers is ordered by Priority. For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] means
execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
"""
if not set(providers).issubset(C.get_available_providers()):
raise ValueError("{} does not contain a subset of available providers {}".format(providers, C.get_available_providers()))
raise ValueError("{} does not contain a subset of available providers {}".format(providers, C.get_available_providers()))
self._reset_session()
self._load_model(providers)
def disable_fallback(self):
"""
Disable session.run() fallback mechanism.
"""
self._enable_fallback = False
def enable_fallback(self):
"""
Enable session.Run() fallback mechanism. If session.Run() fails due to an internal Execution Provider failure, reset the Execution Providers
enabled for this session.
If GPU is enabled, fall back to CUDAExecutionProvider.
otherwise fall back to CPUExecutionProvider.
"""
self._enable_fallback = True
def run(self, output_names, input_feed, run_options=None):
"""
Compute the predictions.
@ -103,7 +132,19 @@ class InferenceSession:
raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs))
if not output_names:
output_names = [output.name for output in self._outputs_meta]
return self._sess.run(output_names, input_feed, run_options)
try:
return self._sess.run(output_names, input_feed, run_options)
except C.EPFail as err:
if self._enable_fallback:
print("EP Error: {} using {}".format(str(err), self._providers))
print("Falling back to {} and retrying.".format(self._fallback_providers))
self.set_providers(self._fallback_providers)
# Fallback only once.
self.disable_fallback()
return self._sess.run(output_names, input_feed, run_options)
else:
raise
def end_profiling(self):
"""