Added Environment::IsInitialized() and added check to InferenceSession constructor. (#169)

This commit is contained in:
edgchen1 2018-12-13 13:34:49 -08:00 committed by Pranav Sharma
parent 34175826df
commit c5a0119d42
5 changed files with 62 additions and 3 deletions

View file

@ -43,8 +43,8 @@ function(AddTest)
endif()
endif()
target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings})
else()
target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM})
else()
target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM})
endif()
set(TEST_ARGS)
@ -304,6 +304,15 @@ else()
set(test_data_target onnxruntime_test_ir)
endif() # SingleUnitTestProject
# standalone test for inference session without environment
# the normal test executables set up a default runtime environment, which we don't want here
AddTest(
TARGET onnxruntime_test_framework_session_without_environment_standalone
SOURCES "${TEST_SRC_DIR}/framework/inference_session_without_environment/inference_session_without_environment_standalone_test.cc"
LIBS ${onnxruntime_test_framework_libs}
DEPENDS ${onnxruntime_EXTERNAL_DEPENDENCIES}
)
#
# onnxruntime_ir_graph test data
#

View file

@ -3,6 +3,7 @@
#pragma once
#include <atomic>
#include <memory>
#include "core/common/common.h"
#include "core/common/status.h"
@ -20,14 +21,21 @@ class Environment {
static Status Create(std::unique_ptr<Environment>& environment);
/**
* This function will call ::google::protobuf::ShutdownProtobufLibrary
This function will call ::google::protobuf::ShutdownProtobufLibrary
*/
~Environment();
/**
Returns whether any runtime environment instance has been initialized.
*/
static bool IsInitialized() { return is_initialized_; }
private:
ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);
Environment() = default;
Status Initialize();
static std::atomic<bool> is_initialized_;
};
} // namespace onnxruntime

View file

@ -13,6 +13,8 @@ using namespace ONNX_NAMESPACE;
std::once_flag schemaRegistrationOnceFlag;
std::atomic<bool> Environment::is_initialized_{false};
Status Environment::Create(std::unique_ptr<Environment>& environment) {
environment = std::unique_ptr<Environment>(new Environment());
auto status = environment->Initialize();
@ -88,6 +90,8 @@ Internal copy node
// Register contributed schemas.
// The corresponding kernels are registered inside the appropriate execution provider.
contrib::RegisterContribSchemas();
is_initialized_ = true;
} catch (std::exception& ex) {
status = Status{ONNXRUNTIME, common::RUNTIME_EXCEPTION, std::string{"Exception caught: "} + ex.what()};
} catch (...) {

View file

@ -17,6 +17,7 @@
#include "core/graph/model.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/customregistry.h"
#include "core/framework/environment.h"
#include "core/framework/execution_frame.h"
#include "core/framework/graph_partitioner.h"
#include "core/framework/insert_cast_transformer.h"
@ -50,6 +51,9 @@ class InferenceSession::Impl {
logging_manager_{logging_manager},
session_state_{execution_providers_},
insert_cast_transformer_{"CastFloat16Transformer"} {
ONNXRUNTIME_ENFORCE(Environment::IsInitialized(),
"Environment must be initialized before creating an InferenceSession.");
InitLogger(logging_manager);
// currently the threadpool is used by the parallel executor only and hence

View file

@ -0,0 +1,34 @@
#include "gtest/gtest.h"
#include "google/protobuf/stubs/common.h"
#include "core/framework/environment.h"
#include "core/session/inference_session.h"
namespace onnxruntime {
namespace test {
TEST(InferenceSessionWithoutEnvironment, UninitializedEnvironment)
{
EXPECT_FALSE(onnxruntime::Environment::IsInitialized());
onnxruntime::SessionOptions session_options{};
EXPECT_THROW(onnxruntime::InferenceSession{session_options},
onnxruntime::OnnxRuntimeException);
}
// call protobuf shutdown to avoid memory leak
class TestEnvironment : public ::testing::Environment {
public:
void TearDown() override {
::google::protobuf::ShutdownProtobufLibrary();
}
};
} // namespace test
} // namespace onnxruntime
GTEST_API_ int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
// the following call takes ownership of the test environment
::testing::AddGlobalTestEnvironment(new onnxruntime::test::TestEnvironment{});
int status = RUN_ALL_TESTS();
return status;
}