From f7ff5a7aa1813bddca716f07bfe8c0e6492cb8f0 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 5 May 2020 11:47:46 -0700 Subject: [PATCH] Fix state_dict and save_as_onnx for training (#3774) --- .../python/onnxruntime_test_ort_trainer.py | 63 +++++++++++++++++-- orttraining/orttraining/python/ort_trainer.py | 49 +++++++++------ 2 files changed, 86 insertions(+), 26 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index 6112f293a4..a7b057a820 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -82,7 +82,7 @@ def create_ort_trainer(gradient_accumulation_steps, use_mixed_precision=use_mixed_precision, allreduce_post_accumulation=allreduce_post_accumulation, partition_optimizer = partition_optimizer) - + return model, model_desc, device def runBertTrainingTest(gradient_accumulation_steps, @@ -92,7 +92,7 @@ def runBertTrainingTest(gradient_accumulation_steps, use_internel_loss_scale=False): torch.manual_seed(1) onnxruntime.set_seed(1) - + loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None model, model_desc, device = create_ort_trainer(gradient_accumulation_steps, @@ -237,7 +237,7 @@ class MNISTWrapper(): kwargs = {'num_workers': 0, 'pin_memory': True} train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, - transform=transforms.Compose([transforms.ToTensor(), + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), batch_size=args_batch_size, shuffle=True, **kwargs) test_loader = torch.utils.data.DataLoader( @@ -259,7 +259,7 @@ class MNISTWrapper(): return model, model_desc def get_trainer(self, model, model_desc, device): - return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ], + return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ], torch.float32), device, _opset_version=12) class TestOrtTrainer(unittest.TestCase): @@ -355,6 +355,57 @@ class TestOrtTrainer(unittest.TestCase): assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch") assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch") + def testMNISTStateDict(self): + torch.manual_seed(1) + device = torch.device("cuda") + + mnist = MNISTWrapper() + train_loader, test_loader = mnist.get_loaders() + model, model_desc = mnist.get_model() + + trainer = mnist.get_trainer(model, model_desc, device) + state_dict = trainer.state_dict() + assert state_dict == {} + + learningRate = 0.02 + epoch = 0 + + data, target = next(iter(train_loader)) + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + + loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) + + state_dict = trainer.state_dict() + assert state_dict.keys() == {'model_.fc1.bias', 'model_.fc1.weight', 'model_.fc2.bias', 'model_.fc2.weight'} + + def testMNISTSaveAsONNX(self): + torch.manual_seed(1) + device = torch.device("cuda") + onnx_file_name = 'mnist.onnx' + if os.path.exists(onnx_file_name): + os.remove(onnx_file_name) + + mnist = MNISTWrapper() + train_loader, test_loader = mnist.get_loaders() + model, model_desc = mnist.get_model() + + trainer = mnist.get_trainer(model, model_desc, device) + trainer.save_as_onnx(onnx_file_name) + assert not os.path.exists(onnx_file_name) + + learningRate = 0.02 + epoch = 0 + + data, target = next(iter(train_loader)) + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + + loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) + + trainer.save_as_onnx(onnx_file_name) + assert os.path.exists(onnx_file_name) + def testBertTrainingBasic(self): expected_losses = [ 11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543, @@ -379,7 +430,7 @@ class TestOrtTrainer(unittest.TestCase): 11.02906322479248, 11.094074249267578, 11.008995056152344, 11.061283111572266, 11.029059410095215, 11.04024887084961, 11.04680347442627, 10.993708610534668] expected_eval_loss = [10.959011] - + actual_losses, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=4, use_mixed_precision=False, allreduce_post_accumulation=False) @@ -436,7 +487,7 @@ class TestOrtTrainer(unittest.TestCase): ckpt_dir = get_name("ort_ckpt") load_checkpoint(model, ckpt_dir, 'bert_toy_lamb') - + expected_eval_loss = [10.997552871] input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[ 635],[27399],[20647],[18519],[15546]], device=device) diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 71d583c08d..7e0e678026 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -1,5 +1,6 @@ import io import os +import warnings import numpy as np import onnx from onnx import numpy_helper @@ -92,7 +93,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out device_index = input_get_device_index(input) iobinding.bind_input(input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype), list(input.size()), input.data_ptr()) - + output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) torch_outputs = {} for output_desc in output_descs_resolved: @@ -526,11 +527,11 @@ def save_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", c assert os.path.exists(checkpoint_dir), "ERROR: Checkpoint directory doesn't exist: {}".format(checkpoint_dir) - checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.partition_optimizer_, model.world_rank, model.world_size) + checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.partition_optimizer_, model.world_rank, model.world_size) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) if os.path.exists(checkpoint_file): - print("WARNING: {} already exists, overwriting.".format(checkpoint_file)) + warnings.warn("{} already exists, overwriting.".format(checkpoint_file)) torch.save(checkpoint_state_dict, checkpoint_file) @@ -538,7 +539,7 @@ def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partiti checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - if is_partitioned: + if is_partitioned: assert_msg = ("Couldn't find checkpoint file {}." + "Optimizer partitioning is enabled using ZeRO. Please make sure that the "+ "checkpoint file exists for rank {} of {}.").format(checkpoint_file,model.world_rank, model.world_size) @@ -561,24 +562,24 @@ def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): model.load_state_dict(aggregate_state_dict, strict=strict) - # aggregate other keys in the state_dict. + # aggregate other keys in the state_dict. # Values will be overwritten for matching keys among workers all_checkpoint_states=dict() for checkpoint_file in checkpoint_files: checkpoint_state = torch.load(checkpoint_file, map_location='cpu') del(checkpoint_state['model']) - all_checkpoint_states.update(checkpoint_state) + all_checkpoint_states.update(checkpoint_state) return all_checkpoint_states def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) is_partitioned = False if len(checkpoint_files) > 1: - print(f"WARNING: Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." + + warnings.warn(f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." + "Attempting to load ZeRO checkpoint.") is_partitioned = True if (not model.partition_optimizer_) and is_partitioned: - return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) + return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) else: return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) @@ -629,7 +630,7 @@ class ORTTrainer(): else: self.onnx_model_ = model if loss_fn is not None: - print("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") + warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") # TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn self.loss_fn_ = None @@ -657,7 +658,7 @@ class ORTTrainer(): self.loss_scaler_ = loss_scaler if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None: - print("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") + warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") self.training_optimizer_name_ = training_optimizer_name self.learning_rate_description_ = learning_rate_description self.map_optimizer_attributes_ = map_optimizer_attributes @@ -678,7 +679,7 @@ class ORTTrainer(): if self.onnx_model_ is None: return - self.verify_fully_optimized_model(self.onnx_model_) + self._verify_fully_optimized_model(self.onnx_model_) self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \ create_ort_training_session_with_optimizer( self.onnx_model_, self.device_, @@ -719,7 +720,7 @@ class ORTTrainer(): IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)] if self.state_dict_: - self.load_state_dict(self.state_dict_, self.strict_) + self.load_state_dict(self.state_dict_, self.strict_) self.state_dict_ = None def _init_onnx_model(self, inputs): @@ -742,6 +743,10 @@ class ORTTrainer(): self.is_train = False def state_dict(self): + if not self.session: + warnings.warn("ONNXRuntime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling state_dict().") + return {} session_state = self.session.get_state() torch_state = {} for name in session_state: @@ -763,6 +768,10 @@ class ORTTrainer(): self.session.load_state(session_state, strict) def save_as_onnx(self, path): + if not self.session: + warnings.warn("ONNXRuntime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling save_as_onnx().") + return state_tensors = self.session.get_state() # replace the initializers with new value new_weights = [] @@ -781,7 +790,7 @@ class ORTTrainer(): with open(path, "wb") as f: f.write(self.onnx_model_.SerializeToString()) - def prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs): + def _prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs): fetches = None if type(args) == tuple and len(args) == 1 and type(args[0]) == list: input = tuple(args[0]) @@ -842,17 +851,17 @@ class ORTTrainer(): loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) if self.onnx_model_ is None: - sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_, + sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) self._init_onnx_model(sample_input) if self.use_mixed_precision: - input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale, + input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs) assert len(self.input_desc_with_lr_and_loss_scale) == len(input) input_descs = self.input_desc_with_lr_and_loss_scale else: - input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr, + input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs) assert len(self.input_desc_with_lr) == len(input) input_descs = self.input_desc_with_lr @@ -925,8 +934,8 @@ class ORTTrainer(): """ # with model_loss_cls, the last input is label, first output is loss - input, fetches = self.prepare_input_and_fetches(self.model_desc_.inputs_, - None, None, *args, **kwargs) + input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, + None, None, *args, **kwargs) if self.onnx_model_ is None: if self.torch_model_ is not None: @@ -958,7 +967,7 @@ class ORTTrainer(): else: return [session_run_results[output_desc.name_] for output_desc in output_desc] - def verify_fully_optimized_model(self, model): + def _verify_fully_optimized_model(self, model): assert(len(model.graph.output) > 0) # model's first output must be the loss tensor if model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT and\ @@ -983,7 +992,7 @@ class ORTModel(): self.world_size = world_size self.gradient_accumulation_steps = gradient_accumulation_steps self.opset_version = _opset_version - + # Adding to not break checkpointing functions for ORTModel self.partition_optimizer_ = False