Fix session load state & run extra_postpasses only once (#4255)

* Fix session load state & run extra_postpasses only once

* Add a test case for the ONNX model as well
This commit is contained in:
Bowen Bao 2020-06-23 11:45:26 -07:00 committed by GitHub
parent d3c5cb6349
commit 15cb4b3023
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 136 additions and 5 deletions

View file

@ -287,9 +287,9 @@ class MNISTWrapper():
return model, model_desc
def get_trainer(self, model, model_desc, device, onnx_opset_ver=12, frozen_weights=None,
                internal_loss_fn=False, get_lr_this_step=None, optimizer="SGDOptimizer"):
    """Build an ORTTrainer for the MNIST model.

    Args:
        model: model object handed to ORTTrainer as its first argument.
        model_desc: model description handed to ORTTrainer.
        device: device handed to ORTTrainer.
        onnx_opset_ver: forwarded to ORTTrainer as ``_opset_version``.
        frozen_weights: names of weights forwarded as ``frozen_weights``.
            ``None`` (the default) is treated as an empty list, so the
            default is no longer a list object shared between calls.
        internal_loss_fn: when true, no external loss function is passed
            (``None`` is used); otherwise ``MNISTWrapper.my_loss`` is passed.
        get_lr_this_step: forwarded unchanged to ORTTrainer.
        optimizer: optimizer name forwarded to ORTTrainer; defaults to
            "SGDOptimizer", the value that used to be hard-coded.

    Returns:
        The constructed ORTTrainer.
    """
    if frozen_weights is None:
        frozen_weights = []
    loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None
    learning_rate_desc = IODescription('Learning_Rate', [1, ], torch.float32)
    # The fifth positional argument is passed as None, exactly as before.
    return ORTTrainer(model, loss_fn, model_desc, optimizer, None, learning_rate_desc,
                      device, _opset_version=onnx_opset_ver, frozen_weights=frozen_weights,
                      get_lr_this_step=get_lr_this_step)
@ -606,6 +606,54 @@ class TestOrtTrainer(unittest.TestCase):
ckpt_loss, _ = trainer.eval_step(data, target)
assert loss == ckpt_loss
loaded_state_dict = trainer.state_dict()
assert state_dict.keys() == loaded_state_dict.keys()
def testMNISTTrainingCheckpoint(self):
    """Check that a trainer state dict round-trips through a fresh trainer.

    Runs five training steps and one evaluation step with the Lamb
    optimizer and a frozen ``fc1.weight``, captures ``state_dict()``,
    loads it into a newly built trainer, and verifies that the evaluation
    loss on the same batch, the state dict keys, and every stored value
    are unchanged.
    """
    torch.manual_seed(1)
    device = torch.device("cuda")

    mnist = MNISTWrapper()
    train_loader, _ = mnist.get_loaders()
    model, model_desc = mnist.get_model()

    def make_trainer(torch_model):
        # Both trainers must be configured identically for the comparison to hold.
        return mnist.get_trainer(torch_model, model_desc, device,
                                 optimizer='LambOptimizer', frozen_weights=['fc1.weight'])

    def fetch_batch():
        # A new iterator is created on every call, as the inline code did.
        data, target = next(iter(train_loader))
        data, target = data.to(device), target.to(device)
        # Flatten each sample to one row.
        return data.reshape(data.shape[0], -1), target

    trainer = make_trainer(model)
    learning_rate = 0.02

    # Do 5 training steps.
    for _ in range(5):
        data, target = fetch_batch()
        trainer.train_step(data, target, torch.tensor([learning_rate]))

    # Do one eval step; this batch is reused after the checkpoint is loaded.
    data, target = fetch_batch()
    loss, _ = trainer.eval_step(data, target)

    # Save checkpoint, load it into a new trainer and compare.
    state_dict = trainer.state_dict()

    new_model, _ = mnist.get_model()
    trainer = make_trainer(new_model)
    trainer.load_state_dict(state_dict)

    ckpt_loss, _ = trainer.eval_step(data, target)
    assert loss == ckpt_loss

    loaded_state_dict = trainer.state_dict()
    assert state_dict.keys() == loaded_state_dict.keys()
    for key in state_dict:
        assert np.array_equal(state_dict[key], loaded_state_dict[key])
def testBertTrainingBasic(self):
expected_losses = [
11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,

View file

@ -721,9 +721,6 @@ class ORTTrainer():
self.onnx_model_ = convert_model_loss_fn_to_onnx(
self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs, opset_version=self.opset_version_, _enable_internal_postprocess=self._enable_internal_postprocess)
if self._extra_postprocess:
self._extra_postprocess(self.onnx_model_)
self._init_session()
def train(self):
@ -789,6 +786,10 @@ class ORTTrainer():
self.state_dict_ = None
self._init_session()
# load training state
session_state = {name:state_dict[name].numpy() for name in state_dict}
self.session.load_state(session_state, strict)
def save_as_onnx(self, path):
if not self.session:
warnings.warn("ONNXRuntime training session is not initialized yet. "

View file

@ -209,6 +209,88 @@ class Test_PostPasses(unittest.TestCase):
assert model_info[0].name == expand_nodes[0].output[0]
assert model_info[0].type == onnx_model.graph.input[0].type
def test_extra_postpass(self):
    """Verify that a user-supplied extra post-pass runs, and runs only once.

    The post-pass used here rewrites the first ``Add`` node as ``Sub``:

        (subgraph 1)   (subgraph 2)          (subgraph 1)   (subgraph 2)
             |______   ______|                    |______   ______|
                    Add               ==>                Sub
                     |                                    |
                (subgraph 3)                         (subgraph 3)

    The test asserts three ``Add`` nodes in the unprocessed graph, so one
    application must leave two ``Add`` nodes and one ``Sub`` node. The
    check is made both for a torch module and for an already exported
    ONNX model fed back into ``get_onnx_model``.
    """
    def swap_first_add_for_sub(exported):
        # Replace the op type of the first Add node found in the graph.
        matching = list(filter(lambda node: node.op_type == 'Add', exported.graph.node))
        matching[0].op_type = "Sub"

    class MultiAdd(nn.Module):
        def __init__(self, target):
            super().__init__()
            self.loss = nn.CrossEntropyLoss()
            self.target = target
            self.linear = torch.nn.Linear(2, 2, bias=False)

        def forward(self, x, x1):
            # Three separate additions, kept as three statements.
            total = x + x1
            total = total + x
            total = total + x1
            projected = self.linear(total)
            return self.loss(projected, self.target), projected

    def op_counts(graph_model):
        # Look up both op types before any assertion is evaluated.
        adds = self.find_nodes(graph_model, "Add")
        subs = self.find_nodes(graph_model, "Sub")
        return len(adds), len(subs)

    device = torch.device("cpu")
    target = torch.ones(5, 2, dtype=torch.int64).to(device)
    model = MultiAdd(target).to(device)

    x = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
    x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device)

    model_desc = ModelDescription(
        [IODescription('x', [5, 5, 2], "float32"),
         IODescription('x1', [5, 5, 2], "float32")],
        [IODescription('output0', [], "float32"),
         IODescription('output1', [5, 5, 2], "float32")])

    learning_rate = torch.tensor([1.0]).to(device)
    input_args = [x, x1, learning_rate]

    # Torch module in, extra post-pass enabled: exactly one Add becomes Sub.
    onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
                                     _extra_postprocess=swap_first_add_for_sub)
    add_count, sub_count = op_counts(onnx_model)
    assert add_count == 2
    assert sub_count == 1

    # No post-processing requested: the graph keeps all of its Add nodes.
    unprocessed_onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
                                                 _extra_postprocess=None,
                                                 _enable_internal_postprocess=False)
    add_count, sub_count = op_counts(unprocessed_onnx_model)
    assert add_count == 3
    assert sub_count == 0

    # ONNX model in, extra post-pass enabled: again exactly one replacement.
    processed_onnx_model = self.get_onnx_model(unprocessed_onnx_model, model_desc, input_args, device,
                                               _extra_postprocess=swap_first_add_for_sub)
    add_count, sub_count = op_counts(processed_onnx_model)
    assert add_count == 2
    assert sub_count == 1
# Script entry point: run this module's tests when the file is executed directly.
# buffer=True asks unittest to buffer stdout/stderr while each test runs.
if __name__ == '__main__':
unittest.main(module=__name__, buffer=True)