mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Clean up status checks in gradient_graph_builder_test.cc.
This commit is contained in:
parent
7096e6f5ef
commit
6b4f652017
1 changed files with 10 additions and 10 deletions
|
|
@ -73,25 +73,25 @@ static Status BuildBackPropGraph(
|
|||
/**
|
||||
* Run a training session for this model for 1 epoch, using batch size of 1 and synthetic input data.
|
||||
* @param so - SessionOptions for this run.
|
||||
* @param backprop_model_file - Mocel file to be run. This should already contain loss function and backward prop nodes.
|
||||
* @param backprop_model_file - Model file to be run. This should already contain loss function and backward prop nodes.
|
||||
* @return TrainingSession for this run.
|
||||
*/
|
||||
static std::unique_ptr<TrainingSession> RunTrainingSessionWithChecks(
|
||||
const SessionOptions& so, const PathString& backprop_model_file) {
|
||||
std::unique_ptr<Environment> env;
|
||||
EXPECT_TRUE(Environment::Create(nullptr, env).IsOK());
|
||||
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
|
||||
|
||||
std::unique_ptr<TrainingSession> training_session = onnxruntime::make_unique<TrainingSession>(so, *env);
|
||||
|
||||
EXPECT_TRUE(training_session->Load(backprop_model_file).IsOK());
|
||||
ORT_THROW_IF_ERROR(training_session->Load(backprop_model_file));
|
||||
|
||||
std::pair<common::Status, const ModelMetadata*> res = training_session->GetModelMetadata();
|
||||
EXPECT_TRUE(res.first.IsOK());
|
||||
EXPECT_TRUE(res.second != nullptr);
|
||||
ORT_THROW_IF_ERROR(res.first);
|
||||
ORT_ENFORCE(res.second != nullptr);
|
||||
auto model_metadata = res.second;
|
||||
std::cout << "Loaded " << model_metadata->graph_name << '\n';
|
||||
|
||||
EXPECT_TRUE(training_session->Initialize().IsOK());
|
||||
ORT_THROW_IF_ERROR(training_session->Initialize());
|
||||
|
||||
std::vector<MLValue> gradient_fetches;
|
||||
RunOptions run_options;
|
||||
|
|
@ -116,7 +116,7 @@ static std::unique_ptr<TrainingSession> RunTrainingSessionWithChecks(
|
|||
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
EXPECT_TRUE(training_session->Run(run_options, fw_feeds.first, fw_feeds.second, training_output_names, &gradient_fetches).IsOK());
|
||||
ORT_THROW_IF_ERROR(training_session->Run(run_options, fw_feeds.first, fw_feeds.second, training_output_names, &gradient_fetches));
|
||||
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
auto elapsed = TimeDiffMicroSeconds(start_time, end_time);
|
||||
|
|
@ -297,14 +297,14 @@ static void RunBertTrainingWithChecks(
|
|||
const SessionOptions& so,
|
||||
const PathString& backprop_model_file) {
|
||||
std::unique_ptr<Environment> env;
|
||||
EXPECT_TRUE(Environment::Create(nullptr, env).IsOK());
|
||||
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
|
||||
|
||||
std::unique_ptr<TrainingSession> training_session = onnxruntime::make_unique<TrainingSession>(so, *env);
|
||||
|
||||
EXPECT_TRUE(training_session->Load(backprop_model_file).IsOK());
|
||||
ASSERT_STATUS_OK(training_session->Load(backprop_model_file));
|
||||
|
||||
std::pair<common::Status, const ModelMetadata*> res = training_session->GetModelMetadata();
|
||||
EXPECT_TRUE(res.first.IsOK());
|
||||
ASSERT_STATUS_OK(res.first);
|
||||
ASSERT_TRUE(res.second != nullptr);
|
||||
auto model_metadata = res.second;
|
||||
std::cout << "Loaded " << model_metadata->graph_name << '\n';
|
||||
|
|
|
|||
Loading…
Reference in a new issue