From 15cb4b3023dfd6e5439dcd8fc6151de4f3242a5c Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 23 Jun 2020 11:45:26 -0700 Subject: [PATCH] Fix session load state & run extra_postpasses only once (#4255) * Fix session load state & run extra_postpasses only once * add testcase for onnx model as well --- .../python/onnxruntime_test_ort_trainer.py | 52 +++++++++++- orttraining/orttraining/python/ort_trainer.py | 7 +- .../python/onnxruntime_test_postprocess.py | 82 +++++++++++++++++++ 3 files changed, 136 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index d77cd65699..e77b93520e 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -287,9 +287,9 @@ class MNISTWrapper(): return model, model_desc def get_trainer(self, model, model_desc, device, onnx_opset_ver=12, frozen_weights=[], - internal_loss_fn=False, get_lr_this_step=None): + internal_loss_fn=False, get_lr_this_step=None, optimizer="SGDOptimizer"): loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None - return ORTTrainer(model, loss_fn, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ], + return ORTTrainer(model, loss_fn, model_desc, optimizer, None, IODescription('Learning_Rate', [1, ], torch.float32), device, _opset_version=onnx_opset_ver, frozen_weights=frozen_weights, get_lr_this_step=get_lr_this_step) @@ -606,6 +606,54 @@ class TestOrtTrainer(unittest.TestCase): ckpt_loss, _ = trainer.eval_step(data, target) assert loss == ckpt_loss + loaded_state_dict = trainer.state_dict() + assert state_dict.keys() == loaded_state_dict.keys() + + def testMNISTTrainingCheckpoint(self): + torch.manual_seed(1) + device = torch.device("cuda") + + mnist = MNISTWrapper() + train_loader, test_loader = mnist.get_loaders() + model, model_desc = mnist.get_model() + + trainer = mnist.get_trainer(model, model_desc, device, + optimizer='LambOptimizer', frozen_weights=['fc1.weight']) + + learningRate = 0.02 + epoch = 0 + + # do 5 train step + for i in range(5): + data, target = next(iter(train_loader)) + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + + loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) + + # do one eval step + data, target = next(iter(train_loader)) + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + + loss, _ = trainer.eval_step(data, target) + + # save checkpoint, load model and compare + state_dict = trainer.state_dict() + + new_model, _ = mnist.get_model() + trainer = mnist.get_trainer(new_model, model_desc, device, + optimizer='LambOptimizer', frozen_weights=['fc1.weight']) + trainer.load_state_dict(state_dict) + + ckpt_loss, _ = trainer.eval_step(data, target) + assert loss == ckpt_loss + + loaded_state_dict = trainer.state_dict() + assert state_dict.keys() == loaded_state_dict.keys() + for key in state_dict: + assert np.array_equal(state_dict[key], loaded_state_dict[key]) + def testBertTrainingBasic(self): expected_losses = [ 11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543, diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 07f9d1369d..c28fd47e21 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -721,9 +721,6 @@ class ORTTrainer(): self.onnx_model_ = convert_model_loss_fn_to_onnx( self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs, opset_version=self.opset_version_, _enable_internal_postprocess=self._enable_internal_postprocess) - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - self._init_session() def train(self): @@ -789,6 +786,10 @@ class ORTTrainer(): self.state_dict_ = None self._init_session() + # load training state + session_state = {name:state_dict[name].numpy() for name in state_dict} + self.session.load_state(session_state, strict) + def save_as_onnx(self, path): if not self.session: warnings.warn("ONNXRuntime training session is not initialized yet. " diff --git a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py index 32df362fc4..21c87ae68a 100644 --- a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py +++ b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py @@ -209,6 +209,88 @@ class Test_PostPasses(unittest.TestCase): assert model_info[0].name == expand_nodes[0].output[0] assert model_info[0].type == onnx_model.graph.input[0].type + def test_extra_postpass(self): + def postpass_replace_first_add_with_sub(model): + # this post pass replaces the first Add node with Sub in the model. + # Previous graph + # (subgraph 1) (subgraph 2) + # | | + # | | + # |________ ________| + # | | + # Add + # | + # (subgraph 3) + # + # Post graph + # (subgraph 1) (subgraph 2) + # | | + # | | + # |________ ________| + # | | + # Sub + # | + # (subgraph 3) + add_nodes = [n for n in model.graph.node if n.op_type == 'Add'] + add_nodes[0].op_type = "Sub" + + class MultiAdd(nn.Module): + def __init__(self, target): + super(MultiAdd, self).__init__() + self.loss = nn.CrossEntropyLoss() + self.target = target + self.linear = torch.nn.Linear(2, 2, bias=False) + + def forward(self, x, x1): + output = x + x1 + output = output + x + output = output + x1 + output = self.linear(output) + loss = self.loss(output, self.target) + return loss, output + + device = torch.device("cpu") + target = torch.ones(5, 2, dtype=torch.int64).to(device) + model = MultiAdd(target).to(device) + + x = torch.randn(5, 5, 2, dtype=torch.float32).to(device) + x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device) + + input0_desc = IODescription('x', [5, 5, 2], "float32") + input1_desc = IODescription('x1', [5, 5, 2], "float32") + output0_desc = IODescription('output0', [], "float32") + output1_desc = IODescription('output1', [5, 5, 2], "float32") + model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) + + learning_rate = torch.tensor([1.0000000e+00]).to(device) + input_args = [x, x1, learning_rate] + + onnx_model = self.get_onnx_model(model, model_desc, input_args, device, + _extra_postprocess=postpass_replace_first_add_with_sub) + + # check that extra postpass is called, and called only once. + add_nodes = self.find_nodes(onnx_model, "Add") + sub_nodes = self.find_nodes(onnx_model, "Sub") + assert len(add_nodes) == 2 + assert len(sub_nodes) == 1 + + + unprocessed_onnx_model = self.get_onnx_model(model, model_desc, input_args, device, + _extra_postprocess=None, _enable_internal_postprocess=False) + # check that the model is unchanged. + add_nodes = self.find_nodes(unprocessed_onnx_model, "Add") + sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub") + assert len(add_nodes) == 3 + assert len(sub_nodes) == 0 + + processed_onnx_model = self.get_onnx_model(unprocessed_onnx_model, model_desc, input_args, device, + _extra_postprocess=postpass_replace_first_add_with_sub) + # check that extra postpass is called, and called only once. + add_nodes = self.find_nodes(processed_onnx_model, "Add") + sub_nodes = self.find_nodes(processed_onnx_model, "Sub") + assert len(add_nodes) == 2 + assert len(sub_nodes) == 1 + if __name__ == '__main__': unittest.main(module=__name__, buffer=True)