Fix state_dict and save_as_onnx for training (#3774)

This commit is contained in:
Bowen Bao 2020-05-05 11:47:46 -07:00 committed by GitHub
parent 5dfc91db51
commit f7ff5a7aa1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 86 additions and 26 deletions

View file

@ -82,7 +82,7 @@ def create_ort_trainer(gradient_accumulation_steps,
use_mixed_precision=use_mixed_precision,
allreduce_post_accumulation=allreduce_post_accumulation,
partition_optimizer = partition_optimizer)
return model, model_desc, device
def runBertTrainingTest(gradient_accumulation_steps,
@ -92,7 +92,7 @@ def runBertTrainingTest(gradient_accumulation_steps,
use_internel_loss_scale=False):
torch.manual_seed(1)
onnxruntime.set_seed(1)
loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None
model, model_desc, device = create_ort_trainer(gradient_accumulation_steps,
@ -237,7 +237,7 @@ class MNISTWrapper():
kwargs = {'num_workers': 0, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args_batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
@ -259,7 +259,7 @@ class MNISTWrapper():
return model, model_desc
def get_trainer(self, model, model_desc, device):
return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
torch.float32), device, _opset_version=12)
class TestOrtTrainer(unittest.TestCase):
@ -355,6 +355,57 @@ class TestOrtTrainer(unittest.TestCase):
assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch")
assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch")
def testMNISTStateDict(self):
torch.manual_seed(1)
device = torch.device("cuda")
mnist = MNISTWrapper()
train_loader, test_loader = mnist.get_loaders()
model, model_desc = mnist.get_model()
trainer = mnist.get_trainer(model, model_desc, device)
state_dict = trainer.state_dict()
assert state_dict == {}
learningRate = 0.02
epoch = 0
data, target = next(iter(train_loader))
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)
loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
state_dict = trainer.state_dict()
assert state_dict.keys() == {'model_.fc1.bias', 'model_.fc1.weight', 'model_.fc2.bias', 'model_.fc2.weight'}
def testMNISTSaveAsONNX(self):
torch.manual_seed(1)
device = torch.device("cuda")
onnx_file_name = 'mnist.onnx'
if os.path.exists(onnx_file_name):
os.remove(onnx_file_name)
mnist = MNISTWrapper()
train_loader, test_loader = mnist.get_loaders()
model, model_desc = mnist.get_model()
trainer = mnist.get_trainer(model, model_desc, device)
trainer.save_as_onnx(onnx_file_name)
assert not os.path.exists(onnx_file_name)
learningRate = 0.02
epoch = 0
data, target = next(iter(train_loader))
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)
loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
trainer.save_as_onnx(onnx_file_name)
assert os.path.exists(onnx_file_name)
def testBertTrainingBasic(self):
expected_losses = [
11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,
@ -379,7 +430,7 @@ class TestOrtTrainer(unittest.TestCase):
11.02906322479248, 11.094074249267578, 11.008995056152344, 11.061283111572266,
11.029059410095215, 11.04024887084961, 11.04680347442627, 10.993708610534668]
expected_eval_loss = [10.959011]
actual_losses, actual_eval_loss = runBertTrainingTest(
gradient_accumulation_steps=4, use_mixed_precision=False, allreduce_post_accumulation=False)
@ -436,7 +487,7 @@ class TestOrtTrainer(unittest.TestCase):
ckpt_dir = get_name("ort_ckpt")
load_checkpoint(model, ckpt_dir, 'bert_toy_lamb')
expected_eval_loss = [10.997552871]
input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[ 635],[27399],[20647],[18519],[15546]], device=device)

View file

