diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py index 8384367415..2b11b816bb 100644 --- a/orttraining/orttraining/python/training/checkpoint.py +++ b/orttraining/orttraining/python/training/checkpoint.py @@ -152,7 +152,7 @@ def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_di # create a new entry for this state in the state_sub_dict state_sub_dict[state_key] = state_value -def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict): +def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled): """Aggregates all model states from the rank_state_dict into state_dict""" model = _utils.state_dict_model_key() @@ -172,7 +172,9 @@ def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state # iterate over all model state keys for model_state_key, model_state_value in rank_state_dict[model][full_precision].items(): - if model_state_key in rank_state_dict[partition_info]: + # full precision model states are sharded only when they exist in the partition_info subdict and mixed + # precision training was enabled. for full precision training, full precision model states are not sharded + if mixed_precision_enabled and (model_state_key in rank_state_dict[partition_info]): # this model state is sharded since a record exists in the partition_info subdict _add_or_update_sharded_key_for_zero(model_state_key, model_state_value, state_dict[model][full_precision], model_state_key, @@ -215,7 +217,7 @@ def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, s state_dict[optimizer][model_state_key], "Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key)) -def _reshape_states(sharded_states_original_dims, state_dict): +def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled): """Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims""" model = _utils.state_dict_model_key() @@ -224,8 +226,8 @@ def _reshape_states(sharded_states_original_dims, state_dict): sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() for sharded_state_key, original_dim in sharded_states_original_dims.items(): - # reshape model states to original_dim - if model in state_dict: + # reshape model states to original_dim only when mixed precision is enabled + if mixed_precision_enabled and (model in state_dict): state_dict[model][full_precision][sharded_state_key] = \ state_dict[model][full_precision][sharded_state_key].reshape(original_dim) @@ -315,7 +317,7 @@ def aggregate_checkpoints(paths, pytorch_format=True): "Optimizer name mismatch among checkpoint files. File: {}".format(path) # aggregate all model states - _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict) + _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision) if not pytorch_format: # aggregate all optimizer states if pytorch_format is False @@ -331,7 +333,7 @@ def aggregate_checkpoints(paths, pytorch_format=True): state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims - _reshape_states(sharded_states_original_dims, state_dict) + _reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision) # return a flat structure for PyTorch model in case pytorch_format is True # else return the hierarchical structure for ORTTrainer diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py index cade9e09c3..7a4d6785f1 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py @@ -597,6 +597,103 @@ def test_checkpoint_aggregation(load_mock): 'optimizer_name': b'Adam' } + state_dict1 = { + 'model': { + 'full_precision': { + 'optimizer_sharded': np.array([1, 2, 3]), + 'non_sharded': np.array([11, 22, 33]) + } + }, + 'optimizer': { + 'optimizer_sharded': { + 'Moment_1': np.array([9, 8, 7]), + 'Moment_2': np.array([99, 88, 77]), + 'Step': np.array([5]) + }, + 'non_sharded': { + 'Moment_1': np.array([666, 555, 444]), + 'Moment_2': np.array([6666, 5555, 4444]), + 'Step': np.array([55]) + } + }, + 'trainer_options': { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(0), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0), + 'optimizer_name': b'Adam' + }, + 'partition_info': { + 'optimizer_sharded': {'original_dim': np.array([2, 3])} + } + } + + state_dict2 = { + 'model': { + 'full_precision': { + 'optimizer_sharded': np.array([1, 2, 3]), + 'non_sharded': np.array([11, 22, 33]) + } + }, + 'optimizer': { + 'optimizer_sharded': { + 'Moment_1': np.array([6, 5, 4]), + 'Moment_2': np.array([66, 55, 44]), + 'Step': np.array([5]) + }, + 'non_sharded': { + 'Moment_1': np.array([666, 555, 444]), + 'Moment_2': np.array([6666, 5555, 4444]), + 'Step': np.array([55]) + } + }, + 'trainer_options': { + 'mixed_precision': np.bool_(False), + 'world_rank': np.int64(1), + 'world_size': np.int64(1), + 'zero_stage': np.int64(0), + 'optimizer_name': b'Adam' + }, + 'partition_info': { + 'optimizer_sharded': {'original_dim': np.array([2, 3])} + } + } + + load_mock.side_effect = [trainer_options1, trainer_options2, state_dict1, state_dict2] + state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False) + + assert (state_dict['model']['full_precision']['optimizer_sharded'] == np.array([1, 2, 3])).all() + assert (state_dict['model']['full_precision']['non_sharded'] == np.array([11, 22, 33])).all() + assert (state_dict['optimizer']['optimizer_sharded']['Moment_1'] == np.array([[9, 8, 7], [6, 5, 4]])).all() + assert (state_dict['optimizer']['optimizer_sharded']['Moment_2'] == np.array([[99, 88, 77], [66, 55, 44]])).all() + assert (state_dict['optimizer']['optimizer_sharded']['Step'] == np.array([5])).all() + assert (state_dict['optimizer']['non_sharded']['Moment_1'] == np.array([666, 555, 444])).all() + assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all() + assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all() + + assert state_dict['trainer_options']['mixed_precision'] == False + assert state_dict['trainer_options']['world_rank'] == 0 + assert state_dict['trainer_options']['world_size'] == 1 + assert state_dict['trainer_options']['zero_stage'] == 0 + assert state_dict['trainer_options']['optimizer_name'] == b'Adam' + +@patch('onnxruntime.training._checkpoint_storage.load') +def test_checkpoint_aggregation_mixed_precision(load_mock): + trainer_options1 = { + 'mixed_precision': np.bool_(True), + 'world_rank': np.int64(0), + 'world_size': np.int64(2), + 'zero_stage': np.int64(1), + 'optimizer_name': b'Adam' + } + trainer_options2 = { + 'mixed_precision': np.bool_(True), + 'world_rank': np.int64(1), + 'world_size': np.int64(2), + 'zero_stage': np.int64(1), + 'optimizer_name': b'Adam' + } + state_dict1 = { 'model': { 'full_precision': { @@ -617,7 +714,7 @@ def test_checkpoint_aggregation(load_mock): } }, 'trainer_options': { - 'mixed_precision': np.bool_(False), + 'mixed_precision': np.bool_(True), 'world_rank': np.int64(0), 'world_size': np.int64(1), 'zero_stage': np.int64(0), @@ -648,7 +745,7 @@ def test_checkpoint_aggregation(load_mock): } }, 'trainer_options': { - 'mixed_precision': np.bool_(False), + 'mixed_precision': np.bool_(True), 'world_rank': np.int64(1), 'world_size': np.int64(1), 'zero_stage': np.int64(0), @@ -671,7 +768,7 @@ def test_checkpoint_aggregation(load_mock): assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all() assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all() - assert state_dict['trainer_options']['mixed_precision'] == False + assert state_dict['trainer_options']['mixed_precision'] == True assert state_dict['trainer_options']['world_rank'] == 0 assert state_dict['trainer_options']['world_size'] == 1 assert state_dict['trainer_options']['zero_stage'] == 0