onnxruntime/onnxruntime/test/python/transformers/test_optimizer.py

320 lines
14 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG
import os
import shutil
import unittest

import pytest
from onnx import TensorProto, load_model

from parity_utilities import find_transformers_source

# Prefer the in-tree transformers sources when they can be located; otherwise
# fall back to the modules shipped inside the installed onnxruntime package.
if find_transformers_source():
    from optimizer import optimize_model
    from onnx_model import OnnxModel
    from onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt
    from huggingface_models import MODELS
    from benchmark_helper import Precision
else:
    from onnxruntime.transformers.optimizer import optimize_model
    from onnxruntime.transformers.onnx_model import OnnxModel
    from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt
    from onnxruntime.transformers.huggingface_models import MODELS
    from onnxruntime.transformers.benchmark_helper import Precision
# Maps a short test-model name to a (sub_dir, file_name) pair. A sub_dir of
# "FUSION" is handled specially by _get_test_model_path; any other value is
# used as a directory name under the test data folder.
BERT_TEST_MODELS = {
    # bert_mrpc_tensorflow2.1_opset10
    "bert_keras_0": ("models", "TFBertForSequenceClassification_1.onnx"),
    # bert_squad_tensorflow2.1_keras2onnx_opset11
    "bert_keras_squad": ("models", "TFBertForQuestionAnswering.onnx"),
    # gpt2_pytorch1.5_opset11
    "gpt2_past": ("models", "gpt2_past.onnx"),
    "gpt2_past_mask": ("FUSION", "gpt2_past_mask_one_layer.onnx"),
    "multiple_embed": ("FUSION", "embed_layer_norm_multiple.onnx"),
    "bert_tf2onnx_0": ("models", "bert_tf2onnx_0.onnx"),
}
def _get_test_model_path(name):
    """Return the path of the test model registered under ``name``.

    The (sub_dir, file name) pair is read from ``BERT_TEST_MODELS``. A path
    relative to the directory of this file is preferred; when no file exists
    there, a path relative to the current working directory is returned
    instead. The fallback path is returned without checking that it exists.

    Args:
        name: key of ``BERT_TEST_MODELS``.

    Returns:
        Path (str) to the ONNX test model.

    Raises:
        KeyError: if ``name`` is not a key of ``BERT_TEST_MODELS``.
    """
    sub_dir, file_name = BERT_TEST_MODELS[name]
    this_dir = os.path.dirname(__file__)

    if sub_dir == "FUSION":
        # Fusion test models live in the shared testdata/transform/fusion folder.
        preferred_path = os.path.join(this_dir, '..', '..', 'testdata', 'transform', 'fusion', file_name)
        fallback_path = os.path.join('.', 'testdata', 'transform', 'fusion', file_name)
    else:
        preferred_path = os.path.join(this_dir, 'test_data', sub_dir, file_name)
        fallback_path = os.path.join('.', 'transformers', 'test_data', sub_dir, file_name)

    return preferred_path if os.path.exists(preferred_path) else fallback_path
