fixes for ort_trainer.py to resume from checkpoint (#3510)

* fixes for ort_trainer.py to resume from checkpoint

* define self.state_dict_ during init

* add comment of explanation

* add unit test for restore from checkpoint

* fix file not found

Co-authored-by: suffian khan <sukha@microsoft.com>
Authored by suffiank on 2020-04-22 16:33:58 -07:00, committed via GitHub.
parent e4fc83252d, commit 0e12d05cd2
3 changed files with 140 additions and 62 deletions
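For orientation before the diff: the resume flow that the new unit test exercises saves both the trainer weights and the torch RNG state at a checkpoint epoch, then restores both into a freshly constructed trainer. A minimal sketch of that pattern, assuming an already-constructed ORTTrainer instance named `trainer` and an illustrative checkpoint path (the real test writes to get_name("ckpt_mnist.pt")):

    import torch

    # save at the checkpoint epoch: the weights held by the trainer plus the
    # torch RNG state, so data shuffling replays identically on resume
    state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()}
    torch.save(state, 'ckpt_mnist.pt')

    # resume: restore the RNG state first, then hand the weights back to a
    # fresh trainer; load_state_dict may be called before the ONNX session
    # exists, in which case the restore is deferred (see ort_trainer.py below)
    checkpoint = torch.load('ckpt_mnist.pt', map_location='cpu')
    torch.set_rng_state(checkpoint['rng_state'])
    trainer.load_state_dict(checkpoint['model'], strict=True)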


@@ -152,71 +152,69 @@ def runBertTrainingTest(gradient_accumulation_steps,
     else:
         return actual_losses, eval_loss

-class TestOrtTrainer(unittest.TestCase):
-    def testMNISTTrainingAndTesting(self):
-        class NeuralNet(nn.Module):
-            def __init__(self, input_size, hidden_size, num_classes):
-                super(NeuralNet, self).__init__()
-                self.fc1 = nn.Linear(input_size, hidden_size)
-                self.relu = nn.ReLU()
-                self.fc2 = nn.Linear(hidden_size, num_classes)
+class MNISTWrapper():
+    class NeuralNet(nn.Module):
+        def __init__(self, input_size, hidden_size, num_classes):
+            super(MNISTWrapper.NeuralNet, self).__init__()
+            self.fc1 = nn.Linear(input_size, hidden_size)
+            self.relu = nn.ReLU()
+            self.fc2 = nn.Linear(hidden_size, num_classes)

-            def forward(self, x):
-                out = self.fc1(x)
-                out = self.relu(out)
-                out = self.fc2(out)
-                return out
+        def forward(self, x):
+            out = self.fc1(x)
+            out = self.relu(out)
+            out = self.fc2(out)
+            return out

-        def my_loss(x, target):
-            return F.nll_loss(F.log_softmax(x, dim=1), target)
+    def my_loss(x, target):
+        return F.nll_loss(F.log_softmax(x, dim=1), target)

-        def train_with_trainer(learningRate, trainer, device, train_loader, epoch):
-            actual_losses = []
-            for batch_idx, (data, target) in enumerate(train_loader):
-                data, target = data.to(device), target.to(device)
-                data = data.reshape(data.shape[0], -1)
-                loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
-                args_log_interval = 100
-                if batch_idx % args_log_interval == 0:
-                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                        epoch, batch_idx * len(data), len(train_loader.dataset),
-                        100. * batch_idx / len(train_loader), loss.item()))
-                    actual_losses = [*actual_losses, loss.cpu().numpy().item()]
-            return actual_losses
+    def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch):
+        actual_losses = []
+        for batch_idx, (data, target) in enumerate(train_loader):
+            data, target = data.to(device), target.to(device)
+            data = data.reshape(data.shape[0], -1)
+            loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+            args_log_interval = 100
+            if batch_idx % args_log_interval == 0:
+                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                    epoch, batch_idx * len(data), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), loss.item()))
+                actual_losses = [*actual_losses, loss.cpu().numpy().item()]
+        return actual_losses

-        # TODO: complete this once ORT training can do evaluation.
-        def test_with_trainer(trainer, device, test_loader):
-            test_loss = 0
-            correct = 0
-            with torch.no_grad():
-                for data, target in test_loader:
-                    data, target = data.to(device), target.to(device)
-                    data = data.reshape(data.shape[0], -1)
-                    output = F.log_softmax(trainer.eval_step((data), fetches=['probability']), dim=1)
-                    test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
-                    correct += pred.eq(target.view_as(pred)).sum().item()
-            test_loss /= len(test_loader.dataset)
-            print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-                test_loss, correct, len(test_loader.dataset),
-                100. * correct / len(test_loader.dataset)))
-            return test_loss, correct / len(test_loader.dataset)
+    # TODO: complete this once ORT training can do evaluation.
+    def test_with_trainer(self, trainer, device, test_loader):
+        test_loss = 0
+        correct = 0
+        with torch.no_grad():
+            for data, target in test_loader:
+                data, target = data.to(device), target.to(device)
+                data = data.reshape(data.shape[0], -1)
+                output = F.log_softmax(trainer.eval_step((data), fetches=['probability']), dim=1)
+                test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+                correct += pred.eq(target.view_as(pred)).sum().item()
+        test_loss /= len(test_loader.dataset)
+        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+            test_loss, correct, len(test_loader.dataset),
+            100. * correct / len(test_loader.dataset)))
+        return test_loss, correct / len(test_loader.dataset)

-        def mnist_model_description():
-            input_desc = IODescription('input1', ['batch', 784], torch.float32)
-            label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10)
-            loss_desc = IODescription('loss', [], torch.float32)
-            probability_desc = IODescription('probability', ['batch', 10], torch.float32)
-            return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])
-
-        torch.manual_seed(1)
+    def mnist_model_description():
+        input_desc = IODescription('input1', ['batch', 784], torch.float32)
+        label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10)
+        loss_desc = IODescription('loss', [], torch.float32)
+        probability_desc = IODescription('probability', ['batch', 10], torch.float32)
+        return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])

+    def get_loaders(self):
         args_batch_size = 64
         args_test_batch_size = 1000
@@ -232,17 +230,31 @@ class TestOrtTrainer(unittest.TestCase):
                 transforms.Normalize((0.1307,), (0.3081,))])),
             batch_size=args_test_batch_size, shuffle=True, **kwargs)

-        device = torch.device("cuda")
+        return train_loader, test_loader

+    def get_model(self):
         input_size = 784
         hidden_size = 500
         num_classes = 10
-        model = NeuralNet(input_size, hidden_size, num_classes)
-        model_desc = mnist_model_description()
+
+        # warning: changes the pytorch random generator state
+        model = MNISTWrapper.NeuralNet(input_size, hidden_size, num_classes)
+        model_desc = MNISTWrapper.mnist_model_description()
+        return model, model_desc

-        trainer = ORTTrainer(model, my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
-                             torch.float32), device, _opset_version=12)
+    def get_trainer(self, model, model_desc, device):
+        return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
+                          torch.float32), device, _opset_version=12)
+
+class TestOrtTrainer(unittest.TestCase):
+    def testMNISTTrainingAndTesting(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+        trainer = mnist.get_trainer(model, model_desc, device)

         learningRate = 0.01
         args_epochs = 2
@@ -257,9 +269,62 @@ class TestOrtTrainer(unittest.TestCase):
         actual_losses = []
         actual_test_losses, actual_accuracies = [], []
         for epoch in range(1, args_epochs + 1):
-            actual_losses = [*actual_losses, *train_with_trainer(learningRate, trainer, device, train_loader, epoch)]
-            test_loss, accuracy = test_with_trainer(trainer, device, test_loader)
+            actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch)]
+            test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader)
             actual_test_losses = [*actual_test_losses, test_loss]
             actual_accuracies = [*actual_accuracies, accuracy]

+            # if you update outcomes, also do so for resume from checkpoint test
+            # args_checkpoint_epoch = 1
+            # if epoch == args_checkpoint_epoch:
+            #     state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()}
+            #     torch.save(state, get_name("ckpt_mnist.pt"))
+
         print("actual_losses=", actual_losses)
         print("actual_test_losses=", actual_test_losses)
         print("actual_accuracies=", actual_accuracies)

+        # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs
+        # import pdb; pdb.set_trace()
+
         rtol = 1e-03
         assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
         assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch")
         assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch")

+    def testMNISTResumeTrainingAndTesting(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        learningRate = 0.01
+        args_epochs = 2
+        args_checkpoint_epoch = 1
+
+        # should match those in test without checkpointing
+        expected_losses = [0.23006735742092133, 0.48427966237068176,
+                           0.30716797709465027, 0.3238796889781952, 0.19543828070163727, 0.3561663031578064,
+                           0.3089643716812134, 0.37738722562789917, 0.24883587658405304, 0.30744990706443787]
+
+        expected_test_losses = [0.25183824462890625]
+        expected_test_accuracies = [0.9304]
+
+        actual_losses = []
+        actual_test_losses, actual_accuracies = [], []
+
+        # restore from checkpoint
+        resume_trainer = mnist.get_trainer(model, model_desc, device)
+        checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu")
+        torch.set_rng_state(checkpoint['rng_state'])
+        resume_trainer.load_state_dict(checkpoint['model'], strict=True)
+
+        # continue ..
+        for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1):
+            actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch)]
+            test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader)
+            actual_test_losses = [*actual_test_losses, test_loss]
+            actual_accuracies = [*actual_accuracies, accuracy]

New binary file (vendored): onnxruntime/test/testdata/ckpt_mnist.pt (binary contents not shown)
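The vendored checkpoint is produced by the non-resume test itself: the commented-out block inside testMNISTTrainingAndTesting above writes it at epoch 1. If the expected outcomes are ever updated, the file can be regenerated by temporarily re-enabling that block, roughly:

    # re-enable inside the epoch loop of testMNISTTrainingAndTesting
    args_checkpoint_epoch = 1
    if epoch == args_checkpoint_epoch:
        # capture weights and RNG state so the resume test replays epoch 2 exactly
        state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()}
        torch.save(state, get_name("ckpt_mnist.pt"))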


@@ -598,6 +598,7 @@ class ORTTrainer():
         self.frozen_weights_ = frozen_weights
         self.opset_version_ = _opset_version
         self.loss_scale_input_name = ''
+        self.state_dict_ = None

         self._init_session()
@@ -645,6 +646,10 @@ class ORTTrainer():
             *self.model_desc_.outputs_,
             IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]

+        if self.state_dict_:
+            self.load_state_dict(self.state_dict_, self.strict_)
+            self.state_dict_ = None
+
     def _init_onnx_model(self, inputs):
         if self.onnx_model_ is not None:
             return
@@ -672,6 +677,14 @@ class ORTTrainer():
         return torch_state

     def load_state_dict(self, state_dict, strict=False):
+        # Note: It may happen ONNX model has not yet been initialized
+        # In this case we cache a reference to desired state and delay the restore until after initialization
+        # Unexpected behavior will result if the user changes the reference before initialization
+        if not self.session:
+            self.state_dict_ = state_dict
+            self.strict_ = strict
+            return
+
         session_state = {}
         for name in state_dict:
             session_state[name] = state_dict[name].numpy()
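The early return above is a small deferred-initialization pattern: ORTTrainer builds its ONNX session lazily, so load_state_dict caches the caller's reference and _init_session replays it once the session exists. A standalone sketch of the same idea (illustrative class, not ORT code):

    class LazySession:
        def __init__(self):
            self.session = None
            self.state_dict_ = None  # cached reference, applied after init
            self.strict_ = False

        def load_state_dict(self, state_dict, strict=False):
            # no session yet: cache the reference and restore later; note that
            # mutating state_dict before init changes what gets restored
            if not self.session:
                self.state_dict_ = state_dict
                self.strict_ = strict
                return
            self._restore(state_dict, strict)

        def _init_session(self):
            self.session = "session"  # stand-in for real session creation
            if self.state_dict_:
                self.load_state_dict(self.state_dict_, self.strict_)
                self.state_dict_ = None  # drop the cached reference

        def _restore(self, state_dict, strict):
            print("restoring", sorted(state_dict), "strict =", strict)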