mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update gradient ops tests (#18783)
### Description <!-- Describe your changes. --> TrainingSession has been deprecated for a while now, but the gradient ops tests are still using training session. This PR updates these tests to use inference session instead of training session. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This will enable us to remove all the training session related deprecated code from the repo.
This commit is contained in:
parent
17eaf9b053
commit
487abcd25e
1 changed files with 4 additions and 5 deletions
|
|
@ -8,7 +8,6 @@
|
|||
#include "core/framework/kernel_type_str_resolver.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/framework/gradient_graph_builder.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
|
||||
|
|
@ -76,7 +75,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,
|
|||
}
|
||||
}
|
||||
|
||||
onnxruntime::training::TrainingSession session_object{so, GetEnvironment()};
|
||||
onnxruntime::InferenceSession session_object{so, GetEnvironment()};
|
||||
|
||||
ASSERT_TRUE(!execution_providers->empty()) << "Empty execution providers vector.";
|
||||
std::string provider_types;
|
||||
|
|
@ -102,7 +101,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,
|
|||
|
||||
has_run = true;
|
||||
|
||||
ExecuteModel<onnxruntime::training::TrainingSession>(
|
||||
ExecuteModel<onnxruntime::InferenceSession>(
|
||||
model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_types);
|
||||
} else {
|
||||
for (const std::string& provider_type : all_provider_types) {
|
||||
|
|
@ -158,11 +157,11 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,
|
|||
continue;
|
||||
|
||||
has_run = true;
|
||||
onnxruntime::training::TrainingSession session_object{so, GetEnvironment()};
|
||||
onnxruntime::InferenceSession session_object{so, GetEnvironment()};
|
||||
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
|
||||
ExecuteModel<onnxruntime::training::TrainingSession>(
|
||||
ExecuteModel<onnxruntime::InferenceSession>(
|
||||
model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_type);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue