mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Fix session load state & run extra_postpasses only once (#4255)
* Fix session load state & run extra_postpasses only once
* Add test case for ONNX model as well
This commit is contained in:
parent
d3c5cb6349
commit
15cb4b3023
3 changed files with 136 additions and 5 deletions
|
|
@ -287,9 +287,9 @@ class MNISTWrapper():
|
|||
return model, model_desc
|
||||
|
||||
def get_trainer(self, model, model_desc, device, onnx_opset_ver=12, frozen_weights=None,
                internal_loss_fn=False, get_lr_this_step=None, optimizer="SGDOptimizer"):
    """Build an ORTTrainer around ``model``.

    Args:
        model: model handed to ORTTrainer as its first argument.
        model_desc: model description forwarded to ORTTrainer.
        device: device forwarded to ORTTrainer.
        onnx_opset_ver: forwarded as ``_opset_version``.
        frozen_weights: names of weights to freeze; ``None`` (the default)
            means an empty list.
        internal_loss_fn: when true, no external loss function is passed
            (``loss_fn`` is ``None``); otherwise ``MNISTWrapper.my_loss`` is used.
        get_lr_this_step: forwarded unchanged to ORTTrainer.
        optimizer: optimizer name passed to ORTTrainer.

    Returns:
        The constructed ORTTrainer.
    """
    # A list literal as a default is created once and shared by every call;
    # use a None sentinel and build a fresh list per call instead.
    if frozen_weights is None:
        frozen_weights = []

    loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None
    learning_rate_desc = IODescription('Learning_Rate', [1, ], torch.float32)
    return ORTTrainer(model, loss_fn, model_desc, optimizer, None, learning_rate_desc,
                      device, _opset_version=onnx_opset_ver, frozen_weights=frozen_weights,
                      get_lr_this_step=get_lr_this_step)
|
||||
|
||||
|
|
@ -606,6 +606,54 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
ckpt_loss, _ = trainer.eval_step(data, target)
|
||||
assert loss == ckpt_loss
|
||||
|
||||
loaded_state_dict = trainer.state_dict()
|
||||
assert state_dict.keys() == loaded_state_dict.keys()
|
||||
|
||||
def testMNISTTrainingCheckpoint(self):
    """Check that a trainer's state survives state_dict()/load_state_dict().

    Trains a LambOptimizer trainer (with 'fc1.weight' frozen) for five steps,
    evaluates once, snapshots the state, loads it into a trainer built around
    a fresh model, and asserts that the eval loss and every state entry match.
    """
    torch.manual_seed(1)
    device = torch.device("cuda")

    mnist = MNISTWrapper()
    train_loader, _ = mnist.get_loaders()
    model, model_desc = mnist.get_model()

    def make_trainer(net):
        # Both trainers must be configured identically for the comparison to hold.
        return mnist.get_trainer(net, model_desc, device,
                                 optimizer='LambOptimizer', frozen_weights=['fc1.weight'])

    def fetch_batch():
        # A new iterator is created on every call, as each step did inline.
        images, labels = next(iter(train_loader))
        images, labels = images.to(device), labels.to(device)
        # Flatten each sample to a vector.
        return images.reshape(images.shape[0], -1), labels

    trainer = make_trainer(model)
    learningRate = 0.02

    # do 5 train steps
    for _ in range(5):
        data, target = fetch_batch()
        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))

    # do one eval step
    data, target = fetch_batch()
    loss, _ = trainer.eval_step(data, target)

    # save checkpoint, load it into a trainer for a fresh model and compare
    state_dict = trainer.state_dict()

    new_model, _ = mnist.get_model()
    trainer = make_trainer(new_model)
    trainer.load_state_dict(state_dict)

    # same batch as the pre-checkpoint eval step
    ckpt_loss, _ = trainer.eval_step(data, target)
    assert loss == ckpt_loss

    loaded_state_dict = trainer.state_dict()
    assert state_dict.keys() == loaded_state_dict.keys()
    for key in state_dict:
        assert np.array_equal(state_dict[key], loaded_state_dict[key])
