diff --git a/orttraining/orttraining/python/experimental/orttrainer.py b/orttraining/orttraining/python/experimental/orttrainer.py index e01b83c0fc..ac0bb34a08 100644 --- a/orttraining/orttraining/python/experimental/orttrainer.py +++ b/orttraining/orttraining/python/experimental/orttrainer.py @@ -453,10 +453,16 @@ class ORTTrainer(object): with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. sample_inputs_copy = copy.deepcopy(sample_inputs) - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) + try: + # Deepcopy model, in case model is stateful and changes after model run. + model_copy = copy.deepcopy(model) + except Exception: + model_copy = model + warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." + " Compute will continue, but unexpected results may occur!") sample_outputs = model_copy(*sample_inputs_copy) model.train() + if isinstance(sample_outputs, torch.Tensor): sample_outputs = [sample_outputs] @@ -472,6 +478,7 @@ class ORTTrainer(object): # Export the model to ONNX f = io.BytesIO() + # Deepcopy inputs, since input values may change after model run. sample_inputs_copy = copy.deepcopy(sample_inputs) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py index 23bd66a3d5..6200e10348 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py @@ -7,7 +7,6 @@ import onnx import os import pytest import torch -import torch.nn.functional as F import onnxruntime from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\ @@ -454,80 +453,6 @@ def testToyBertCheckpointLoadZero(): assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol) -@pytest.mark.parametrize("loss_scaler, optimizer_config, gradient_accumulation_steps", [ - (None, optim.AdamConfig(), 1), - (None, optim.LambConfig(), 1), - (None, optim.SGDConfig(), 1), - (amp.DynamicLossScaler(), optim.AdamConfig(), 1), - (amp.DynamicLossScaler(), optim.LambConfig(), 5), - #(amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 -]) -def testToyBertStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps): - # Common setup - seed = 1 - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - model_desc = {'inputs' : [('x', [2, 2]), - ('label', [2, ])], - 'outputs' : [('loss', [], True), - ('output', [2, 4])]} - - # Dummy data - data1 = torch.randn(2, 2) - label1 = torch.tensor([0, 1], dtype=torch.int64) - data2 = torch.randn(2, 2) - label2 = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = {'debug' : {'deterministic_compute': True}, - 'batch' : { 'gradient_accumulation_steps' : gradient_accumulation_steps}} - if loss_scaler: - opts['mixed_precision'] = { 'enabled': True, 'loss_scaler': loss_scaler} - opts = orttrainer.ORTTrainerOptions(opts) - - # Training session 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - pt_model = LinearModel() - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - - # Check state_dict keys before train. Must be empty - state_dict = checkpoint.experimental_state_dict(trainer) - assert state_dict == {} - - # Train once and check initial state - trainer.train_step(x=data1, label=label1) - state_dict = checkpoint.experimental_state_dict(trainer) - assert all([weight in state_dict.keys() for weight in ['linear.bias', 'linear.weight']]) - - # Initialize training session 2 from state of Training 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - checkpoint.experimental_load_state_dict(trainer2, state_dict) - - # Verify state was loaded properly - for k,v in state_dict.items(): - assert_allclose(v, trainer2._state_dict[k]) - - # Perform a second step in both training session 1 and 2 and verify they match - trainer.train_step(x=data2, label=label2) - state_dict = checkpoint.experimental_state_dict(trainer) - trainer2.train_step(x=data2, label=label2) - state_dict2 = checkpoint.experimental_state_dict(trainer2) - for k,v in state_dict.items(): - assert_allclose(v, state_dict2[k]) - - def testToyBertCheckpointFrozenWeights(): # Common setup seed = 1 diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 23015fe584..22e7a36ce2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -6,14 +6,14 @@ import onnx import os import pytest import torch - +import torch.nn.functional as F from onnxruntime import set_seed from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\ ModelDescription as Legacy_ModelDescription,\ LossScaler as Legacy_LossScaler,\ ORTTrainer as Legacy_ORTTrainer -from onnxruntime.experimental import _utils, amp, optim, orttrainer, TrainStepInfo,\ +from onnxruntime.experimental import _utils, amp, checkpoint, optim, orttrainer, TrainStepInfo,\ model_desc_validation as md_val,\ orttrainer_options as orttrainer_options import _test_commons,_test_helpers @@ -850,6 +850,121 @@ def testORTTrainerFrozenWeights(model_params): assert not all([param in session_state for param in model_params]) +@pytest.mark.parametrize("loss_scaler, optimizer_config, gradient_accumulation_steps", [ + (None, optim.AdamConfig(), 1), + (None, optim.LambConfig(), 1), + (None, optim.SGDConfig(), 1), + (amp.DynamicLossScaler(), optim.AdamConfig(), 1), + (amp.DynamicLossScaler(), optim.LambConfig(), 5), + #(amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 +]) +def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps): + # Common setup + seed = 1 + class LinearModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 4) + def forward(self, y=None, x=None): + if y is not None: + return self.linear(x) + y + else: + return self.linear(x) + torch.ones(2, 4) + model_desc = {'inputs' : [('x', [2, 2]), + ('label', [2, ])], + 'outputs' : [('loss', [], True), + ('output', [2, 4])]} + + # Dummy data + data1 = torch.randn(2, 2) + label1 = torch.tensor([0, 1], dtype=torch.int64) + data2 = torch.randn(2, 2) + label2 = torch.tensor([0, 1], dtype=torch.int64) + + # Setup training based on test parameters + opts = {'debug' : {'deterministic_compute': True}, + 'batch' : { 'gradient_accumulation_steps' : gradient_accumulation_steps}} + if loss_scaler: + opts['mixed_precision'] = { 'enabled': True, 'loss_scaler': loss_scaler} + opts = orttrainer.ORTTrainerOptions(opts) + + # Training session 1 + torch.manual_seed(seed) + set_seed(seed) + pt_model = LinearModel() + def loss_fn(x, label): + return F.nll_loss(F.log_softmax(x, dim=1), label) + trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) + + # Check state_dict keys before train. Must be empty + state_dict = checkpoint.experimental_state_dict(trainer) + assert state_dict == {} + + # Train once and check initial state + trainer.train_step(x=data1, label=label1) + state_dict = checkpoint.experimental_state_dict(trainer) + assert all([weight in state_dict.keys() for weight in ['linear.bias', 'linear.weight']]) + + # Initialize training session 2 from state of Training 1 + torch.manual_seed(seed) + set_seed(seed) + trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) + checkpoint.experimental_load_state_dict(trainer2, state_dict) + + # Verify state was loaded properly + for k,v in state_dict.items(): + assert_allclose(v, trainer2._state_dict[k]) + + # Perform a second step in both training session 1 and 2 and verify they match + trainer.train_step(x=data2, label=label2) + state_dict = checkpoint.experimental_state_dict(trainer) + trainer2.train_step(x=data2, label=label2) + state_dict2 = checkpoint.experimental_state_dict(trainer2) + for k,v in state_dict.items(): + assert_allclose(v, state_dict2[k]) + + +def testORTTrainerNonPickableModel(): + # Common setup + import threading + seed = 1 + class UnpickableModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 4) + self._lock = threading.Lock() + + def forward(self, y=None, x=None): + with self._lock: + if y is not None: + return self.linear(x) + y + else: + return self.linear(x) + torch.ones(2, 4) + + model_desc = {'inputs' : [('x', [2, 2]), + ('label', [2, ])], + 'outputs' : [('loss', [], True), + ('output', [2, 4])]} + + # Dummy data + data = torch.randn(2, 2) + label = torch.tensor([0, 1], dtype=torch.int64) + + # Setup training based on test parameters + opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}}) + + # Training session + torch.manual_seed(seed) + set_seed(seed) + pt_model = UnpickableModel() + def loss_fn(x, label): + return F.nll_loss(F.log_softmax(x, dim=1), label) + optim_config = optim.AdamConfig() + trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts) + + # Train must succeed despite warning + _, _ = trainer.train_step(data, label) + ############################################################################### # Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ ###############################################################################