state_dict and load_state_dict for ORTTrainer (#6095)

* add functions state_dict and load_state_dict to ORTTrainer

* unit tests for state_dict and load_state_dict for ORTTrainer
This commit is contained in:
baijumeswani 2020-12-14 11:55:52 -08:00 committed by GitHub
parent d4dddd99d9
commit dd2e5a1a05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 806 additions and 4 deletions

View file

@ -3,6 +3,7 @@ import numpy as np
import os
import sys
import torch
from onnx import TensorProto
def get_device_index(device):
@ -177,3 +178,28 @@ def import_module_from_file(file_path, module_name=None):
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def state_dict_model_key():
    """Return the top-level state dictionary key that holds the model states."""
    key = 'model'
    return key
def state_dict_optimizer_key():
    """Return the top-level state dictionary key that holds the optimizer states."""
    key = 'optimizer'
    return key
def state_dict_partition_info_key():
    """Return the top-level state dictionary key that holds the partition info."""
    key = 'partition_info'
    return key
def state_dict_trainer_options_key():
    """Return the top-level state dictionary key that holds the trainer options."""
    key = 'trainer_options'
    return key
def state_dict_full_precision_key():
    """Return the key under which full precision states are stored in the model sub-dictionary."""
    key = 'fp32'
    return key

View file

@ -5,6 +5,8 @@ import onnx
import torch
from inspect import signature
import warnings
from functools import partial
import numpy as np
import onnxruntime as ort
from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions
@ -206,6 +208,7 @@ class ORTTrainer(object):
self._train_step_info = TrainStepInfo(self.optim_config)
self._training_session = None
self._load_state_dict = None
self._init_session()
def eval_step(self, *args, **kwargs):
@ -565,7 +568,7 @@ class ORTTrainer(object):
return onnx_model
def _create_ort_training_session(self):
def _create_ort_training_session(self, state_dict = {}):
# Validating frozen_weights names
unused_frozen_weights = [n for n in self.options.utils.frozen_weights\
if n not in [i.name for i in self._onnx_model.graph.initializer]]
@ -635,6 +638,8 @@ class ORTTrainer(object):
ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
if bool(self._optim_state_dict):
ort_parameters.set_optimizer_initial_state(self._optim_state_dict)
if bool(state_dict) and bool(state_dict[_utils.state_dict_optimizer_key()]):
ort_parameters.set_optimizer_initial_state(state_dict[_utils.state_dict_optimizer_key()])
ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute
ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute
@ -682,9 +687,13 @@ class ORTTrainer(object):
if self.options._internal_use.extra_postprocess:
self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model)
self._init_session()
state_dict = {}
if self._load_state_dict:
state_dict = self._load_state_dict()
def _init_session(self):
self._init_session(state_dict)
def _init_session(self, state_dict = {}):
if self._onnx_model is None:
return
@ -692,7 +701,8 @@ class ORTTrainer(object):
self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True)
# Create training session used by train_step
self._create_ort_training_session()
# pass all optimizer states to the backend
self._create_ort_training_session(state_dict)
# Update model description to update dtype when mixed precision is enabled
# C++ backend modifies model's output dtype from float32 to float16 for mixed precision
@ -843,3 +853,357 @@ class ORTTrainer(object):
for w_i in replace_indices:
del self._onnx_model.graph.initializer[w_i]
self._onnx_model.graph.initializer.extend(new_weights)
def _extract_model_states(self, state_dict, pytorch_format):
    """Extract model states from the training session and load them into the state_dict

    Trained weights are read from the training session, grouped by precision.
    Frozen weights (names listed in `self.options.utils.frozen_weights`) are read
    from the onnx graph initializers and stored under the full precision key.

    Args:
        state_dict: dictionary that is updated in place; its model entry is (re)created
        pytorch_format: when True, every value is converted with `torch.from_numpy`,
            otherwise the arrays are stored as they were obtained
    """

    def _convert(array):
        """Convert a single weight to the requested output format"""
        return torch.from_numpy(array) if pytorch_format else array

    model_key = _utils.state_dict_model_key()
    full_precision_key = _utils.state_dict_full_precision_key()

    # extract trained model weights from the training session
    model_states = self._training_session.get_model_state(include_mixed_precision_weights=False)
    extracted = {}
    for precision in model_states:
        extracted[precision] = {name: _convert(model_states[precision][name])
                                for name in model_states[precision]}
    state_dict[model_key] = extracted

    # extract untrained (frozen) model weights
    # the frozen names are collected once so that each initializer costs a single set lookup
    frozen_weights = set(self.options.utils.frozen_weights)
    for node in self._onnx_model.graph.initializer:
        if node.name not in frozen_weights:
            continue
        # the full precision bucket is created lazily: the training session may not have
        # reported one, and indexing it unconditionally would raise a KeyError
        full_precision_states = extracted.setdefault(full_precision_key, {})
        if node.name in full_precision_states:
            # the training session already provided this weight; do not overwrite it
            continue
        full_precision_states[node.name] = _convert(onnx.numpy_helper.to_array(node))
def _extract_trainer_options(self, state_dict):
    """Store the trainer configuration relevant to a checkpoint in the state_dict

    The entry is written under the trainer options key and replaces any existing one.
    Unset (falsy) distributed settings fall back to zero_stage=0, world_rank=0 and world_size=1.
    """
    distributed_options = self.options.distributed
    state_dict[_utils.state_dict_trainer_options_key()] = {
        'mixed_precision': self.options.mixed_precision.enabled,
        'zero_stage': distributed_options.deepspeed_zero_optimization.stage or 0,
        'world_rank': distributed_options.world_rank or 0,
        'world_size': distributed_options.world_size or 1,
        'optimizer_name': self.optim_config.name,
    }
def state_dict(self, pytorch_format=False):
    """Returns a dictionary with model, and optionally, optimizer states

    The returned dictionary contains the following information:
    - Model and optimizer states
    - Required ORTTrainerOptions settings
    - Distributed training information, such as but not limited to ZeRO

    Structure of the returned dictionary:
    - When `pytorch_format = False`

        {
            "model": {
                "fp32": {
                    model_weight_name: array
                }
            },
            "optimizer": {
                model_weight_name: {
                    "Moment_1": array,
                    "Moment_2": array,
                    "Update_Count": array       # optional: present if optimizer is adam, absent otherwise
                },
                "shared_optimizer_state": {     # optional: present if optimizer is shared, absent otherwise
                    "step": array
                }
            },
            "trainer_options": {
                "mixed_precision": bool,
                "zero_stage": int,
                "world_rank": int,
                "world_size": int,
                "optimizer_name": str
            },
            "partition_info": {                 # optional: present if states partitioned, else absent
                model_weight_name: {
                    "original_dim": array
                }
            }
        }

    - When `pytorch_format = True`

        {
            model_weight_name: tensor
        }

    If the training session has not been created yet, a warning is issued and the
    state dictionary given to a pending `load_state_dict` call is returned as is
    (or an empty dictionary when there is no pending load).

    Args:
        pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema

    Returns:
        A dictionary with `ORTTrainer` state
    """
    if not self._training_session:
        warnings.warn("ONNX Runtime training session is not initialized yet. "
                      "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().",
                      UserWarning)
        pending_load = self._load_state_dict
        # the first positional argument bound to the pending partial is the state_dict to load
        return pending_load.args[0] if pending_load else {}

    collected_states = {}

    # model states come from the training session (and the onnx graph for frozen weights)
    self._extract_model_states(collected_states, pytorch_format)

    zero_stage = self.options.distributed.deepspeed_zero_optimization.stage

    if pytorch_format:
        if zero_stage > 0:
            warnings.warn("Incomplete state_dict: ZeRO enabled", UserWarning)
        # a PyTorch model only understands a flat name -> tensor mapping,
        # so only the full precision model states are returned
        model_states = collected_states[_utils.state_dict_model_key()]
        return model_states[_utils.state_dict_full_precision_key()]

    # optimizer states come straight from the training session
    collected_states[_utils.state_dict_optimizer_key()] = self._training_session.get_optimizer_state()

    # record the trainer configuration the states were produced with
    self._extract_trainer_options(collected_states)

    # partition information is only added for a ZeRO run
    if zero_stage > 0:
        collected_states[_utils.state_dict_partition_info_key()] = \
            self._training_session.get_partition_info_map()

    return collected_states