@ -1,5 +1,6 @@
import io
import os
import warnings
import numpy as np
import onnx
from onnx import numpy_helper
@ -92,7 +93,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out
device_index = input_get_device_index(input)
iobinding.bind_input(input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype),
list(input.size()), input.data_ptr())
output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs)
torch_outputs = {}
for output_desc in output_descs_resolved:
@ -526,11 +527,11 @@ def save_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", c
assert os.path.exists(checkpoint_dir), "ERROR: Checkpoint directory doesn't exist: {}".format(checkpoint_dir)
checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.partition_optimizer_, model.world_rank, model.world_size)
checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.partition_optimizer_, model.world_rank, model.world_size)
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
if os.path.exists(checkpoint_file):
print("WARNING: {} already exists, overwriting.".format(checkpoint_file))
warnings.warn("{} already exists, overwriting.".format(checkpoint_file))
torch.save(checkpoint_state_dict, checkpoint_file)
@ -538,7 +539,7 @@ def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partiti
checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size)
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
if is_partitioned:
if is_partitioned:
assert_msg = ("Couldn't find checkpoint file {}." +
"Optimizer partitioning is enabled using ZeRO. Please make sure that the "+
"checkpoint file exists for rank {} of {}.").format(checkpoint_file,model.world_rank, model.world_size)
@ -561,24 +562,24 @@ def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict):
model.load_state_dict(aggregate_state_dict, strict=strict)
# aggregate other keys in the state_dict.
# aggregate other keys in the state_dict.
# Values will be overwritten for matching keys among workers
all_checkpoint_states=dict()
for checkpoint_file in checkpoint_files:
checkpoint_state = torch.load(checkpoint_file, map_location='cpu')
del(checkpoint_state['model'])
all_checkpoint_states.update(checkpoint_state)
all_checkpoint_states.update(checkpoint_state)
return all_checkpoint_states
def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False):
checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
is_partitioned = False
if len(checkpoint_files) > 1:
print(f"WARNING: Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." +
warnings.warn(f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." +
"Attempting to load ZeRO checkpoint.")
is_partitioned = True
if (not model.partition_optimizer_) and is_partitioned:
return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict)
return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict)
else:
return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
@ -629,7 +630,7 @@ class ORTTrainer():
else:
self.onnx_model_ = model
if loss_fn is not None:
print("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
# TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn
self.loss_fn_ = None
@ -657,7 +658,7 @@ class ORTTrainer():
self.loss_scaler_ = loss_scaler
if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None:
print("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
self.training_optimizer_name_ = training_optimizer_name
self.learning_rate_description_ = learning_rate_description
self.map_optimizer_attributes_ = map_optimizer_attributes
@ -678,7 +679,7 @@ class ORTTrainer():
if self.onnx_model_ is None:
return
self.verify_fully_optimized_model(self.onnx_model_)
self._verify_fully_optimized_model(self.onnx_model_)
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
create_ort_training_session_with_optimizer(
self.onnx_model_, self.device_,
@ -719,7 +720,7 @@ class ORTTrainer():
IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]
if self.state_dict_:
self.load_state_dict(self.state_dict_, self.strict_)
self.load_state_dict(self.state_dict_, self.strict_)
self.state_dict_ = None
def _init_onnx_model(self, inputs):
@ -742,6 +743,10 @@ class ORTTrainer():
self.is_train = False
def state_dict(self):
if not self.session:
warnings.warn("ONNXRuntime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling state_dict().")
return {}
session_state = self.session.get_state()
torch_state = {}
for name in session_state:
@ -763,6 +768,10 @@ class ORTTrainer():
self.session.load_state(session_state, strict)
def save_as_onnx(self, path):
if not self.session:
warnings.warn("ONNXRuntime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling save_as_onnx().")
return
state_tensors = self.session.get_state()
# replace the initializers with new value
new_weights = []
@ -781,7 +790,7 @@ class ORTTrainer():
with open(path, "wb") as f:
f.write(self.onnx_model_.SerializeToString())
def prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs):
def _prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs):
fetches = None
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
input = tuple(args[0])
@ -842,17 +851,17 @@ class ORTTrainer():
loss_scale = torch.tensor([self.loss_scaler_.loss_scale_])
if self.onnx_model_ is None:
sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_,
sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_,
None, None, *args, **kwargs)
self._init_onnx_model(sample_input)
if self.use_mixed_precision:
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale,
input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale,
learning_rate, loss_scale, *args, **kwargs)
assert len(self.input_desc_with_lr_and_loss_scale) == len(input)
input_descs = self.input_desc_with_lr_and_loss_scale
else:
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr,
input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr,
learning_rate, loss_scale, *args, **kwargs)
assert len(self.input_desc_with_lr) == len(input)
input_descs = self.input_desc_with_lr
@ -925,8 +934,8 @@ class ORTTrainer():
"""
# with model_loss_cls, the last input is label, first output is loss
input, fetches = self.prepare_input_and_fetches(self.model_desc_.inputs_,
None, None, *args, **kwargs)
input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_,
None, None, *args, **kwargs)
if self.onnx_model_ is None:
if self.torch_model_ is not None:
@ -958,7 +967,7 @@ class ORTTrainer():
else:
return [session_run_results[output_desc.name_] for output_desc in output_desc]
def verify_fully_optimized_model(self, model):
def _verify_fully_optimized_model(self, model):
assert(len(model.graph.output) > 0)
# model's first output must be the loss tensor
if model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT and\
@ -983,7 +992,7 @@ class ORTModel():
self.world_size = world_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.opset_version = _opset_version
# Adding to not break checkpointing functions for ORTModel
self.partition_optimizer_ = False