diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py
index b2934939ff..00a73a0b71 100644
--- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py
+++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py
@@ -152,71 +152,69 @@ def runBertTrainingTest(gradient_accumulation_steps,
    else:
        return actual_losses, eval_loss
 
-class TestOrtTrainer(unittest.TestCase):
-    def testMNISTTrainingAndTesting(self):
-        class NeuralNet(nn.Module):
-            def __init__(self, input_size, hidden_size, num_classes):
-                super(NeuralNet, self).__init__()
-                self.fc1 = nn.Linear(input_size, hidden_size)
-                self.relu = nn.ReLU()
-                self.fc2 = nn.Linear(hidden_size, num_classes)
+class MNISTWrapper():
+    class NeuralNet(nn.Module):
+        def __init__(self, input_size, hidden_size, num_classes):
+            super(MNISTWrapper.NeuralNet, self).__init__()
+            self.fc1 = nn.Linear(input_size, hidden_size)
+            self.relu = nn.ReLU()
+            self.fc2 = nn.Linear(hidden_size, num_classes)
 
-            def forward(self, x):
-                out = self.fc1(x)
-                out = self.relu(out)
-                out = self.fc2(out)
-                return out
+        def forward(self, x):
+            out = self.fc1(x)
+            out = self.relu(out)
+            out = self.fc2(out)
+            return out
 
-        def my_loss(x, target):
-            return F.nll_loss(F.log_softmax(x, dim=1), target)
+    def my_loss(x, target):
+        return F.nll_loss(F.log_softmax(x, dim=1), target)
 
-        def train_with_trainer(learningRate, trainer, device, train_loader, epoch):
-            actual_losses = []
-            for batch_idx, (data, target) in enumerate(train_loader):
+    def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch):
+        actual_losses = []
+        for batch_idx, (data, target) in enumerate(train_loader):
+            data, target = data.to(device), target.to(device)
+            data = data.reshape(data.shape[0], -1)
+
+            loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+            args_log_interval = 100
+            if batch_idx % args_log_interval == 0:
+                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                    epoch, batch_idx * len(data), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), loss.item()))
+            actual_losses = [*actual_losses, loss.cpu().numpy().item()]
+
+        return actual_losses
+
+    # TODO: complete this once ORT training can do evaluation.
+    def test_with_trainer(self, trainer, device, test_loader):
+        test_loss = 0
+        correct = 0
+        with torch.no_grad():
+            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                data = data.reshape(data.shape[0], -1)
+                output = F.log_softmax(trainer.eval_step(data, fetches=['probability']), dim=1)
+                test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+                correct += pred.eq(target.view_as(pred)).sum().item()
 
-                loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+        test_loss /= len(test_loader.dataset)
 
-                args_log_interval = 100
-                if batch_idx % args_log_interval == 0:
-                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                        epoch, batch_idx * len(data), len(train_loader.dataset),
-                        100. * batch_idx / len(train_loader), loss.item()))
-                actual_losses = [*actual_losses, loss.cpu().numpy().item()]
+        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+            test_loss, correct, len(test_loader.dataset),
+            100. * correct / len(test_loader.dataset)))
 
-            return actual_losses
+        return test_loss, correct / len(test_loader.dataset)
 
-        # TODO: comple this once ORT training can do evaluation.
-        def test_with_trainer(trainer, device, test_loader):
-            test_loss = 0
-            correct = 0
-            with torch.no_grad():
-                for data, target in test_loader:
-                    data, target = data.to(device), target.to(device)
-                    data = data.reshape(data.shape[0], -1)
-                    output = F.log_softmax(trainer.eval_step((data), fetches=['probability']), dim=1)
-                    test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
-                    pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
-                    correct += pred.eq(target.view_as(pred)).sum().item()
-
-            test_loss /= len(test_loader.dataset)
-
-            print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-                test_loss, correct, len(test_loader.dataset),
-                100. * correct / len(test_loader.dataset)))
-
-            return test_loss, correct / len(test_loader.dataset)
-
-        def mnist_model_description():
-            input_desc = IODescription('input1', ['batch', 784], torch.float32)
-            label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10)
-            loss_desc = IODescription('loss', [], torch.float32)
-            probability_desc = IODescription('probability', ['batch', 10], torch.float32)
-            return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])
-
-        torch.manual_seed(1)
+    def mnist_model_description():
+        input_desc = IODescription('input1', ['batch', 784], torch.float32)
+        label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10)
+        loss_desc = IODescription('loss', [], torch.float32)
+        probability_desc = IODescription('probability', ['batch', 10], torch.float32)
+        return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])
 
+    def get_loaders(self):
        args_batch_size = 64
        args_test_batch_size = 1000
@@ -232,17 +230,31 @@ class TestOrtTrainer(unittest.TestCase):
                transforms.Normalize((0.1307,), (0.3081,))])),
            batch_size=args_test_batch_size, shuffle=True, **kwargs)
 
-        device = torch.device("cuda")
+        return train_loader, test_loader
 
+    def get_model(self):
        input_size = 784
        hidden_size = 500
        num_classes = 10
-        model = NeuralNet(input_size, hidden_size, num_classes)
-        model_desc = mnist_model_description()
+        # warning: constructing the model advances the PyTorch random generator state
+        model = MNISTWrapper.NeuralNet(input_size, hidden_size, num_classes)
+        model_desc = MNISTWrapper.mnist_model_description()
+        return model, model_desc
 
-        trainer = ORTTrainer(model, my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
-                             torch.float32), device, _opset_version=12)
+    def get_trainer(self, model, model_desc, device):
+        return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
+                          torch.float32), device, _opset_version=12)
+
+class TestOrtTrainer(unittest.TestCase):
+    def testMNISTTrainingAndTesting(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+        trainer = mnist.get_trainer(model, model_desc, device)
 
        learningRate = 0.01
        args_epochs = 2
@@ -257,9 +269,62 @@ class TestOrtTrainer(unittest.TestCase):
        actual_losses = []
        actual_test_losses, actual_accuracies = [], []
        for epoch in range(1, args_epochs + 1):
-            actual_losses = [*actual_losses, *train_with_trainer(learningRate, trainer, device, train_loader, epoch)]
+            actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch)]
 
-            test_loss, accuracy = test_with_trainer(trainer, device, test_loader)
+            test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader)
+            actual_test_losses = [*actual_test_losses, test_loss]
+            actual_accuracies = [*actual_accuracies, accuracy]
+
+            # if you update these outcomes, also update them in the resume-from-checkpoint test
+            # args_checkpoint_epoch = 1
+            # if epoch == args_checkpoint_epoch:
+            #     state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()}
+            #     torch.save(state, get_name("ckpt_mnist.pt"))
+
+
+        print("actual_losses=", actual_losses)
+        print("actual_test_losses=", actual_test_losses)
+        print("actual_accuracies=", actual_accuracies)
+
+        # to update the expected outcomes, enable pdb below, run the test with -s, and copy-paste the printed values
+        # import pdb; pdb.set_trace()
+        rtol = 1e-03
+        assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
+        assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch")
+        assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch")
+
+    def testMNISTResumeTrainingAndTesting(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        learningRate = 0.01
+        args_epochs = 2
+        args_checkpoint_epoch = 1
+        # these should match the values in the test without checkpointing
+        expected_losses = [0.23006735742092133, 0.48427966237068176,
+            0.30716797709465027, 0.3238796889781952, 0.19543828070163727, 0.3561663031578064,
+            0.3089643716812134, 0.37738722562789917, 0.24883587658405304, 0.30744990706443787]
+        expected_test_losses = [0.25183824462890625]
+        expected_test_accuracies = [0.9304]
+
+        actual_losses = []
+        actual_test_losses, actual_accuracies = [], []
+
+        # restore from checkpoint
+        resume_trainer = mnist.get_trainer(model, model_desc, device)
+        checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu")
+        torch.set_rng_state(checkpoint['rng_state'])
+        resume_trainer.load_state_dict(checkpoint['model'], strict=True)
+
+        # continue training from the epoch after the checkpoint
+        for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1):
+            actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch)]
+
+            test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader)
            actual_test_losses = [*actual_test_losses, test_loss]
            actual_accuracies = [*actual_accuracies, accuracy]
diff --git a/onnxruntime/test/testdata/ckpt_mnist.pt b/onnxruntime/test/testdata/ckpt_mnist.pt
new file mode 100644
index 0000000000..01a9a9b448
Binary files /dev/null and b/onnxruntime/test/testdata/ckpt_mnist.pt differ
diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py
index f92a632fbc..d9f58e0704 100644
--- a/orttraining/orttraining/python/ort_trainer.py
+++ b/orttraining/orttraining/python/ort_trainer.py
@@ -598,6 +598,7 @@ class ORTTrainer():
        self.frozen_weights_ = frozen_weights
        self.opset_version_ = _opset_version
        self.loss_scale_input_name = ''
+        self.state_dict_ = None
 
        self._init_session()
 
@@ -645,6 +646,10 @@ class ORTTrainer():
            *self.model_desc_.outputs_,
            IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]
 
+        if self.state_dict_:
+            self.load_state_dict(self.state_dict_, self.strict_)
+            self.state_dict_ = None
+
    def _init_onnx_model(self, inputs):
        if self.onnx_model_ is not None:
            return
@@ -672,6 +677,14 @@ class ORTTrainer():
        return torch_state
 
    def load_state_dict(self, state_dict, strict=False):
+        # Note: the ONNX model may not have been initialized yet.
+        # In that case, cache a reference to the desired state and delay the restore until after initialization.
+        # Unexpected behavior will result if the caller mutates the referenced state before initialization completes.
+        if not self.session:
+            self.state_dict_ = state_dict
+            self.strict_ = strict
+            return
+
        session_state = {}
        for name in state_dict:
            session_state[name] = state_dict[name].numpy()
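Note on the test refactor: the point of MNISTWrapper is that both tests now build identical setups from shared helpers instead of test-local closures. A hypothetical further test would compose the same pieces; all names below are the helpers introduced by this diff, and the test itself is only an illustration:

    def testMNISTOneEpoch(self):                 # hypothetical example, not part of the diff
        torch.manual_seed(1)                     # fixes weight init and loader shuffling
        device = torch.device("cuda")

        mnist = MNISTWrapper()
        train_loader, test_loader = mnist.get_loaders()
        model, model_desc = mnist.get_model()    # note: advances the global RNG state
        trainer = mnist.get_trainer(model, model_desc, device)

        losses = mnist.train_with_trainer(0.01, trainer, device, train_loader, epoch=1)
        test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader)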
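Note on the binary test asset: the commented-out block inside testMNISTTrainingAndTesting is what produced onnxruntime/test/testdata/ckpt_mnist.pt, and testMNISTResumeTrainingAndTesting consumes it. A minimal sketch of that round trip, assuming `trainer` and `resume_trainer` are ORTTrainer instances built by MNISTWrapper.get_trainer and `get_name` resolves file names into the testdata directory, as in the tests above:

    import torch

    # save side: capture the model state together with the global torch RNG
    # state, so a resumed run replays the same shuffling/initialization draws
    state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()}
    torch.save(state, get_name("ckpt_mnist.pt"))

    # resume side: restore the RNG state before iterating the data loaders,
    # then push the saved weights into a freshly constructed trainer
    checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu")
    torch.set_rng_state(checkpoint['rng_state'])
    resume_trainer.load_state_dict(checkpoint['model'], strict=True)

Since get_loaders builds both loaders with shuffle=True against the global generator, restoring rng_state before the second epoch is what allows the resumed losses to be compared against expected_losses from the uninterrupted run.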
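The ort_trainer.py hunks address an ordering hazard that the resume test exercises: load_state_dict is called immediately after ORTTrainer construction, but the backing session may not exist yet (for a PyTorch model the ONNX graph is built lazily). The new code caches the requested state and replays it at the end of _init_session. A stripped-down, runnable sketch of that cache-and-replay idiom; PendingStateTrainer and _apply_state are hypothetical stand-ins for the ORTTrainer internals:

    class PendingStateTrainer:
        def __init__(self):
            self.session = None        # the real trainer creates this lazily
            self.state_dict_ = None
            self.strict_ = False

        def _apply_state(self, state_dict, strict):
            # hypothetical stand-in for pushing tensors into the session
            print("restoring", sorted(state_dict), "strict =", strict)

        def load_state_dict(self, state_dict, strict=False):
            if not self.session:
                # no session yet: keep a *reference* and replay after init;
                # mutating state_dict in the meantime changes what is restored
                self.state_dict_ = state_dict
                self.strict_ = strict
                return
            self._apply_state(state_dict, strict)

        def _init_session(self):
            self.session = object()    # stand-in for the real training session
            if self.state_dict_:
                self.load_state_dict(self.state_dict_, self.strict_)
                self.state_dict_ = None

    trainer = PendingStateTrainer()
    trainer.load_state_dict({'fc1.weight': 0}, strict=True)  # cached, no session yet
    trainer._init_session()                                  # replays the cached state

Keeping a reference rather than a copy is cheap, but it is exactly why the new comment warns against mutating the state before initialization; copying the tensors would remove that hazard at the cost of briefly holding the weights twice.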