aggregate model states only for the case when mixed precision was true (#6176)

This commit is contained in:
baijumeswani 2020-12-18 14:09:32 -08:00 committed by GitHub
parent 86493e6d0c
commit 39aedbc97f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 10 deletions

View file

@ -152,7 +152,7 @@ def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_di
# create a new entry for this state in the state_sub_dict
state_sub_dict[state_key] = state_value
def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict):
def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled):
"""Aggregates all model states from the rank_state_dict into state_dict"""
model = _utils.state_dict_model_key()
@ -172,7 +172,9 @@ def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state
# iterate over all model state keys
for model_state_key, model_state_value in rank_state_dict[model][full_precision].items():
if model_state_key in rank_state_dict[partition_info]:
# full precision model states are sharded only when they exist in the partition_info subdict and mixed
# precision training was enabled. for full precision training, full precision model states are not sharded
if mixed_precision_enabled and (model_state_key in rank_state_dict[partition_info]):
# this model state is sharded since a record exists in the partition_info subdict
_add_or_update_sharded_key_for_zero(model_state_key, model_state_value,
state_dict[model][full_precision], model_state_key,
@ -215,7 +217,7 @@ def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, s
state_dict[optimizer][model_state_key],
"Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key))
def _reshape_states(sharded_states_original_dims, state_dict):
def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled):
"""Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims"""
model = _utils.state_dict_model_key()
@ -224,8 +226,8 @@ def _reshape_states(sharded_states_original_dims, state_dict):
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
for sharded_state_key, original_dim in sharded_states_original_dims.items():
# reshape model states to original_dim
if model in state_dict:
# reshape model states to original_dim only when mixed precision is enabled
if mixed_precision_enabled and (model in state_dict):
state_dict[model][full_precision][sharded_state_key] = \
state_dict[model][full_precision][sharded_state_key].reshape(original_dim)
@ -315,7 +317,7 @@ def aggregate_checkpoints(paths, pytorch_format=True):
"Optimizer name mismatch among checkpoint files. File: {}".format(path)
# aggregate all model states
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict)
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision)
if not pytorch_format:
# aggregate all optimizer states if pytorch_format is False
@ -331,7 +333,7 @@ def aggregate_checkpoints(paths, pytorch_format=True):
state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()]
# reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims
_reshape_states(sharded_states_original_dims, state_dict)
_reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision)
# return a flat structure for PyTorch model in case pytorch_format is True
# else return the hierarchical structure for ORTTrainer

View file

@ -597,6 +597,103 @@ def test_checkpoint_aggregation(load_mock):
'optimizer_name': b'Adam'
}
state_dict1 = {
'model': {
'full_precision': {
'optimizer_sharded': np.array([1, 2, 3]),
'non_sharded': np.array([11, 22, 33])
}
},
'optimizer': {
'optimizer_sharded': {
'Moment_1': np.array([9, 8, 7]),
'Moment_2': np.array([99, 88, 77]),
'Step': np.array([5])
},
'non_sharded': {
'Moment_1': np.array([666, 555, 444]),
'Moment_2': np.array([6666, 5555, 4444]),
'Step': np.array([55])
}
},
'trainer_options': {
'mixed_precision': np.bool_(False),
'world_rank': np.int64(0),
'world_size': np.int64(1),
'zero_stage': np.int64(0),
'optimizer_name': b'Adam'
},
'partition_info': {
'optimizer_sharded': {'original_dim': np.array([2, 3])}
}
}
state_dict2 = {
'model': {
'full_precision': {
'optimizer_sharded': np.array([1, 2, 3]),
'non_sharded': np.array([11, 22, 33])
}
},
'optimizer': {
'optimizer_sharded': {
'Moment_1': np.array([6, 5, 4]),
'Moment_2': np.array([66, 55, 44]),
'Step': np.array([5])
},
'non_sharded': {
'Moment_1': np.array([666, 555, 444]),
'Moment_2': np.array([6666, 5555, 4444]),
'Step': np.array([55])
}
},
'trainer_options': {
'mixed_precision': np.bool_(False),
'world_rank': np.int64(1),
'world_size': np.int64(1),
'zero_stage': np.int64(0),
'optimizer_name': b'Adam'
},
'partition_info': {
'optimizer_sharded': {'original_dim': np.array([2, 3])}
}
}
load_mock.side_effect = [trainer_options1, trainer_options2, state_dict1, state_dict2]
state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False)
assert (state_dict['model']['full_precision']['optimizer_sharded'] == np.array([1, 2, 3])).all()
assert (state_dict['model']['full_precision']['non_sharded'] == np.array([11, 22, 33])).all()
assert (state_dict['optimizer']['optimizer_sharded']['Moment_1'] == np.array([[9, 8, 7], [6, 5, 4]])).all()
assert (state_dict['optimizer']['optimizer_sharded']['Moment_2'] == np.array([[99, 88, 77], [66, 55, 44]])).all()
assert (state_dict['optimizer']['optimizer_sharded']['Step'] == np.array([5])).all()
assert (state_dict['optimizer']['non_sharded']['Moment_1'] == np.array([666, 555, 444])).all()
assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all()
assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all()
assert state_dict['trainer_options']['mixed_precision'] == False
assert state_dict['trainer_options']['world_rank'] == 0
assert state_dict['trainer_options']['world_size'] == 1
assert state_dict['trainer_options']['zero_stage'] == 0
assert state_dict['trainer_options']['optimizer_name'] == b'Adam'
@patch('onnxruntime.training._checkpoint_storage.load')
def test_checkpoint_aggregation_mixed_precision(load_mock):
trainer_options1 = {
'mixed_precision': np.bool_(True),
'world_rank': np.int64(0),
'world_size': np.int64(2),
'zero_stage': np.int64(1),
'optimizer_name': b'Adam'
}
trainer_options2 = {
'mixed_precision': np.bool_(True),
'world_rank': np.int64(1),
'world_size': np.int64(2),
'zero_stage': np.int64(1),
'optimizer_name': b'Adam'
}
state_dict1 = {
'model': {
'full_precision': {
@ -617,7 +714,7 @@ def test_checkpoint_aggregation(load_mock):
}
},
'trainer_options': {
'mixed_precision': np.bool_(False),
'mixed_precision': np.bool_(True),
'world_rank': np.int64(0),
'world_size': np.int64(1),
'zero_stage': np.int64(0),
@ -648,7 +745,7 @@ def test_checkpoint_aggregation(load_mock):
}
},
'trainer_options': {
'mixed_precision': np.bool_(False),
'mixed_precision': np.bool_(True),
'world_rank': np.int64(1),
'world_size': np.int64(1),
'zero_stage': np.int64(0),
@ -671,7 +768,7 @@ def test_checkpoint_aggregation(load_mock):
assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all()
assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all()
assert state_dict['trainer_options']['mixed_precision'] == False
assert state_dict['trainer_options']['mixed_precision'] == True
assert state_dict['trainer_options']['world_rank'] == 0
assert state_dict['trainer_options']['world_size'] == 1
assert state_dict['trainer_options']['zero_stage'] == 0