From 487abcd25ec2bcb2255a361e4b061f020a90c043 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Wed, 13 Dec 2023 11:26:52 -0800 Subject: [PATCH] Update gradient ops tests (#18783) ### Description TrainingSession has been deprecated for a while now, but the gradient ops tests are still using training session. This PR updates these tests to use inference session instead of training session. ### Motivation and Context This will enable us to remove all the training session related deprecated code from the repo. --- .../orttraining/test/gradient/gradient_op_test_utils.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index b9f7e3fe46..0944e46ff8 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -8,7 +8,6 @@ #include "core/framework/kernel_type_str_resolver.h" #include "core/session/inference_session.h" -#include "orttraining/core/session/training_session.h" #include "orttraining/core/framework/gradient_graph_builder.h" #include "orttraining/core/graph/gradient_config.h" @@ -76,7 +75,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, } } - onnxruntime::training::TrainingSession session_object{so, GetEnvironment()}; + onnxruntime::InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(!execution_providers->empty()) << "Empty execution providers vector."; std::string provider_types; @@ -102,7 +101,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, has_run = true; - ExecuteModel( + ExecuteModel( model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_types); } else { for (const std::string& provider_type : all_provider_types) { @@ -158,11 +157,11 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, continue; has_run = true; - onnxruntime::training::TrainingSession session_object{so, GetEnvironment()}; + onnxruntime::InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - ExecuteModel( + ExecuteModel( model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_type); } }