mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Enable multiple step run for adamw tests (on device training) (#14520)
(cherry picked from commit 414b73a02123b672e496326664cd2dc3bd6c6d24) ### Rework for PR https://github.com/microsoft/onnxruntime/pull/14068: Enable multiple-step runs for the AdamW tests (on-device training) ### Removed duplicated macro checks for training. ### Motivation and Context The AdamW test loop was hard-coded to a single step (`step < 1`); it now iterates `total_step` times and builds a fresh execution provider for every step through an `ExecutionProviderCreationFunc` instead of receiving a single `std::unique_ptr<IExecutionProvider>`.
This commit is contained in:
parent
7954976e0a
commit
62442c3d27
4 changed files with 20 additions and 39 deletions
|
|
@ -224,25 +224,6 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
set(onnxruntime_ENABLE_ATEN ON)
|
||||
endif()
|
||||
|
||||
# ENABLE_TRAINING includes all training functionality
|
||||
# The following 2 entry points
|
||||
# 1. ORTModule
|
||||
# 2. ORT Training APIs
|
||||
# It includes all the feature additions as well like
|
||||
# 1. Python OP
|
||||
# 2. Aten Fallback
|
||||
# 3. Strided Tensors
|
||||
# 4. All training ops including communication and collectives ops
|
||||
# 5. ONNXBlock (Front end for training preparation when using training apis)
|
||||
# Some features are only enabled when onnxruntime_ENABLE_PYTHON is ON as they are only relevant
|
||||
# when using python env
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
set(onnxruntime_ENABLE_TRAINING_OPS ON)
|
||||
set(onnxruntime_ENABLE_TRAINING_APIS ON)
|
||||
set(onnxruntime_ENABLE_TRAINING_TORCH_INTEROP ON)
|
||||
set(onnxruntime_ENABLE_ATEN ON)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING_APIS)
|
||||
set(onnxruntime_ENABLE_TRAINING_OPS ON)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -36,14 +36,14 @@ void TorchAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_ite
|
|||
momentum2_tolerance.second = 1e-6f;
|
||||
}
|
||||
|
||||
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
|
||||
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cpu/adamw_test_single_weight_mode_0.json"),
|
||||
DefaultCpuExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
|
||||
#if USE_CUDA
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cuda/adamw_test_single_weight_mode_0.json"),
|
||||
DefaultCudaExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
|
||||
#endif
|
||||
|
||||
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {
|
||||
|
|
@ -79,7 +79,7 @@ void TorchAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_ite
|
|||
std::unordered_map<std::string, VectorInt64> weight_name_shape_mapping =
|
||||
{{"fc1.weight", {2, 3}}};
|
||||
|
||||
AdamWTestLoop(std::move(it->second),
|
||||
AdamWTestLoop(it->second,
|
||||
use_baseline_inputs_for_each_iteration, total_step, lr,
|
||||
static_cast<float>(0.9f), // alpha
|
||||
static_cast<float>(0.999f), // beta
|
||||
|
|
@ -125,14 +125,14 @@ void TorchAdamWMultipleWeightsTestLoop10Steps(bool use_baseline_inputs_for_each_
|
|||
momentum2_tolerance.second = 1e-6f;
|
||||
}
|
||||
|
||||
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
|
||||
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cpu/adamw_test_multiple_weights_mode_0.json"),
|
||||
DefaultCpuExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
|
||||
#if USE_CUDA
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cuda/adamw_test_multiple_weights_mode_0.json"),
|
||||
DefaultCudaExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
|
||||
#endif
|
||||
|
||||
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {
|
||||
|
|
@ -208,14 +208,14 @@ void HFAdamWSingleWeightTestLoop10Steps(bool use_baseline_inputs_for_each_iterat
|
|||
std::pair<float, float> momentum1_tolerance{1e-3f, 1e-6f};
|
||||
std::pair<float, float> momentum2_tolerance{1e-2f, 1e-7f};
|
||||
|
||||
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
|
||||
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cpu/adamw_test_single_weight_mode_1.json"),
|
||||
DefaultCpuExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
|
||||
#if USE_CUDA
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cuda/adamw_test_single_weight_mode_1.json"),
|
||||
DefaultCudaExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
|
||||
#endif
|
||||
|
||||
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {
|
||||
|
|
@ -291,14 +291,14 @@ void HFAdamWMultipleWeightsTestLoop10Steps(
|
|||
momentum2_tolerance.second = 1e-6f;
|
||||
}
|
||||
|
||||
std::vector<std::pair<const ORTCHAR_T*, std::unique_ptr<IExecutionProvider>>> testdata_ep_pair_vector;
|
||||
std::vector<std::pair<const ORTCHAR_T*, ExecutionProviderCreationFunc>> testdata_ep_pair_vector;
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cpu/adamw_test_multiple_weights_mode_1.json"),
|
||||
DefaultCpuExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); }));
|
||||
#if USE_CUDA
|
||||
testdata_ep_pair_vector.push_back(std::make_pair(
|
||||
ORT_TSTR("cuda/adamw_test_multiple_weights_mode_1.json"),
|
||||
DefaultCudaExecutionProvider()));
|
||||
[]() -> std::unique_ptr<IExecutionProvider> { return DefaultCudaExecutionProvider(); }));
|
||||
#endif
|
||||
|
||||
for (auto it = testdata_ep_pair_vector.begin(); it != testdata_ep_pair_vector.end(); ++it) {
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ void GetPerStepInput(
|
|||
}
|
||||
|
||||
void AdamWTestLoop(
|
||||
std::unique_ptr<IExecutionProvider> execution_provider,
|
||||
ExecutionProviderCreationFunc execution_provider_creator,
|
||||
bool use_baseline_inputs_for_each_iteration, size_t total_step, float lr,
|
||||
float alpha, float beta, float epsilon, float weight_decay, int64_t adam_mode, int64_t correct_bias,
|
||||
std::unordered_map<std::string, std::vector<std::vector<float>>>& named_weights,
|
||||
|
|
@ -117,9 +117,6 @@ void AdamWTestLoop(
|
|||
std::pair<float, float> momentum_1_tolerance,
|
||||
std::pair<float, float> momentum_2_tolerance,
|
||||
bool* update_signal) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.emplace_back(std::move(execution_provider));
|
||||
|
||||
std::vector<std::string> ordered_weight_names;
|
||||
for (auto it = weight_name_shape_mapping.begin(); it != weight_name_shape_mapping.end(); ++it) {
|
||||
const std::string& weight_name = it->first;
|
||||
|
|
@ -134,12 +131,12 @@ void AdamWTestLoop(
|
|||
GetPerStepInput(weight_name_shape_mapping, named_weights, named_momentums_1, named_momentums_2,
|
||||
0, weights_to_train, momentum1_to_train, momentum2_to_train);
|
||||
|
||||
for (size_t step = 0; step < 1; ++step) {
|
||||
for (size_t step = 0; step < total_step; ++step) {
|
||||
OpTester test("AdamWOptimizer", 1, onnxruntime::kMSDomain);
|
||||
|
||||
// Update the steps for each param group update.
|
||||
// Both torch and HF increase training step before applying gradients.
|
||||
// The test alignes with them.
|
||||
// The test aligns with them.
|
||||
int64_t increased_update_count = step + 1;
|
||||
|
||||
// Weights/momentums before applying optimization.
|
||||
|
|
@ -206,6 +203,9 @@ void AdamWTestLoop(
|
|||
momentum_2_tolerance.second);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.emplace_back(std::move(execution_provider_creator()));
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
|
||||
if (use_baseline_inputs_for_each_iteration) {
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ struct AdamTestInputOutput {
|
|||
};
|
||||
|
||||
void AdamWTestLoop(
|
||||
std::unique_ptr<IExecutionProvider> execution_provider,
|
||||
ExecutionProviderCreationFunc execution_provider_creator,
|
||||
bool use_baseline_inputs_for_each_iteration, size_t total_step, float lr,
|
||||
float alpha, float beta, float epsilon, float weight_decay, int64_t adam_mode, int64_t correct_bias,
|
||||
std::unordered_map<std::string, std::vector<std::vector<float>>>& named_weights,
|
||||
|
|
|
|||
Loading…
Reference in a new issue