Enable multiple step run for adamw tests (on device training) (#14520)

(cherry picked from commit 414b73a02123b672e496326664cd2dc3bd6c6d24)

### Rework for PR https://github.com/microsoft/onnxruntime/pull/14068:
Enable multi-step runs for the AdamW tests (on-device training)
### Removed the duplicated macro checks for training.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-02-02 18:40:30 +08:00 committed by GitHub
parent 7954976e0a
commit 62442c3d27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 39 deletions

View file

@ -224,25 +224,6 @@ if (onnxruntime_ENABLE_TRAINING)
set(onnxruntime_ENABLE_ATEN ON)
endif()
# ENABLE_TRAINING includes all training functionality
# The following 2 entry points
# 1. ORTModule
# 2. ORT Training APIs
# It includes all the feature additions as well like
# 1. Python OP
# 2. Aten Fallback
# 3. Strided Tensors
# 4. All training ops including communication and collectives ops
# 5. ONNXBlock (Front end for training preparation when using training apis)
# Some features are only enabled when onnxruntime_ENABLE_PYTHON is ON as they are only relevant
# when using python env
if (onnxruntime_ENABLE_TRAINING)
set(onnxruntime_ENABLE_TRAINING_OPS ON)
set(onnxruntime_ENABLE_TRAINING_APIS ON)
set(onnxruntime_ENABLE_TRAINING_TORCH_INTEROP ON)
set(onnxruntime_ENABLE_ATEN ON)
endif()
if (onnxruntime_ENABLE_TRAINING_APIS)
set(onnxruntime_ENABLE_TRAINING_OPS ON)
endif()

View file

@ -36,14 +36,14 @@ void TorchAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_ite
momentum2_tolerance.second = 1e-6f;
}
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cpu/adamw_test_single_weight_mode_0.json"),
DefaultCpuExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
#if USE_CUDA
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cuda/adamw_test_single_weight_mode_0.json"),
DefaultCudaExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
#endif
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {
@ -79,7 +79,7 @@ void TorchAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_ite
std::unordered_map<std::string, VectorInt64> weight_name_shape_mapping =
{{"fc1.weight", {2, 3}}};
AdamWTestLoop(std::move(it->second),
AdamWTestLoop(it->second,
use_baseline_inputs_for_each_iteration, total_step, lr,
static_cast<float>(0.9f), // alpha
static_cast<float>(0.999f), // beta
@ -125,14 +125,14 @@ void TorchAdamWMultipleWeightsTestLoop10Steps(bool use_baseline_inputs_for_each_
momentum2_tolerance.second = 1e-6f;
}
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cpu/adamw_test_multiple_weights_mode_0.json"),
DefaultCpuExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
#if USE_CUDA
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cuda/adamw_test_multiple_weights_mode_0.json"),
DefaultCudaExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
#endif
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {
@ -208,14 +208,14 @@ void HFAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_iterat
std::pair<float, float> momentum1_tolerance{1e-3f, 1e-6f};
std::pair<float, float> momentum2_tolerance{1e-2f, 1e-7f};
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cpu/adamw_test_single_weight_mode_1.json"),
DefaultCpuExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
#if USE_CUDA
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cuda/adamw_test_single_weight_mode_1.json"),
DefaultCudaExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
#endif
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {
@ -291,14 +291,14 @@ void HFAdamWMultipleWeightsTestLoop10Steps(
momentum2_tolerance.second = 1e-6f;
}
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cpu/adamw_test_multiple_weights_mode_1.json"),
DefaultCpuExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
#if USE_CUDA
testdata_ep_pair_vector.push_back(std::make_pair(
ORT_TSTR("cuda/adamw_test_multiple_weights_mode_1.json"),
DefaultCudaExecutionProvider()));
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
#endif
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {

View file

@ -105,7 +105,7 @@ void GetPerStepInput(
}
void AdamWTestLoop(
std::unique_ptr<IExecutionProvider> execution_provider,
ExecutionProviderCreationFunc execution_provider_creator,
bool use_baseline_inputs_for_each_iteration, size_t total_step, float lr,
float alpha, float beta, float epsilon, float weight_decay, int64_t adam_mode, int64_t correct_bias,
std::unordered_map<std::string, std::vector<std::vector<float>>>& named_weights,
@ -117,9 +117,6 @@ void AdamWTestLoop(
std::pair<float, float> momentum_1_tolerance,
std::pair<float, float> momentum_2_tolerance,
bool* update_signal) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.emplace_back(std::move(execution_provider));
std::vector<std::string> ordered_weight_names;
for (auto it = weight_name_shape_mapping.begin(); it != weight_name_shape_mapping.end(); ++it) {
const std::string& weight_name = it->first;
@ -134,12 +131,12 @@ void AdamWTestLoop(
GetPerStepInput(weight_name_shape_mapping, named_weights, named_momentums_1, named_momentums_2,
0, weights_to_train, momentum1_to_train, momentum2_to_train);
for (size_t step = 0; step < 1; ++step) {
for (size_t step = 0; step < total_step; ++step) {
OpTester test("AdamWOptimizer", 1, onnxruntime::kMSDomain);
// Update the steps for each param group update.
// Both torch and HF increase training step before applying gradients.
// The test alignes with them.
// The test aligns with them.
int64_t increased_update_count = step + 1;
// Weights/momentums before applying optimization.
@ -206,6 +203,9 @@ void AdamWTestLoop(
momentum_2_tolerance.second);
}
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.emplace_back(std::move(execution_provider_creator()));
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
if (use_baseline_inputs_for_each_iteration) {

View file

@ -107,7 +107,7 @@ struct AdamTestInputOutput {
};
void AdamWTestLoop(
std::unique_ptr<IExecutionProvider> execution_provider,
ExecutionProviderCreationFunc execution_provider_creator,
bool use_baseline_inputs_for_each_iteration, size_t total_step, float lr,
float alpha, float beta, float epsilon, float weight_decay, int64_t adam_mode, int64_t correct_bias,
std::unordered_map<std::string, std::vector<std::vector<float>>>& named_weights,