|
||||
|
||||
def testBertTrainingBasic(self):
|
||||
expected_losses = [
|
||||
11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,
|
||||
|
|
|
|||
|
|
@ -721,9 +721,6 @@ class ORTTrainer():
|
|||
self.onnx_model_ = convert_model_loss_fn_to_onnx(
|
||||
self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs, opset_version=self.opset_version_, _enable_internal_postprocess=self._enable_internal_postprocess)
|
||||
|
||||
if self._extra_postprocess:
|
||||
self._extra_postprocess(self.onnx_model_)
|
||||
|
||||
self._init_session()
|
||||
|
||||
def train(self):
|
||||
|
|
@ -789,6 +786,10 @@ class ORTTrainer():
|
|||
self.state_dict_ = None
|
||||
self._init_session()
|
||||
|
||||
# load training state
|
||||
session_state = {name:state_dict[name].numpy() for name in state_dict}
|
||||
self.session.load_state(session_state, strict)
|
||||
|
||||
def save_as_onnx(self, path):
|
||||
if not self.session:
|
||||
warnings.warn("ONNXRuntime training session is not initialized yet. "
|
||||
|
|
|
|||
|
|
@ -209,6 +209,88 @@ class Test_PostPasses(unittest.TestCase):
|
|||
assert model_info[0].name == expand_nodes[0].output[0]
|
||||
assert model_info[0].type == onnx_model.graph.input[0].type
|
||||
|
||||
def test_extra_postpass(self):
    """Check that a user-supplied extra postpass runs, and runs only once.

    The model contains three Add nodes and the postpass turns the first Add it
    finds into a Sub, so a single application leaves two Add nodes and one Sub.
    The check is made for a torch model and again for an already-exported,
    unprocessed ONNX model.
    """
    def postpass_replace_first_add_with_sub(model):
        # This post pass replaces the first Add node with Sub in the model.
        #
        # Previous graph
        #   (subgraph 1)        (subgraph 2)
        #        |                   |
        #        |                   |
        #        |________   ________|
        #                 | |
        #                 Add
        #                  |
        #             (subgraph 3)
        #
        # Post graph
        #   (subgraph 1)        (subgraph 2)
        #        |                   |
        #        |                   |
        #        |________   ________|
        #                 | |
        #                 Sub
        #                  |
        #             (subgraph 3)
        candidates = [node for node in model.graph.node if node.op_type == 'Add']
        first_add = candidates[0]
        first_add.op_type = "Sub"

    class MultiAdd(nn.Module):
        """Three chained additions followed by a bias-free linear layer and a loss."""

        def __init__(self, target):
            super().__init__()
            self.loss = nn.CrossEntropyLoss()
            self.target = target
            self.linear = torch.nn.Linear(2, 2, bias=False)

        def forward(self, x, x1):
            # Three separate additions; the test counts the resulting Add nodes.
            first_sum = x + x1
            second_sum = first_sum + x
            third_sum = second_sum + x1
            output = self.linear(third_sum)
            return self.loss(output, self.target), output

    def op_counts(onnx_graph_model):
        # (number of Add nodes, number of Sub nodes) in the given model.
        adds = self.find_nodes(onnx_graph_model, "Add")
        subs = self.find_nodes(onnx_graph_model, "Sub")
        return len(adds), len(subs)

    device = torch.device("cpu")
    target = torch.ones(5, 2, dtype=torch.int64).to(device)
    model = MultiAdd(target).to(device)

    x = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
    x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device)

    model_desc = ModelDescription(
        [IODescription('x', [5, 5, 2], "float32"),
         IODescription('x1', [5, 5, 2], "float32")],
        [IODescription('output0', [], "float32"),
         IODescription('output1', [5, 5, 2], "float32")])

    learning_rate = torch.tensor([1.0]).to(device)
    input_args = [x, x1, learning_rate]

    onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
                                     _extra_postprocess=postpass_replace_first_add_with_sub)

    # check that extra postpass is called, and called only once.
    assert op_counts(onnx_model) == (2, 1)

    unprocessed_onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
                                                 _extra_postprocess=None,
                                                 _enable_internal_postprocess=False)

    # check that the model is unchanged.
    assert op_counts(unprocessed_onnx_model) == (3, 0)

    processed_onnx_model = self.get_onnx_model(unprocessed_onnx_model, model_desc, input_args, device,
                                               _extra_postprocess=postpass_replace_first_add_with_sub)

    # check that extra postpass is called, and called only once.
    assert op_counts(processed_onnx_model) == (2, 1)
|
||||
|
||||
|
||||
# Run this module's tests when the file is executed directly.
# buffer=True makes unittest buffer stdout/stderr during each test.
if __name__ == '__main__':
    unittest.main(module=__name__, buffer=True)
|
||||
|
|
|
|||
Loading…
Reference in a new issue