From dd2e5a1a052c0ef8241af6fe38404dc2b4d1310d Mon Sep 17 00:00:00 2001 From: baijumeswani Date: Mon, 14 Dec 2020 11:55:52 -0800 Subject: [PATCH] state_dict and load_state_dict for ORTTrainer (#6095) * add functions state_dict and load_state_dict to ORTTrainer * unit tests for state_dict and load_state_dict for ORTTrainer --- .../orttraining/python/training/_utils.py | 26 ++ .../orttraining/python/training/orttrainer.py | 372 +++++++++++++++- ...ng_test_orttrainer_checkpoint_functions.py | 409 ++++++++++++++++++ tools/ci_build/build.py | 3 + 4 files changed, 806 insertions(+), 4 deletions(-) create mode 100644 orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index ba21412b2a..4444e8327c 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -3,6 +3,7 @@ import numpy as np import os import sys import torch +from onnx import TensorProto def get_device_index(device): @@ -177,3 +178,28 @@ def import_module_from_file(file_path, module_name=None): sys.modules[module_name] = module spec.loader.exec_module(module) return module + +def state_dict_model_key(): + """Returns the model key name in the state dictionary""" + + return 'model' + +def state_dict_optimizer_key(): + """Returns the optimizer key name in the state dictionary""" + + return 'optimizer' + +def state_dict_partition_info_key(): + """Returns the partition info key name in the state dictionary""" + + return 'partition_info' + +def state_dict_trainer_options_key(): + """Returns the trainer options key name in the state dictionary""" + + return 'trainer_options' + +def state_dict_full_precision_key(): + """Returns the full precision key name in the state dictionary""" + + return 'fp32' diff --git a/orttraining/orttraining/python/training/orttrainer.py 
b/orttraining/orttraining/python/training/orttrainer.py index 326640dca0..fc86cb48e0 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -5,6 +5,8 @@ import onnx import torch from inspect import signature import warnings +from functools import partial +import numpy as np import onnxruntime as ort from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions @@ -206,6 +208,7 @@ class ORTTrainer(object): self._train_step_info = TrainStepInfo(self.optim_config) self._training_session = None + self._load_state_dict = None self._init_session() def eval_step(self, *args, **kwargs): @@ -565,7 +568,7 @@ class ORTTrainer(object): return onnx_model - def _create_ort_training_session(self): + def _create_ort_training_session(self, state_dict = {}): # Validating frozen_weights names unused_frozen_weights = [n for n in self.options.utils.frozen_weights\ if n not in [i.name for i in self._onnx_model.graph.initializer]] @@ -635,6 +638,8 @@ class ORTTrainer(object): ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map if bool(self._optim_state_dict): ort_parameters.set_optimizer_initial_state(self._optim_state_dict) + if bool(state_dict) and bool(state_dict[_utils.state_dict_optimizer_key()]): + ort_parameters.set_optimizer_initial_state(state_dict[_utils.state_dict_optimizer_key()]) ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute @@ -682,9 +687,13 @@ class ORTTrainer(object): if self.options._internal_use.extra_postprocess: self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - self._init_session() + state_dict = {} + if self._load_state_dict: + state_dict = self._load_state_dict() - def _init_session(self): + self._init_session(state_dict) + + def _init_session(self, state_dict = {}): if self._onnx_model is None: 
return @@ -692,7 +701,8 @@ class ORTTrainer(object): self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True) # Create training session used by train_step - self._create_ort_training_session() + # pass all optimizer states to the backend + self._create_ort_training_session(state_dict) # Update model description to update dtype when mixed precision is enabled # C++ backend modifies model's output dtype from float32 to float16 for mixed precision @@ -843,3 +853,357 @@ class ORTTrainer(object): for w_i in replace_indices: del self._onnx_model.graph.initializer[w_i] self._onnx_model.graph.initializer.extend(new_weights) + + def _extract_model_states(self, state_dict, pytorch_format): + """Extract model states from the training session and load into the state_dict""" + + model_states = self._training_session.get_model_state(include_mixed_precision_weights=False) + state_dict[_utils.state_dict_model_key()] = {} + + # extract trained model weights from the training session + for precision in model_states: + state_dict[_utils.state_dict_model_key()][precision] = {} + for model_state_key in model_states[precision]: + if pytorch_format: + state_dict[_utils.state_dict_model_key()][precision][model_state_key] = \ + torch.from_numpy(model_states[precision][model_state_key]) + else: + state_dict[_utils.state_dict_model_key()][precision][model_state_key] = \ + model_states[precision][model_state_key] + + # extract untrained (frozen) model weights + for node in self._onnx_model.graph.initializer: + if node.name not in state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] and \ + node.name in self.options.utils.frozen_weights: + if pytorch_format: + state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][node.name] = \ + torch.from_numpy(onnx.numpy_helper.to_array(node)) + else: + state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][node.name] 
= \ + onnx.numpy_helper.to_array(node) + + def _extract_trainer_options(self, state_dict): + """Extract relevant trainer configuration and load it into the state_dict""" + + state_dict[_utils.state_dict_trainer_options_key()] = {} + state_dict[_utils.state_dict_trainer_options_key()]['mixed_precision'] = self.options.mixed_precision.enabled + state_dict[_utils.state_dict_trainer_options_key()]['zero_stage'] = \ + self.options.distributed.deepspeed_zero_optimization.stage or 0 + state_dict[_utils.state_dict_trainer_options_key()]['world_rank'] = self.options.distributed.world_rank or 0 + state_dict[_utils.state_dict_trainer_options_key()]['world_size'] = self.options.distributed.world_size or 1 + state_dict[_utils.state_dict_trainer_options_key()]['optimizer_name'] = self.optim_config.name + + def state_dict(self, pytorch_format=False): + """Returns a dictionary with model, and optionally, optimizer states + + The returned dictionary contains the following information: + - Model and optimizer states + - Required ORTTrainerOptions settings + - Distributed training information, such as but not limited to ZeRO + + Structure of the returned dictionary: + - When `pytorch_format = False` + schema: + { + "model": + { + type: dict, + schema: + { + "fp32": + { + type: dict, + schema: + { + model_weight_name: + { + type: array + } + } + } + } + }, + "optimizer": + { + type: dict, + schema: + { + model_weight_name: + { + type: dict, + schema: + { + "Moment_1": + { + type: array + }, + "Moment_2": + { + type: array + }, + "Update_Count": + { + type: array, + optional: True # present if optimizer is adam, absent otherwise + } + } + }, + "shared_optimizer_state": + { + type: dict, + optional: True, # present if optimizer is shared, absent otherwise. 
+ schema: + { + "step": + { + type: array, + } + } + } + } + }, + "trainer_options": + { + type: dict, + schema: + { + "mixed_precision": + { + type: bool + }, + "zero_stage": + { + type: int + }, + "world_rank": + { + type: int + }, + "world_size": + { + type: int + }, + "optimizer_name": + { + type: str + } + } + }, + "partition_info": + { + type: dict, + optional: True, # present if states partitioned, else absent + schema: + { + model_weight_name: + { + type: dict, + schema: + { + "original_dim": + { + type: array + } + } + } + } + } + } + - When `pytorch_format = True` + schema: + { + model_weight_name: + { + type: tensor + } + } + + Args: + pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema + + Returns: + A dictionary with `ORTTrainer` state + """ + if not self._training_session: + warnings.warn("ONNX Runtime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().", + UserWarning) + return self._load_state_dict.args[0] if self._load_state_dict else {} + + state_dict = {} + + # load training session model states into the state_dict + self._extract_model_states(state_dict, pytorch_format) + if pytorch_format: + if self.options.distributed.deepspeed_zero_optimization.stage > 0: + warnings.warn("Incomplete state_dict: ZeRO enabled", UserWarning) + # if pytorch_format is true, return a flat dictionary with only model states + # which is compatible with a PyTorch model + return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] + + # load training session optimizer states into the state_dict + state_dict[_utils.state_dict_optimizer_key()] = self._training_session.get_optimizer_state() + + # extract the relevant training configuration from the trainer and load them into the state_dict + self._extract_trainer_options(state_dict) + + # add partition information in case of a distributed run + if 
self.options.distributed.deepspeed_zero_optimization.stage > 0: + state_dict[_utils.state_dict_partition_info_key()] = self._training_session.get_partition_info_map() + + return state_dict + + def _load_model_states(self, state_dict, strict): + """Load the model states onto the onnx model graph""" + + if _utils.state_dict_model_key() not in state_dict: + return + + # collect all initializer names from the current onnx graph + assert self._onnx_model, "ONNX model graph is not exported" + initializer_names = {node.name for node in self._onnx_model.graph.initializer} + + # loaded_initializers dict will be loaded with all the model states from the state dictionary + # that are found in the initializer_names dictionary + loaded_initializers = {} + + # copy over model states from the input state dict onto the onnx model + for precision, precision_states in state_dict[_utils.state_dict_model_key()].items(): + for state_key, state_value in precision_states.items(): + if state_key in initializer_names: + loaded_initializers[state_key] = state_value + elif strict: + raise RuntimeError("Unexpected key: {} in state_dict[model][{}]".format(state_key, precision)) + + # update onnx model from loaded initializers + self._update_onnx_model_initializers(loaded_initializers) + + def _load_optimizer_states(self, current_state_dict, state_dict): + """Load the optimizer states onto the training session state dictionary""" + + if _utils.state_dict_optimizer_key() not in state_dict: + return + + # create an entry for the optimizer in the training session state dictionary + if _utils.state_dict_optimizer_key() not in current_state_dict: + current_state_dict[_utils.state_dict_optimizer_key()] = {} + + # copy over optimizer states from the input state dict onto the training session state dict + for model_state_key, optimizer_dict in state_dict[_utils.state_dict_optimizer_key()].items(): + if model_state_key not in current_state_dict[_utils.state_dict_optimizer_key()]: + 
current_state_dict[_utils.state_dict_optimizer_key()][model_state_key] = {} + for optimizer_state_key, optimizer_state_value in optimizer_dict.items(): + current_state_dict[_utils.state_dict_optimizer_key()][model_state_key][optimizer_state_key] = \ + optimizer_state_value + + + def _load_state_dict_impl(self, state_dict, strict=True): + """Load the state dictionary onto the onnx model and on the training session graph""" + + # clear the callable partial + self._load_state_dict = None + + def _mismatch_keys(keys1, keys2, in_error_str): + """Find out the missing and the unexpected keys in two dictionaries + + Throws a runtime error if missing or unexpected keys are found + - Keys in keys1 not in keys2 will be marked as missing + - Keys in keys2 not in keys1 will be marked as unexpected + """ + keys1 = set(keys1) + keys2 = set(keys2) + missing_keys = list(keys1 - keys2) + unexpected_keys = list(keys2 - keys1) + if len(missing_keys) > 0: + raise RuntimeError("Missing keys: {} in {}".format(missing_keys, in_error_str)) + if len(unexpected_keys) > 0: + raise RuntimeError("Unexpected keys: {} in {}".format(unexpected_keys, in_error_str)) + + def _check_model_key_mismatch(current_state_dict, state_dict): + """Check if there is any mismatch in the model sub state dictionary between the two state_dicts""" + + # check unexpected and missing precision keys in the model state_dict compared to the training + # session model state_dict + _mismatch_keys(current_state_dict[_utils.state_dict_model_key()], + state_dict[_utils.state_dict_model_key()], 'state_dict[model]') + + # check for model state key mismatch + for precision_key in current_state_dict[_utils.state_dict_model_key()]: + _mismatch_keys(current_state_dict[_utils.state_dict_model_key()][precision_key], + state_dict[_utils.state_dict_model_key()][precision_key], + 'state_dict[model][{}]'.format(precision_key)) + + def _check_optimizer_key_mismatch(current_state_dict, state_dict): + """Check if there is any mismatch in 
the optimizer sub state dictionary between the two state_dicts""" + + # check for model state key mismatch for the optimizer state_dict + _mismatch_keys(current_state_dict[_utils.state_dict_optimizer_key()], + state_dict[_utils.state_dict_optimizer_key()], + 'state_dict[optimizer]') + + # check for optimizer state keys mismatch + for model_state_key in current_state_dict[_utils.state_dict_optimizer_key()]: + _mismatch_keys(current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], + state_dict[_utils.state_dict_optimizer_key()][model_state_key], + 'state_dict[optimizer][{}]'.format(model_state_key)) + + def _check_key_mismatch(current_state_dict, state_dict): + """Check if there is a mismatch in the keys (model and optimizer) in the two state_dicts""" + + # check presence of 'model' in the input state_dict + if _utils.state_dict_model_key() in state_dict: + _check_model_key_mismatch(current_state_dict, state_dict) + else: + warnings.warn("Missing key: model in state_dict", UserWarning) + # check presence of 'optimizer' in the input state_dict + if _utils.state_dict_optimizer_key() in state_dict: + _check_optimizer_key_mismatch(current_state_dict, state_dict) + else: + warnings.warn("Missing key: optimizer in state_dict", UserWarning) + + # extract state dict from the current training session. this is to persist the states between + # two training sessions. 
+ # for example, if user provided only the model states, the optimizer states from the current + # training session must be persisted + current_state_dict = {} + if self._training_session: + current_state_dict = self.state_dict() + if strict: + _check_key_mismatch(current_state_dict, state_dict) + + # load the model states from the input state dictionary into the onnx graph + self._load_model_states(state_dict, strict) + + # load the optimizer states from the input state dictionary into the training session states + # dictionary + self._load_optimizer_states(current_state_dict, state_dict) + + return current_state_dict + + def load_state_dict(self, state_dict, strict=True): + """Loads state_dict containing model/optimizer states into ORTTrainer + + The state_dict dictionary may contain the following information: + - Model and optimizer states + - Required ORTTrainerOptions settings + - Distributed training information, such as but not limited to ZeRO + + Args: + state_dict: state dictionary containing both model and optimizer states. The structure of this dictionary + should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False + strict: boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict + """ + + # if onnx graph has not been initialized, loading of states will be put on hold. + # a copy of the state_dict and other arguments to the function will be stored until the onnx graph has + # been initialized. 
Once the graph is initialized, the desired states will be loaded onto the graph + if not self._training_session: + self._load_state_dict = partial(self._load_state_dict_impl, state_dict, strict=strict) + return + + # load states onto the frontend onnx graph + state_dict = self._load_state_dict_impl(state_dict, strict=strict) + + # create a new training session after loading initializer states onto the onnx graph + # pass the populated states to the training session to populate the backend graph + self._init_session(state_dict) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py new file mode 100644 index 0000000000..f3359f6f73 --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py @@ -0,0 +1,409 @@ +import pytest +from unittest.mock import patch, Mock +from orttraining_test_orttrainer_frontend import _load_pytorch_transformer_model +from onnxruntime.training import amp, checkpoint, optim, orttrainer +import numpy as np +import onnx +import torch + +# Helper functions + +def _create_trainer(zero_enabled=False): + """Creates a simple ORTTrainer for ORTTrainer functional tests""" + + device = 'cuda' + optim_config = optim.LambConfig(lr=0.1) + opts = { + 'device' : {'id' : device}, + 'debug' : {'deterministic_compute': True} + } + if zero_enabled: + opts['distributed'] = { + 'world_rank' : 0, + 'world_size' : 1, + 'allreduce_post_accumulation' : True, + 'deepspeed_zero_optimization': + { + 'stage': 1 + } + } + model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) + trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts)) + + return trainer + +class _training_session_mock(object): + """Mock object for the ORTTrainer _training_session member""" + + def __init__(self, 
model_states, optimizer_states, partition_info): + self.model_states = model_states + self.optimizer_states = optimizer_states + self.partition_info = partition_info + + def get_model_state(self, include_mixed_precision_weights=False): + return self.model_states + + def get_optimizer_state(self): + return self.optimizer_states + + def get_partition_info_map(self): + return self.partition_info + +def _get_load_state_dict_strict_error_arguments(): + """Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing + + Construct a list of tuples (training_session_state_dict, input_state_dict, error_arguments) + The load_state_dict function will compare the two state dicts (training_session_state_dict, input_state_dict) and + throw a runtime error with the missing/unexpected keys. The error arguments capture these missing/unexpected keys. + """ + + training_session_state_dict = { + 'model': { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + }, + 'optimizer': { + 'a': { + 'Moment_1': np.arange(5), + 'Moment_2': np.arange(7) + }, + 'shared_optimizer_state': { + 'step': np.arange(5) + } + } + } + + # input state dictionaries + precision_key_missing = {'model': {}, 'optimizer': {}} + precision_key_unexpected = {'model': {'fp32': {}, 'fp16': {}}, 'optimizer': {}} + model_state_key_missing = {'model': {'fp32': {}}, 'optimizer': {}} + model_state_key_unexpected = {'model': {'fp32': {'a': 2, 'b': 3, 'c': 4}}, 'optimizer': {}} + optimizer_model_state_key_missing = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': {}} + optimizer_model_state_key_unexpected = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': \ + {'a': {}, 'shared_optimizer_state': {}, 'b': {}}} + optimizer_state_key_missing = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': \ + {'a': {}, 'shared_optimizer_state': {'step': np.arange(5)}}} + optimizer_state_key_unexpected = {'model': {'fp32': {'a': 2, 'b': 3}}, 'optimizer': \ + {'a': {'Moment_1': 
np.arange(5), 'Moment_2': np.arange(7)}, 'shared_optimizer_state': {'step': np.arange(5), 'another_step': np.arange(1)}}} + + input_arguments = [ + (training_session_state_dict, precision_key_missing, ['fp32']), + (training_session_state_dict, precision_key_unexpected, ['fp16']), + (training_session_state_dict, model_state_key_missing, ['a', 'b']), + (training_session_state_dict, model_state_key_unexpected, ['c']), + (training_session_state_dict, optimizer_model_state_key_missing, ['a', 'shared_optimizer_state']), + (training_session_state_dict, optimizer_model_state_key_unexpected, ['b']), + (training_session_state_dict, optimizer_state_key_missing, ['Moment_1', 'Moment_2']), + (training_session_state_dict, optimizer_state_key_unexpected, ['another_step']) + ] + + return input_arguments + +# Tests + +def test_empty_state_dict_when_training_session_uninitialized(): + trainer = _create_trainer() + with pytest.warns(UserWarning) as user_warning: + state_dict = trainer.state_dict() + + assert len(state_dict.keys()) == 0 + assert user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. " \ + "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()." 
+ +@patch('onnx.ModelProto') +def test_training_session_provides_empty_model_states(onnx_model_mock): + trainer = _create_trainer() + training_session_mock = _training_session_mock({}, {}, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict() + assert len(state_dict['model'].keys()) == 0 + +@patch('onnx.ModelProto') +def test_training_session_provides_model_states(onnx_model_mock): + trainer = _create_trainer() + model_states = { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + } + training_session_mock = _training_session_mock(model_states, {}, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict() + assert (state_dict['model']['fp32']['a'] == np.arange(5)).all() + assert (state_dict['model']['fp32']['b'] == np.arange(7)).all() + +@patch('onnx.ModelProto') +def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): + trainer = _create_trainer() + model_states = { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + } + training_session_mock = _training_session_mock(model_states, {}, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict(pytorch_format=True) + assert torch.all(torch.eq(state_dict['a'], torch.tensor(np.arange(5)))) + assert torch.all(torch.eq(state_dict['b'], torch.tensor(np.arange(7)))) + +@patch('onnx.ModelProto') +def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): + trainer = _create_trainer() + model_states = { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + } + training_session_mock = _training_session_mock(model_states, {}, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + trainer.options.utils.frozen_weights = ['a_frozen_weight', 'a_float16_weight'] + trainer._onnx_model.graph.initializer = [ 
+ onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), 'a_frozen_weight'), + onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), 'a_non_fronzen_weight'), + onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), 'a_float16_weight') + ] + + state_dict = trainer.state_dict() + assert (state_dict['model']['fp32']['a'] == np.arange(5)).all() + assert (state_dict['model']['fp32']['b'] == np.arange(7)).all() + assert (state_dict['model']['fp32']['a_frozen_weight'] == np.array([1, 2, 3], dtype=np.float32)).all() + assert 'a_non_fronzen_weight' not in state_dict['model']['fp32'] + assert (state_dict['model']['fp32']['a_float16_weight'] == np.array([7, 8, 9], dtype=np.float32)).all() + +@patch('onnx.ModelProto') +def test_training_session_provides_empty_optimizer_states(onnx_model_mock): + trainer = _create_trainer() + training_session_mock = _training_session_mock({}, {}, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict() + assert len(state_dict['optimizer'].keys()) == 0 + +@patch('onnx.ModelProto') +def test_training_session_provides_optimizer_states(onnx_model_mock): + trainer = _create_trainer() + optimizer_states = { + 'model_weight': { + 'Moment_1': np.arange(5), + 'Moment_2': np.arange(7) + }, + 'shared_optimizer_state': { + 'step': np.arange(1) + } + } + training_session_mock = _training_session_mock({}, optimizer_states, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict() + assert (state_dict['optimizer']['model_weight']['Moment_1'] == np.arange(5)).all() + assert (state_dict['optimizer']['model_weight']['Moment_2'] == np.arange(7)).all() + assert (state_dict['optimizer']['shared_optimizer_state']['step'] == np.arange(1)).all() + +@patch('onnx.ModelProto') +def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock): + trainer = 
_create_trainer() + model_states = { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + } + optimizer_states = { + 'model_weight': { + 'Moment_1': np.arange(5), + 'Moment_2': np.arange(7) + }, + 'shared_optimizer_state': { + 'step': np.arange(1) + } + } + training_session_mock = _training_session_mock(model_states, optimizer_states, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict(pytorch_format=True) + assert 'optimizer' not in state_dict + +@patch('onnx.ModelProto') +def test_training_session_provides_empty_partition_info_map(onnx_model_mock): + trainer = _create_trainer(zero_enabled=True) + training_session_mock = _training_session_mock({}, {}, {}) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict() + assert len(state_dict['partition_info'].keys()) == 0 + +@patch('onnx.ModelProto') +def test_training_session_provides_partition_info_map(onnx_model_mock): + trainer = _create_trainer(zero_enabled=True) + partition_info = { + 'a': { + 'original_dim': [1, 2, 3] + } + } + training_session_mock = _training_session_mock({}, {}, partition_info) + trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict() + assert state_dict['partition_info']['a']['original_dim'] == [1, 2, 3] + +@patch('onnx.ModelProto') +def test_training_session_provides_all_states(onnx_model_mock): + trainer = _create_trainer(zero_enabled=True) + model_states = { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + } + optimizer_states = { + 'model_weight': { + 'Moment_1': np.arange(5), + 'Moment_2': np.arange(7) + }, + 'shared_optimizer_state': { + 'step': np.arange(1) + } + } + partition_info = { + 'a': { + 'original_dim': [1, 2, 3] + } + } + training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info) + 
trainer._training_session = training_session_mock + trainer._onnx_model = onnx_model_mock() + + state_dict = trainer.state_dict() + assert (state_dict['model']['fp32']['a'] == np.arange(5)).all() + assert (state_dict['model']['fp32']['b'] == np.arange(7)).all() + assert (state_dict['optimizer']['model_weight']['Moment_1'] == np.arange(5)).all() + assert (state_dict['optimizer']['model_weight']['Moment_2'] == np.arange(7)).all() + assert (state_dict['optimizer']['shared_optimizer_state']['step'] == np.arange(1)).all() + assert state_dict['partition_info']['a']['original_dim'] == [1, 2, 3] + +def test_load_state_dict_holds_when_training_session_not_initialized(): + trainer = _create_trainer() + state_dict = { + 'model': { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + }, + 'optimizer': { + 'a': { + 'Moment_1': np.arange(5), + 'Moment_2': np.arange(7) + }, + 'shared_optimizer_state': { + 'step': np.arange(5) + } + } + } + assert not trainer._load_state_dict + state_dict = trainer.load_state_dict(state_dict) + assert trainer._load_state_dict + +@pytest.mark.parametrize("state_dict, input_state_dict, error_key", [({'optimizer':{}}, {'optimizer':{}}, 'model'), ({'model':{}}, {'model':{}}, 'optimizer')]) +def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key): + trainer = _create_trainer() + trainer._training_session = _training_session_mock({}, {}, {}) + trainer.state_dict = Mock(return_value=state_dict) + trainer._update_onnx_model_initializers = Mock() + trainer._init_session = Mock() + with patch('onnx.ModelProto') as onnx_model_mock: + trainer._onnx_model = onnx_model_mock() + trainer._onnx_model.graph.initializer = [] + with pytest.warns(UserWarning) as user_warning: + trainer.load_state_dict(input_state_dict) + + assert user_warning[0].message.args[0] == "Missing key: {} in state_dict".format(error_key) + +@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", 
_get_load_state_dict_strict_error_arguments()) +def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys): + trainer = _create_trainer() + trainer._training_session = _training_session_mock({}, {}, {}) + trainer.state_dict = Mock(return_value=state_dict) + with pytest.raises(RuntimeError) as runtime_error: + trainer.load_state_dict(input_state_dict) + + assert any(key in str(runtime_error.value) for key in error_keys) + +@patch('onnx.ModelProto') +def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock): + trainer = _create_trainer() + training_session_state_dict = { + 'model': { + 'fp32': { + 'a': np.arange(5), + 'b': np.arange(7) + } + }, + 'optimizer': { + 'a': { + 'Moment_1': np.arange(5), + 'Moment_2': np.arange(7) + }, + 'shared_optimizer_state': { + 'step': np.arange(1) + } + } + } + + input_state_dict = { + 'model': { + 'fp32': { + 'a': np.array([1, 2]), + 'b': np.array([3, 4]) + } + }, + 'optimizer': { + 'a': { + 'Moment_1': np.array([5, 6]), + 'Moment_2': np.array([7, 8]) + }, + 'shared_optimizer_state': { + 'step': np.array([9]) + } + } + } + trainer._training_session = _training_session_mock({}, {}, {}) + trainer.state_dict = Mock(return_value=training_session_state_dict) + trainer._onnx_model = onnx_model_mock() + trainer._onnx_model.graph.initializer = [ + onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), 'a'), + onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), 'b') + ] + trainer._update_onnx_model_initializers = Mock() + trainer._init_session = Mock() + + trainer.load_state_dict(input_state_dict) + + loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args + state_dict_to_load, _ = trainer._init_session.call_args + + assert 'a' in loaded_initializers[0] + assert (loaded_initializers[0]['a'] == np.array([1, 2])).all() + assert 'b' in loaded_initializers[0] + assert (loaded_initializers[0]['b'] == np.array([3, 4])).all() + + assert 
(state_dict_to_load[0]['optimizer']['a']['Moment_1'] == np.array([5, 6])).all() + assert (state_dict_to_load[0]['optimizer']['a']['Moment_2'] == np.array([7, 8])).all() + assert (state_dict_to_load[0]['optimizer']['shared_optimizer_state']['step'] == np.array([9])).all() diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 0a965e4fb4..f4187e6d5c 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1299,6 +1299,9 @@ def run_training_python_frontend_tests(cwd): run_subprocess([sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_checkpoint_storage.py'], cwd=cwd) + run_subprocess([ + sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_checkpoint_functions.py'], cwd=cwd) + def run_training_python_frontend_e2e_tests(cwd): # frontend tests are to be added here: