mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
* Fix session load state & run extra_postpasses only once * add testcase for onnx model as well
296 lines
12 KiB
Python
296 lines
12 KiB
Python
import unittest
|
|
import pytest
|
|
import sys
|
|
import os
|
|
import copy
|
|
from numpy.testing import assert_allclose, assert_array_equal
|
|
|
|
import onnx
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from orttraining_test_utils import map_optimizer_attributes
|
|
from orttraining_test_transformers import BertModelTest, BertForPreTraining
|
|
from orttraining_test_data_loader import create_ort_test_dataloader
|
|
from orttraining_test_bert_postprocess import postprocess_model
|
|
import onnxruntime
|
|
|
|
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample
|
|
|
|
# Pin the PyTorch and ONNX Runtime seeds at import time, before any test
# creates tensors or trainers (presumably so that runs are repeatable).
torch.manual_seed(1)
onnxruntime.set_seed(1)
|
|
|
|
class Test_PostPasses(unittest.TestCase):
|
|
def get_onnx_model(self, model, model_desc, inputs, device,
|
|
_enable_internal_postprocess=True, _extra_postprocess=None):
|
|
lr_desc = IODescription('Learning_Rate', [1,], torch.float32)
|
|
model = ORTTrainer(model,
|
|
None,
|
|
model_desc,
|
|
"LambOptimizer",
|
|
map_optimizer_attributes,
|
|
lr_desc,
|
|
device,
|
|
world_rank=0,
|
|
world_size=1,
|
|
_opset_version=12,
|
|
_enable_internal_postprocess=_enable_internal_postprocess,
|
|
_extra_postprocess=_extra_postprocess)
|
|
|
|
train_output = model.train_step(*inputs)
|
|
return model.onnx_model_
|
|
|
|
def count_all_nodes(self, model):
|
|
return len(model.graph.node)
|
|
|
|
def count_nodes(self, model, node_type):
|
|
count = 0
|
|
for node in model.graph.node:
|
|
if node.op_type == node_type:
|
|
count += 1
|
|
return count
|
|
|
|
def find_nodes(self, model, node_type):
|
|
nodes = []
|
|
for node in model.graph.node:
|
|
if node.op_type == node_type:
|
|
nodes.append(node)
|
|
return nodes
|
|
|
|
def get_name(self, name):
|
|
if os.path.exists(name):
|
|
return name
|
|
rel = os.path.join("testdata", name)
|
|
if os.path.exists(rel):
|
|
return rel
|
|
this = os.path.dirname(__file__)
|
|
data = os.path.join(this, "..", "..", "..", "..", "onnxruntime", "test", "testdata")
|
|
res = os.path.join(data, name)
|
|
if os.path.exists(res):
|
|
return res
|
|
raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res))
|
|
|
|
def test_layer_norm(self):
|
|
class LayerNormNet(nn.Module):
|
|
def __init__(self, target):
|
|
super(LayerNormNet, self).__init__()
|
|
self.ln_1 = nn.LayerNorm(10)
|
|
self.loss = nn.CrossEntropyLoss()
|
|
self.target = target
|
|
|
|
def forward(self, x):
|
|
output1 = self.ln_1(x)
|
|
loss = self.loss(output1, self.target)
|
|
return loss, output1
|
|
|
|
device = torch.device("cpu")
|
|
target = torch.ones(20, 10, 10, dtype=torch.int64).to(device)
|
|
model = LayerNormNet(target)
|
|
input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device)
|
|
|
|
input_desc = IODescription('input', [], "float32")
|
|
output0_desc = IODescription('output0', [], "float32")
|
|
output1_desc = IODescription('output1', [20, 5, 10, 10], "float32")
|
|
model_desc = ModelDescription([input_desc], [output0_desc, output1_desc])
|
|
|
|
learning_rate = torch.tensor([1.0000000e+00]).to(device)
|
|
input_args=[input, learning_rate]
|
|
|
|
onnx_model = self.get_onnx_model(model, model_desc, input_args, device)
|
|
|
|
count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization")
|
|
count_nodes = self.count_all_nodes(onnx_model)
|
|
|
|
assert count_layer_norm == 1
|
|
assert count_nodes == 3
|
|
|
|
def test_expand(self):
|
|
class ExpandNet(nn.Module):
|
|
def __init__(self, target):
|
|
super(ExpandNet, self).__init__()
|
|
self.loss = nn.CrossEntropyLoss()
|
|
self.target = target
|
|
self.linear = torch.nn.Linear(2, 2)
|
|
|
|
def forward(self, x, x1):
|
|
output = x.expand_as(x1)
|
|
output = self.linear(output)
|
|
output = output + output
|
|
loss = self.loss(output, self.target)
|
|
return loss, output
|
|
|
|
device = torch.device("cpu")
|
|
target = torch.ones(5, 5, 2, dtype=torch.int64).to(device)
|
|
model = ExpandNet(target).to(device)
|
|
|
|
x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device)
|
|
x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device)
|
|
|
|
input0_desc = IODescription('x', [5, 3, 1, 2], "float32")
|
|
input1_desc = IODescription('x1', [5, 3, 5, 2], "float32")
|
|
output0_desc = IODescription('output0', [], "float32")
|
|
output1_desc = IODescription('output1', [5, 3, 5, 2], "float32")
|
|
model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc])
|
|
|
|
learning_rate = torch.tensor([1.0000000e+00]).to(device)
|
|
input_args = [x, x1, learning_rate]
|
|
|
|
onnx_model = self.get_onnx_model(model, model_desc, input_args, device)
|
|
|
|
# check that expand output has shape
|
|
expand_nodes = self.find_nodes(onnx_model, "Expand")
|
|
assert len(expand_nodes) == 1
|
|
|
|
model_info = onnx_model.graph.value_info
|
|
assert model_info[0].name == expand_nodes[0].output[0]
|
|
assert model_info[0].type == onnx_model.graph.input[1].type
|
|
|
|
def test_bert(self):
|
|
device = torch.device("cpu")
|
|
|
|
model_tester = BertModelTest.BertModelTester(self)
|
|
config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = model_tester.prepare_config_and_inputs()
|
|
|
|
model = BertForPreTraining(config=config)
|
|
model.eval()
|
|
|
|
loss, prediction_scores, seq_relationship_score = model(input_ids,
|
|
attention_mask=input_mask,
|
|
token_type_ids=token_type_ids,
|
|
masked_lm_labels=token_labels,
|
|
next_sentence_label=sequence_labels)
|
|
|
|
model_desc = ModelDescription([model_tester.input_ids_desc,
|
|
model_tester.attention_mask_desc,
|
|
model_tester.token_type_ids_desc,
|
|
model_tester.masked_lm_labels_desc,
|
|
model_tester.next_sentence_label_desc],
|
|
[model_tester.loss_desc,
|
|
model_tester.prediction_scores_desc,
|
|
model_tester.seq_relationship_scores_desc])
|
|
|
|
from collections import namedtuple
|
|
MyArgs = namedtuple("MyArgs",
|
|
"local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
|
|
args = MyArgs(local_rank=0,
|
|
world_size=1,
|
|
max_steps=100,
|
|
learning_rate=0.00001,
|
|
warmup_proportion=0.01,
|
|
batch_size=13,
|
|
seq_len=7)
|
|
|
|
dataloader = create_ort_test_dataloader(model_desc.inputs_,
|
|
args.batch_size,
|
|
args.seq_len,
|
|
device)
|
|
learning_rate = torch.tensor(1.0e+0, dtype=torch.float32).to(device)
|
|
for b in dataloader:
|
|
batch = b
|
|
break
|
|
learning_rate = torch.tensor([1.00e+00]).to(device)
|
|
inputs = batch + [learning_rate,]
|
|
|
|
onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model)
|
|
|
|
self._bert_helper(onnx_model)
|
|
|
|
def _bert_helper(self, onnx_model):
|
|
# count layer_norm
|
|
count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization")
|
|
assert count_layer_norm == 12
|
|
|
|
# get expand node and check output shape
|
|
expand_nodes = self.find_nodes(onnx_model, "Expand")
|
|
assert len(expand_nodes) == 1
|
|
|
|
model_info = onnx_model.graph.value_info
|
|
assert model_info[0].name == expand_nodes[0].output[0]
|
|
assert model_info[0].type == onnx_model.graph.input[0].type
|
|
|
|
def test_extra_postpass(self):
|
|
def postpass_replace_first_add_with_sub(model):
|
|
# this post pass replaces the first Add node with Sub in the model.
|
|
# Previous graph
|
|
# (subgraph 1) (subgraph 2)
|
|
# | |
|
|
# | |
|
|
# |________ ________|
|
|
# | |
|
|
# Add
|
|
# |
|
|
# (subgraph 3)
|
|
#
|
|
# Post graph
|
|
# (subgraph 1) (subgraph 2)
|
|
# | |
|
|
# | |
|
|
# |________ ________|
|
|
# | |
|
|
# Sub
|
|
# |
|
|
# (subgraph 3)
|
|
add_nodes = [n for n in model.graph.node if n.op_type == 'Add']
|
|
add_nodes[0].op_type = "Sub"
|
|
|
|
class MultiAdd(nn.Module):
|
|
def __init__(self, target):
|
|
super(MultiAdd, self).__init__()
|
|
self.loss = nn.CrossEntropyLoss()
|
|
self.target = target
|
|
self.linear = torch.nn.Linear(2, 2, bias=False)
|
|
|
|
def forward(self, x, x1):
|
|
output = x + x1
|
|
output = output + x
|
|
output = output + x1
|
|
output = self.linear(output)
|
|
loss = self.loss(output, self.target)
|
|
return loss, output
|
|
|
|
device = torch.device("cpu")
|
|
target = torch.ones(5, 2, dtype=torch.int64).to(device)
|
|
model = MultiAdd(target).to(device)
|
|
|
|
x = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
|
|
x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
|
|
|
|
input0_desc = IODescription('x', [5, 5, 2], "float32")
|
|
input1_desc = IODescription('x1', [5, 5, 2], "float32")
|
|
output0_desc = IODescription('output0', [], "float32")
|
|
output1_desc = IODescription('output1', [5, 5, 2], "float32")
|
|
model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc])
|
|
|
|
learning_rate = torch.tensor([1.0000000e+00]).to(device)
|
|
input_args = [x, x1, learning_rate]
|
|
|
|
onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
|
|
_extra_postprocess=postpass_replace_first_add_with_sub)
|
|
|
|
# check that extra postpass is called, and called only once.
|
|
add_nodes = self.find_nodes(onnx_model, "Add")
|
|
sub_nodes = self.find_nodes(onnx_model, "Sub")
|
|
assert len(add_nodes) == 2
|
|
assert len(sub_nodes) == 1
|
|
|
|
|
|
unprocessed_onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
|
|
_extra_postprocess=None, _enable_internal_postprocess=False)
|
|
# check that the model is unchanged.
|
|
add_nodes = self.find_nodes(unprocessed_onnx_model, "Add")
|
|
sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub")
|
|
assert len(add_nodes) == 3
|
|
assert len(sub_nodes) == 0
|
|
|
|
processed_onnx_model = self.get_onnx_model(unprocessed_onnx_model, model_desc, input_args, device,
|
|
_extra_postprocess=postpass_replace_first_add_with_sub)
|
|
# check that extra postpass is called, and called only once.
|
|
add_nodes = self.find_nodes(processed_onnx_model, "Add")
|
|
sub_nodes = self.find_nodes(processed_onnx_model, "Sub")
|
|
assert len(add_nodes) == 2
|
|
assert len(sub_nodes) == 1
|
|
|
|
|
|
# Script entry point: hand this module to the unittest runner, buffering
# stdout/stderr while the tests run.
if __name__ == '__main__':
    unittest.main(module=__name__, buffer=True)
|