diff --git a/onnxruntime/server/environment.cc b/onnxruntime/server/environment.cc index f4a87ee7c4..97c10a6d47 100644 --- a/onnxruntime/server/environment.cc +++ b/onnxruntime/server/environment.cc @@ -36,36 +36,52 @@ ServerEnvironment::ServerEnvironment(OrtLoggingLevel severity, spdlog::sinks_ini logger_id_("ServerApp"), sink_(sink), default_logger_(std::make_shared(logger_id_, sink)), - runtime_environment_(severity, logger_id_.c_str(), Log, default_logger_.get()), - session(nullptr) { + runtime_environment_(severity, logger_id_.c_str(), Log, default_logger_.get()) { spdlog::set_automatic_registration(false); spdlog::set_level(Convert(severity_)); spdlog::initialize_logger(default_logger_); } -void ServerEnvironment::InitializeModel(const std::string& model_path) { - session = Ort::Session(runtime_environment_, model_path.c_str(), Ort::SessionOptions()); +void ServerEnvironment::InitializeModel(const std::string& model_path, const std::string& model_name, const std::string& model_version) { + auto result = sessions_.emplace(std::piecewise_construct, std::forward_as_tuple(model_name, model_version), std::forward_as_tuple(runtime_environment_, model_path.c_str(), Ort::SessionOptions())); - auto output_count = session.GetOutputCount(); + if (!result.second) { + throw Ort::Exception("Model of that name already loaded.", ORT_INVALID_ARGUMENT); + } + + auto iterator = result.first; + auto output_count = (iterator->second).session.GetOutputCount(); Ort::AllocatorWithDefaultOptions allocator; for (size_t i = 0; i < output_count; i++) { - auto name = session.GetOutputName(i, allocator); - model_output_names_.push_back(name); + auto name = (iterator->second).session.GetOutputName(i, allocator); + (iterator->second).output_names.push_back(name); allocator.Free(name); } } -const std::vector& ServerEnvironment::GetModelOutputNames() const { - return model_output_names_; +const std::vector& ServerEnvironment::GetModelOutputNames(const std::string& model_name, const std::string& model_version) const { + auto identifier = std::make_pair(model_name, model_version); + auto it = sessions_.find(identifier); + if (it == sessions_.end()) { + throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL); + } + + return it->second.output_names; } OrtLoggingLevel ServerEnvironment::GetLogSeverity() const { return severity_; } -const Ort::Session& ServerEnvironment::GetSession() const { - return session; +const Ort::Session& ServerEnvironment::GetSession(const std::string& model_name, const std::string& model_version) const { + auto identifier = std::make_pair(model_name, model_version); + auto it = sessions_.find(identifier); + if (it == sessions_.end()) { + throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL); + } + + return it->second.session; } std::shared_ptr ServerEnvironment::GetLogger(const std::string& request_id) const { @@ -78,5 +94,15 @@ std::shared_ptr ServerEnvironment::GetAppLogger() const { return default_logger_; } +void ServerEnvironment::UnloadModel(const std::string& model_name, const std::string& model_version) { + auto identifier = std::make_pair(model_name, model_version); + auto it = sessions_.find(identifier); + if (it == sessions_.end()) { + throw Ort::Exception("No model loaded of that name.", ORT_NO_MODEL); + } + + sessions_.erase(it); +} + } // namespace server } // namespace onnxruntime diff --git a/onnxruntime/server/environment.h b/onnxruntime/server/environment.h index 332ab6e064..fe58125dda 100644 --- a/onnxruntime/server/environment.h +++ b/onnxruntime/server/environment.h @@ -8,6 +8,8 @@ #include "core/session/onnxruntime_cxx_api.h" #include +#include +#include namespace onnxruntime { namespace server { @@ -20,11 +22,12 @@ class ServerEnvironment { OrtLoggingLevel GetLogSeverity() const; - const Ort::Session& GetSession() const; - void InitializeModel(const std::string& model_path); - const std::vector& GetModelOutputNames() const; + const Ort::Session& GetSession(const std::string& model_name, const std::string& model_version) const; + void InitializeModel(const std::string& model_path, const std::string& model_name, const std::string& model_version); + const std::vector& GetModelOutputNames(const std::string& model_name, const std::string& model_version) const; std::shared_ptr GetLogger(const std::string& request_id) const; std::shared_ptr GetAppLogger() const; + void UnloadModel(const std::string& model_name, const std::string& model_version); private: const OrtLoggingLevel severity_; @@ -34,8 +37,20 @@ class ServerEnvironment { Ort::Env runtime_environment_; Ort::SessionOptions options_; - Ort::Session session; - std::vector model_output_names_; + + struct SessionHolder { + Ort::Session session; + std::vector output_names; + explicit SessionHolder(Ort::Env& env, std::string path, const Ort::SessionOptions& options) : session(nullptr) { + session = Ort::Session(env, path.c_str(), options); + }; + ~SessionHolder() = default; + SessionHolder(const SessionHolder&) = delete; + SessionHolder(const SessionHolder&&) = delete; + SessionHolder& operator=(const SessionHolder&) = delete; + }; + + std::unordered_map, ServerEnvironment::SessionHolder, boost::hash>> sessions_; }; } // namespace server diff --git a/onnxruntime/server/executor.cc b/onnxruntime/server/executor.cc index e50e390082..0e5e54dd6d 100644 --- a/onnxruntime/server/executor.cc +++ b/onnxruntime/server/executor.cc @@ -131,12 +131,12 @@ protobufutil::Status Executor::Predict(const std::string& model_name, output_names.push_back(name); } } else { - output_names = env_->GetModelOutputNames(); + output_names = env_->GetModelOutputNames(model_name, model_version); } std::vector outputs; try { - outputs = Run(env_->GetSession(), run_options, input_names, input_values, output_names); + outputs = Run(env_->GetSession(model_name, model_version), run_options, input_names, input_values, output_names); } catch (const Ort::Exception& e) { return GenerateProtobufStatus(e.GetOrtErrorCode(), e.what()); } diff --git a/onnxruntime/server/main.cc b/onnxruntime/server/main.cc index 724b563689..a69d967131 100644 --- a/onnxruntime/server/main.cc +++ b/onnxruntime/server/main.cc @@ -63,7 +63,7 @@ int main(int argc, char* argv[]) { logger->info("Model path: {}", config.model_path); try { - env->InitializeModel(config.model_path); + env->InitializeModel(config.model_path, "default", "1"); logger->debug("Initialize Model Successfully!"); } catch (const Ort::Exception& ex) { logger->critical("Initialize Model Failed: {} ---- Error: [{}]", ex.GetOrtErrorCode(), ex.what()); diff --git a/onnxruntime/test/server/unit_tests/executor_test.cc b/onnxruntime/test/server/unit_tests/executor_test.cc index e069360eba..40561198d8 100644 --- a/onnxruntime/test/server/unit_tests/executor_test.cc +++ b/onnxruntime/test/server/unit_tests/executor_test.cc @@ -16,13 +16,26 @@ namespace onnxruntime { namespace server { namespace test { -TEST(ExecutorTests, TestMul_1) { - const static auto model_file = "testdata/mul_1.onnx"; +class ExecutorTest : public ::testing::Test { + protected: + void SetUp() override { + const static auto model_file = "testdata/mul_1.onnx"; + + onnxruntime::server::ServerEnvironment* env = ServerEnv(); + env->InitializeModel(model_file, "Name", "version"); + } + + void TearDown() override { + onnxruntime::server::ServerEnvironment* env = ServerEnv(); + env->UnloadModel("Name", "version"); + } +}; + +TEST_F(ExecutorTest, TestMul_1) { const static auto input_json = R"({"inputs":{"X":{"dims":[3,2],"dataType":1,"floatData":[1,2,3,4,5,6]}},"outputFilter":["Y"]})"; const static auto expected = R"({"outputs":{"Y":{"dims":["3","2"],"dataType":1,"floatData":[1,4,9,16,25,36]}}})"; onnxruntime::server::ServerEnvironment* env = ServerEnv(); - env->InitializeModel(model_file); onnxruntime::server::Executor executor(env, "RequestId"); onnxruntime::server::PredictRequest request{}; diff --git a/onnxruntime/test/server/unit_tests/prediction_service_impl_test.cc b/onnxruntime/test/server/unit_tests/prediction_service_impl_test.cc index 9b38c3f577..39f32b62c1 100644 --- a/onnxruntime/test/server/unit_tests/prediction_service_impl_test.cc +++ b/onnxruntime/test/server/unit_tests/prediction_service_impl_test.cc @@ -8,11 +8,14 @@ #include "test_server_environment.h" #include "external/server_context_test_spouse.h" #include +#include "test_server_environment.h" namespace onnxruntime { namespace server { namespace grpc { namespace test { + +// This class's constructor calls the GRPC library's init method. static ::grpc::internal::GrpcLibraryInitializer g_initializer; PredictRequest GetRequest() { @@ -28,11 +31,25 @@ PredictRequest GetRequest() { return req; } -std::shared_ptr GetEnvironment() { - return std::shared_ptr(onnxruntime::server::test::ServerEnv(), [](onnxruntime::server::ServerEnvironment *){}); -} +class PredictionServiceImplTest : public ::testing::Test { + protected: + void SetUp() override { + const static auto model_file = "testdata/mul_1.onnx"; -TEST(PredictionServiceImplTests, HappyPath) { + onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv(); + // Implementation detail - currently predict is hardcoded to model "default" version 1. + env->InitializeModel(model_file, "default", "1"); + } + void TearDown() override { + onnxruntime::server::ServerEnvironment* env = onnxruntime::server::test::ServerEnv(); + env->UnloadModel("default", "1"); + } + std::shared_ptr GetEnvironment() { + return std::shared_ptr(onnxruntime::server::test::ServerEnv(), [](onnxruntime::server::ServerEnvironment*) {}); + } +}; + +TEST_F(PredictionServiceImplTest, HappyPath) { auto env = GetEnvironment(); PredictionServiceImpl test{env}; auto request = GetRequest(); @@ -42,7 +59,7 @@ TEST(PredictionServiceImplTests, HappyPath) { EXPECT_TRUE(status.ok()); } -TEST(PredictionServiceImplTests, InvalidInput) { +TEST_F(PredictionServiceImplTest, InvalidInput) { auto env = GetEnvironment(); PredictionServiceImpl test{env}; auto request = GetRequest(); @@ -53,7 +70,7 @@ TEST(PredictionServiceImplTests, InvalidInput) { EXPECT_EQ(status.error_code(), ::grpc::INVALID_ARGUMENT); } -TEST(PredictionServiceImplTests, SuccessRequestID) { +TEST_F(PredictionServiceImplTest, SuccessRequestID) { auto env = GetEnvironment(); PredictionServiceImpl test{env}; auto request = GetRequest(); @@ -66,7 +83,7 @@ TEST(PredictionServiceImplTests, SuccessRequestID) { EXPECT_TRUE(status.ok()); } -TEST(PredictionServiceImplTests, InvalidInputRequestID) { +TEST_F(PredictionServiceImplTest, InvalidInputRequestID) { auto env = GetEnvironment(); PredictionServiceImpl test{env}; auto request = GetRequest(); @@ -81,7 +98,7 @@ TEST(PredictionServiceImplTests, InvalidInputRequestID) { EXPECT_FALSE(status.ok()); } -TEST(PredictionServiceImplTests, SuccessClientID) { +TEST_F(PredictionServiceImplTest, SuccessClientID) { auto env = GetEnvironment(); PredictionServiceImpl test{env}; auto request = GetRequest(); @@ -96,7 +113,7 @@ TEST(PredictionServiceImplTests, SuccessClientID) { EXPECT_TRUE(status.ok()); } -TEST(PredictionServiceImplTests, InvalidInputClientID) { +TEST_F(PredictionServiceImplTest, InvalidInputClientID) { auto env = GetEnvironment(); PredictionServiceImpl test{env}; auto request = GetRequest(); diff --git a/onnxruntime/test/server/unit_tests/test_main.cc b/onnxruntime/test/server/unit_tests/test_main.cc index 09a5304d5f..64f20061cc 100644 --- a/onnxruntime/test/server/unit_tests/test_main.cc +++ b/onnxruntime/test/server/unit_tests/test_main.cc @@ -10,7 +10,7 @@ GTEST_API_ int main(int argc, char** argv) { try { onnxruntime::server::test::TestServerEnvironment server_env{}; - onnxruntime::test::TestEnvironment env{argc, argv, false}; + onnxruntime::test::TestEnvironment env{argc, argv, true}; status = RUN_ALL_TESTS(); } catch (const std::exception& ex) { std::cerr << ex.what();