mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Improve Session Capabilities in ORTServer (#1862)
* add unload and fixtures * update to create logger * clang-format
This commit is contained in:
parent
b43254282f
commit
60e5eee52a
7 changed files with 103 additions and 32 deletions
|
|
@ -36,36 +36,52 @@ ServerEnvironment::ServerEnvironment(OrtLoggingLevel severity, spdlog::sinks_ini
|
|||
logger_id_("ServerApp"),
|
||||
sink_(sink),
|
||||
default_logger_(std::make_shared<spdlog::logger>(logger_id_, sink)),
|
||||
runtime_environment_(severity, logger_id_.c_str(), Log, default_logger_.get()),
|
||||
session(nullptr) {
|
||||
runtime_environment_(severity, logger_id_.c_str(), Log, default_logger_.get()) {
|
||||
spdlog::set_automatic_registration(false);
|
||||
spdlog::set_level(Convert(severity_));
|
||||
spdlog::initialize_logger(default_logger_);
|
||||
}
|
||||
|
||||
void ServerEnvironment::InitializeModel(const std::string& model_path) {
|
||||
session = Ort::Session(runtime_environment_, model_path.c_str(), Ort::SessionOptions());
|
||||
void ServerEnvironment::InitializeModel(const std::string& model_path, const std::string& model_name, const std::string& model_version) {
|
||||
auto result = sessions_.emplace(std::piecewise_construct, std::forward_as_tuple(model_name, model_version), std::forward_as_tuple(runtime_environment_, model_path.c_str(), Ort::SessionOptions()));
|
||||
|
||||
auto output_count = session.GetOutputCount();
|
||||
if (!result.second) {
|
||||
throw Ort::Exception("Model of that name already loaded.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto iterator = result.first;
|
||||
auto output_count = (iterator->second).session.GetOutputCount();
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
for (size_t i = 0; i < output_count; i++) {
|
||||
auto name = session.GetOutputName(i, allocator);
|
||||
model_output_names_.push_back(name);
|
||||
auto name = (iterator->second).session.GetOutputName(i, allocator);
|
||||
(iterator->second).output_names.push_back(name);
|
||||
allocator.Free(name);
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<std::string>& ServerEnvironment::GetModelOutputNames() const {
|
||||
return model_output_names_;
|
||||
const std::vector<std::string>& ServerEnvironment::GetModelOutputNames(const std::string& model_name, const std::string& model_version) const {
|
||||
auto identifier = std::make_pair(model_name, model_version);
|
||||
auto it = sessions_.find(identifier);
|
||||
if (it == sessions_.end()) {
|
||||
throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL);
|
||||
}
|
||||
|
||||
return it->second.output_names;
|
||||
}
|
||||
|
||||
OrtLoggingLevel ServerEnvironment::GetLogSeverity() const {
|
||||
return severity_;
|
||||
}
|
||||
|
||||
const Ort::Session& ServerEnvironment::GetSession() const {
|
||||
return session;
|
||||
const Ort::Session& ServerEnvironment::GetSession(const std::string& model_name, const std::string& model_version) const {
|
||||
auto identifier = std::make_pair(model_name, model_version);
|
||||
auto it = sessions_.find(identifier);
|
||||
if (it == sessions_.end()) {
|
||||
throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL);
|
||||
}
|
||||
|
||||
return it->second.session;
|
||||
}
|
||||
|
||||
std::shared_ptr<spdlog::logger> ServerEnvironment::GetLogger(const std::string& request_id) const {
|
||||
|
|
@ -78,5 +94,15 @@ std::shared_ptr<spdlog::logger> ServerEnvironment::GetAppLogger() const {
|
|||
return default_logger_;
|
||||
}
|
||||
|
||||
void ServerEnvironment::UnloadModel(const std::string& model_name, const std::string& model_version) {
|
||||
auto identifier = std::make_pair(model_name, model_version);
|
||||
auto it = sessions_.find(identifier);
|
||||
if (it == sessions_.end()) {
|
||||
throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL);
|
||||
}
|
||||
|
||||
sessions_.erase(it);
|
||||
}
|
||||
|
||||
} // namespace server
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@
|
|||
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <unordered_map>
|
||||
#include <boost/functional/hash.hpp>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace server {
|
||||
|
|
@ -20,11 +22,12 @@ class ServerEnvironment {
|
|||
|
||||
OrtLoggingLevel GetLogSeverity() const;
|
||||
|
||||
const Ort::Session& GetSession() const;
|
||||
void InitializeModel(const std::string& model_path);
|
||||
const std::vector<std::string>& GetModelOutputNames() const;
|
||||
const Ort::Session& GetSession(const std::string& model_name, const std::string& model_version) const;
|
||||
void InitializeModel(const std::string& model_path, const std::string& model_name, const std::string& model_version);
|
||||
const std::vector<std::string>& GetModelOutputNames(const std::string& model_name, const std::string& model_version) const;
|
||||
std::shared_ptr<spdlog::logger> GetLogger(const std::string& request_id) const;
|
||||
std::shared_ptr<spdlog::logger> GetAppLogger() const;
|
||||
void UnloadModel(const std::string& model_name, const std::string& model_version);
|
||||
|
||||
private:
|
||||
const OrtLoggingLevel severity_;
|
||||
|
|
@ -34,8 +37,20 @@ class ServerEnvironment {
|
|||
|
||||
Ort::Env runtime_environment_;
|
||||
Ort::SessionOptions options_;
|
||||
Ort::Session session;
|
||||
std::vector<std::string> model_output_names_;
|
||||
|
||||
struct SessionHolder {
|
||||
Ort::Session session;
|
||||
std::vector<std::string> output_names;
|
||||
explicit SessionHolder(Ort::Env& env, std::string path, const Ort::SessionOptions& options) : session(nullptr) {
|
||||
session = Ort::Session(env, path.c_str(), options);
|
||||
};
|
||||
~SessionHolder() = default;
|
||||
SessionHolder(const SessionHolder&) = delete;
|
||||
SessionHolder(const SessionHolder&&) = delete;
|
||||
SessionHolder& operator=(const SessionHolder&) = delete;
|
||||
};
|
||||
|
||||
std::unordered_map<std::pair<std::string, std::string>, ServerEnvironment::SessionHolder, boost::hash<std::pair<std::string, std::string>>> sessions_;
|
||||
};
|
||||
|
||||
} // namespace server
|
||||
|
|
|
|||
|
|
@ -131,12 +131,12 @@ protobufutil::Status Executor::Predict(const std::string& model_name,
|
|||
output_names.push_back(name);
|
||||
}
|
||||
} else {
|
||||
output_names = env_->GetModelOutputNames();
|
||||
output_names = env_->GetModelOutputNames(model_name, model_version);
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> outputs;
|
||||
try {
|
||||
outputs = Run(env_->GetSession(), run_options, input_names, input_values, output_names);
|
||||
outputs = Run(env_->GetSession(model_name, model_version), run_options, input_names, input_values, output_names);
|
||||
} catch (const Ort::Exception& e) {
|
||||
return GenerateProtobufStatus(e.GetOrtErrorCode(), e.what());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ int main(int argc, char* argv[]) {
|
|||
logger->info("Model path: {}", config.model_path);
|
||||
|
||||
try {
|
||||
env->InitializeModel(config.model_path);
|
||||
env->InitializeModel(config.model_path, "default", "1");
|
||||
logger->debug("Initialize Model Successfully!");
|
||||
} catch (const Ort::Exception& ex) {
|
||||
logger->critical("Initialize Model Failed: {} ---- Error: [{}]", ex.GetOrtErrorCode(), ex.what());
|
||||
|
|
|
|||
|
|
@ -16,13 +16,26 @@ namespace onnxruntime {
|
|||
namespace server {
|
||||
namespace test {
|
||||
|
||||
TEST(ExecutorTests, TestMul_1) {
|
||||
const static auto model_file = "testdata/mul_1.onnx";
|
||||
class ExecutorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
const static auto model_file = "testdata/mul_1.onnx";
|
||||
|
||||
onnxruntime::server::ServerEnvironment* env = ServerEnv();
|
||||
env->InitializeModel(model_file, "Name", "version");
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
onnxruntime::server::ServerEnvironment* env = ServerEnv();
|
||||
env->UnloadModel("Name", "version");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ExecutorTest, TestMul_1) {
|
||||
const static auto input_json = R"({"inputs":{"X":{"dims":[3,2],"dataType":1,"floatData":[1,2,3,4,5,6]}},"outputFilter":["Y"]})";
|
||||
const static auto expected = R"({"outputs":{"Y":{"dims":["3","2"],"dataType":1,"floatData":[1,4,9,16,25,36]}}})";
|
||||
|
||||
onnxruntime::server::ServerEnvironment* env = ServerEnv();
|
||||
env->InitializeModel(model_file);
|
||||
|
||||
onnxruntime::server::Executor executor(env, "RequestId");
|
||||
onnxruntime::server::PredictRequest request{};
|
||||
|
|
|
|||
|
|
@ -8,11 +8,14 @@
|
|||
#include "test_server_environment.h"
|
||||
#include "external/server_context_test_spouse.h"
|
||||
#include <grpcpp/impl/grpc_library.h>
|
||||
#include "test_server_environment.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace server {
|
||||
namespace grpc {
|
||||
namespace test {
|
||||
|
||||
// This class's constructor calls the GRPC library's init method.
|
||||
static ::grpc::internal::GrpcLibraryInitializer g_initializer;
|
||||
|
||||
PredictRequest GetRequest() {
|
||||
|
|
@ -28,11 +31,25 @@ PredictRequest GetRequest() {
|
|||
return req;
|
||||
}
|
||||
|
||||
std::shared_ptr<onnxruntime::server::ServerEnvironment> GetEnvironment() {
|
||||
return std::shared_ptr<onnxruntime::server::ServerEnvironment>(onnxruntime::server::test::ServerEnv(), [](onnxruntime::server::ServerEnvironment *){});
|
||||
}
|
||||
class PredictionServiceImplTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
const static auto model_file = "testdata/mul_1.onnx";
|
||||
|
||||
TEST(PredictionServiceImplTests, HappyPath) {
|
||||
onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv();
|
||||
// Implementation detail - currently predict is hardcoded to model "default" version 1.
|
||||
env->InitializeModel(model_file, "default", "1");
|
||||
}
|
||||
void TearDown() override {
|
||||
onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv();
|
||||
env->UnloadModel("default", "1");
|
||||
}
|
||||
std::shared_ptr<onnxruntime::server::ServerEnvironment> GetEnvironment() {
|
||||
return std::shared_ptr<onnxruntime::server::ServerEnvironment>(onnxruntime::server::test::ServerEnv(), [](onnxruntime::server::ServerEnvironment*) {});
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(PredictionServiceImplTest, HappyPath) {
|
||||
auto env = GetEnvironment();
|
||||
PredictionServiceImpl test{env};
|
||||
auto request = GetRequest();
|
||||
|
|
@ -42,7 +59,7 @@ TEST(PredictionServiceImplTests, HappyPath) {
|
|||
EXPECT_TRUE(status.ok());
|
||||
}
|
||||
|
||||
TEST(PredictionServiceImplTests, InvalidInput) {
|
||||
TEST_F(PredictionServiceImplTest, InvalidInput) {
|
||||
auto env = GetEnvironment();
|
||||
PredictionServiceImpl test{env};
|
||||
auto request = GetRequest();
|
||||
|
|
@ -53,7 +70,7 @@ TEST(PredictionServiceImplTests, InvalidInput) {
|
|||
EXPECT_EQ(status.error_code(), ::grpc::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
TEST(PredictionServiceImplTests, SuccessRequestID) {
|
||||
TEST_F(PredictionServiceImplTest, SuccessRequestID) {
|
||||
auto env = GetEnvironment();
|
||||
PredictionServiceImpl test{env};
|
||||
auto request = GetRequest();
|
||||
|
|
@ -66,7 +83,7 @@ TEST(PredictionServiceImplTests, SuccessRequestID) {
|
|||
EXPECT_TRUE(status.ok());
|
||||
}
|
||||
|
||||
TEST(PredictionServiceImplTests, InvalidInputRequestID) {
|
||||
TEST_F(PredictionServiceImplTest, InvalidInputRequestID) {
|
||||
auto env = GetEnvironment();
|
||||
PredictionServiceImpl test{env};
|
||||
auto request = GetRequest();
|
||||
|
|
@ -81,7 +98,7 @@ TEST(PredictionServiceImplTests, InvalidInputRequestID) {
|
|||
EXPECT_FALSE(status.ok());
|
||||
}
|
||||
|
||||
TEST(PredictionServiceImplTests, SuccessClientID) {
|
||||
TEST_F(PredictionServiceImplTest, SuccessClientID) {
|
||||
auto env = GetEnvironment();
|
||||
PredictionServiceImpl test{env};
|
||||
auto request = GetRequest();
|
||||
|
|
@ -96,7 +113,7 @@ TEST(PredictionServiceImplTests, SuccessClientID) {
|
|||
EXPECT_TRUE(status.ok());
|
||||
}
|
||||
|
||||
TEST(PredictionServiceImplTests, InvalidInputClientID) {
|
||||
TEST_F(PredictionServiceImplTest, InvalidInputClientID) {
|
||||
auto env = GetEnvironment();
|
||||
PredictionServiceImpl test{env};
|
||||
auto request = GetRequest();
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ GTEST_API_ int main(int argc, char** argv) {
|
|||
|
||||
try {
|
||||
onnxruntime::server::test::TestServerEnvironment server_env{};
|
||||
onnxruntime::test::TestEnvironment env{argc, argv, false};
|
||||
onnxruntime::test::TestEnvironment env{argc, argv, true};
|
||||
status = RUN_ALL_TESTS();
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << ex.what();
|
||||
|
|
|
|||
Loading…
Reference in a new issue