def _load_model_states(self, state_dict, strict):
    """Copy the model states of the input state_dict onto the onnx model graph

    Does nothing when the input state_dict has no model entry.

    Args:
        state_dict: input state dictionary; its model entry maps precision -> {weight name -> value}
        strict: when True, a weight name that is not an initializer of the onnx graph is an error

    Raises:
        RuntimeError: if `strict` is True and a weight name is not found in the onnx graph
    """
    model_key = _utils.state_dict_model_key()
    if model_key not in state_dict:
        return

    assert self._onnx_model, "ONNX model graph is not exported"

    # names of every initializer present in the current onnx graph
    graph_weight_names = {initializer.name for initializer in self._onnx_model.graph.initializer}

    # states from the input dictionary that have a matching initializer in the graph
    matched_states = {}
    for precision, weights in state_dict[model_key].items():
        for weight_name, weight_value in weights.items():
            if weight_name in graph_weight_names:
                matched_states[weight_name] = weight_value
                continue
            if strict:
                raise RuntimeError("Unexpected key: {} in state_dict[model][{}]".format(weight_name, precision))

    # write the matched states into the onnx model
    self._update_onnx_model_initializers(matched_states)
def _load_optimizer_states(self, current_state_dict, state_dict):
    """Merge the optimizer states of the input state_dict into the training session state dictionary

    `current_state_dict` is updated in place. Entries it already holds are kept unless the
    input state_dict provides a value for the same weight name and optimizer state name.
    Does nothing when the input state_dict has no optimizer entry.
    """
    optimizer_key = _utils.state_dict_optimizer_key()
    if optimizer_key not in state_dict:
        return

    # make sure the training session state dictionary has an optimizer entry to merge into
    merged_optimizer_states = current_state_dict.setdefault(optimizer_key, {})

    # overlay the incoming optimizer states, one model weight at a time
    for weight_name, incoming_states in state_dict[optimizer_key].items():
        merged_optimizer_states.setdefault(weight_name, {}).update(incoming_states)
def _load_state_dict_impl(self, state_dict, strict=True):
    """Load the state dictionary onto the onnx model and merge it with the training session states

    Args:
        state_dict: input state dictionary with optional model and optimizer entries
        strict: when True and a training session exists, the keys of the input state_dict
            must match the keys of the state dictionary of the training session

    Returns:
        The state dictionary of the current training session (empty if there is no session)
        with the optimizer states of the input state_dict merged into it

    Raises:
        RuntimeError: if `strict` is True and keys are missing from or unexpected in the input state_dict
    """
    # the pending load (if any) is being served now, so drop the stored callable
    self._load_state_dict = None

    def _raise_on_key_difference(expected_keys, provided_keys, location):
        """Raise a RuntimeError if the two key collections differ

        - Keys in expected_keys that are not in provided_keys are reported as missing
        - Keys in provided_keys that are not in expected_keys are reported as unexpected
        """
        expected_keys, provided_keys = set(expected_keys), set(provided_keys)
        missing = list(expected_keys - provided_keys)
        unexpected = list(provided_keys - expected_keys)
        if missing:
            raise RuntimeError("Missing keys: {} in {}".format(missing, location))
        if unexpected:
            raise RuntimeError("Unexpected keys: {} in {}".format(unexpected, location))

    def _verify_section(session_states, section_key, location):
        """Compare one section (model or optimizer) of the session states with the input state_dict"""
        session_section = session_states[section_key]
        # first level: precision keys for the model, model weight names for the optimizer
        _raise_on_key_difference(session_section, state_dict[section_key], location)
        # second level: the keys stored under each first level key
        for entry_key, session_entry in session_section.items():
            _raise_on_key_difference(session_entry,
                                     state_dict[section_key][entry_key],
                                     '{}[{}]'.format(location, entry_key))

    def _verify_keys(session_states):
        """Verify the model section and then the optimizer section; warn about an absent section"""
        sections = (
            (_utils.state_dict_model_key(), 'state_dict[model]', "Missing key: model in state_dict"),
            (_utils.state_dict_optimizer_key(), 'state_dict[optimizer]', "Missing key: optimizer in state_dict"),
        )
        for section_key, location, absent_warning in sections:
            if section_key in state_dict:
                _verify_section(session_states, section_key, location)
            else:
                warnings.warn(absent_warning, UserWarning)

    # the states of the current training session are kept so that they persist across sessions:
    # if the user provides only the model states, the optimizer states of the current
    # training session must survive the reload
    session_states = {}
    if self._training_session:
        session_states = self.state_dict()
        # keys can only be compared when there is a session to compare against
        if strict:
            _verify_keys(session_states)

    # model states go onto the onnx graph
    self._load_model_states(state_dict, strict)

    # optimizer states are merged into the training session state dictionary
    self._load_optimizer_states(session_states, state_dict)

    return session_states