class TestBertOptimization(unittest.TestCase):
    """Tests that the transformers optimizer produces the expected fused operators.

    The fast tests run the optimizer (or a float16 conversion) on small ONNX
    files located through ``_get_test_model_path``. Tests marked with
    ``@pytest.mark.slow`` export a model listed in ``MODELS`` and compare the
    fusion statistics collected during the export with expected values.
    """

    def verify_node_count(self, bert_model, expected_node_count, test_name):
        """Assert that ``bert_model`` contains the expected number of nodes per op type.

        Args:
            bert_model: object exposing ``get_nodes_by_op_type(op_type)``.
            expected_node_count: dict mapping op type to the expected node count.
            test_name: name printed in the diagnostic output on a mismatch.

        On the first mismatch the actual and expected counts of every op type
        are printed before the failing assertion is raised.
        """
        # Query each op type once instead of on every comparison and print.
        actual_node_count = {
            op_type: len(bert_model.get_nodes_by_op_type(op_type)) for op_type in expected_node_count
        }

        for op_type, count in expected_node_count.items():
            if actual_node_count[op_type] != count:
                print(f"Counters is not expected in test: {test_name}")
                for op, counter in expected_node_count.items():
                    print("{}: {} expected={}".format(op, actual_node_count[op], counter))
            self.assertEqual(actual_node_count[op_type], count)

    def _check_fusion_statistics(self, export_function, model_name, expected_fusion_result_list, inputs_count,
                                 validate_model):
        """Export ``model_name`` with ``export_function`` and check its fusion statistics.

        Shared implementation of the PyTorch and TensorFlow helpers below.

        Args:
            export_function: ``export_onnx_model_from_pt`` or ``export_onnx_model_from_tf``.
            model_name: key of ``MODELS``.
            expected_fusion_result_list: expected values of the fusion statistics.
            inputs_count: number of leading input names of the model passed to the export.
            validate_model: when true, also require that the export reports a valid ONNX model.
        """
        # Remove cached model so that CI machine will have space
        shutil.rmtree('./cache_models', ignore_errors=True)
        shutil.rmtree('./onnx_models', ignore_errors=True)

        # expect fusion result list have the following keys
        # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization
        model_fusion_statistics = {}

        model_config = MODELS[model_name]
        input_names = model_config[0]

        # torch is imported lazily so that importing this module does not require it.
        # NOTE(review): the TensorFlow export also runs under torch.no_grad(), as it
        # did before this refactoring - confirm whether that is actually required.
        import torch
        with torch.no_grad():
            _, is_valid_onnx_model, _, _ = export_function(model_name, model_config[1], model_config[2],
                                                           model_config[3], None, './cache_models',
                                                           './onnx_models', input_names[:inputs_count], False,
                                                           Precision.FLOAT32, True, True, True, True,
                                                           model_fusion_statistics)

        # The export fills model_fusion_statistics; only its first entry is checked.
        onnx_model = list(model_fusion_statistics.keys())[0]
        fusion_result_list = list(model_fusion_statistics[onnx_model].values())

        if validate_model:
            self.assertEqual(is_valid_onnx_model, True)
        self.assertEqual(fusion_result_list, expected_fusion_result_list)

    # add test function for huggingface pytorch model
    def _test_optimizer_on_huggingface_model(self,
                                             model_name,
                                             expected_fusion_result_list,
                                             inputs_count=1,
                                             validate_model=True):
        """Export a model from PyTorch and compare its fusion statistics with the expected list."""
        self._check_fusion_statistics(export_onnx_model_from_pt, model_name, expected_fusion_result_list,
                                      inputs_count, validate_model)

    def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, inputs_count, validate_model=True):
        """Export a model from TensorFlow and compare its fusion statistics with the expected list."""
        print("testing mode ", model_name)
        print("testing input number = ", inputs_count)
        self._check_fusion_statistics(export_onnx_model_from_tf, model_name, expected_fusion_result_list,
                                      inputs_count, validate_model)

    # def test_keras_model_1(self):
    #     input = _get_test_model_path('bert_keras_0')
    #     bert_model = optimize_model(input, 'bert_keras', num_heads=2, hidden_size=8)
    #     expected_node_count = {
    #         'EmbedLayerNormalization': 1,
    #         'Attention': 12,
    #         'LayerNormalization': 0,
    #         'SkipLayerNormalization': 24,
    #         'BiasGelu': 12,
    #         'Gelu': 0,
    #         'FastGelu': 0
    #     }
    #     self.verify_node_count(bert_model, expected_node_count, 'test_keras_model_1')

    # def test_keras_squad_model(self):
    #     input = _get_test_model_path('bert_keras_squad')
    #     bert_model = optimize_model(input, 'bert_keras', num_heads=2, hidden_size=8)
    #     print("fused_operator_statistics for test_keras_squad_model", bert_model.get_fused_operator_statistics())
    #     self.assertTrue(bert_model.is_fully_optimized())

    def test_gpt2_past(self):
        """Optimize the 'gpt2_past' test model and check the fused node counts."""
        model_path = _get_test_model_path('gpt2_past')
        model = optimize_model(model_path, 'gpt2', num_heads=2, hidden_size=4)
        expected_node_count = {
            'EmbedLayerNormalization': 0,
            'Attention': 12,
            'Gelu': 0,
            'FastGelu': 12,
            'BiasGelu': 0,
            'LayerNormalization': 25,
            'SkipLayerNormalization': 0
        }
        self.verify_node_count(model, expected_node_count, 'test_gpt2_past')

    def test_gpt2_past_fp16(self):
        """Convert the 'gpt2_past' test model to float16 and check the graph input/output types."""
        input_model_path = _get_test_model_path('gpt2_past')
        model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True))
        model.convert_float_to_float16(keep_io_types=False, use_symbolic_shape_infer=False)

        # The first graph input is skipped; every other input must be float16.
        for graph_input in model.graph().input[1:]:
            self.assertEqual(graph_input.type.tensor_type.elem_type, TensorProto.FLOAT16)
        for graph_output in model.graph().output:
            self.assertEqual(graph_output.type.tensor_type.elem_type, TensorProto.FLOAT16)

    def test_gpt2_past_mask(self):
        """Optimize the 'gpt2_past_mask' test model and check the fused node counts."""
        model_path = _get_test_model_path('gpt2_past_mask')
        model = optimize_model(model_path, 'gpt2', num_heads=2, hidden_size=4)
        expected_node_count = {
            'EmbedLayerNormalization': 0,
            'Attention': 1,
            'Gelu': 0,
            'FastGelu': 1,
            'BiasGelu': 0,
            'LayerNormalization': 2,
            'SkipLayerNormalization': 0
        }
        self.verify_node_count(model, expected_node_count, 'test_gpt2_past_mask')

    def test_multiple_embed(self):
        """Optimize the 'multiple_embed' test model and check the fused node counts."""
        input_model_path = _get_test_model_path('multiple_embed')
        model = optimize_model(input_model_path, 'bert', num_heads=2, hidden_size=4)
        expected_node_count = {
            'EmbedLayerNormalization': 2,
            'Attention': 2,
            'Gelu': 0,
            'FastGelu': 0,
            'BiasGelu': 0,
            'LayerNormalization': 0,
            'SkipLayerNormalization': 0
        }
        self.verify_node_count(model, expected_node_count, 'test_multiple_embed')

    # def test_bert_tf2onnx_0(self):
    #     input = _get_test_model_path('bert_tf2onnx_0')
    #     model = optimize_model(input, 'bert_tf', num_heads=2, hidden_size=8)
    #     expected_node_count = {
    #         'EmbedLayerNormalization': 0,
    #         'Attention': 6,
    #         'Gelu': 0,
    #         'FastGelu': 6,
    #         'BiasGelu': 0,
    #         'LayerNormalization': 0,
    #         'SkipLayerNormalization': 13
    #     }
    #     self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0')

    @pytest.mark.slow
    def test_huggingface_bert_fusion_1(self):
        self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1)

    @pytest.mark.slow
    def test_huggingface_bert_fusion_2(self):
        self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2)

    @pytest.mark.slow
    def test_huggingface_bert_fusion_3(self):
        self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=3)

    @pytest.mark.slow
    def test_huggingface_openaigpt_fusion(self):
        self._test_optimizer_on_huggingface_model("openai-gpt", [0, 12, 0, 12, 0, 24, 0])

    # @pytest.mark.slow
    # def test_huggingface_gpt2_fusion(self):
    #     self._test_optimizer_on_huggingface_model("gpt2", [0, 12, 0, 12, 0, 25, 0])

    @pytest.mark.slow
    def test_huggingface_xlm_fusion(self):
        self._test_optimizer_on_huggingface_model("xlm-mlm-ende-1024", [0, 6, 0, 0, 6, 0, 13])

    @pytest.mark.slow
    def test_huggingface_roberta_fusion(self):
        self._test_optimizer_on_huggingface_model("roberta-base", [0, 12, 0, 0, 12, 1, 24])

    @pytest.mark.slow
    def test_huggingface_distillbert_fusion(self):
        self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=1)
        self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=2)

    # @pytest.mark.slow
    # def test_huggingface_camembert_fusion(self):
    #     # output not close issue
    #     self._test_optimizer_on_huggingface_model("camembert-base", [0, 12, 0, 0, 12, 1, 24], validate_model=False)

    @pytest.mark.slow
    def test_huggingface_albert_fusion(self):
        self._test_optimizer_on_huggingface_model("albert-base-v1", [0, 12, 0, 0, 12, 1, 24])

    # @pytest.mark.slow
    # def test_huggingface_t5_fusion(self):
    #     self._test_optimizer_on_huggingface_model("t5-small", [0, 0, 0, 0, 0, 0, 0])

    @pytest.mark.slow
    def test_huggingface_xlmroberta_fusion(self):
        self._test_optimizer_on_huggingface_model("xlm-roberta-base", [0, 12, 0, 0, 12, 1, 24])

    @pytest.mark.slow
    def test_huggingface_flaubert_fusion(self):
        # output not close issue
        self._test_optimizer_on_huggingface_model("flaubert/flaubert_base_cased", [0, 12, 0, 0, 12, 0, 25],
                                                  validate_model=False)
        self._test_optimizer_on_huggingface_model("flaubert/flaubert_small_cased", [0, 6, 0, 0, 6, 12, 1],
                                                  validate_model=False)

    # @pytest.mark.slow
    # def test_huggingface_dialogpt_fusion(self):
    #     self._test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0])

    @pytest.mark.slow
    def test_huggingface_bart_fusion(self):
        self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])

    @pytest.mark.slow
    def test_huggingface_bert_base_cased_from_tf2onnx_1(self):
        self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1)

    @pytest.mark.slow
    def test_huggingface_bert_base_cased_from_tf2onnx_2(self):
        self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2)

    @pytest.mark.slow
    def test_huggingface_bert_base_cased_from_tf2onnx_3(self):
        self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3)

    @pytest.mark.slow
    def test_huggingface_distilgpt2_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("distilgpt2", [0, 0, 0, 0, 0, 12, 1], 1)

    @pytest.mark.slow
    def test_huggingface_albert_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("albert-base-v1", [0, 0, 0, 0, 0, 0, 25], 1)

    @pytest.mark.slow
    def test_huggingface_gpt2_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("gpt2", [0, 0, 0, 0, 0, 24, 1], 1, validate_model=False)

    @pytest.mark.slow
    def test_huggingface_roberta_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("roberta-base", [0, 12, 0, 0, 0, 0, 25], 1, validate_model=False)

    @pytest.mark.slow
    def test_huggingface_distilbert_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("distilbert-base-uncased", [0, 0, 0, 0, 0, 0, 13], 1, validate_model=False)

    @pytest.mark.slow
    def test_huggingface_xlm_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", [0, 0, 0, 0, 0, 1, 12], 1, validate_model=False)


# Run the tests with the unittest runner when this file is executed directly.
if __name__ == '__main__':
    unittest.main()