Improve Session Capabilities in ORTServer (#1862)

* add unload and fixtures

* update to create logger

* clang-format
This commit is contained in:
Colin Versteeg 2019-09-19 09:04:12 -07:00 committed by Vinitra Swamy
parent b43254282f
commit 60e5eee52a
7 changed files with 103 additions and 32 deletions

View file

@ -36,36 +36,52 @@ ServerEnvironment::ServerEnvironment(OrtLoggingLevel severity, spdlog::sinks_ini
logger_id_("ServerApp"),
sink_(sink),
default_logger_(std::make_shared<spdlog::logger>(logger_id_, sink)),
runtime_environment_(severity, logger_id_.c_str(), Log, default_logger_.get()),
session(nullptr) {
runtime_environment_(severity, logger_id_.c_str(), Log, default_logger_.get()) {
spdlog::set_automatic_registration(false);
spdlog::set_level(Convert(severity_));
spdlog::initialize_logger(default_logger_);
}
void ServerEnvironment::InitializeModel(const std::string& model_path) {
session = Ort::Session(runtime_environment_, model_path.c_str(), Ort::SessionOptions());
void ServerEnvironment::InitializeModel(const std::string& model_path, const std::string& model_name, const std::string& model_version) {
auto result = sessions_.emplace(std::piecewise_construct, std::forward_as_tuple(model_name, model_version), std::forward_as_tuple(runtime_environment_, model_path.c_str(), Ort::SessionOptions()));
auto output_count = session.GetOutputCount();
if (!result.second) {
throw Ort::Exception("Model of that name already loaded.", ORT_INVALID_ARGUMENT);
}
auto iterator = result.first;
auto output_count = (iterator->second).session.GetOutputCount();
Ort::AllocatorWithDefaultOptions allocator;
for (size_t i = 0; i < output_count; i++) {
auto name = session.GetOutputName(i, allocator);
model_output_names_.push_back(name);
auto name = (iterator->second).session.GetOutputName(i, allocator);
(iterator->second).output_names.push_back(name);
allocator.Free(name);
}
}
const std::vector<std::string>& ServerEnvironment::GetModelOutputNames() const {
return model_output_names_;
const std::vector<std::string>& ServerEnvironment::GetModelOutputNames(const std::string& model_name, const std::string& model_version) const {
auto identifier = std::make_pair(model_name, model_version);
auto it = sessions_.find(identifier);
if (it == sessions_.end()) {
throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL);
}
return it->second.output_names;
}
OrtLoggingLevel ServerEnvironment::GetLogSeverity() const {
return severity_;
}
const Ort::Session& ServerEnvironment::GetSession() const {
return session;
const Ort::Session& ServerEnvironment::GetSession(const std::string& model_name, const std::string& model_version) const {
auto identifier = std::make_pair(model_name, model_version);
auto it = sessions_.find(identifier);
if (it == sessions_.end()) {
throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL);
}
return it->second.session;
}
std::shared_ptr<spdlog::logger> ServerEnvironment::GetLogger(const std::string& request_id) const {
@ -78,5 +94,15 @@ std::shared_ptr<spdlog::logger> ServerEnvironment::GetAppLogger() const {
return default_logger_;
}
void ServerEnvironment::UnloadModel(const std::string& model_name, const std::string& model_version) {
auto identifier = std::make_pair(model_name, model_version);
auto it = sessions_.find(identifier);
if (it == sessions_.end()) {
throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL);
}
sessions_.erase(it);
}
} // namespace server
} // namespace onnxruntime

View file

@ -8,6 +8,8 @@
#include "core/session/onnxruntime_cxx_api.h"
#include <spdlog/spdlog.h>
#include <unordered_map>
#include <boost/functional/hash.hpp>
namespace onnxruntime {
namespace server {
@ -20,11 +22,12 @@ class ServerEnvironment {
OrtLoggingLevel GetLogSeverity() const;
const Ort::Session& GetSession() const;
void InitializeModel(const std::string& model_path);
const std::vector<std::string>& GetModelOutputNames() const;
const Ort::Session& GetSession(const std::string& model_name, const std::string& model_version) const;
void InitializeModel(const std::string& model_path, const std::string& model_name, const std::string& model_version);
const std::vector<std::string>& GetModelOutputNames(const std::string& model_name, const std::string& model_version) const;
std::shared_ptr<spdlog::logger> GetLogger(const std::string& request_id) const;
std::shared_ptr<spdlog::logger> GetAppLogger() const;
void UnloadModel(const std::string& model_name, const std::string& model_version);
private:
const OrtLoggingLevel severity_;
@ -34,8 +37,20 @@ class ServerEnvironment {
Ort::Env runtime_environment_;
Ort::SessionOptions options_;
Ort::Session session;
std::vector<std::string> model_output_names_;
struct SessionHolder {
Ort::Session session;
std::vector<std::string> output_names;
explicit SessionHolder(Ort::Env& env, std::string path, const Ort::SessionOptions& options) : session(nullptr) {
session = Ort::Session(env, path.c_str(), options);
};
~SessionHolder() = default;
SessionHolder(const SessionHolder&) = delete;
SessionHolder(const SessionHolder&&) = delete;
SessionHolder& operator=(const SessionHolder&) = delete;
};
std::unordered_map<std::pair<std::string, std::string>, ServerEnvironment::SessionHolder, boost::hash<std::pair<std::string, std::string>>> sessions_;
};
} // namespace server

View file

@ -131,12 +131,12 @@ protobufutil::Status Executor::Predict(const std::string& model_name,
output_names.push_back(name);
}
} else {
output_names = env_->GetModelOutputNames();
output_names = env_->GetModelOutputNames(model_name, model_version);
}
std::vector<Ort::Value> outputs;
try {
outputs = Run(env_->GetSession(), run_options, input_names, input_values, output_names);
outputs = Run(env_->GetSession(model_name, model_version), run_options, input_names, input_values, output_names);
} catch (const Ort::Exception& e) {
return GenerateProtobufStatus(e.GetOrtErrorCode(), e.what());
}

View file

@ -63,7 +63,7 @@ int main(int argc, char* argv[]) {
logger->info("Model path: {}", config.model_path);
try {
env->InitializeModel(config.model_path);
env->InitializeModel(config.model_path, "default", "1");
logger->debug("Initialize Model Successfully!");
} catch (const Ort::Exception& ex) {
logger->critical("Initialize Model Failed: {} ---- Error: [{}]", ex.GetOrtErrorCode(), ex.what());

View file

@ -16,13 +16,26 @@ namespace onnxruntime {
namespace server {
namespace test {
TEST(ExecutorTests, TestMul_1) {
const static auto model_file = "testdata/mul_1.onnx";
class ExecutorTest : public ::testing::Test {
protected:
void SetUp() override {
const static auto model_file = "testdata/mul_1.onnx";
onnxruntime::server::ServerEnvironment* env = ServerEnv();
env->InitializeModel(model_file, "Name", "version");
}
void TearDown() override {
onnxruntime::server::ServerEnvironment* env = ServerEnv();
env->UnloadModel("Name", "version");
}
};
TEST_F(ExecutorTest, TestMul_1) {
const static auto input_json = R"({"inputs":{"X":{"dims":[3,2],"dataType":1,"floatData":[1,2,3,4,5,6]}},"outputFilter":["Y"]})";
const static auto expected = R"({"outputs":{"Y":{"dims":["3","2"],"dataType":1,"floatData":[1,4,9,16,25,36]}}})";
onnxruntime::server::ServerEnvironment* env = ServerEnv();
env->InitializeModel(model_file);
onnxruntime::server::Executor executor(env, "RequestId");
onnxruntime::server::PredictRequest request{};

View file

@ -8,11 +8,14 @@
#include "test_server_environment.h"
#include "external/server_context_test_spouse.h"
#include <grpcpp/impl/grpc_library.h>
#include "test_server_environment.h"
namespace onnxruntime {
namespace server {
namespace grpc {
namespace test {
// This class's constructor calls the GRPC library's init method.
static ::grpc::internal::GrpcLibraryInitializer g_initializer;
PredictRequest GetRequest() {
@ -28,11 +31,25 @@ PredictRequest GetRequest() {
return req;
}
std::shared_ptr<onnxruntime::server::ServerEnvironment> GetEnvironment() {
return std::shared_ptr<onnxruntime::server::ServerEnvironment>(onnxruntime::server::test::ServerEnv(), [](onnxruntime::server::ServerEnvironment *){});
}
class PredictionServiceImplTest : public ::testing::Test {
protected:
void SetUp() override {
const static auto model_file = "testdata/mul_1.onnx";
TEST(PredictionServiceImplTests, HappyPath) {
onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv();
// Implementation detail - currently predict is hardcoded to model "default" version 1.
env->InitializeModel(model_file, "default", "1");
}
void TearDown() override {
onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv();
env->UnloadModel("default", "1");
}
std::shared_ptr<onnxruntime::server::ServerEnvironment> GetEnvironment() {
return std::shared_ptr<onnxruntime::server::ServerEnvironment>(onnxruntime::server::test::ServerEnv(), [](onnxruntime::server::ServerEnvironment*) {});
}
};
TEST_F(PredictionServiceImplTest, HappyPath) {
auto env = GetEnvironment();
PredictionServiceImpl test{env};
auto request = GetRequest();
@ -42,7 +59,7 @@ TEST(PredictionServiceImplTests, HappyPath) {
EXPECT_TRUE(status.ok());
}
TEST(PredictionServiceImplTests, InvalidInput) {
TEST_F(PredictionServiceImplTest, InvalidInput) {
auto env = GetEnvironment();
PredictionServiceImpl test{env};
auto request = GetRequest();
@ -53,7 +70,7 @@ TEST(PredictionServiceImplTests, InvalidInput) {
EXPECT_EQ(status.error_code(), ::grpc::INVALID_ARGUMENT);
}
TEST(PredictionServiceImplTests, SuccessRequestID) {
TEST_F(PredictionServiceImplTest, SuccessRequestID) {
auto env = GetEnvironment();
PredictionServiceImpl test{env};
auto request = GetRequest();
@ -66,7 +83,7 @@ TEST(PredictionServiceImplTests, SuccessRequestID) {
EXPECT_TRUE(status.ok());
}
TEST(PredictionServiceImplTests, InvalidInputRequestID) {
TEST_F(PredictionServiceImplTest, InvalidInputRequestID) {
auto env = GetEnvironment();
PredictionServiceImpl test{env};
auto request = GetRequest();
@ -81,7 +98,7 @@ TEST(PredictionServiceImplTests, InvalidInputRequestID) {
EXPECT_FALSE(status.ok());
}
TEST(PredictionServiceImplTests, SuccessClientID) {
TEST_F(PredictionServiceImplTest, SuccessClientID) {
auto env = GetEnvironment();
PredictionServiceImpl test{env};
auto request = GetRequest();
@ -96,7 +113,7 @@ TEST(PredictionServiceImplTests, SuccessClientID) {
EXPECT_TRUE(status.ok());
}
TEST(PredictionServiceImplTests, InvalidInputClientID) {
TEST_F(PredictionServiceImplTest, InvalidInputClientID) {
auto env = GetEnvironment();
PredictionServiceImpl test{env};
auto request = GetRequest();

View file

@ -10,7 +10,7 @@ GTEST_API_ int main(int argc, char** argv) {
try {
onnxruntime::server::test::TestServerEnvironment server_env{};
onnxruntime::test::TestEnvironment env{argc, argv, false};
onnxruntime::test::TestEnvironment env{argc, argv, true};
status = RUN_ALL_TESTS();
} catch (const std::exception& ex) {
std::cerr << ex.what();