def load_state_dict(self, state_dict, strict=True):
    """Loads state_dict containing model/optimizer states into ORTTrainer

    The state_dict dictionary may contain the following information:
    - Model and optimizer states
    - Required ORTTrainerOptions settings
    - Distributed training information, such as but not limited to ZeRO

    Args:
        state_dict: state dictionary containing both model and optimizer states. The structure of this dictionary
            should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False
        strict: boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict
    """
    if self._training_session:
        # load the states onto the frontend onnx graph and collect the states for the backend
        states_for_session = self._load_state_dict_impl(state_dict, strict=strict)

        # the initializers of the onnx graph changed, so a new training session is created
        # and handed the collected states to populate the backend graph
        self._init_session(states_for_session)
        return

    # no training session yet: put the load on hold.
    # the state_dict and the strict flag are bound to a callable that is stored on the trainer
    # so that the states can be loaded onto the graph once it has been initialized
    self._load_state_dict = partial(self._load_state_dict_impl, state_dict, strict=strict)

View file

@ -0,0 +1,409 @@
import pytest
from unittest.mock import patch, Mock
from orttraining_test_orttrainer_frontend import _load_pytorch_transformer_model
from onnxruntime.training import amp, checkpoint, optim, orttrainer
import numpy as np
import onnx
import torch
# Helper functions
def _create_trainer(zero_enabled=False):
    """Creates a simple ORTTrainer for ORTTrainer functional tests

    Args:
        zero_enabled: when True, the trainer is configured with DeepSpeed ZeRO stage 1
            for a single rank (world_rank 0, world_size 1)
    """
    device = 'cuda'
    trainer_options = {
        'device': {'id': device},
        'debug': {'deterministic_compute': True},
    }
    if zero_enabled:
        trainer_options['distributed'] = {
            'world_rank': 0,
            'world_size': 1,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': {'stage': 1},
        }

    optim_config = optim.LambConfig(lr=0.1)
    # only the model, its description and the loss function are needed here
    model, model_desc, loss_fn, _, _, _, _ = _load_pytorch_transformer_model(device)
    return orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                 options=orttrainer.ORTTrainerOptions(trainer_options))
class _training_session_mock(object):
    """Stand-in for the ORTTrainer _training_session member that returns canned states"""

    def __init__(self, model_states, optimizer_states, partition_info):
        # the values are stored untouched and handed back by the getters below
        self.model_states = model_states
        self.optimizer_states = optimizer_states
        self.partition_info = partition_info

    def get_model_state(self, include_mixed_precision_weights=False):
        """Return the canned model states; the flag is accepted but ignored"""
        return self.model_states

    def get_optimizer_state(self):
        """Return the canned optimizer states"""
        return self.optimizer_states

    def get_partition_info_map(self):
        """Return the canned partition info"""
        return self.partition_info
def _get_load_state_dict_strict_error_arguments():
    """Return the parameters for test_load_state_dict_errors_when_state_dict_mismatch

    Each parameter is a tuple (training_session_state_dict, input_state_dict, error_arguments).
    load_state_dict compares the two state dicts and throws a runtime error naming the
    missing/unexpected keys; error_arguments lists the keys expected to show up in that error.
    The same training_session_state_dict object is shared by every tuple.
    """
    session_states = {
        'model': {
            'fp32': {
                'a': np.arange(5),
                'b': np.arange(7)
            }
        },
        'optimizer': {
            'a': {
                'Moment_1': np.arange(5),
                'Moment_2': np.arange(7)
            },
            'shared_optimizer_state': {
                'step': np.arange(5)
            }
        }
    }

    def _matching_model():
        """Model section whose keys match the training session model section"""
        return {'fp32': {'a': 2, 'b': 3}}

    # (input state dictionary, keys expected in the error message)
    cases = [
        # precision key missing
        ({'model': {}, 'optimizer': {}}, ['fp32']),
        # precision key unexpected
        ({'model': {'fp32': {}, 'fp16': {}}, 'optimizer': {}}, ['fp16']),
        # model state key missing
        ({'model': {'fp32': {}}, 'optimizer': {}}, ['a', 'b']),
        # model state key unexpected
        ({'model': {'fp32': {'a': 2, 'b': 3, 'c': 4}}, 'optimizer': {}}, ['c']),
        # optimizer model state key missing
        ({'model': _matching_model(), 'optimizer': {}}, ['a', 'shared_optimizer_state']),
        # optimizer model state key unexpected
        ({'model': _matching_model(),
          'optimizer': {'a': {}, 'shared_optimizer_state': {}, 'b': {}}},
         ['b']),
        # optimizer state key missing
        ({'model': _matching_model(),
          'optimizer': {'a': {}, 'shared_optimizer_state': {'step': np.arange(5)}}},
         ['Moment_1', 'Moment_2']),
        # optimizer state key unexpected
        ({'model': _matching_model(),
          'optimizer': {'a': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)},
                        'shared_optimizer_state': {'step': np.arange(5), 'another_step': np.arange(1)}}},
         ['another_step']),
    ]

    return [(session_states, input_states, error_keys) for input_states, error_keys in cases]
