mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
aggregate model states only for the case when mixed precision was true (#6176)
This commit is contained in:
parent
86493e6d0c
commit
39aedbc97f
2 changed files with 109 additions and 10 deletions
|
|
@ -152,7 +152,7 @@ def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_di
|
|||
# create a new entry for this state in the state_sub_dict
|
||||
state_sub_dict[state_key] = state_value
|
||||
|
||||
def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict):
|
||||
def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled):
|
||||
"""Aggregates all model states from the rank_state_dict into state_dict"""
|
||||
|
||||
model = _utils.state_dict_model_key()
|
||||
|
|
@ -172,7 +172,9 @@ def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state
|
|||
|
||||
# iterate over all model state keys
|
||||
for model_state_key, model_state_value in rank_state_dict[model][full_precision].items():
|
||||
if model_state_key in rank_state_dict[partition_info]:
|
||||
# full precision model states are sharded only when they exist in the partition_info subdict and mixed
|
||||
# precision training was enabled. for full precision training, full precision model states are not sharded
|
||||
if mixed_precision_enabled and (model_state_key in rank_state_dict[partition_info]):
|
||||
# this model state is sharded since a record exists in the partition_info subdict
|
||||
_add_or_update_sharded_key_for_zero(model_state_key, model_state_value,
|
||||
state_dict[model][full_precision], model_state_key,
|
||||
|
|
@ -215,7 +217,7 @@ def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, s
|
|||
state_dict[optimizer][model_state_key],
|
||||
"Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key))
|
||||
|
||||
def _reshape_states(sharded_states_original_dims, state_dict):
|
||||
def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled):
|
||||
"""Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims"""
|
||||
|
||||
model = _utils.state_dict_model_key()
|
||||
|
|
@ -224,8 +226,8 @@ def _reshape_states(sharded_states_original_dims, state_dict):
|
|||
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
|
||||
|
||||
for sharded_state_key, original_dim in sharded_states_original_dims.items():
|
||||
# reshape model states to original_dim
|
||||
if model in state_dict:
|
||||
# reshape model states to original_dim only when mixed precision is enabled
|
||||
if mixed_precision_enabled and (model in state_dict):
|
||||
state_dict[model][full_precision][sharded_state_key] = \
|
||||
state_dict[model][full_precision][sharded_state_key].reshape(original_dim)
|
||||
|
||||
|
|
@ -315,7 +317,7 @@ def aggregate_checkpoints(paths, pytorch_format=True):
|
|||
"Optimizer name mismatch among checkpoint files. File: {}".format(path)
|
||||
|
||||
# aggregate all model states
|
||||
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict)
|
||||
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision)
|
||||
|
||||
if not pytorch_format:
|
||||
# aggregate all optimizer states if pytorch_format is False
|
||||
|
|
@ -331,7 +333,7 @@ def aggregate_checkpoints(paths, pytorch_format=True):
|
|||
state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()]
|
||||
|
||||
# reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims
|
||||
_reshape_states(sharded_states_original_dims, state_dict)
|
||||
_reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision)
|
||||
|
||||
# return a flat structure for PyTorch model in case pytorch_format is True
|
||||
# else return the hierarchical structure for ORTTrainer
|
||||
|
|
|
|||
|
|
@ -597,6 +597,103 @@ def test_checkpoint_aggregation(load_mock):
|
|||
'optimizer_name': b'Adam'
|
||||
}
|
||||
|
||||
state_dict1 = {
|
||||
'model': {
|
||||
'full_precision': {
|
||||
'optimizer_sharded': np.array([1, 2, 3]),
|
||||
'non_sharded': np.array([11, 22, 33])
|
||||
}
|
||||
},
|
||||
'optimizer': {
|
||||
'optimizer_sharded': {
|
||||
'Moment_1': np.array([9, 8, 7]),
|
||||
'Moment_2': np.array([99, 88, 77]),
|
||||
'Step': np.array([5])
|
||||
},
|
||||
'non_sharded': {
|
||||
'Moment_1': np.array([666, 555, 444]),
|
||||
'Moment_2': np.array([6666, 5555, 4444]),
|
||||
'Step': np.array([55])
|
||||
}
|
||||
},
|
||||
'trainer_options': {
|
||||
'mixed_precision': np.bool_(False),
|
||||
'world_rank': np.int64(0),
|
||||
'world_size': np.int64(1),
|
||||
'zero_stage': np.int64(0),
|
||||
'optimizer_name': b'Adam'
|
||||
},
|
||||
'partition_info': {
|
||||
'optimizer_sharded': {'original_dim': np.array([2, 3])}
|
||||
}
|
||||
}
|
||||
|
||||
state_dict2 = {
|
||||
'model': {
|
||||
'full_precision': {
|
||||
'optimizer_sharded': np.array([1, 2, 3]),
|
||||
'non_sharded': np.array([11, 22, 33])
|
||||
}
|
||||
},
|
||||
'optimizer': {
|
||||
'optimizer_sharded': {
|
||||
'Moment_1': np.array([6, 5, 4]),
|
||||
'Moment_2': np.array([66, 55, 44]),
|
||||
'Step': np.array([5])
|
||||
},
|
||||
'non_sharded': {
|
||||
'Moment_1': np.array([666, 555, 444]),
|
||||
'Moment_2': np.array([6666, 5555, 4444]),
|
||||
'Step': np.array([55])
|
||||
}
|
||||
},
|
||||
'trainer_options': {
|
||||
'mixed_precision': np.bool_(False),
|
||||
'world_rank': np.int64(1),
|
||||
'world_size': np.int64(1),
|
||||
'zero_stage': np.int64(0),
|
||||
'optimizer_name': b'Adam'
|
||||
},
|
||||
'partition_info': {
|
||||
'optimizer_sharded': {'original_dim': np.array([2, 3])}
|
||||
}
|
||||
}
|
||||
|
||||
load_mock.side_effect = [trainer_options1, trainer_options2, state_dict1, state_dict2]
|
||||
state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False)
|
||||
|
||||
assert (state_dict['model']['full_precision']['optimizer_sharded'] == np.array([1, 2, 3])).all()
|
||||
assert (state_dict['model']['full_precision']['non_sharded'] == np.array([11, 22, 33])).all()
|
||||
assert (state_dict['optimizer']['optimizer_sharded']['Moment_1'] == np.array([[9, 8, 7], [6, 5, 4]])).all()
|
||||
assert (state_dict['optimizer']['optimizer_sharded']['Moment_2'] == np.array([[99, 88, 77], [66, 55, 44]])).all()
|
||||
assert (state_dict['optimizer']['optimizer_sharded']['Step'] == np.array([5])).all()
|
||||
assert (state_dict['optimizer']['non_sharded']['Moment_1'] == np.array([666, 555, 444])).all()
|
||||
assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all()
|
||||
assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all()
|
||||
|
||||
assert state_dict['trainer_options']['mixed_precision'] == False
|
||||
assert state_dict['trainer_options']['world_rank'] == 0
|
||||
assert state_dict['trainer_options']['world_size'] == 1
|
||||
assert state_dict['trainer_options']['zero_stage'] == 0
|
||||
assert state_dict['trainer_options']['optimizer_name'] == b'Adam'
|
||||
|
||||
@patch('onnxruntime.training._checkpoint_storage.load')
|
||||
def test_checkpoint_aggregation_mixed_precision(load_mock):
|
||||
trainer_options1 = {
|
||||
'mixed_precision': np.bool_(True),
|
||||
'world_rank': np.int64(0),
|
||||
'world_size': np.int64(2),
|
||||
'zero_stage': np.int64(1),
|
||||
'optimizer_name': b'Adam'
|
||||
}
|
||||
trainer_options2 = {
|
||||
'mixed_precision': np.bool_(True),
|
||||
'world_rank': np.int64(1),
|
||||
'world_size': np.int64(2),
|
||||
'zero_stage': np.int64(1),
|
||||
'optimizer_name': b'Adam'
|
||||
}
|
||||
|
||||
state_dict1 = {
|
||||
'model': {
|
||||
'full_precision': {
|
||||
|
|
@ -617,7 +714,7 @@ def test_checkpoint_aggregation(load_mock):
|
|||
}
|
||||
},
|
||||
'trainer_options': {
|
||||
'mixed_precision': np.bool_(False),
|
||||
'mixed_precision': np.bool_(True),
|
||||
'world_rank': np.int64(0),
|
||||
'world_size': np.int64(1),
|
||||
'zero_stage': np.int64(0),
|
||||
|
|
@ -648,7 +745,7 @@ def test_checkpoint_aggregation(load_mock):
|
|||
}
|
||||
},
|
||||
'trainer_options': {
|
||||
'mixed_precision': np.bool_(False),
|
||||
'mixed_precision': np.bool_(True),
|
||||
'world_rank': np.int64(1),
|
||||
'world_size': np.int64(1),
|
||||
'zero_stage': np.int64(0),
|
||||
|
|
@ -671,7 +768,7 @@ def test_checkpoint_aggregation(load_mock):
|
|||
assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all()
|
||||
assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all()
|
||||
|
||||
assert state_dict['trainer_options']['mixed_precision'] == False
|
||||
assert state_dict['trainer_options']['mixed_precision'] == True
|
||||
assert state_dict['trainer_options']['world_rank'] == 0
|
||||
assert state_dict['trainer_options']['world_size'] == 1
|
||||
assert state_dict['trainer_options']['zero_stage'] == 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue