diff --git a/orttraining/orttraining/python/experimental/orttrainer.py b/orttraining/orttraining/python/experimental/orttrainer.py
index 5741e37cc1..f7b86b1c19 100644
--- a/orttraining/orttraining/python/experimental/orttrainer.py
+++ b/orttraining/orttraining/python/experimental/orttrainer.py
@@ -267,8 +267,9 @@ class ORTTrainer(object):
             ValueError: raised when `path` is not valid path
         """
         if not self._training_session:
-            raise RuntimeWarning("Training session is not initialized yet. "
+            warnings.warn("Training session is not initialized yet. "
                 "'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'.")
+            return
         state_tensors = self._training_session.get_state()
         self._update_onnx_model_initializers(state_tensors)
 
@@ -276,7 +277,8 @@ class ORTTrainer(object):
         dir_name = os.path.dirname(path)
         file_name = os.path.basename(path)
         if not dir_name or not os.path.exists(dir_name) or not file_name:
-            raise ValueError("'path' is not valid. It must contain an existing folder + filename")
+            warnings.warn("'path' is not valid. It must contain an existing folder + filename")
+            return
         with open(path, "wb") as f:
             f.write(self._onnx_model.SerializeToString())
 
diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py
index 09c81d8826..07b6cf4c41 100644
--- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py
+++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py
@@ -79,6 +79,7 @@ def optimizer_parameters(model):
         if any(key in initializer.name for key in no_decay_keys):
             no_decay_param_group.append(initializer.name)
     params = [{'params': no_decay_param_group, "alpha": 0.9, "beta": 0.999, "lambda_coef": 0.0, "epsilon": 1e-6}]
+
     return params
 
 
@@ -593,6 +594,100 @@ def testToyBertCheckpointFrozenWeights():
     loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
     assert state_dict.keys() == loaded_state_dict.keys()
 
+@pytest.mark.parametrize("model_params", [
+    (['bert.embeddings.LayerNorm.bias']),
+    (['bert.embeddings.LayerNorm.bias',
+      'bert.embeddings.LayerNorm.weight',
+      'bert.encoder.layer.0.attention.output.LayerNorm.bias']),
+])
+def testORTTrainerFrozenWeights(model_params):
+    device = 'cuda'
+    total_steps = 10
+    seed = 1
+
+    # EXPERIMENTAL API
+    model_desc = bert_model_description()
+    model = load_bert_onnx_model()
+
+    optim_config = optim.LambConfig()
+    # Setup ORTTrainer WITHOUT frozen weights
+    opts_dict = {
+        'debug' : {
+            'deterministic_compute': True
+        },
+        'device': {
+            'id': device,
+        },
+    }
+    opts = orttrainer.ORTTrainerOptions(opts_dict)
+
+    torch.manual_seed(seed)
+    set_seed(seed)
+    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
+
+    for i in range(total_steps):
+        sample_input = generate_random_input_from_model_desc(model_desc, i)
+        trainer.train_step(*sample_input)
+
+    # All model_params must be in the session state
+    assert trainer._onnx_model is not None
+    session_state = trainer._training_session.get_state()
+    assert all([param in session_state for param in model_params])
+
+    # Setup ORTTrainer WITH frozen weights
+    opts_dict.update({'utils' : {'frozen_weights' : model_params}})
+    opts = orttrainer.ORTTrainerOptions(opts_dict)
+    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
+
+    for i in range(total_steps):
+        sample_input = generate_random_input_from_model_desc(model_desc, i)
+        trainer.train_step(*sample_input)
+
+    # All model_params CANNOT be in the session state
+    assert trainer._onnx_model is not None
+    session_state = trainer._training_session.get_state()
+    assert not any([param in session_state for param in model_params])
+
+def testToyBERTSaveAsONNX():
+    device = 'cuda'
+    onnx_file_name = os.path.join('..','..','..','temp_toy_bert_onnx_model.onnx')
+    if os.path.exists(onnx_file_name):
+        os.remove(onnx_file_name)
+    assert not os.path.exists(onnx_file_name)
+
+    # Load trainer
+    model_desc = bert_model_description()
+    model = load_bert_onnx_model()
+
+    optim_config = optim.LambConfig()
+    opts = orttrainer.ORTTrainerOptions({
+        'debug' : {
+            'deterministic_compute': True
+        },
+        'device': {
+            'id': device,
+        },
+    })
+
+    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)#, options=opts)
+
+    trainer.save_as_onnx(onnx_file_name)
+    assert os.path.exists(onnx_file_name)
+
+    with open(onnx_file_name, "rb") as f:
+        bin_str = f.read()
+        reload_onnx_model = onnx.load_model_from_string(bin_str)
+    os.remove(onnx_file_name)
+
+    # Create a new trainer from persisted ONNX model and compare with original ONNX model
+    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config)#, options=opts)
+    assert trainer_from_onnx._onnx_model is not None
+    assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
+    for initializer, loaded_initializer in zip(trainer._onnx_model.graph.initializer, trainer_from_onnx._onnx_model.graph.initializer):
+        assert initializer.name == loaded_initializer.name
+    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph))
+    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
+
 
 ###############################################################################
 # Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############
@@ -682,7 +777,7 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega
     for i in range(total_steps):
         sample_input = generate_random_input_from_model_desc(model_desc, i)
         experimental_losses.append(trainer.train_step(*sample_input).cpu().item())
-        assert trainer.options.lr_scheduler.get_last_lr()[0] == legacy_lr_scheduler(i)
+        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i))
 
     # LEGACY IMPLEMENTATION
     torch.manual_seed(seed)