# Tests
def test_empty_state_dict_when_training_session_uninitialized():
    """state_dict() warns and returns an empty dictionary before a training session exists"""
    expected_warning = "ONNX Runtime training session is not initialized yet. " \
        "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()."
    trainer = _create_trainer()

    with pytest.warns(UserWarning) as recorded_warnings:
        returned_states = trainer.state_dict()

    assert len(returned_states.keys()) == 0
    assert recorded_warnings[0].message.args[0] == expected_warning
@patch('onnx.ModelProto')
def test_training_session_provides_empty_model_states(onnx_model_mock):
    """The model entry is empty when the training session reports no model states"""
    trainer = _create_trainer()
    trainer._training_session = _training_session_mock({}, {}, {})
    trainer._onnx_model = onnx_model_mock()

    model_section = trainer.state_dict()['model']

    assert len(model_section.keys()) == 0
@patch('onnx.ModelProto')
def test_training_session_provides_model_states(onnx_model_mock):
    """Model states reported by the training session show up under state_dict['model']"""
    session_model_states = {'fp32': {'a': np.arange(5), 'b': np.arange(7)}}
    trainer = _create_trainer()
    trainer._training_session = _training_session_mock(session_model_states, {}, {})
    trainer._onnx_model = onnx_model_mock()

    fp32_states = trainer.state_dict()['model']['fp32']

    assert (fp32_states['a'] == np.arange(5)).all()
    assert (fp32_states['b'] == np.arange(7)).all()
@patch('onnx.ModelProto')
def test_training_session_provides_model_states_pytorch_format(onnx_model_mock):
    """With pytorch_format=True the model states come back as a flat dictionary of tensors"""
    session_model_states = {'fp32': {'a': np.arange(5), 'b': np.arange(7)}}
    trainer = _create_trainer()
    trainer._training_session = _training_session_mock(session_model_states, {}, {})
    trainer._onnx_model = onnx_model_mock()

    flat_states = trainer.state_dict(pytorch_format=True)

    assert torch.all(torch.eq(flat_states['a'], torch.tensor(np.arange(5))))
    assert torch.all(torch.eq(flat_states['b'], torch.tensor(np.arange(7))))
@patch('onnx.ModelProto')
def test_onnx_graph_provides_frozen_model_states(onnx_model_mock):
    """Frozen weights are picked up from the onnx graph; other initializers are left out"""
    session_model_states = {'fp32': {'a': np.arange(5), 'b': np.arange(7)}}
    graph_initializers = [
        onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), 'a_frozen_weight'),
        onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), 'a_non_fronzen_weight'),
        onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), 'a_float16_weight')
    ]

    trainer = _create_trainer()
    trainer._training_session = _training_session_mock(session_model_states, {}, {})
    trainer._onnx_model = onnx_model_mock()
    trainer.options.utils.frozen_weights = ['a_frozen_weight', 'a_float16_weight']
    trainer._onnx_model.graph.initializer = graph_initializers

    fp32_states = trainer.state_dict()['model']['fp32']

    # states reported by the training session
    assert (fp32_states['a'] == np.arange(5)).all()
    assert (fp32_states['b'] == np.arange(7)).all()
    # frozen weights taken from the graph, including the float16 one
    assert (fp32_states['a_frozen_weight'] == np.array([1, 2, 3], dtype=np.float32)).all()
    assert 'a_non_fronzen_weight' not in fp32_states
    assert (fp32_states['a_float16_weight'] == np.array([7, 8, 9], dtype=np.float32)).all()
