diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint.py index 6d348a949a..1bbcd41221 100644 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint.py +++ b/orttraining/orttraining/test/python/orttraining_test_checkpoint.py @@ -59,22 +59,14 @@ backend_api_file = os.path.join('checkpoint', 'orttraining_test_backend_api.py') single_node_full_precision_path = os.path.join(checkpoint_dir, 'single_node', 'full_precision') single_node_mixed_precision_path = os.path.join(checkpoint_dir, 'single_node', 'mixed_precision') -data_parallelism_full_precision_path = os.path.join(checkpoint_dir, 'data_parallelism', 'full_precision') -data_parallelism_mixed_precision_path = os.path.join(checkpoint_dir, 'data_parallelism', 'mixed_precision') -distributed_zero_full_precision_adam_path = os.path.join(checkpoint_dir, 'distributed_zero', 'full_precision', 'adam') -distributed_zero_mixed_precision_adam_path = os.path.join(checkpoint_dir, 'distributed_zero', 'mixed_precision', 'adam') distributed_zero_full_precision_lamb_path = os.path.join(checkpoint_dir, 'distributed_zero', 'full_precision', 'lamb') distributed_zero_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'distributed_zero', 'mixed_precision', 'lamb') # megatron saving and loading uses a different model single_node_full_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'single_node', 'full_precision') single_node_mixed_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'single_node', 'mixed_precision') -data_parallelism_full_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'data_parallelism', 'full_precision') -data_parallelism_mixed_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'data_parallelism', 'mixed_precision') distributed_zero_full_precision_lamb_bart_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero', 'full_precision', 'lamb') distributed_zero_mixed_precision_lamb_bart_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero', 'mixed_precision', 'lamb') -distributed_megatron_full_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'full_precision', 'adam') -distributed_megatron_mixed_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'mixed_precision', 'adam') distributed_megatron_full_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'full_precision', 'lamb') distributed_megatron_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'mixed_precision', 'lamb') distributed_zero_megatron_full_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'full_precision', 'adam') @@ -85,26 +77,16 @@ distributed_zero_megatron_mixed_precision_lamb_path = os.path.join(checkpoint_di # save all checkpoint files (pre-checkpoint) _single_run(save_checkpoint_file, 'single_node_full_precision', single_node_full_precision_path) _single_run(save_checkpoint_file, 'single_node_mixed_precision', single_node_mixed_precision_path) -_distributed_run(save_checkpoint_file, 'data_parallelism_full_precision', data_parallelism_full_precision_path) -_distributed_run(save_checkpoint_file, 'data_parallelism_mixed_precision', data_parallelism_mixed_precision_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_full_precision_adam', distributed_zero_full_precision_adam_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_mixed_precision_adam', distributed_zero_mixed_precision_adam_path) _distributed_run(save_checkpoint_file, 'distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path) _distributed_run(save_checkpoint_file, 'distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path) _single_run(save_checkpoint_file, 'single_node_full_precision_bart', single_node_full_precision_bart_path) _single_run(save_checkpoint_file, 'single_node_mixed_precision_bart', single_node_mixed_precision_bart_path) -_distributed_run(save_checkpoint_file, 'data_parallelism_full_precision_bart', data_parallelism_full_precision_bart_path) -_distributed_run(save_checkpoint_file, 'data_parallelism_mixed_precision_bart', data_parallelism_mixed_precision_bart_path) _distributed_run(save_checkpoint_file, 'distributed_zero_full_precision_lamb_bart', distributed_zero_full_precision_lamb_bart_path) _distributed_run(save_checkpoint_file, 'distributed_zero_mixed_precision_lamb_bart', distributed_zero_mixed_precision_lamb_bart_path) -_distributed_run(save_checkpoint_file, 'distributed_megatron_full_precision_adam', distributed_megatron_full_precision_adam_path) -_distributed_run(save_checkpoint_file, 'distributed_megatron_mixed_precision_adam', distributed_megatron_mixed_precision_adam_path) _distributed_run(save_checkpoint_file, 'distributed_megatron_full_precision_lamb', distributed_megatron_full_precision_lamb_path) _distributed_run(save_checkpoint_file, 'distributed_megatron_mixed_precision_lamb', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_full_precision_adam', distributed_zero_megatron_full_precision_adam_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_mixed_precision_adam', distributed_zero_megatron_mixed_precision_adam_path) _distributed_run(save_checkpoint_file, 'distributed_zero_megatron_full_precision_lamb', distributed_zero_megatron_full_precision_lamb_path) _distributed_run(save_checkpoint_file, 'distributed_zero_megatron_mixed_precision_lamb', distributed_zero_megatron_mixed_precision_lamb_path) @@ -114,10 +96,6 @@ _single_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_int _single_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_single_node_full_precision', single_node_mixed_precision_path) _single_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_single_node_mixed_precision', single_node_mixed_precision_path) _single_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_single_node_mixed_precision', single_node_full_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_single_node_full_precision', data_parallelism_full_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision', data_parallelism_mixed_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision', data_parallelism_mixed_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision', data_parallelism_full_precision_path) _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_single_node_full_precision', distributed_zero_full_precision_lamb_path) _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision', distributed_zero_mixed_precision_lamb_path) _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision', distributed_zero_mixed_precision_lamb_path) @@ -131,37 +109,11 @@ _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixe _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) _single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) -# going to data parallel trainer -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_data_parallelism_full_precision', single_node_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision', single_node_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision', single_node_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision', single_node_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision', data_parallelism_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision', data_parallelism_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision', data_parallelism_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision', data_parallelism_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision', distributed_zero_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision', distributed_zero_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision', distributed_zero_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision', distributed_zero_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision', distributed_zero_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) - # going to distributed zero trainer _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_full_precision', single_node_full_precision_path) _distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision', single_node_mixed_precision_path) _distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision', single_node_mixed_precision_path) _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision', single_node_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision', data_parallelism_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision', data_parallelism_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision', data_parallelism_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision', data_parallelism_full_precision_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision', distributed_zero_full_precision_lamb_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision', distributed_zero_mixed_precision_lamb_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision', distributed_zero_mixed_precision_lamb_path) @@ -180,10 +132,6 @@ _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precisio _distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision', single_node_mixed_precision_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision', single_node_mixed_precision_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision', single_node_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision', data_parallelism_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision', data_parallelism_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision', data_parallelism_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision', data_parallelism_full_precision_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision', distributed_zero_full_precision_lamb_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision', distributed_zero_mixed_precision_lamb_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision', distributed_zero_mixed_precision_lamb_bart_path) @@ -202,10 +150,6 @@ _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precisio _distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision', single_node_mixed_precision_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision', single_node_mixed_precision_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision', single_node_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision', data_parallelism_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision', data_parallelism_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision', data_parallelism_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision', data_parallelism_full_precision_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision', distributed_zero_full_precision_lamb_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision', distributed_zero_mixed_precision_lamb_bart_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_mixed_precision_lamb_bart_path) @@ -219,28 +163,4 @@ _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) _distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) -# checkpoint aggregation tests -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_full_precision_adam', distributed_zero_full_precision_adam_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_mixed_precision_adam', distributed_zero_mixed_precision_adam_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_full_precision_adam', distributed_megatron_full_precision_adam_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_mixed_precision_adam', distributed_megatron_mixed_precision_adam_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_mixed_precision_lamb', distributed_megatron_mixed_precision_lamb_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_megatron_full_precision_lamb', distributed_megatron_full_precision_lamb_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_full_precision_adam', distributed_zero_megatron_full_precision_adam_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_mixed_precision_adam', distributed_zero_megatron_mixed_precision_adam_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb', distributed_zero_megatron_mixed_precision_lamb_path) -_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_megatron_full_precision_lamb', distributed_zero_megatron_full_precision_lamb_path) - -# optimizer state loading into model-parallel tests -_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_full_precision_adam', distributed_zero_full_precision_adam_path) -_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_mixed_precision_adam', distributed_zero_mixed_precision_adam_path) -_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path) -_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path) - -# backend api tests -_single_run(backend_api_file, 'test_single_node_full_precision_lamb', single_node_full_precision_path) -_distributed_run(backend_api_file, 'test_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path) - shutil.rmtree(checkpoint_dir)