Remove ml_status.h, add StatusCode to pybind exception mappings (#1889)

* initial checkin.

* add onnxruntime status code to ort pybind exception mapping.

* address review feedback.
This commit is contained in:
jywu-msft 2019-09-24 11:13:14 -07:00 committed by GitHub
parent 77176e8678
commit 686bd36210
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 214 additions and 185 deletions

View file

@ -1,57 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdint>
namespace onnxruntime {
enum class MLStatus : uint32_t {
OK = 0,
FAIL = 1,
INVALID_ARGUMENT = 2,
NO_SUCHFILE = 3,
NO_MODEL = 4,
ENGINE_ERROR = 5,
RUNTIME_EXCEPTION = 6,
INVALID_PROTOBUF = 7,
MODEL_LOADED = 8,
NOT_IMPLEMENTED = 9,
INVALID_GRAPH = 10,
SHAPE_INFERENCE_NOT_REGISTERED = 11,
REQUIREMENT_NOT_REGISTERED = 12
};
inline const char* MLStatusToString(MLStatus status) noexcept {
switch (status) {
case MLStatus::OK:
return "SUCCESS";
case MLStatus::INVALID_ARGUMENT:
return "INVALID_ARGUMENT";
case MLStatus::NO_SUCHFILE:
return "NO_SUCHFILE";
case MLStatus::NO_MODEL:
return "NO_MODEL";
case MLStatus::ENGINE_ERROR:
return "ENGINE_ERROR";
case MLStatus::RUNTIME_EXCEPTION:
return "RUNTIME_EXCEPTION";
case MLStatus::INVALID_PROTOBUF:
return "INVALID_PROTOBUF";
case MLStatus::MODEL_LOADED:
return "MODEL_LOADED";
case MLStatus::NOT_IMPLEMENTED:
return "NOT_IMPLEMENTED";
case MLStatus::INVALID_GRAPH:
return "INVALID_GRAPH";
case MLStatus::SHAPE_INFERENCE_NOT_REGISTERED:
return "SHAPE_INFERENCE_NOT_REGISTERED";
case MLStatus::REQUIREMENT_NOT_REGISTERED:
return "REQUIREMENT_NOT_REGISTERED";
default:
return "GENERAL ERROR";
}
}
} // namespace onnxruntime

View file

@ -16,7 +16,6 @@ limitations under the License.
#include <memory>
#include <ostream>
#include <string>
#include "core/common/ml_status.h"
namespace onnxruntime {
namespace common {
@ -31,21 +30,51 @@ enum StatusCategory {
Error code for ONNXRuntime.
*/
enum StatusCode {
OK = static_cast<unsigned int>(MLStatus::OK),
FAIL = static_cast<unsigned int>(MLStatus::FAIL),
INVALID_ARGUMENT = static_cast<unsigned int>(MLStatus::INVALID_ARGUMENT),
NO_SUCHFILE = static_cast<unsigned int>(MLStatus::NO_SUCHFILE),
NO_MODEL = static_cast<unsigned int>(MLStatus::NO_MODEL),
ENGINE_ERROR = static_cast<unsigned int>(MLStatus::ENGINE_ERROR),
RUNTIME_EXCEPTION = static_cast<unsigned int>(MLStatus::RUNTIME_EXCEPTION),
INVALID_PROTOBUF = static_cast<unsigned int>(MLStatus::INVALID_PROTOBUF),
MODEL_LOADED = static_cast<unsigned int>(MLStatus::MODEL_LOADED),
NOT_IMPLEMENTED = static_cast<unsigned int>(MLStatus::NOT_IMPLEMENTED),
INVALID_GRAPH = static_cast<unsigned int>(MLStatus::INVALID_GRAPH),
SHAPE_INFERENCE_NOT_REGISTERED = static_cast<unsigned int>(MLStatus::SHAPE_INFERENCE_NOT_REGISTERED),
REQUIREMENT_NOT_REGISTERED = static_cast<unsigned int>(MLStatus::REQUIREMENT_NOT_REGISTERED),
OK = 0,
FAIL = 1,
INVALID_ARGUMENT = 2,
NO_SUCHFILE = 3,
NO_MODEL = 4,
ENGINE_ERROR = 5,
RUNTIME_EXCEPTION = 6,
INVALID_PROTOBUF = 7,
MODEL_LOADED = 8,
NOT_IMPLEMENTED = 9,
INVALID_GRAPH = 10,
EP_FAIL = 11
};
inline const char* StatusCodeToString(StatusCode status) noexcept {
switch (status) {
case StatusCode::OK:
return "SUCCESS";
case StatusCode::FAIL:
return "FAIL";
case StatusCode::INVALID_ARGUMENT:
return "INVALID_ARGUMENT";
case StatusCode::NO_SUCHFILE:
return "NO_SUCHFILE";
case StatusCode::NO_MODEL:
return "NO_MODEL";
case StatusCode::ENGINE_ERROR:
return "ENGINE_ERROR";
case StatusCode::RUNTIME_EXCEPTION:
return "RUNTIME_EXCEPTION";
case StatusCode::INVALID_PROTOBUF:
return "INVALID_PROTOBUF";
case StatusCode::MODEL_LOADED:
return "MODEL_LOADED";
case StatusCode::NOT_IMPLEMENTED:
return "NOT_IMPLEMENTED";
case StatusCode::INVALID_GRAPH:
return "INVALID_GRAPH";
case StatusCode::EP_FAIL:
return "EP_FAIL";
default:
return "GENERAL ERROR";
}
}
class Status {
public:
Status() noexcept = default;

View file

@ -122,8 +122,7 @@ typedef enum OrtErrorCode {
ORT_MODEL_LOADED,
ORT_NOT_IMPLEMENTED,
ORT_INVALID_GRAPH,
ORT_SHAPE_INFERENCE_NOT_REGISTERED,
ORT_REQUIREMENT_NOT_REGISTERED,
ORT_EP_FAIL,
} OrtErrorCode;
// __VA_ARGS__ on Windows and Linux are different

View file

@ -18,14 +18,14 @@ namespace onnxruntime {
namespace common {
Status::Status(StatusCategory category, int code, const std::string& msg) {
// state_ will be allocated here causing the status to be treated as a failure
ORT_ENFORCE(code != static_cast<int>(MLStatus::OK));
ORT_ENFORCE(code != static_cast<int>(common::OK));
state_ = std::make_unique<State>(category, code, msg);
}
Status::Status(StatusCategory category, int code, const char* msg) {
// state_ will be allocated here causing the status to be treated as a failure
ORT_ENFORCE(code != static_cast<int>(MLStatus::OK));
ORT_ENFORCE(code != static_cast<int>(common::OK));
state_ = std::make_unique<State>(category, code, msg);
}
@ -62,7 +62,7 @@ std::string Status::ToString() const {
result += " : ";
result += std::to_string(Code());
result += " : ";
result += MLStatusToString(static_cast<MLStatus>(Code()));
result += StatusCodeToString(static_cast<StatusCode>(Code()));
result += " : ";
result += state_->msg;
}

View file

@ -32,7 +32,6 @@ using onnxruntime::Environment;
using onnxruntime::IAllocator;
using onnxruntime::InputDefList;
using onnxruntime::MLFloat16;
using onnxruntime::MLStatus;
using onnxruntime::OutputDefList;
using onnxruntime::Tensor;
using onnxruntime::ToOrtStatus;

View file

@ -0,0 +1,87 @@
#include <pybind11/pybind11.h>
#include <stdexcept>
#include "core/common/status.h"
namespace onnxruntime {
namespace python {
// onnxruntime::python exceptions map 1:1 to onnxruntime:common::StatusCode enum.
struct Fail : std::runtime_error {
explicit Fail(const std::string& what) : std::runtime_error(what) {}
};
struct InvalidArgument : std::runtime_error {
explicit InvalidArgument(const std::string& what) : std::runtime_error(what) {}
};
struct NoSuchFile : std::runtime_error {
explicit NoSuchFile(const std::string& what) : std::runtime_error(what) {}
};
struct NoModel : std::runtime_error {
explicit NoModel(const std::string& what) : std::runtime_error(what) {}
};
struct EngineError : std::runtime_error {
explicit EngineError(const std::string& what) : std::runtime_error(what) {}
};
struct RuntimeException : std::runtime_error {
explicit RuntimeException(const std::string& what) : std::runtime_error(what) {}
};
struct InvalidProtobuf : std::runtime_error {
explicit InvalidProtobuf(const std::string& what) : std::runtime_error(what) {}
};
struct ModelLoaded : std::runtime_error {
explicit ModelLoaded(const std::string& what) : std::runtime_error(what) {}
};
struct NotImplemented : std::runtime_error {
explicit NotImplemented(const std::string& what) : std::runtime_error(what) {}
};
struct InvalidGraph : std::runtime_error {
explicit InvalidGraph(const std::string& what) : std::runtime_error(what) {}
};
struct EPFail : std::runtime_error {
explicit EPFail(const std::string& what) : std::runtime_error(what) {}
};
void RegisterExceptions(pybind11::module& m) {
pybind11::register_exception<Fail>(m, "Fail");
pybind11::register_exception<InvalidArgument>(m, "InvalidArgument");
pybind11::register_exception<NoSuchFile>(m, "NoSuchFile");
pybind11::register_exception<NoModel>(m, "NoModel");
pybind11::register_exception<EngineError>(m, "EngineError");
pybind11::register_exception<RuntimeException>(m, "RuntimeException");
pybind11::register_exception<InvalidProtobuf>(m, "InvalidProtobuf");
pybind11::register_exception<ModelLoaded>(m, "ModelLoaded");
pybind11::register_exception<NotImplemented>(m, "NotImplemented");
pybind11::register_exception<InvalidGraph>(m, "InvalidGraph");
pybind11::register_exception<EPFail>(m, "EPFail");
}
inline void OrtPybindThrowIfError(onnxruntime::common::Status status) {
std::string msg = status.ToString();
if (!status.IsOK()) {
switch (status.Code()) {
case onnxruntime::common::StatusCode::FAIL:
throw Fail(std::move(msg));
case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
throw InvalidArgument(std::move(msg));
case onnxruntime::common::StatusCode::NO_SUCHFILE:
throw NoSuchFile(std::move(msg));
case onnxruntime::common::StatusCode::NO_MODEL:
throw NoModel(std::move(msg));
case onnxruntime::common::StatusCode::ENGINE_ERROR:
throw EngineError(std::move(msg));
case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
throw RuntimeException(std::move(msg));
case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
throw InvalidProtobuf(std::move(msg));
case onnxruntime::common::StatusCode::NOT_IMPLEMENTED:
throw NotImplemented(std::move(msg));
case onnxruntime::common::StatusCode::INVALID_GRAPH:
throw InvalidGraph(std::move(msg));
case onnxruntime::common::StatusCode::EP_FAIL:
throw EPFail(std::move(msg));
default:
throw std::runtime_error(std::move(msg));
}
}
}
} // namespace python
} // namespace onnxruntime

View file

@ -15,8 +15,6 @@
#include "core/framework/ml_value.h"
#include "core/session/inference_session.h"
using namespace std;
namespace onnxruntime {
namespace python {

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_pybind_exceptions.h"
#include "onnxruntime_pybind_mlvalue.h"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@ -125,7 +126,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_BrainS
#pragma warning(disable : 4267 4996 4503 4003)
#endif // _MSC_VER
using namespace std;
namespace onnxruntime {
namespace python {
@ -144,10 +144,10 @@ static const SessionOptions& GetDefaultCPUSessionOptions() {
}
template <typename T>
void AddNonTensor(OrtValue& val, vector<py::object>& pyobjs) {
void AddNonTensor(OrtValue& val, std::vector<py::object>& pyobjs) {
pyobjs.push_back(py::cast(val.Get<T>()));
}
void AddNonTensorAsPyObj(OrtValue& val, vector<py::object>& pyobjs) {
void AddNonTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
// Should be in sync with core/framework/datatypes.h
if (val.Type() == DataTypeImpl::GetType<MapStringToString>()) {
AddNonTensor<MapStringToString>(val, pyobjs);
@ -182,7 +182,7 @@ void AddNonTensorAsPyObj(OrtValue& val, vector<py::object>& pyobjs) {
}
}
void AddTensorAsPyObj(OrtValue& val, vector<py::object>& pyobjs) {
void AddTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
const Tensor& rtensor = val.Get<Tensor>();
std::vector<npy_intp> npy_dims;
const TensorShape& shape = rtensor.Shape();
@ -235,10 +235,7 @@ class SessionObjectInitializer {
inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
auto p = f.CreateProvider();
auto status = sess->RegisterExecutionProvider(std::move(p));
if (!status.IsOK()) {
throw std::runtime_error(status.ErrorMessage().c_str());
}
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p)));
}
// ordered by default priority. highest to lowest.
@ -327,11 +324,7 @@ void InitializeSession(InferenceSession* sess, const std::vector<std::string>& p
} else {
RegisterExecutionProviders(sess, provider_types);
}
onnxruntime::common::Status status;
status = sess->Initialize();
if (!status.IsOK()) {
throw std::runtime_error(status.ToString().c_str());
}
OrtPybindThrowIfError(sess->Initialize());
}
void addGlobalMethods(py::module& m) {
@ -339,8 +332,12 @@ void addGlobalMethods(py::module& m) {
m.def(
"get_device", []() -> std::string { return BACKEND_DEVICE; },
"Return the device used to compute the prediction (CPU, MKL, ...)");
m.def("get_all_providers", []() -> const std::vector<std::string>& { return GetAllProviders(); });
m.def("get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableProviders(); });
m.def(
"get_all_providers", []() -> const std::vector<std::string>& { return GetAllProviders(); },
"Return list of Execution Providers that this version of Onnxruntime can support.");
m.def(
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableProviders(); },
"Return list of available Execution Providers available in this installed version of Onnxruntime.");
#ifdef USE_NUPHAR
m.def("set_nuphar_settings", [](const std::string& str) {
@ -624,55 +621,57 @@ including arg name, arg type (contains both type and shape).)pbdoc")
return *(na.Type());
},
"node type")
.def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
std::ostringstream res;
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
auto shape = na.Shape();
std::vector<py::object> arr;
if (shape == nullptr || shape->dim_size() == 0) {
res << "[]";
} else {
res << "[";
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
res << shape->dim(i).dim_value();
} else if (utils::HasDimParam(shape->dim(i))) {
res << "'" << shape->dim(i).dim_param() << "'";
.def(
"__str__", [](const onnxruntime::NodeArg& na) -> std::string {
std::ostringstream res;
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
auto shape = na.Shape();
std::vector<py::object> arr;
if (shape == nullptr || shape->dim_size() == 0) {
res << "[]";
} else {
res << "None";
res << "[";
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
res << shape->dim(i).dim_value();
} else if (utils::HasDimParam(shape->dim(i))) {
res << "'" << shape->dim(i).dim_param() << "'";
} else {
res << "None";
}
if (i < shape->dim_size() - 1) {
res << ", ";
}
}
res << "]";
}
res << ")";
return std::string(res.str());
},
"converts the node into a readable string")
.def_property_readonly(
"shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
auto shape = na.Shape();
std::vector<py::object> arr;
if (shape == nullptr || shape->dim_size() == 0) {
return arr;
}
if (i < shape->dim_size() - 1) {
res << ", ";
arr.resize(shape->dim_size());
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_value());
} else if (utils::HasDimParam(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_param());
} else {
arr[i] = py::none();
}
}
}
res << "]";
}
res << ")";
return std::string(res.str());
},
"converts the node into a readable string")
.def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
auto shape = na.Shape();
std::vector<py::object> arr;
if (shape == nullptr || shape->dim_size() == 0) {
return arr;
}
arr.resize(shape->dim_size());
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_value());
} else if (utils::HasDimParam(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_param());
} else {
arr[i] = py::none();
}
}
return arr;
},
"node shape (assuming the node holds a tensor)");
return arr;
},
"node shape (assuming the node holds a tensor)");
py::class_<SessionObjectInitializer>(m, "SessionObjectInitializer");
py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
@ -680,22 +679,17 @@ including arg name, arg type (contains both type and shape).)pbdoc")
.def(py::init<SessionOptions, SessionObjectInitializer>())
.def(
"load_model", [](InferenceSession* sess, const std::string& path, std::vector<std::string>& provider_types) {
auto status = sess->Load(path);
if (!status.IsOK()) {
throw std::runtime_error(status.ToString().c_str());
}
OrtPybindThrowIfError(sess->Load(path));
InitializeSession(sess, provider_types);
},
R"pbdoc(Load a model saved in ONNX format.)pbdoc")
.def("read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel, std::vector<std::string>& provider_types) {
std::istringstream buffer(serializedModel);
auto status = sess->Load(buffer);
if (!status.IsOK()) {
throw std::runtime_error(status.ToString().c_str());
}
InitializeSession(sess, provider_types);
},
R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
.def(
"read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel, std::vector<std::string>& provider_types) {
std::istringstream buffer(serializedModel);
OrtPybindThrowIfError(sess->Load(buffer));
InitializeSession(sess, provider_types);
},
R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
.def("run", [](InferenceSession* sess, std::vector<std::string> output_names, std::map<std::string, py::object> pyfeeds, RunOptions* run_options = nullptr) -> std::vector<py::object> {
NameMLValMap feeds;
for (auto _ : pyfeeds) {
@ -724,17 +718,12 @@ including arg name, arg type (contains both type and shape).)pbdoc")
// release GIL to allow multiple python threads to invoke Run() in parallel.
py::gil_scoped_release release;
if (run_options != nullptr) {
status = sess->Run(*run_options, feeds, output_names, &fetches);
OrtPybindThrowIfError(sess->Run(*run_options, feeds, output_names, &fetches));
} else {
status = sess->Run(feeds, output_names, &fetches);
OrtPybindThrowIfError(sess->Run(feeds, output_names, &fetches));
}
}
if (!status.IsOK()) {
auto mes = status.ToString();
throw std::runtime_error(std::string("Method run failed due to: ") + std::string(mes.c_str()));
}
std::vector<py::object> rfetch;
rfetch.reserve(fetches.size());
for (auto _ : fetches) {
@ -754,40 +743,29 @@ including arg name, arg type (contains both type and shape).)pbdoc")
})
.def_property_readonly("inputs_meta", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetModelInputs();
if (!res.first.IsOK()) {
throw std::runtime_error(res.first.ToString().c_str());
} else {
return *(res.second);
}
OrtPybindThrowIfError(res.first);
return *(res.second);
})
.def_property_readonly("outputs_meta", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetModelOutputs();
if (!res.first.IsOK()) {
throw std::runtime_error(res.first.ToString().c_str());
} else {
return *(res.second);
}
OrtPybindThrowIfError(res.first);
return *(res.second);
})
.def_property_readonly("overridable_initializers", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetOverridableInitializers();
if (!res.first.IsOK()) {
throw std::runtime_error(res.first.ToString().c_str());
} else {
return *res.second;
}
OrtPybindThrowIfError(res.first);
return *(res.second);
})
.def_property_readonly("model_meta", [](const InferenceSession* sess) -> const onnxruntime::ModelMetadata& {
auto res = sess->GetModelMetadata();
if (!res.first.IsOK()) {
throw std::runtime_error(res.first.ToString().c_str());
} else {
return *(res.second);
}
OrtPybindThrowIfError(res.first);
return *(res.second);
});
}
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
m.doc() = "pybind11 stateful interface to ONNX runtime";
RegisterExceptions(m);
auto initialize = [&]() {
// Initialization of the module
@ -797,10 +775,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
})();
static std::unique_ptr<Environment> env;
auto status = Environment::Create(env);
if (!status.IsOK()) {
throw std::runtime_error(status.ToString().c_str());
}
OrtPybindThrowIfError(Environment::Create(env));
static bool initialized = false;
if (initialized) {

View file

@ -23,8 +23,6 @@ protobufutil::Status GenerateProtobufStatus(const int& onnx_status, const std::s
case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
case onnxruntime::common::StatusCode::INVALID_GRAPH:
case onnxruntime::common::StatusCode::SHAPE_INFERENCE_NOT_REGISTERED:
case onnxruntime::common::StatusCode::REQUIREMENT_NOT_REGISTERED:
case onnxruntime::common::StatusCode::NO_SUCHFILE:
case onnxruntime::common::StatusCode::NO_MODEL:
code = protobufutil::error::Code::INVALID_ARGUMENT;
@ -33,6 +31,7 @@ protobufutil::Status GenerateProtobufStatus(const int& onnx_status, const std::s
code = protobufutil::error::Code::UNIMPLEMENTED;
break;
case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
case onnxruntime::common::StatusCode::EP_FAIL:
code = protobufutil::error::Code::INTERNAL;
break;
default: