mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Description: Format all python files under onnxruntime with black and isort. After checking in, we can use .git-blame-ignore-revs to ignore the formatting PR in git blame. #11315, #11316
748 lines
33 KiB
Python
748 lines
33 KiB
Python
import numpy as np
|
|
import onnx
|
|
import os
|
|
import torch
|
|
import warnings
|
|
import tempfile
|
|
from enum import Enum
|
|
from . import _checkpoint_storage, _utils
|
|
|
|
|
|
################################################################################
|
|
# Experimental Checkpoint APIs
|
|
################################################################################
|
|
|
|
|
|
def experimental_state_dict(ort_trainer, include_optimizer_state=True):
|
|
warnings.warn(
|
|
"experimental_state_dict() will be deprecated soon. " "Please use ORTTrainer.state_dict() instead.",
|
|
DeprecationWarning,
|
|
)
|
|
|
|
if not ort_trainer._training_session:
|
|
warnings.warn(
|
|
"ONNX Runtime training session is not initialized yet. "
|
|
"Please run train_step or eval_step at least once before calling state_dict()."
|
|
)
|
|
return ort_trainer._state_dict
|
|
|
|
# extract trained weights
|
|
session_state = ort_trainer._training_session.get_state()
|
|
torch_state = {}
|
|
for name in session_state:
|
|
torch_state[name] = torch.from_numpy(session_state[name])
|
|
|
|
# extract untrained weights and buffer
|
|
for n in ort_trainer._onnx_model.graph.initializer:
|
|
if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights:
|
|
torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n)))
|
|
|
|
# Need to remove redundant (optimizer) initializers to map back to original torch state names
|
|
if not include_optimizer_state and ort_trainer._torch_state_dict_keys:
|
|
return {key: torch_state[key] for key in ort_trainer._torch_state_dict_keys if key in torch_state}
|
|
return torch_state
|
|
|
|
|
|
def experimental_load_state_dict(ort_trainer, state_dict, strict=False):
|
|
warnings.warn(
|
|
"experimental_load_state_dict() will be deprecated soon. " "Please use ORTTrainer.load_state_dict() instead.",
|
|
DeprecationWarning,
|
|
)
|
|
|
|
# Note: It may happen ONNX model has not yet been initialized
|
|
# In this case we cache a reference to desired state and delay the restore until after initialization
|
|
# Unexpected behavior will result if the user changes the reference before initialization
|
|
if not ort_trainer._training_session:
|
|
ort_trainer._state_dict = state_dict
|
|
ort_trainer._load_state_dict_strict = strict
|
|
return
|
|
|
|
# Update onnx model from loaded state dict
|
|
cur_initializers_names = [n.name for n in ort_trainer._onnx_model.graph.initializer]
|
|
new_initializers = {}
|
|
|
|
for name in state_dict:
|
|
if name in cur_initializers_names:
|
|
new_initializers[name] = state_dict[name].numpy()
|
|
elif strict:
|
|
raise RuntimeError("Checkpoint tensor: {} is not present in the model.".format(name))
|
|
|
|
ort_trainer._update_onnx_model_initializers(new_initializers)
|
|
|
|
# create new session based on updated onnx model
|
|
ort_trainer._state_dict = None
|
|
ort_trainer._init_session()
|
|
|
|
# load training state
|
|
session_state = {name: state_dict[name].numpy() for name in state_dict}
|
|
ort_trainer._training_session.load_state(session_state, strict)
|
|
|
|
|
|
def experimental_save_checkpoint(
|
|
ort_trainer,
|
|
checkpoint_dir,
|
|
checkpoint_prefix="ORT_checkpoint",
|
|
checkpoint_state_dict=None,
|
|
include_optimizer_state=True,
|
|
):
|
|
warnings.warn(
|
|
"experimental_save_checkpoint() will be deprecated soon. " "Please use ORTTrainer.save_checkpoint() instead.",
|
|
DeprecationWarning,
|
|
)
|
|
|
|
if checkpoint_state_dict is None:
|
|
checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)}
|
|
else:
|
|
checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)})
|
|
|
|
assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist"
|
|
|
|
checkpoint_name = _get_checkpoint_name(
|
|
checkpoint_prefix,
|
|
ort_trainer.options.distributed.deepspeed_zero_optimization.stage,
|
|
ort_trainer.options.distributed.world_rank,
|
|
ort_trainer.options.distributed.world_size,
|
|
)
|
|
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
|
if os.path.exists(checkpoint_file):
|
|
msg = f"{checkpoint_file} already exists, overwriting."
|
|
warnings.warn(msg)
|
|
torch.save(checkpoint_state_dict, checkpoint_file)
|
|
|
|
|
|
def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False):
|
|
warnings.warn(
|
|
"experimental_load_checkpoint() will be deprecated soon. " "Please use ORTTrainer.load_checkpoint() instead.",
|
|
DeprecationWarning,
|
|
)
|
|
|
|
checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
|
|
is_partitioned = False
|
|
if len(checkpoint_files) > 1:
|
|
msg = (
|
|
f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}."
|
|
" Attempting to load ZeRO checkpoint."
|
|
)
|
|
warnings.warn(msg)
|
|
is_partitioned = True
|
|
if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned:
|
|
return _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict)
|
|
else:
|
|
return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
|
|
|
|
|
|
class _AGGREGATION_MODE(Enum):
|
|
Zero = 0
|
|
Megatron = 1
|
|
|
|
|
|
def _order_paths(paths, D_groups, H_groups):
|
|
"""Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively
|
|
and returns the ordered dict"""
|
|
|
|
trainer_options_path_tuples = []
|
|
world_rank = _utils.state_dict_trainer_options_world_rank_key()
|
|
|
|
for path in paths:
|
|
trainer_options_path_tuples.append(
|
|
(_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path)
|
|
)
|
|
|
|
# sort paths according to rank
|
|
sorted_paths = [
|
|
path
|
|
for _, path in sorted(
|
|
trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank]
|
|
)
|
|
]
|
|
|
|
ordered_paths = dict()
|
|
ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))]
|
|
ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))]
|
|
|
|
return ordered_paths
|
|
|
|
|
|
def _add_or_update_sharded_key(
|
|
state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode
|
|
):
|
|
"""Add or update the record for the sharded state_key in the state_sub_dict"""
|
|
|
|
# record the original dimension for this state
|
|
original_dim = _utils.state_dict_original_dimension_key()
|
|
sharded_states_original_dims[model_state_key] = state_partition_info[original_dim]
|
|
|
|
axis = 0
|
|
if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0:
|
|
axis = -1
|
|
|
|
if state_key in state_sub_dict:
|
|
# state_dict already contains a record for this state
|
|
# since this state is sharded, concatenate the state value to
|
|
# the record in the state_dict
|
|
state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis)
|
|
else:
|
|
# create a new entry for this state in the state_dict
|
|
state_sub_dict[state_key] = state_value
|
|
|
|
|
|
def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string):
|
|
"""Add or validate the record for the unsharded state_key in the state_sub_dict"""
|
|
|
|
if state_key in state_sub_dict:
|
|
# state_dict already contains a record for this unsharded state.
|
|
# assert that all values are the same for this previously loaded state
|
|
assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string
|
|
else:
|
|
# create a new entry for this state in the state_sub_dict
|
|
state_sub_dict[state_key] = state_value
|
|
|
|
|
|
def _aggregate_model_states(
|
|
rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero
|
|
):
|
|
"""Aggregates all model states from the rank_state_dict into state_dict"""
|
|
|
|
model = _utils.state_dict_model_key()
|
|
full_precision = _utils.state_dict_full_precision_key()
|
|
partition_info = _utils.state_dict_partition_info_key()
|
|
|
|
# if there are no model states in the rank_state_dict, no model aggregation is needed
|
|
if model not in rank_state_dict:
|
|
return
|
|
|
|
if model not in state_dict:
|
|
state_dict[model] = {}
|
|
|
|
if full_precision not in state_dict[model]:
|
|
state_dict[model][full_precision] = {}
|
|
|
|
# iterate over all model state keys
|
|
for model_state_key, model_state_value in rank_state_dict[model][full_precision].items():
|
|
# ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed
|
|
# precision training was enabled. for full precision training, full precision model states are not sharded
|
|
# MEGATRON : full precision model states are sharded when they exist in the partition_info subdict
|
|
if (model_state_key in rank_state_dict[partition_info]) and (
|
|
mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled
|
|
):
|
|
# this model state is sharded
|
|
_add_or_update_sharded_key(
|
|
model_state_key,
|
|
model_state_value,
|
|
state_dict[model][full_precision],
|
|
model_state_key,
|
|
rank_state_dict[partition_info][model_state_key],
|
|
sharded_states_original_dims,
|
|
mode,
|
|
)
|
|
else:
|
|
# this model state is not sharded since a record for it does not exist in the partition_info subdict
|
|
_add_or_validate_unsharded_key(
|
|
model_state_key,
|
|
model_state_value,
|
|
state_dict[model][full_precision],
|
|
"Value mismatch for model state {}".format(model_state_key),
|
|
)
|
|
|
|
|
|
def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero):
|
|
"""Aggregates all optimizer states from the rank_state_dict into state_dict"""
|
|
|
|
optimizer = _utils.state_dict_optimizer_key()
|
|
partition_info = _utils.state_dict_partition_info_key()
|
|
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
|
|
|
|
# if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed
|
|
if optimizer not in rank_state_dict:
|
|
return
|
|
|
|
if optimizer not in state_dict:
|
|
state_dict[optimizer] = {}
|
|
|
|
# iterate over all optimizer state keys
|
|
for model_state_key, optimizer_dict in rank_state_dict[optimizer].items():
|
|
for optimizer_key, optimizer_value in optimizer_dict.items():
|
|
if model_state_key not in state_dict[optimizer]:
|
|
state_dict[optimizer][model_state_key] = {}
|
|
|
|
if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]:
|
|
# this optimizer state is sharded since a record exists in the partition_info subdict
|
|
_add_or_update_sharded_key(
|
|
optimizer_key,
|
|
optimizer_value,
|
|
state_dict[optimizer][model_state_key],
|
|
model_state_key,
|
|
rank_state_dict[partition_info][model_state_key],
|
|
sharded_states_original_dims,
|
|
mode,
|
|
)
|
|
else:
|
|
# this optimizer state is not sharded since a record for it does not exist in the partition_info subdict
|
|
# or this optimizer key is not one of the sharded optimizer keys
|
|
_add_or_validate_unsharded_key(
|
|
optimizer_key,
|
|
optimizer_value,
|
|
state_dict[optimizer][model_state_key],
|
|
"Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key),
|
|
)
|
|
|
|
|
|
def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled):
|
|
"""Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims"""
|
|
|
|
model = _utils.state_dict_model_key()
|
|
full_precision = _utils.state_dict_full_precision_key()
|
|
optimizer = _utils.state_dict_optimizer_key()
|
|
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
|
|
|
|
for sharded_state_key, original_dim in sharded_states_original_dims.items():
|
|
# reshape model states to original_dim only when mixed precision is enabled
|
|
if mixed_precision_enabled and (model in state_dict):
|
|
state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][
|
|
sharded_state_key
|
|
].reshape(original_dim)
|
|
|
|
# reshape optimizer states to original_dim
|
|
if optimizer in state_dict:
|
|
for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items():
|
|
if optimizer_key in sharded_optimizer_keys:
|
|
state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim)
|
|
|
|
|
|
def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation):
|
|
"""Extracts trainer options from rank_state_dict and loads them accordingly on state_dict"""
|
|
trainer_options = _utils.state_dict_trainer_options_key()
|
|
state_dict[trainer_options] = {}
|
|
|
|
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
|
|
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
|
|
world_rank = _utils.state_dict_trainer_options_world_rank_key()
|
|
world_size = _utils.state_dict_trainer_options_world_size_key()
|
|
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()
|
|
D_size = _utils.state_dict_trainer_options_data_parallel_size_key()
|
|
H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key()
|
|
|
|
state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision]
|
|
state_dict[trainer_options][zero_stage] = 0
|
|
state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0
|
|
state_dict[trainer_options][world_size] = 1
|
|
state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name]
|
|
state_dict[trainer_options][D_size] = 1
|
|
state_dict[trainer_options][H_size] = 1
|
|
|
|
|
|
def _aggregate_megatron_partition_info(rank_state_dict, state_dict):
|
|
"""Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights"""
|
|
partition_info = _utils.state_dict_partition_info_key()
|
|
if partition_info not in state_dict:
|
|
state_dict[partition_info] = {}
|
|
|
|
rank_partition_info = rank_state_dict[partition_info]
|
|
for model_state_key, partition_info_dict in rank_partition_info.items():
|
|
if model_state_key not in state_dict[partition_info]:
|
|
# add partition info only if weight is megatron partitioned
|
|
if partition_info_dict["megatron_row_partition"] >= 0:
|
|
state_dict[partition_info][model_state_key] = partition_info_dict
|
|
|
|
|
|
def _to_pytorch_format(state_dict):
|
|
"""Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)"""
|
|
|
|
pytorch_state_dict = {}
|
|
for model_state_key, model_state_value in state_dict[_utils.state_dict_model_key()][
|
|
_utils.state_dict_full_precision_key()
|
|
].items():
|
|
# convert numpy array to a torch tensor
|
|
pytorch_state_dict[model_state_key] = torch.tensor(model_state_value)
|
|
return pytorch_state_dict
|
|
|
|
|
|
def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size):
|
|
"""Returns the D and H groups for the given sizes"""
|
|
num_data_groups = world_size // data_parallel_size
|
|
data_groups = []
|
|
for data_group_id in range(num_data_groups):
|
|
data_group_ranks = []
|
|
for r in range(data_parallel_size):
|
|
data_group_ranks.append(data_group_id + horizontal_parallel_size * r)
|
|
data_groups.append(data_group_ranks)
|
|
|
|
num_horizontal_groups = world_size // horizontal_parallel_size
|
|
horizontal_groups = []
|
|
for hori_group_id in range(num_horizontal_groups):
|
|
hori_group_ranks = []
|
|
for r in range(horizontal_parallel_size):
|
|
hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r)
|
|
horizontal_groups.append(hori_group_ranks)
|
|
|
|
return data_groups, horizontal_groups
|
|
|
|
|
|
def _aggregate_over_ranks(
|
|
ordered_paths,
|
|
ranks,
|
|
sharded_states_original_dims=None,
|
|
mode=_AGGREGATION_MODE.Zero,
|
|
partial_aggregation=False,
|
|
pytorch_format=True,
|
|
):
|
|
"""Aggregate checkpoint files over set of ranks and return a single state dictionary
|
|
|
|
Args:
|
|
ordered_paths: list of paths in the order in which they must be aggregated
|
|
ranks: list of ranks that are to be aggregated
|
|
sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over
|
|
multiple calls to _aggregate_over_ranks()
|
|
mode: mode of aggregation: Zero or Megatron
|
|
partial_aggregation: boolean flag to indicate whether to produce a partially
|
|
aggregated state which can be further aggregated over
|
|
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
|
|
Returns:
|
|
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
|
|
"""
|
|
state_dict = {}
|
|
if sharded_states_original_dims is None:
|
|
sharded_states_original_dims = dict()
|
|
world_rank = _utils.state_dict_trainer_options_world_rank_key()
|
|
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
|
|
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
|
|
world_size = _utils.state_dict_trainer_options_world_size_key()
|
|
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()
|
|
|
|
loaded_mixed_precision = None
|
|
loaded_world_size = None
|
|
loaded_zero_stage = None
|
|
loaded_optimizer_name = None
|
|
|
|
for i, path in enumerate(ordered_paths):
|
|
rank_state_dict = _checkpoint_storage.load(path)
|
|
|
|
assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info"
|
|
assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options"
|
|
assert (
|
|
ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]
|
|
), "Unexpected rank in file at path {}. Expected {}, got {}".format(
|
|
path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]
|
|
)
|
|
if loaded_mixed_precision is None:
|
|
loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
|
|
else:
|
|
assert (
|
|
loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
|
|
), "Mixed precision state mismatch among checkpoint files. File: {}".format(path)
|
|
if loaded_world_size is None:
|
|
loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
|
|
else:
|
|
assert (
|
|
loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
|
|
), "World size state mismatch among checkpoint files. File: {}".format(path)
|
|
if loaded_zero_stage is None:
|
|
loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
|
|
else:
|
|
assert (
|
|
loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
|
|
), "Zero stage mismatch among checkpoint files. File: {}".format(path)
|
|
if loaded_optimizer_name is None:
|
|
loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
|
|
else:
|
|
assert (
|
|
loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
|
|
), "Optimizer name mismatch among checkpoint files. File: {}".format(path)
|
|
|
|
# aggregate all model states
|
|
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode)
|
|
|
|
if not pytorch_format:
|
|
# aggregate all optimizer states if pytorch_format is False
|
|
_aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode)
|
|
|
|
# for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups
|
|
# to aggregate over Zero, and another pass to aggregate Megatron partitioned
|
|
# states. Preserve the relevant partition info only for weights that are megatron partitioned for
|
|
# a partial aggregation call
|
|
if partial_aggregation:
|
|
_aggregate_megatron_partition_info(rank_state_dict, state_dict)
|
|
|
|
# entry for trainer_options in the state_dict to perform other sanity checks
|
|
if _utils.state_dict_trainer_options_key() not in state_dict:
|
|
_aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation)
|
|
|
|
# entry for user_dict in the state_dict if not already present
|
|
if (
|
|
_utils.state_dict_user_dict_key() not in state_dict
|
|
and _utils.state_dict_user_dict_key() in rank_state_dict
|
|
):
|
|
state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()]
|
|
|
|
# for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape
|
|
if not partial_aggregation:
|
|
# reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims
|
|
_reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision)
|
|
|
|
# return a flat structure for PyTorch model in case pytorch_format is True
|
|
# else return the hierarchical structure for ORTTrainer
|
|
return _to_pytorch_format(state_dict) if pytorch_format else state_dict
|
|
|
|
|
|
def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format):
|
|
"""Aggregate checkpoint files and return a single state dictionary for the D+H
|
|
(Zero+Megatron) partitioning strategy.
|
|
For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups
|
|
to aggregate over Zero, and another pass over the previously aggregated states
|
|
to aggregate Megatron partitioned states.
|
|
"""
|
|
sharded_states_original_dims = {}
|
|
aggregate_data_checkpoint_files = []
|
|
|
|
# combine for Zero over data groups and save to temp file
|
|
with tempfile.TemporaryDirectory() as save_dir:
|
|
for group_id, d_group in enumerate(D_groups):
|
|
aggregate_state_dict = _aggregate_over_ranks(
|
|
ordered_paths["D"][group_id],
|
|
d_group,
|
|
sharded_states_original_dims,
|
|
partial_aggregation=True,
|
|
pytorch_format=False,
|
|
)
|
|
|
|
filename = "ort.data_group." + str(group_id) + ".ort.pt"
|
|
filepath = os.path.join(save_dir, filename)
|
|
_checkpoint_storage.save(aggregate_state_dict, filepath)
|
|
aggregate_data_checkpoint_files.append(filepath)
|
|
|
|
assert len(aggregate_data_checkpoint_files) > 0
|
|
|
|
# combine for megatron:
|
|
aggregate_state = _aggregate_over_ranks(
|
|
aggregate_data_checkpoint_files,
|
|
H_groups[0],
|
|
sharded_states_original_dims,
|
|
mode=_AGGREGATION_MODE.Megatron,
|
|
pytorch_format=pytorch_format,
|
|
)
|
|
|
|
return aggregate_state
|
|
|
|
|
|
def aggregate_checkpoints(paths, pytorch_format=True):
|
|
"""Aggregate checkpoint files and return a single state dictionary
|
|
|
|
Aggregates checkpoint files specified by paths and loads them one at a time, merging
|
|
them into a single state dictionary.
|
|
The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function.
|
|
The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict()
|
|
|
|
Args:
|
|
paths: list of more than one file represented as strings where the checkpoint is saved
|
|
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
|
|
Returns:
|
|
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
|
|
"""
|
|
|
|
loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key())
|
|
D_size = _utils.state_dict_trainer_options_data_parallel_size_key()
|
|
H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key()
|
|
world_size = _utils.state_dict_trainer_options_world_size_key()
|
|
|
|
D_size = loaded_trainer_options[D_size]
|
|
H_size = loaded_trainer_options[H_size]
|
|
world_size = loaded_trainer_options[world_size]
|
|
D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size)
|
|
|
|
combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0
|
|
combine_megatron = len(H_groups[0]) > 1
|
|
|
|
# order the paths in the order of groups in which they must be aggregated according to
|
|
# data-parallel groups and H-parallel groups obtained
|
|
# eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]}
|
|
ordered_paths = _order_paths(paths, D_groups, H_groups)
|
|
|
|
aggregate_state = None
|
|
if combine_zero and combine_megatron:
|
|
aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format)
|
|
elif combine_zero:
|
|
aggregate_state = _aggregate_over_ranks(
|
|
ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format
|
|
)
|
|
elif combine_megatron:
|
|
aggregate_state = _aggregate_over_ranks(
|
|
ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format
|
|
)
|
|
|
|
return aggregate_state
|
|
|
|
|
|
################################################################################
|
|
# Helper functions
|
|
################################################################################
|
|
|
|
|
|
def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict):
|
|
checkpoint_name = _get_checkpoint_name(
|
|
checkpoint_prefix,
|
|
is_partitioned,
|
|
ort_trainer.options.distributed.world_rank,
|
|
ort_trainer.options.distributed.world_size,
|
|
)
|
|
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
|
|
|
if is_partitioned:
|
|
assert_msg = (
|
|
f"Couldn't find checkpoint file {checkpoint_file}."
|
|
" Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists "
|
|
f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}"
|
|
)
|
|
else:
|
|
assert_msg = f"Couldn't find checkpoint file {checkpoint_file}."
|
|
assert os.path.exists(checkpoint_file), assert_msg
|
|
|
|
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
|
|
experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict)
|
|
del checkpoint_state["model"]
|
|
return checkpoint_state
|
|
|
|
|
|
def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict):
|
|
checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
|
|
|
|
ckpt_agg = _CombineZeroCheckpoint(checkpoint_files)
|
|
aggregate_state_dict = ckpt_agg.aggregate_checkpoints()
|
|
|
|
experimental_load_state_dict(ort_trainer, aggregate_state_dict, strict=strict)
|
|
|
|
# aggregate other keys in the state_dict.
|
|
# Values will be overwritten for matching keys among workers
|
|
all_checkpoint_states = dict()
|
|
for checkpoint_file in checkpoint_files:
|
|
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
|
|
del checkpoint_state["model"]
|
|
all_checkpoint_states.update(checkpoint_state)
|
|
return all_checkpoint_states
|
|
|
|
|
|
def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"):
|
|
ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)]
|
|
ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)]
|
|
ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names]
|
|
|
|
assert len(ckpt_file_names) > 0, f"No checkpoint found with prefix '{checkpoint_prefix}' at '{checkpoint_dir}'"
|
|
return ckpt_file_names
|
|
|
|
|
|
def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None):
|
|
SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt"
|
|
MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt"
|
|
|
|
if is_partitioned:
|
|
filename = MULTIPLE_CHECKPOINT_FILENAME.format(
|
|
prefix=prefix, world_rank=world_rank, world_size=(world_size - 1)
|
|
)
|
|
else:
|
|
filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix)
|
|
return filename
|
|
|
|
|
|
def _split_state_dict(state_dict):
|
|
optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"]
|
|
split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}}
|
|
for k, v in state_dict.items():
|
|
mode = "fp32_param"
|
|
for optim_key in optimizer_keys:
|
|
if k.startswith(optim_key):
|
|
mode = "optimizer"
|
|
break
|
|
if k.endswith("_fp16"):
|
|
mode = "fp16_param"
|
|
split_sd[mode][k] = v
|
|
return split_sd
|
|
|
|
|
|
class _CombineZeroCheckpoint(object):
|
|
def __init__(self, checkpoint_files, clean_state_dict=None):
|
|
|
|
assert len(checkpoint_files) > 0, "No checkpoint files passed"
|
|
self.checkpoint_files = checkpoint_files
|
|
self.clean_state_dict = clean_state_dict
|
|
self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1
|
|
assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files"
|
|
self.weight_shape_map = dict()
|
|
self.sharded_params = set()
|
|
|
|
def _split_name(self, name):
|
|
name_split = name.split("_view_")
|
|
view_num = None
|
|
if len(name_split) > 1:
|
|
view_num = int(name_split[1])
|
|
optimizer_key = ""
|
|
mp_suffix = ""
|
|
if name_split[0].startswith("Moment_1"):
|
|
optimizer_key = "Moment_1_"
|
|
elif name_split[0].startswith("Moment_2"):
|
|
optimizer_key = "Moment_2_"
|
|
elif name_split[0].startswith("Update_Count"):
|
|
optimizer_key = "Update_Count_"
|
|
elif name_split[0].endswith("_fp16"):
|
|
mp_suffix = "_fp16"
|
|
param_name = name_split[0]
|
|
if optimizer_key != "":
|
|
param_name = param_name.split(optimizer_key)[1]
|
|
param_name = param_name.split("_fp16")[0]
|
|
return param_name, optimizer_key, view_num, mp_suffix
|
|
|
|
def _update_weight_statistics(self, name, value):
|
|
if name not in self.weight_shape_map:
|
|
self.weight_shape_map[name] = value.size() # original shape of tensor
|
|
|
|
def _reshape_tensor(self, key):
|
|
value = self.aggregate_state_dict[key]
|
|
weight_name, _, _, _ = self._split_name(key)
|
|
set_size = self.weight_shape_map[weight_name]
|
|
self.aggregate_state_dict[key] = value.reshape(set_size)
|
|
|
|
def _aggregate(self, param_dict):
|
|
for k, v in param_dict.items():
|
|
weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k)
|
|
if view_num is not None:
|
|
# parameter is sharded
|
|
param_name = optimizer_key + weight_name + mp_suffix
|
|
|
|
if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]:
|
|
self.sharded_params.add(param_name)
|
|
# Found a previous shard of the param, concatenate shards ordered by ranks
|
|
self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v))
|
|
else:
|
|
self.aggregate_state_dict[param_name] = v
|
|
else:
|
|
if k in self.aggregate_state_dict:
|
|
assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value"
|
|
else:
|
|
self.aggregate_state_dict[k] = v
|
|
self._update_weight_statistics(weight_name, v)
|
|
|
|
def aggregate_checkpoints(self):
|
|
warnings.warn(
|
|
"_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. "
|
|
"Please use aggregate_checkpoints() instead.",
|
|
DeprecationWarning,
|
|
)
|
|
|
|
checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0]
|
|
self.aggregate_state_dict = dict()
|
|
|
|
for i in range(self.world_size):
|
|
checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size)
|
|
rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu"))
|
|
if "model" in rank_state_dict:
|
|
rank_state_dict = rank_state_dict["model"]
|
|
|
|
if self.clean_state_dict:
|
|
rank_state_dict = self.clean_state_dict(rank_state_dict)
|
|
|
|
rank_state_dict = _split_state_dict(rank_state_dict)
|
|
self._aggregate(rank_state_dict["fp16_param"])
|
|
self._aggregate(rank_state_dict["fp32_param"])
|
|
self._aggregate(rank_state_dict["optimizer"])
|
|
|
|
for k in self.sharded_params:
|
|
self._reshape_tensor(k)
|
|
return self.aggregate_state_dict
|