Clean up status checks in gradient_graph_builder_test.cc.

This commit is contained in:
Edward Chen 2020-06-11 00:53:45 +00:00 committed by edgchen1
parent 7096e6f5ef
commit 6b4f652017

View file

@ -73,25 +73,25 @@ static Status BuildBackPropGraph(
/**
* Run a training session for this model for 1 epoch, using batch size of 1 and synthetic input data.
* @param so - SessionOptions for this run.
* @param backprop_model_file - Mocel file to be run. This should already contain loss function and backward prop nodes.
* @param backprop_model_file - Model file to be run. This should already contain loss function and backward prop nodes.
* @return TrainingSession for this run.
*/
static std::unique_ptr<TrainingSession> RunTrainingSessionWithChecks(
const SessionOptions& so, const PathString& backprop_model_file) {
std::unique_ptr<Environment> env;
EXPECT_TRUE(Environment::Create(nullptr, env).IsOK());
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
std::unique_ptr<TrainingSession> training_session = onnxruntime::make_unique<TrainingSession>(so, *env);
EXPECT_TRUE(training_session->Load(backprop_model_file).IsOK());
ORT_THROW_IF_ERROR(training_session->Load(backprop_model_file));
std::pair<common::Status, const ModelMetadata*> res = training_session->GetModelMetadata();
EXPECT_TRUE(res.first.IsOK());
EXPECT_TRUE(res.second != nullptr);
ORT_THROW_IF_ERROR(res.first);
ORT_ENFORCE(res.second != nullptr);
auto model_metadata = res.second;
std::cout << "Loaded " << model_metadata->graph_name << '\n';
EXPECT_TRUE(training_session->Initialize().IsOK());
ORT_THROW_IF_ERROR(training_session->Initialize());
std::vector<MLValue> gradient_fetches;
RunOptions run_options;
@ -116,7 +116,7 @@ static std::unique_ptr<TrainingSession> RunTrainingSessionWithChecks(
auto start_time = std::chrono::high_resolution_clock::now();
EXPECT_TRUE(training_session->Run(run_options, fw_feeds.first, fw_feeds.second, training_output_names, &gradient_fetches).IsOK());
ORT_THROW_IF_ERROR(training_session->Run(run_options, fw_feeds.first, fw_feeds.second, training_output_names, &gradient_fetches));
auto end_time = std::chrono::high_resolution_clock::now();
auto elapsed = TimeDiffMicroSeconds(start_time, end_time);
@ -297,14 +297,14 @@ static void RunBertTrainingWithChecks(
const SessionOptions& so,
const PathString& backprop_model_file) {
std::unique_ptr<Environment> env;
EXPECT_TRUE(Environment::Create(nullptr, env).IsOK());
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
std::unique_ptr<TrainingSession> training_session = onnxruntime::make_unique<TrainingSession>(so, *env);
EXPECT_TRUE(training_session->Load(backprop_model_file).IsOK());
ASSERT_STATUS_OK(training_session->Load(backprop_model_file));
std::pair<common::Status, const ModelMetadata*> res = training_session->GetModelMetadata();
EXPECT_TRUE(res.first.IsOK());
ASSERT_STATUS_OK(res.first);
ASSERT_TRUE(res.second != nullptr);
auto model_metadata = res.second;
std::cout << "Loaded " << model_metadata->graph_name << '\n';