mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
fixes for ort_trainer.py to resume from checkpoint (#3510)
* fixes for ort_trainer.py to resume from checkpoint
* define self.state_dict_ during init
* add comment of explanation
* add unit test for restore from checkpoint
* fix file not found

Co-authored-by: suffian khan <sukha@microsoft.com>
This commit is contained in:
parent
e4fc83252d
commit
0e12d05cd2
3 changed files with 140 additions and 62 deletions
|
|
@ -152,71 +152,69 @@ def runBertTrainingTest(gradient_accumulation_steps,
|
|||
else:
|
||||
return actual_losses, eval_loss
|
||||
|
||||
class TestOrtTrainer(unittest.TestCase):
|
||||
def testMNISTTrainingAndTesting(self):
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNet, self).__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
class MNISTWrapper():
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(MNISTWrapper.NeuralNet, self).__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
def my_loss(x, target):
|
||||
return F.nll_loss(F.log_softmax(x, dim=1), target)
|
||||
def my_loss(x, target):
|
||||
return F.nll_loss(F.log_softmax(x, dim=1), target)
|
||||
|
||||
def train_with_trainer(learningRate, trainer, device, train_loader, epoch):
|
||||
actual_losses = []
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch):
|
||||
actual_losses = []
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
|
||||
loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
|
||||
|
||||
args_log_interval = 100
|
||||
if batch_idx % args_log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
actual_losses = [*actual_losses, loss.cpu().numpy().item()]
|
||||
|
||||
return actual_losses
|
||||
|
||||
# TODO: complete this once ORT training can do evaluation.
|
||||
def test_with_trainer(self, trainer, device, test_loader):
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
output = F.log_softmax(trainer.eval_step((data), fetches=['probability']), dim=1)
|
||||
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
args_log_interval = 100
|
||||
if batch_idx % args_log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
actual_losses = [*actual_losses, loss.cpu().numpy().item()]
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
return actual_losses
|
||||
return test_loss, correct / len(test_loader.dataset)
|
||||
|
||||
# TODO: complete this once ORT training can do evaluation.
|
||||
def test_with_trainer(trainer, device, test_loader):
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
output = F.log_softmax(trainer.eval_step((data), fetches=['probability']), dim=1)
|
||||
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
return test_loss, correct / len(test_loader.dataset)
|
||||
|
||||
def mnist_model_description():
|
||||
input_desc = IODescription('input1', ['batch', 784], torch.float32)
|
||||
label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10)
|
||||
loss_desc = IODescription('loss', [], torch.float32)
|
||||
probability_desc = IODescription('probability', ['batch', 10], torch.float32)
|
||||
return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])
|
||||
|
||||
torch.manual_seed(1)
|
||||
def mnist_model_description():
|
||||
input_desc = IODescription('input1', ['batch', 784], torch.float32)
|
||||
label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10)
|
||||
loss_desc = IODescription('loss', [], torch.float32)
|
||||
probability_desc = IODescription('probability', ['batch', 10], torch.float32)
|
||||
return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])
|
||||
|
||||
def get_loaders(self):
|
||||
args_batch_size = 64
|
||||
args_test_batch_size = 1000
|
||||
|
||||
|
|
@ -232,17 +230,31 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args_test_batch_size, shuffle=True, **kwargs)
|
||||
|
||||
device = torch.device("cuda")
|
||||
return train_loader, test_loader
|
||||
|
||||
def get_model(self):
|
||||
input_size = 784
|
||||
hidden_size = 500
|
||||
num_classes = 10
|
||||
model = NeuralNet(input_size, hidden_size, num_classes)
|
||||
|
||||
model_desc = mnist_model_description()
|
||||
# warning: changes the pytorch random generator state
|
||||
model = MNISTWrapper.NeuralNet(input_size, hidden_size, num_classes)
|
||||
model_desc = MNISTWrapper.mnist_model_description()
|
||||
return model, model_desc
|
||||
|
||||
trainer = ORTTrainer(model, my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
|
||||
torch.float32), device, _opset_version=12)
|
||||
def get_trainer(self, model, model_desc, device):
|
||||
return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
|
||||
torch.float32), device, _opset_version=12)
|
||||
|
||||
class TestOrtTrainer(unittest.TestCase):
|
||||
def testMNISTTrainingAndTesting(self):
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda")
|
||||
|
||||
mnist = MNISTWrapper()
|
||||
train_loader, test_loader = mnist.get_loaders()
|
||||
model, model_desc = mnist.get_model()
|
||||
trainer = mnist.get_trainer(model, model_desc, device)
|
||||
|
||||
learningRate = 0.01
|
||||
args_epochs = 2
|
||||
|
|
@ -257,9 +269,62 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
actual_losses = []
|
||||
actual_test_losses, actual_accuracies = [], []
|
||||
for epoch in range(1, args_epochs + 1):
|
||||
actual_losses = [*actual_losses, *train_with_trainer(learningRate, trainer, device, train_loader, epoch)]
|
||||
actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch)]
|
||||
|
||||
test_loss, accuracy = test_with_trainer(trainer, device, test_loader)
|
||||
test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader)
|
||||
actual_test_losses = [*actual_test_losses, test_loss]
|
||||
actual_accuracies = [*actual_accuracies, accuracy]
|
||||
|
||||
# if you update outcomes, also do so for resume from checkpoint test
|
||||
# args_checkpoint_epoch = 1
|
||||
# if epoch == args_checkpoint_epoch:
|
||||
# state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()}
|
||||
# torch.save(state, get_name("ckpt_mnist.pt"))
|
||||
|
||||
|
||||
print("actual_losses=", actual_losses)
|
||||
print("actual_test_losses=", actual_test_losses)
|
||||
print("actual_accuracies=", actual_accuracies)
|
||||
|
||||
# to update expected outcomes, enable pdb and run the test with -s and copy paste outputs
|
||||
# import pdb; pdb.set_trace()
|
||||
rtol = 1e-03
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch")
|
||||
assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch")
|
||||
|
||||
def testMNISTResumeTrainingAndTesting(self):
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda")
|
||||
|
||||
mnist = MNISTWrapper()
|
||||
train_loader, test_loader = mnist.get_loaders()
|
||||
model, model_desc = mnist.get_model()
|
||||
|
||||
learningRate = 0.01
|
||||
args_epochs = 2
|
||||
args_checkpoint_epoch = 1
|
||||
# should match those in test without checkpointing
|
||||
expected_losses = [0.23006735742092133, 0.48427966237068176,
|
||||
0.30716797709465027, 0.3238796889781952, 0.19543828070163727, 0.3561663031578064,
|
||||
0.3089643716812134, 0.37738722562789917, 0.24883587658405304, 0.30744990706443787]
|
||||
expected_test_losses = [0.25183824462890625]
|
||||
expected_test_accuracies = [0.9304]
|
||||
|
||||
actual_losses = []
|
||||
actual_test_losses, actual_accuracies = [], []
|
||||
|
||||
# restore from checkpoint
|
||||
resume_trainer = mnist.get_trainer(model, model_desc, device)
|
||||
checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu")
|
||||
torch.set_rng_state(checkpoint['rng_state'])
|
||||
resume_trainer.load_state_dict(checkpoint['model'], strict=True)
|
||||
|
||||
# continue ..
|
||||
for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1):
|
||||
actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch)]
|
||||
|
||||
test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader)
|
||||
actual_test_losses = [*actual_test_losses, test_loss]
|
||||
actual_accuracies = [*actual_accuracies, accuracy]
|
||||
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/ckpt_mnist.pt
vendored
Normal file
BIN
onnxruntime/test/testdata/ckpt_mnist.pt
vendored
Normal file
Binary file not shown.
|
|
@ -598,6 +598,7 @@ class ORTTrainer():
|
|||
self.frozen_weights_ = frozen_weights
|
||||
self.opset_version_ = _opset_version
|
||||
self.loss_scale_input_name = ''
|
||||
self.state_dict_ = None
|
||||
|
||||
self._init_session()
|
||||
|
||||
|
|
@ -645,6 +646,10 @@ class ORTTrainer():
|
|||
*self.model_desc_.outputs_,
|
||||
IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]
|
||||
|
||||
if self.state_dict_:
|
||||
self.load_state_dict(self.state_dict_, self.strict_)
|
||||
self.state_dict_ = None
|
||||
|
||||
def _init_onnx_model(self, inputs):
|
||||
if self.onnx_model_ is not None:
|
||||
return
|
||||
|
|
@ -672,6 +677,14 @@ class ORTTrainer():
|
|||
return torch_state
|
||||
|
||||
def load_state_dict(self, state_dict, strict=False):
|
||||
# Note: the ONNX model may not have been initialized yet when this is called.
|
||||
# In that case we cache a reference to the desired state and delay the restore until after initialization.
|
||||
# Unexpected behavior will result if the user changes the reference before initialization
|
||||
if not self.session:
|
||||
self.state_dict_ = state_dict
|
||||
self.strict_ = strict
|
||||
return
|
||||
|
||||
session_state = {}
|
||||
for name in state_dict:
|
||||
session_state[name] = state_dict[name].numpy()
|
||||
|
|
|
|||
Loading…
Reference in a new issue