mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Remove ml_status.h, add StatusCode to pybind exception mappings (#1889)
* initial checkin. * add onnxruntime status code to ort pybind exception mapping. * address review feedback.
This commit is contained in:
parent
77176e8678
commit
686bd36210
9 changed files with 214 additions and 185 deletions
|
|
@ -1,57 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
enum class MLStatus : uint32_t {
|
||||
OK = 0,
|
||||
FAIL = 1,
|
||||
INVALID_ARGUMENT = 2,
|
||||
NO_SUCHFILE = 3,
|
||||
NO_MODEL = 4,
|
||||
ENGINE_ERROR = 5,
|
||||
RUNTIME_EXCEPTION = 6,
|
||||
INVALID_PROTOBUF = 7,
|
||||
MODEL_LOADED = 8,
|
||||
NOT_IMPLEMENTED = 9,
|
||||
INVALID_GRAPH = 10,
|
||||
SHAPE_INFERENCE_NOT_REGISTERED = 11,
|
||||
REQUIREMENT_NOT_REGISTERED = 12
|
||||
};
|
||||
|
||||
inline const char* MLStatusToString(MLStatus status) noexcept {
|
||||
switch (status) {
|
||||
case MLStatus::OK:
|
||||
return "SUCCESS";
|
||||
case MLStatus::INVALID_ARGUMENT:
|
||||
return "INVALID_ARGUMENT";
|
||||
case MLStatus::NO_SUCHFILE:
|
||||
return "NO_SUCHFILE";
|
||||
case MLStatus::NO_MODEL:
|
||||
return "NO_MODEL";
|
||||
case MLStatus::ENGINE_ERROR:
|
||||
return "ENGINE_ERROR";
|
||||
case MLStatus::RUNTIME_EXCEPTION:
|
||||
return "RUNTIME_EXCEPTION";
|
||||
case MLStatus::INVALID_PROTOBUF:
|
||||
return "INVALID_PROTOBUF";
|
||||
case MLStatus::MODEL_LOADED:
|
||||
return "MODEL_LOADED";
|
||||
case MLStatus::NOT_IMPLEMENTED:
|
||||
return "NOT_IMPLEMENTED";
|
||||
case MLStatus::INVALID_GRAPH:
|
||||
return "INVALID_GRAPH";
|
||||
case MLStatus::SHAPE_INFERENCE_NOT_REGISTERED:
|
||||
return "SHAPE_INFERENCE_NOT_REGISTERED";
|
||||
case MLStatus::REQUIREMENT_NOT_REGISTERED:
|
||||
return "REQUIREMENT_NOT_REGISTERED";
|
||||
default:
|
||||
return "GENERAL ERROR";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "core/common/ml_status.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace common {
|
||||
|
|
@ -31,21 +30,51 @@ enum StatusCategory {
|
|||
Error code for ONNXRuntime.
|
||||
*/
|
||||
enum StatusCode {
|
||||
OK = static_cast<unsigned int>(MLStatus::OK),
|
||||
FAIL = static_cast<unsigned int>(MLStatus::FAIL),
|
||||
INVALID_ARGUMENT = static_cast<unsigned int>(MLStatus::INVALID_ARGUMENT),
|
||||
NO_SUCHFILE = static_cast<unsigned int>(MLStatus::NO_SUCHFILE),
|
||||
NO_MODEL = static_cast<unsigned int>(MLStatus::NO_MODEL),
|
||||
ENGINE_ERROR = static_cast<unsigned int>(MLStatus::ENGINE_ERROR),
|
||||
RUNTIME_EXCEPTION = static_cast<unsigned int>(MLStatus::RUNTIME_EXCEPTION),
|
||||
INVALID_PROTOBUF = static_cast<unsigned int>(MLStatus::INVALID_PROTOBUF),
|
||||
MODEL_LOADED = static_cast<unsigned int>(MLStatus::MODEL_LOADED),
|
||||
NOT_IMPLEMENTED = static_cast<unsigned int>(MLStatus::NOT_IMPLEMENTED),
|
||||
INVALID_GRAPH = static_cast<unsigned int>(MLStatus::INVALID_GRAPH),
|
||||
SHAPE_INFERENCE_NOT_REGISTERED = static_cast<unsigned int>(MLStatus::SHAPE_INFERENCE_NOT_REGISTERED),
|
||||
REQUIREMENT_NOT_REGISTERED = static_cast<unsigned int>(MLStatus::REQUIREMENT_NOT_REGISTERED),
|
||||
OK = 0,
|
||||
FAIL = 1,
|
||||
INVALID_ARGUMENT = 2,
|
||||
NO_SUCHFILE = 3,
|
||||
NO_MODEL = 4,
|
||||
ENGINE_ERROR = 5,
|
||||
RUNTIME_EXCEPTION = 6,
|
||||
INVALID_PROTOBUF = 7,
|
||||
MODEL_LOADED = 8,
|
||||
NOT_IMPLEMENTED = 9,
|
||||
INVALID_GRAPH = 10,
|
||||
EP_FAIL = 11
|
||||
};
|
||||
|
||||
inline const char* StatusCodeToString(StatusCode status) noexcept {
|
||||
switch (status) {
|
||||
case StatusCode::OK:
|
||||
return "SUCCESS";
|
||||
case StatusCode::FAIL:
|
||||
return "FAIL";
|
||||
case StatusCode::INVALID_ARGUMENT:
|
||||
return "INVALID_ARGUMENT";
|
||||
case StatusCode::NO_SUCHFILE:
|
||||
return "NO_SUCHFILE";
|
||||
case StatusCode::NO_MODEL:
|
||||
return "NO_MODEL";
|
||||
case StatusCode::ENGINE_ERROR:
|
||||
return "ENGINE_ERROR";
|
||||
case StatusCode::RUNTIME_EXCEPTION:
|
||||
return "RUNTIME_EXCEPTION";
|
||||
case StatusCode::INVALID_PROTOBUF:
|
||||
return "INVALID_PROTOBUF";
|
||||
case StatusCode::MODEL_LOADED:
|
||||
return "MODEL_LOADED";
|
||||
case StatusCode::NOT_IMPLEMENTED:
|
||||
return "NOT_IMPLEMENTED";
|
||||
case StatusCode::INVALID_GRAPH:
|
||||
return "INVALID_GRAPH";
|
||||
case StatusCode::EP_FAIL:
|
||||
return "EP_FAIL";
|
||||
default:
|
||||
return "GENERAL ERROR";
|
||||
}
|
||||
}
|
||||
|
||||
class Status {
|
||||
public:
|
||||
Status() noexcept = default;
|
||||
|
|
|
|||
|
|
@ -122,8 +122,7 @@ typedef enum OrtErrorCode {
|
|||
ORT_MODEL_LOADED,
|
||||
ORT_NOT_IMPLEMENTED,
|
||||
ORT_INVALID_GRAPH,
|
||||
ORT_SHAPE_INFERENCE_NOT_REGISTERED,
|
||||
ORT_REQUIREMENT_NOT_REGISTERED,
|
||||
ORT_EP_FAIL,
|
||||
} OrtErrorCode;
|
||||
|
||||
// __VA_ARGS__ on Windows and Linux are different
|
||||
|
|
|
|||
|
|
@ -18,14 +18,14 @@ namespace onnxruntime {
|
|||
namespace common {
|
||||
Status::Status(StatusCategory category, int code, const std::string& msg) {
|
||||
// state_ will be allocated here causing the status to be treated as a failure
|
||||
ORT_ENFORCE(code != static_cast<int>(MLStatus::OK));
|
||||
ORT_ENFORCE(code != static_cast<int>(common::OK));
|
||||
|
||||
state_ = std::make_unique<State>(category, code, msg);
|
||||
}
|
||||
|
||||
Status::Status(StatusCategory category, int code, const char* msg) {
|
||||
// state_ will be allocated here causing the status to be treated as a failure
|
||||
ORT_ENFORCE(code != static_cast<int>(MLStatus::OK));
|
||||
ORT_ENFORCE(code != static_cast<int>(common::OK));
|
||||
|
||||
state_ = std::make_unique<State>(category, code, msg);
|
||||
}
|
||||
|
|
@ -62,7 +62,7 @@ std::string Status::ToString() const {
|
|||
result += " : ";
|
||||
result += std::to_string(Code());
|
||||
result += " : ";
|
||||
result += MLStatusToString(static_cast<MLStatus>(Code()));
|
||||
result += StatusCodeToString(static_cast<StatusCode>(Code()));
|
||||
result += " : ";
|
||||
result += state_->msg;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ using onnxruntime::Environment;
|
|||
using onnxruntime::IAllocator;
|
||||
using onnxruntime::InputDefList;
|
||||
using onnxruntime::MLFloat16;
|
||||
using onnxruntime::MLStatus;
|
||||
using onnxruntime::OutputDefList;
|
||||
using onnxruntime::Tensor;
|
||||
using onnxruntime::ToOrtStatus;
|
||||
|
|
|
|||
87
onnxruntime/python/onnxruntime_pybind_exceptions.h
Normal file
87
onnxruntime/python/onnxruntime_pybind_exceptions.h
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
#include <pybind11/pybind11.h>
|
||||
#include <stdexcept>
|
||||
#include "core/common/status.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
// onnxruntime::python exceptions map 1:1 to onnxruntime:common::StatusCode enum.
|
||||
struct Fail : std::runtime_error {
|
||||
explicit Fail(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct InvalidArgument : std::runtime_error {
|
||||
explicit InvalidArgument(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct NoSuchFile : std::runtime_error {
|
||||
explicit NoSuchFile(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct NoModel : std::runtime_error {
|
||||
explicit NoModel(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct EngineError : std::runtime_error {
|
||||
explicit EngineError(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct RuntimeException : std::runtime_error {
|
||||
explicit RuntimeException(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct InvalidProtobuf : std::runtime_error {
|
||||
explicit InvalidProtobuf(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct ModelLoaded : std::runtime_error {
|
||||
explicit ModelLoaded(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct NotImplemented : std::runtime_error {
|
||||
explicit NotImplemented(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct InvalidGraph : std::runtime_error {
|
||||
explicit InvalidGraph(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
struct EPFail : std::runtime_error {
|
||||
explicit EPFail(const std::string& what) : std::runtime_error(what) {}
|
||||
};
|
||||
|
||||
void RegisterExceptions(pybind11::module& m) {
|
||||
pybind11::register_exception<Fail>(m, "Fail");
|
||||
pybind11::register_exception<InvalidArgument>(m, "InvalidArgument");
|
||||
pybind11::register_exception<NoSuchFile>(m, "NoSuchFile");
|
||||
pybind11::register_exception<NoModel>(m, "NoModel");
|
||||
pybind11::register_exception<EngineError>(m, "EngineError");
|
||||
pybind11::register_exception<RuntimeException>(m, "RuntimeException");
|
||||
pybind11::register_exception<InvalidProtobuf>(m, "InvalidProtobuf");
|
||||
pybind11::register_exception<ModelLoaded>(m, "ModelLoaded");
|
||||
pybind11::register_exception<NotImplemented>(m, "NotImplemented");
|
||||
pybind11::register_exception<InvalidGraph>(m, "InvalidGraph");
|
||||
pybind11::register_exception<EPFail>(m, "EPFail");
|
||||
}
|
||||
|
||||
inline void OrtPybindThrowIfError(onnxruntime::common::Status status) {
|
||||
std::string msg = status.ToString();
|
||||
if (!status.IsOK()) {
|
||||
switch (status.Code()) {
|
||||
case onnxruntime::common::StatusCode::FAIL:
|
||||
throw Fail(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
|
||||
throw InvalidArgument(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::NO_SUCHFILE:
|
||||
throw NoSuchFile(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::NO_MODEL:
|
||||
throw NoModel(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::ENGINE_ERROR:
|
||||
throw EngineError(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
|
||||
throw RuntimeException(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
|
||||
throw InvalidProtobuf(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::NOT_IMPLEMENTED:
|
||||
throw NotImplemented(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::INVALID_GRAPH:
|
||||
throw InvalidGraph(std::move(msg));
|
||||
case onnxruntime::common::StatusCode::EP_FAIL:
|
||||
throw EPFail(std::move(msg));
|
||||
default:
|
||||
throw std::runtime_error(std::move(msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -15,8 +15,6 @@
|
|||
#include "core/framework/ml_value.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "onnxruntime_pybind_exceptions.h"
|
||||
#include "onnxruntime_pybind_mlvalue.h"
|
||||
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
|
@ -125,7 +126,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_BrainS
|
|||
#pragma warning(disable : 4267 4996 4503 4003)
|
||||
#endif // _MSC_VER
|
||||
|
||||
using namespace std;
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
|
|
@ -144,10 +144,10 @@ static const SessionOptions& GetDefaultCPUSessionOptions() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void AddNonTensor(OrtValue& val, vector<py::object>& pyobjs) {
|
||||
void AddNonTensor(OrtValue& val, std::vector<py::object>& pyobjs) {
|
||||
pyobjs.push_back(py::cast(val.Get<T>()));
|
||||
}
|
||||
void AddNonTensorAsPyObj(OrtValue& val, vector<py::object>& pyobjs) {
|
||||
void AddNonTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
|
||||
// Should be in sync with core/framework/datatypes.h
|
||||
if (val.Type() == DataTypeImpl::GetType<MapStringToString>()) {
|
||||
AddNonTensor<MapStringToString>(val, pyobjs);
|
||||
|
|
@ -182,7 +182,7 @@ void AddNonTensorAsPyObj(OrtValue& val, vector<py::object>& pyobjs) {
|
|||
}
|
||||
}
|
||||
|
||||
void AddTensorAsPyObj(OrtValue& val, vector<py::object>& pyobjs) {
|
||||
void AddTensorAsPyObj(OrtValue& val, std::vector<py::object>& pyobjs) {
|
||||
const Tensor& rtensor = val.Get<Tensor>();
|
||||
std::vector<npy_intp> npy_dims;
|
||||
const TensorShape& shape = rtensor.Shape();
|
||||
|
|
@ -235,10 +235,7 @@ class SessionObjectInitializer {
|
|||
|
||||
inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
|
||||
auto p = f.CreateProvider();
|
||||
auto status = sess->RegisterExecutionProvider(std::move(p));
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ErrorMessage().c_str());
|
||||
}
|
||||
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(p)));
|
||||
}
|
||||
|
||||
// ordered by default priority. highest to lowest.
|
||||
|
|
@ -327,11 +324,7 @@ void InitializeSession(InferenceSession* sess, const std::vector<std::string>& p
|
|||
} else {
|
||||
RegisterExecutionProviders(sess, provider_types);
|
||||
}
|
||||
onnxruntime::common::Status status;
|
||||
status = sess->Initialize();
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ToString().c_str());
|
||||
}
|
||||
OrtPybindThrowIfError(sess->Initialize());
|
||||
}
|
||||
|
||||
void addGlobalMethods(py::module& m) {
|
||||
|
|
@ -339,8 +332,12 @@ void addGlobalMethods(py::module& m) {
|
|||
m.def(
|
||||
"get_device", []() -> std::string { return BACKEND_DEVICE; },
|
||||
"Return the device used to compute the prediction (CPU, MKL, ...)");
|
||||
m.def("get_all_providers", []() -> const std::vector<std::string>& { return GetAllProviders(); });
|
||||
m.def("get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableProviders(); });
|
||||
m.def(
|
||||
"get_all_providers", []() -> const std::vector<std::string>& { return GetAllProviders(); },
|
||||
"Return list of Execution Providers that this version of Onnxruntime can support.");
|
||||
m.def(
|
||||
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableProviders(); },
|
||||
"Return list of available Execution Providers available in this installed version of Onnxruntime.");
|
||||
|
||||
#ifdef USE_NUPHAR
|
||||
m.def("set_nuphar_settings", [](const std::string& str) {
|
||||
|
|
@ -624,55 +621,57 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
return *(na.Type());
|
||||
},
|
||||
"node type")
|
||||
.def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
|
||||
std::ostringstream res;
|
||||
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
res << "[]";
|
||||
} else {
|
||||
res << "[";
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (utils::HasDimValue(shape->dim(i))) {
|
||||
res << shape->dim(i).dim_value();
|
||||
} else if (utils::HasDimParam(shape->dim(i))) {
|
||||
res << "'" << shape->dim(i).dim_param() << "'";
|
||||
.def(
|
||||
"__str__", [](const onnxruntime::NodeArg& na) -> std::string {
|
||||
std::ostringstream res;
|
||||
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
res << "[]";
|
||||
} else {
|
||||
res << "None";
|
||||
res << "[";
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (utils::HasDimValue(shape->dim(i))) {
|
||||
res << shape->dim(i).dim_value();
|
||||
} else if (utils::HasDimParam(shape->dim(i))) {
|
||||
res << "'" << shape->dim(i).dim_param() << "'";
|
||||
} else {
|
||||
res << "None";
|
||||
}
|
||||
|
||||
if (i < shape->dim_size() - 1) {
|
||||
res << ", ";
|
||||
}
|
||||
}
|
||||
res << "]";
|
||||
}
|
||||
res << ")";
|
||||
|
||||
return std::string(res.str());
|
||||
},
|
||||
"converts the node into a readable string")
|
||||
.def_property_readonly(
|
||||
"shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
return arr;
|
||||
}
|
||||
|
||||
if (i < shape->dim_size() - 1) {
|
||||
res << ", ";
|
||||
arr.resize(shape->dim_size());
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (utils::HasDimValue(shape->dim(i))) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_value());
|
||||
} else if (utils::HasDimParam(shape->dim(i))) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_param());
|
||||
} else {
|
||||
arr[i] = py::none();
|
||||
}
|
||||
}
|
||||
}
|
||||
res << "]";
|
||||
}
|
||||
res << ")";
|
||||
|
||||
return std::string(res.str());
|
||||
},
|
||||
"converts the node into a readable string")
|
||||
.def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
|
||||
auto shape = na.Shape();
|
||||
std::vector<py::object> arr;
|
||||
if (shape == nullptr || shape->dim_size() == 0) {
|
||||
return arr;
|
||||
}
|
||||
|
||||
arr.resize(shape->dim_size());
|
||||
for (int i = 0; i < shape->dim_size(); ++i) {
|
||||
if (utils::HasDimValue(shape->dim(i))) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_value());
|
||||
} else if (utils::HasDimParam(shape->dim(i))) {
|
||||
arr[i] = py::cast(shape->dim(i).dim_param());
|
||||
} else {
|
||||
arr[i] = py::none();
|
||||
}
|
||||
}
|
||||
return arr;
|
||||
},
|
||||
"node shape (assuming the node holds a tensor)");
|
||||
return arr;
|
||||
},
|
||||
"node shape (assuming the node holds a tensor)");
|
||||
|
||||
py::class_<SessionObjectInitializer>(m, "SessionObjectInitializer");
|
||||
py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
|
||||
|
|
@ -680,22 +679,17 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
.def(py::init<SessionOptions, SessionObjectInitializer>())
|
||||
.def(
|
||||
"load_model", [](InferenceSession* sess, const std::string& path, std::vector<std::string>& provider_types) {
|
||||
auto status = sess->Load(path);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ToString().c_str());
|
||||
}
|
||||
OrtPybindThrowIfError(sess->Load(path));
|
||||
InitializeSession(sess, provider_types);
|
||||
},
|
||||
R"pbdoc(Load a model saved in ONNX format.)pbdoc")
|
||||
.def("read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel, std::vector<std::string>& provider_types) {
|
||||
std::istringstream buffer(serializedModel);
|
||||
auto status = sess->Load(buffer);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ToString().c_str());
|
||||
}
|
||||
InitializeSession(sess, provider_types);
|
||||
},
|
||||
R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
|
||||
.def(
|
||||
"read_bytes", [](InferenceSession* sess, const py::bytes& serializedModel, std::vector<std::string>& provider_types) {
|
||||
std::istringstream buffer(serializedModel);
|
||||
OrtPybindThrowIfError(sess->Load(buffer));
|
||||
InitializeSession(sess, provider_types);
|
||||
},
|
||||
R"pbdoc(Load a model serialized in ONNX format.)pbdoc")
|
||||
.def("run", [](InferenceSession* sess, std::vector<std::string> output_names, std::map<std::string, py::object> pyfeeds, RunOptions* run_options = nullptr) -> std::vector<py::object> {
|
||||
NameMLValMap feeds;
|
||||
for (auto _ : pyfeeds) {
|
||||
|
|
@ -724,17 +718,12 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
// release GIL to allow multiple python threads to invoke Run() in parallel.
|
||||
py::gil_scoped_release release;
|
||||
if (run_options != nullptr) {
|
||||
status = sess->Run(*run_options, feeds, output_names, &fetches);
|
||||
OrtPybindThrowIfError(sess->Run(*run_options, feeds, output_names, &fetches));
|
||||
} else {
|
||||
status = sess->Run(feeds, output_names, &fetches);
|
||||
OrtPybindThrowIfError(sess->Run(feeds, output_names, &fetches));
|
||||
}
|
||||
}
|
||||
|
||||
if (!status.IsOK()) {
|
||||
auto mes = status.ToString();
|
||||
throw std::runtime_error(std::string("Method run failed due to: ") + std::string(mes.c_str()));
|
||||
}
|
||||
|
||||
std::vector<py::object> rfetch;
|
||||
rfetch.reserve(fetches.size());
|
||||
for (auto _ : fetches) {
|
||||
|
|
@ -754,40 +743,29 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
})
|
||||
.def_property_readonly("inputs_meta", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
|
||||
auto res = sess->GetModelInputs();
|
||||
if (!res.first.IsOK()) {
|
||||
throw std::runtime_error(res.first.ToString().c_str());
|
||||
} else {
|
||||
return *(res.second);
|
||||
}
|
||||
OrtPybindThrowIfError(res.first);
|
||||
return *(res.second);
|
||||
})
|
||||
.def_property_readonly("outputs_meta", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
|
||||
auto res = sess->GetModelOutputs();
|
||||
if (!res.first.IsOK()) {
|
||||
throw std::runtime_error(res.first.ToString().c_str());
|
||||
} else {
|
||||
return *(res.second);
|
||||
}
|
||||
OrtPybindThrowIfError(res.first);
|
||||
return *(res.second);
|
||||
})
|
||||
.def_property_readonly("overridable_initializers", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
|
||||
auto res = sess->GetOverridableInitializers();
|
||||
if (!res.first.IsOK()) {
|
||||
throw std::runtime_error(res.first.ToString().c_str());
|
||||
} else {
|
||||
return *res.second;
|
||||
}
|
||||
OrtPybindThrowIfError(res.first);
|
||||
return *(res.second);
|
||||
})
|
||||
.def_property_readonly("model_meta", [](const InferenceSession* sess) -> const onnxruntime::ModelMetadata& {
|
||||
auto res = sess->GetModelMetadata();
|
||||
if (!res.first.IsOK()) {
|
||||
throw std::runtime_error(res.first.ToString().c_str());
|
||||
} else {
|
||||
return *(res.second);
|
||||
}
|
||||
OrtPybindThrowIfError(res.first);
|
||||
return *(res.second);
|
||||
});
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
||||
m.doc() = "pybind11 stateful interface to ONNX runtime";
|
||||
RegisterExceptions(m);
|
||||
|
||||
auto initialize = [&]() {
|
||||
// Initialization of the module
|
||||
|
|
@ -797,10 +775,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
|||
})();
|
||||
|
||||
static std::unique_ptr<Environment> env;
|
||||
auto status = Environment::Create(env);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ToString().c_str());
|
||||
}
|
||||
OrtPybindThrowIfError(Environment::Create(env));
|
||||
|
||||
static bool initialized = false;
|
||||
if (initialized) {
|
||||
|
|
|
|||
|
|
@ -23,8 +23,6 @@ protobufutil::Status GenerateProtobufStatus(const int& onnx_status, const std::s
|
|||
case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
|
||||
case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
|
||||
case onnxruntime::common::StatusCode::INVALID_GRAPH:
|
||||
case onnxruntime::common::StatusCode::SHAPE_INFERENCE_NOT_REGISTERED:
|
||||
case onnxruntime::common::StatusCode::REQUIREMENT_NOT_REGISTERED:
|
||||
case onnxruntime::common::StatusCode::NO_SUCHFILE:
|
||||
case onnxruntime::common::StatusCode::NO_MODEL:
|
||||
code = protobufutil::error::Code::INVALID_ARGUMENT;
|
||||
|
|
@ -33,6 +31,7 @@ protobufutil::Status GenerateProtobufStatus(const int& onnx_status, const std::s
|
|||
code = protobufutil::error::Code::UNIMPLEMENTED;
|
||||
break;
|
||||
case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
|
||||
case onnxruntime::common::StatusCode::EP_FAIL:
|
||||
code = protobufutil::error::Code::INTERNAL;
|
||||
break;
|
||||
default:
|
||||
|
|
|
|||
Loading…
Reference in a new issue