@patch('onnx.ModelProto')
def test_training_session_provides_empty_optimizer_states(onnx_model_mock):
    """The optimizer entry is empty when the training session reports no optimizer states"""
    trainer = _create_trainer()
    trainer._training_session = _training_session_mock({}, {}, {})
    trainer._onnx_model = onnx_model_mock()

    optimizer_section = trainer.state_dict()['optimizer']

    assert len(optimizer_section.keys()) == 0
@patch('onnx.ModelProto')
def test_training_session_provides_optimizer_states(onnx_model_mock):
    """Optimizer states reported by the training session show up under state_dict['optimizer']"""
    session_optimizer_states = {
        'model_weight': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)},
        'shared_optimizer_state': {'step': np.arange(1)}
    }
    trainer = _create_trainer()
    trainer._training_session = _training_session_mock({}, session_optimizer_states, {})
    trainer._onnx_model = onnx_model_mock()

    optimizer_section = trainer.state_dict()['optimizer']

    assert (optimizer_section['model_weight']['Moment_1'] == np.arange(5)).all()
    assert (optimizer_section['model_weight']['Moment_2'] == np.arange(7)).all()
    assert (optimizer_section['shared_optimizer_state']['step'] == np.arange(1)).all()
@patch('onnx.ModelProto')
def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock):
    """With pytorch_format=True the returned dictionary carries no optimizer entry"""
    session_model_states = {'fp32': {'a': np.arange(5), 'b': np.arange(7)}}
    session_optimizer_states = {
        'model_weight': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)},
        'shared_optimizer_state': {'step': np.arange(1)}
    }
    trainer = _create_trainer()
    trainer._training_session = _training_session_mock(session_model_states, session_optimizer_states, {})
    trainer._onnx_model = onnx_model_mock()

    flat_states = trainer.state_dict(pytorch_format=True)

    assert 'optimizer' not in flat_states
@patch('onnx.ModelProto')
def test_training_session_provides_empty_partition_info_map(onnx_model_mock):
    """With ZeRO enabled the partition_info entry is present even when the session reports none"""
    trainer = _create_trainer(zero_enabled=True)
    trainer._training_session = _training_session_mock({}, {}, {})
    trainer._onnx_model = onnx_model_mock()

    partition_section = trainer.state_dict()['partition_info']

    assert len(partition_section.keys()) == 0
@patch('onnx.ModelProto')
def test_training_session_provides_partition_info_map(onnx_model_mock):
    """With ZeRO enabled the partition info reported by the session shows up in the state_dict"""
    session_partition_info = {'a': {'original_dim': [1, 2, 3]}}
    trainer = _create_trainer(zero_enabled=True)
    trainer._training_session = _training_session_mock({}, {}, session_partition_info)
    trainer._onnx_model = onnx_model_mock()

    partition_section = trainer.state_dict()['partition_info']

    assert partition_section['a']['original_dim'] == [1, 2, 3]
@patch('onnx.ModelProto')
def test_training_session_provides_all_states(onnx_model_mock):
    """Model, optimizer and partition info states are all returned for a ZeRO enabled trainer"""
    session_model_states = {'fp32': {'a': np.arange(5), 'b': np.arange(7)}}
    session_optimizer_states = {
        'model_weight': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)},
        'shared_optimizer_state': {'step': np.arange(1)}
    }
    session_partition_info = {'a': {'original_dim': [1, 2, 3]}}

    trainer = _create_trainer(zero_enabled=True)
    trainer._training_session = _training_session_mock(
        session_model_states, session_optimizer_states, session_partition_info)
    trainer._onnx_model = onnx_model_mock()

    returned_states = trainer.state_dict()
    fp32_states = returned_states['model']['fp32']
    optimizer_section = returned_states['optimizer']

    assert (fp32_states['a'] == np.arange(5)).all()
    assert (fp32_states['b'] == np.arange(7)).all()
    assert (optimizer_section['model_weight']['Moment_1'] == np.arange(5)).all()
    assert (optimizer_section['model_weight']['Moment_2'] == np.arange(7)).all()
    assert (optimizer_section['shared_optimizer_state']['step'] == np.arange(1)).all()
    assert returned_states['partition_info']['a']['original_dim'] == [1, 2, 3]
