From 62442c3d2747c2291c8cf9367dcab803e223dda2 Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 2 Feb 2023 18:40:30 +0800 Subject: [PATCH] Enable multiple step run for adamw tests (on device training) (#14520) (cherry picked from commit 414b73a02123b672e496326664cd2dc3bd6c6d24) ### Rework for PR https://github.com/microsoft/onnxruntime/pull/14068: Enable multiple step run for adamw tests (on device training) ### Removed duplicated MACRO checks for training. ### Motivation and Context --- cmake/CMakeLists.txt | 19 -------------- .../training_ops/cuda/optimizer/adamw_test.cc | 26 +++++++++---------- .../training_ops/cuda/optimizer/common.cc | 12 ++++----- .../test/training_ops/cuda/optimizer/common.h | 2 +- 4 files changed, 20 insertions(+), 39 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ec6e340d98..5c088aa8cd 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -224,25 +224,6 @@ if (onnxruntime_ENABLE_TRAINING) set(onnxruntime_ENABLE_ATEN ON) endif() -# ENABLE_TRAINING includes all training functionality -# The following 2 entry points -# 1. ORTModule -# 2. ORT Training APIs -# It includes all the feature additions as well like -# 1. Python OP -# 2. Aten Fallback -# 3. Strided Tensors -# 4. All training ops including communication and collectives ops -# 5. ONNXBlock (Front end for training preparation when using training apis) -# Some features are only enabled when onnxruntime_ENABLE_PYTHON is ON as they are only relevant -# when using python env -if (onnxruntime_ENABLE_TRAINING) - set(onnxruntime_ENABLE_TRAINING_OPS ON) - set(onnxruntime_ENABLE_TRAINING_APIS ON) - set(onnxruntime_ENABLE_TRAINING_TORCH_INTEROP ON) - set(onnxruntime_ENABLE_ATEN ON) -endif() - if (onnxruntime_ENABLE_TRAINING_APIS) set(onnxruntime_ENABLE_TRAINING_OPS ON) endif() diff --git a/orttraining/orttraining/test/training_ops/cuda/optimizer/adamw_test.cc b/orttraining/orttraining/test/training_ops/cuda/optimizer/adamw_test.cc index a6bc39e59e..2e7bef63d7 100644 --- a/orttraining/orttraining/test/training_ops/cuda/optimizer/adamw_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/optimizer/adamw_test.cc @@ -36,14 +36,14 @@ void TorchAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_ite momentum2_tolerance.second = 1e-6f; } - std::vector>> testdata_ep_pair_vector; + std::vector> testdata_ep_pair_vector; testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cpu/adamw_test_single_weight_mode_0.json"), - DefaultCpuExecutionProvider())); + []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); })); #if USE_CUDA testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cuda/adamw_test_single_weight_mode_0.json"), - DefaultCudaExecutionProvider())); + []() -> std::unique_ptr { return DefaultCudaExecutionProvider(); })); #endif for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) { @@ -79,7 +79,7 @@ void TorchAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_ite std::unordered_map weight_name_shape_mapping = {{"fc1.weight", {2, 3}}}; - AdamWTestLoop(std::move(it->second), + AdamWTestLoop(it->second, use_baseline_inputs_for_each_iteration, total_step, lr, static_cast(0.9f), // alpha static_cast(0.999f), // beta @@ -125,14 +125,14 @@ void TorchAdamWMultipleWeightsTestLoop10Steps(bool use_baseline_inputs_for_each_ momentum2_tolerance.second = 1e-6f; } - std::vector>> testdata_ep_pair_vector; + std::vector> testdata_ep_pair_vector; testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cpu/adamw_test_multiple_weights_mode_0.json"), - DefaultCpuExecutionProvider())); + []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); })); #if USE_CUDA testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cuda/adamw_test_multiple_weights_mode_0.json"), - DefaultCudaExecutionProvider())); + []() -> std::unique_ptr { return DefaultCudaExecutionProvider(); })); #endif for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) { @@ -208,14 +208,14 @@ void HFAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_iterat std::pair momentum1_tolerance{1e-3f, 1e-6f}; std::pair momentum2_tolerance{1e-2f, 1e-7f}; - std::vector>> testdata_ep_pair_vector; + std::vector> testdata_ep_pair_vector; testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cpu/adamw_test_single_weight_mode_1.json"), - DefaultCpuExecutionProvider())); + []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); })); #if USE_CUDA testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cuda/adamw_test_single_weight_mode_1.json"), - DefaultCudaExecutionProvider())); + []() -> std::unique_ptr { return DefaultCudaExecutionProvider(); })); #endif for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) { @@ -291,14 +291,14 @@ void HFAdamWMultipleWeightsTestLoop10Steps( momentum2_tolerance.second = 1e-6f; } - std::vector>> testdata_ep_pair_vector; + std::vector> testdata_ep_pair_vector; testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cpu/adamw_test_multiple_weights_mode_1.json"), - DefaultCpuExecutionProvider())); + []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); })); #if USE_CUDA testdata_ep_pair_vector.push_back(std::make_pair( ORT_TSTR("cuda/adamw_test_multiple_weights_mode_1.json"), - DefaultCudaExecutionProvider())); + []() -> std::unique_ptr { return DefaultCudaExecutionProvider(); })); #endif for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) { diff --git a/orttraining/orttraining/test/training_ops/cuda/optimizer/common.cc b/orttraining/orttraining/test/training_ops/cuda/optimizer/common.cc index 59fffee85c..7bbf413e4b 100644 --- a/orttraining/orttraining/test/training_ops/cuda/optimizer/common.cc +++ b/orttraining/orttraining/test/training_ops/cuda/optimizer/common.cc @@ -105,7 +105,7 @@ void GetPerStepInput( } void AdamWTestLoop( - std::unique_ptr execution_provider, + ExecutionProviderCreationFunc execution_provider_creator, bool use_baseline_inputs_for_each_iteration, size_t total_step, float lr, float alpha, float beta, float epsilon, float weight_decay, int64_t adam_mode, int64_t correct_bias, std::unordered_map>>& named_weights, @@ -117,9 +117,6 @@ void AdamWTestLoop( std::pair momentum_1_tolerance, std::pair momentum_2_tolerance, bool* update_signal) { - std::vector> execution_providers; - execution_providers.emplace_back(std::move(execution_provider)); - std::vector ordered_weight_names; for (auto it = weight_name_shape_mapping.begin(); it != weight_name_shape_mapping.end(); ++it) { const std::string& weight_name = it->first; @@ -134,12 +131,12 @@ void AdamWTestLoop( GetPerStepInput(weight_name_shape_mapping, named_weights, named_momentums_1, named_momentums_2, 0, weights_to_train, momentum1_to_train, momentum2_to_train); - for (size_t step = 0; step < 1; ++step) { + for (size_t step = 0; step < total_step; ++step) { OpTester test("AdamWOptimizer", 1, onnxruntime::kMSDomain); // Update the steps for each param group update. // Both torch and HF increase training step before applying gradients. - // The test alignes with them. + // The test aligns with them. int64_t increased_update_count = step + 1; // Weights/momentums before applying optimization. @@ -206,6 +203,9 @@ void AdamWTestLoop( momentum_2_tolerance.second); } + std::vector> execution_providers; + execution_providers.emplace_back(std::move(execution_provider_creator())); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); if (use_baseline_inputs_for_each_iteration) { diff --git a/orttraining/orttraining/test/training_ops/cuda/optimizer/common.h b/orttraining/orttraining/test/training_ops/cuda/optimizer/common.h index cc6045c552..7a983df73a 100644 --- a/orttraining/orttraining/test/training_ops/cuda/optimizer/common.h +++ b/orttraining/orttraining/test/training_ops/cuda/optimizer/common.h @@ -107,7 +107,7 @@ struct AdamTestInputOutput { }; void AdamWTestLoop( - std::unique_ptr execution_provider, + ExecutionProviderCreationFunc execution_provider_creator, bool use_baseline_inputs_for_each_iteration, size_t total_step, float lr, float alpha, float beta, float epsilon, float weight_decay, int64_t adam_mode, int64_t correct_bias, std::unordered_map>>& named_weights,