def test_load_state_dict_holds_when_training_session_not_initialized():
    """load_state_dict() stores a pending load on the trainer when no training session exists"""
    states_to_load = {
        'model': {
            'fp32': {'a': np.arange(5), 'b': np.arange(7)}
        },
        'optimizer': {
            'a': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)},
            'shared_optimizer_state': {'step': np.arange(5)}
        }
    }
    trainer = _create_trainer()

    # nothing is pending before the call
    assert not trainer._load_state_dict

    trainer.load_state_dict(states_to_load)

    # the load has been put on hold
    assert trainer._load_state_dict
@pytest.mark.parametrize(
    "state_dict, input_state_dict, error_key",
    [
        ({'optimizer':{}}, {'optimizer':{}}, 'model'),
        ({'model':{}}, {'model':{}}, 'optimizer'),
    ])
def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key):
    """load_state_dict() warns when the model or the optimizer entry is absent from the input"""
    expected_warning = "Missing key: {} in state_dict".format(error_key)

    trainer = _create_trainer()
    trainer._training_session = _training_session_mock({}, {}, {})
    # isolate load_state_dict from the real session, graph update and session creation
    trainer.state_dict = Mock(return_value=state_dict)
    trainer._update_onnx_model_initializers = Mock()
    trainer._init_session = Mock()

    with patch('onnx.ModelProto') as onnx_model_mock:
        trainer._onnx_model = onnx_model_mock()
        trainer._onnx_model.graph.initializer = []
        with pytest.warns(UserWarning) as recorded_warnings:
            trainer.load_state_dict(input_state_dict)

    assert recorded_warnings[0].message.args[0] == expected_warning
@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments())
def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys):
    """load_state_dict() raises a RuntimeError naming a mismatched key when the keys differ"""
    trainer = _create_trainer()
    trainer._training_session = _training_session_mock({}, {}, {})
    # the training session state dictionary is supplied by the test parameters
    trainer.state_dict = Mock(return_value=state_dict)

    with pytest.raises(RuntimeError) as raised:
        trainer.load_state_dict(input_state_dict)

    error_message = str(raised.value)
    assert any(key in error_message for key in error_keys)
@patch('onnx.ModelProto')
def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock):
    """load_state_dict() updates the graph initializers and re-creates the session with the new states"""
    session_states = {
        'model': {
            'fp32': {'a': np.arange(5), 'b': np.arange(7)}
        },
        'optimizer': {
            'a': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)},
            'shared_optimizer_state': {'step': np.arange(1)}
        }
    }
    states_to_load = {
        'model': {
            'fp32': {'a': np.array([1, 2]), 'b': np.array([3, 4])}
        },
        'optimizer': {
            'a': {'Moment_1': np.array([5, 6]), 'Moment_2': np.array([7, 8])},
            'shared_optimizer_state': {'step': np.array([9])}
        }
    }

    trainer = _create_trainer()
    trainer._training_session = _training_session_mock({}, {}, {})
    trainer.state_dict = Mock(return_value=session_states)
    trainer._onnx_model = onnx_model_mock()
    trainer._onnx_model.graph.initializer = [
        onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), 'a'),
        onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), 'b')
    ]
    # record what load_state_dict hands to the graph update and to the session creation
    trainer._update_onnx_model_initializers = Mock()
    trainer._init_session = Mock()

    trainer.load_state_dict(states_to_load)

    # first positional argument of each recorded call
    loaded_initializers = trainer._update_onnx_model_initializers.call_args[0][0]
    optimizer_for_session = trainer._init_session.call_args[0][0]['optimizer']

    assert 'a' in loaded_initializers
    assert (loaded_initializers['a'] == np.array([1, 2])).all()
    assert 'b' in loaded_initializers
    assert (loaded_initializers['b'] == np.array([3, 4])).all()

    assert (optimizer_for_session['a']['Moment_1'] == np.array([5, 6])).all()
    assert (optimizer_for_session['a']['Moment_2'] == np.array([7, 8])).all()
    assert (optimizer_for_session['shared_optimizer_state']['step'] == np.array([9])).all()

View file

@ -1299,6 +1299,9 @@ def run_training_python_frontend_tests(cwd):
run_subprocess([sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_checkpoint_storage.py'], cwd=cwd)
run_subprocess([
sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_checkpoint_functions.py'], cwd=cwd)
def run_training_python_frontend_e2e_tests(cwd):
# frontend tests are to be added here: