diff --git a/onnxruntime/python/tools/transformers/affinity_helper.py b/onnxruntime/python/tools/transformers/affinity_helper.py index 26676c2282..8fb3e3b571 100644 --- a/onnxruntime/python/tools/transformers/affinity_helper.py +++ b/onnxruntime/python/tools/transformers/affinity_helper.py @@ -26,8 +26,7 @@ class AffinitySetting(): if self.is_os_supported: current_affinity = os.sched_getaffinity(self.pid) if (self.affinity != current_affinity): - logger.warning("Replacing affinity setting %s with %s", str(current_affinity), - str(self.affinity)) + logger.warning("Replacing affinity setting %s with %s", str(current_affinity), str(self.affinity)) os.sched_setaffinity(self.pid, self.affinity) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 691f40d949..0451c06ee8 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -358,7 +358,7 @@ class FusionAttention(Fusion): logger.debug("fuse_attention: failed to match v path") return (_, _, add_v, matmul_v) = v_nodes - + is_distill = False is_distill_add = False qk_paths = { diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py index a27029ddf4..d7cb54ad43 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py @@ -262,9 +262,13 @@ class FusionGptAttention(Fusion): if i == 1: add_qk = qk_nodes[1] _, input_mask_nodes, _ = self.model.match_parent_paths( - add_qk, [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]), - (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0])], - output_name_to_node) + add_qk, + [ + (['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]), + (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0]), + (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0]), # useless cast and reshape are removed. + ], + output_name_to_node) # yapf: disable if input_mask_nodes is None: logger.debug("fuse_attention: failed to match input attention mask path") return diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index ef858237ce..5a036ac0f5 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -39,7 +39,8 @@ class FusionSkipLayerNormalization(Fusion): if self.shape_infer_helper is not None: if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): - logger.debug(f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same") + logger.debug( + f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same") return else: # shape_infer_helper can not handle subgraphs. Current work around is to disable skiplayernorm fusion diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 600f1b6df7..e7d2dbec1b 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -76,7 +76,8 @@ class FusionUtils: value = helper.get_attribute_value(attr) if isinstance(expected_value, list): - return isinstance(value, ndarray) and array_equal(expected_value, value, equal_nan=False) + return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal( + expected_value, value, equal_nan=False) else: return value == expected_value @@ -96,12 +97,13 @@ class FusionUtils: value = self.model.get_constant_value(node.input[input_index]) if isinstance(expected_value, list): - return isinstance(value, ndarray) and array_equal(expected_value, value, equal_nan=False) + return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal( + expected_value, value, equal_nan=False) else: return value == expected_value @staticmethod - def remove_useless_reshape_nodes(model:OnnxModel): + def remove_useless_reshape_nodes(model: OnnxModel): """Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape """ shape_infer = model.infer_runtime_shape(update=True) @@ -114,7 +116,8 @@ class FusionUtils: input_shape = shape_infer.get_edge_shape(node.input[0]) output_shape = shape_infer.get_edge_shape(node.output[0]) if input_shape and output_shape and input_shape == output_shape: - logger.info(f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}") + logger.info( + f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}") nodes_to_remove.append(node) if nodes_to_remove: @@ -123,6 +126,7 @@ class FusionUtils: model.remove_node(node) model.prune_graph() + class NumpyHelper: @staticmethod def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray: diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 9e127ad32f..048dff3a74 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -15,6 +15,7 @@ from benchmark_helper import create_onnxruntime_session, Precision from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState from quantize_helper import QuantizeHelper from huggingface_models import MODEL_CLASSES + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' logger = logging.getLogger(__name__) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 7bf8ced4e6..747b1b2cd5 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -486,21 +486,24 @@ class OnnxModel: # restore opset version self.model.opset_import[0].version = original_opset_version - def convert_model_float32_to_float16(self, cast_input_output=True): + def convert_model_float32_to_float16(self, cast_input_output=True, use_symbolic_shape_infer=True): """Convert a graph to FLOAT16. By default, we will keep data types of inputs and outputs. For decoder model with past_key_values, it is recommended to set cast_input_output=False for better performance. Args: cast_input_output (bool, optional): keep data type of inputs and outputs, and add Cast nodes to convert float32 inputs to float16, and float16 to float32 for outputs. Defaults to True. + use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference. """ from packaging.version import Version import onnxconverter_common as oc if Version(oc.__version__) > Version("1.7.0"): - # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) are not recognized by onnx shape inference. - shape_infer_helper = SymbolicShapeInferenceHelper(self.model) - model_with_shape = shape_infer_helper.infer_shapes(self.model, auto_merge=True, guess_output_rank=False) - self.model = oc.float16.convert_float_to_float16(model_with_shape, + model = self.model + if use_symbolic_shape_infer: + # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) are not recognized by onnx shape inference. + shape_infer_helper = SymbolicShapeInferenceHelper(model) + model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False) + self.model = oc.float16.convert_float_to_float16(model, keep_io_types=cast_input_output, - disable_shape_infer=True) + disable_shape_infer=use_symbolic_shape_infer) return graph = self.model.graph diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py index c0f8f35517..53f2ca1f9d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py @@ -11,6 +11,7 @@ import numpy as np from collections import deque from onnx import ModelProto, TensorProto, numpy_helper from onnx_model_bert_tf import BertOnnxModelTF + logger = logging.getLogger(__name__) diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 6aa2084acf..b46c8e1c49 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -96,25 +96,26 @@ class TestFusion(unittest.TestCase): expected = onnx.load(expected_model_path) self.assertEqual(str(optimized_model.model.graph), str(expected.graph)) - # def test_gpt2_attention_fusion(self): - # hidden_size = 64 - # num_heads = 4 - # for add_order in [False, True]: - # model = create_gpt2_attention(hidden_size=hidden_size, num_heads=num_heads, switch_add_inputs=add_order) - # dir = '.' - # model_path = os.path.join(dir, "gpt2_attention.onnx") - # onnx.save(model, model_path) - # optimized_model = optimize_model(model_path, - # model_type='gpt2', - # num_heads=num_heads, - # hidden_size=hidden_size, - # disable_onnxruntime=True) - # os.remove(model_path) + def test_gpt2_attention_fusion(self): + hidden_size = 64 + num_heads = 4 + for add_order in [False, True]: + model = create_gpt2_attention(hidden_size=hidden_size, num_heads=num_heads, switch_add_inputs=add_order) + dir = '.' + model_path = os.path.join(dir, "gpt2_attention.onnx") + onnx.save(model, model_path) + optimized_model = optimize_model(model_path, + model_type='gpt2', + num_heads=num_heads, + hidden_size=hidden_size, + disable_onnxruntime=True) + optimized_model.topological_sort() + os.remove(model_path) - # model_name = "gpt2_attention_{}.onnx".format("add_opt" if add_order else "opt") - # expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'models', model_name) - # expected = onnx.load(expected_model_path) - # self.assertEqual(str(optimized_model.model.graph), str(expected.graph)) + model_name = "gpt2_attention_{}.onnx".format("add_opt" if add_order else "opt") + expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'models', model_name) + expected = onnx.load(expected_model_path) + self.assertEqual(str(optimized_model.model.graph), str(expected.graph)) if __name__ == '__main__': diff --git a/onnxruntime/test/python/transformers/test_optimizer.py b/onnxruntime/test/python/transformers/test_optimizer.py index afe6b30cdb..3f7bc7fa56 100644 --- a/onnxruntime/test/python/transformers/test_optimizer.py +++ b/onnxruntime/test/python/transformers/test_optimizer.py @@ -10,22 +10,31 @@ import unittest import os -import onnx -import onnxruntime import pytest -from onnx import helper, TensorProto, ModelProto, load_model -from onnx.helper import make_node, make_tensor_value_info -import numpy as np -from onnx import numpy_helper +from onnx import TensorProto, load_model import sys -from onnxruntime.transformers.optimizer import optimize_model, optimize_by_onnxruntime -from onnxruntime.transformers.onnx_model import OnnxModel +# Try import optimizer from source directory so that we need not build and install package after making change. +source_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'python', 'tools', 'transformers') +if (os.path.exists(source_dir) and source_dir not in sys.path): + sys.path.append(source_dir) + from optimizer import optimize_model + from onnx_model import OnnxModel + from onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt + from huggingface_models import MODELS + from benchmark_helper import Precision +else: + from onnxruntime.transformers.optimizer import optimize_model + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt + from onnxruntime.transformers.huggingface_models import MODELS + from onnxruntime.transformers.benchmark_helper import Precision + BERT_TEST_MODELS = { - "bert_keras_0": ('models', 'TFBertForSequenceClassification_1.onnx'), # bert_mrpc_tensorflow2.1_opset10 - "bert_keras_squad": ('models', 'TFBertForQuestionAnswering.onnx'), # bert_squad_tensorflow2.1_keras2onnx_opset11 - "gpt2_past": ('models', 'gpt2_past.onnx'), # gpt2_pytorch1.5_opset11 + "bert_keras_0": ('models', 'TFBertForSequenceClassification_1.onnx'), # bert_mrpc_tensorflow2.1_opset10 + "bert_keras_squad": ('models', 'TFBertForQuestionAnswering.onnx'), # bert_squad_tensorflow2.1_keras2onnx_opset11 + "gpt2_past": ('models', 'gpt2_past.onnx'), # gpt2_pytorch1.5_opset11 "gpt2_past_mask": ('FUSION', 'gpt2_past_mask_one_layer.onnx'), "multiple_embed": ('FUSION', 'embed_layer_norm_multiple.onnx'), "bert_tf2onnx_0": ('models', 'bert_tf2onnx_0.onnx') @@ -35,10 +44,15 @@ BERT_TEST_MODELS = { def _get_test_model_path(name): sub_dir, file = BERT_TEST_MODELS[name] if sub_dir == "FUSION": - #return os.path.join('..', '..', '..', '..', 'test', 'testdata', 'transform', 'fusion', file) - return os.path.join('./', 'testdata', 'transform', 'fusion', file) + relative_path = os.path.join(os.path.dirname(__file__), '..', '..', 'testdata', 'transform', 'fusion', file) + if (os.path.exists(relative_path)): + return relative_path + return os.path.join('.', 'testdata', 'transform', 'fusion', file) else: - return os.path.join('./', 'transformers', 'test_data', sub_dir, file) + relative_path = os.path.join(os.path.dirname(__file__), 'test_data', sub_dir, file) + if (os.path.exists(relative_path)): + return relative_path + return os.path.join('.', 'transformers', 'test_data', sub_dir, file) class TestBertOptimization(unittest.TestCase): @@ -63,9 +77,6 @@ class TestBertOptimization(unittest.TestCase): # expect fusion result list have the following keys # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization model_fusion_statistics = {} - from onnx_exporter import export_onnx_model_from_pt - from huggingface_models import MODELS - from benchmark_helper import Precision input_names = MODELS[model_name][0] @@ -94,9 +105,6 @@ class TestBertOptimization(unittest.TestCase): # expect fusion result list have the following keys # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization model_fusion_statistics = {} - from onnx_exporter import export_onnx_model_from_tf - from huggingface_models import MODELS - from benchmark_helper import Precision print("testing mode ", model_name) print("testing input number = ", inputs_count) input_names = MODELS[model_name][0] @@ -157,28 +165,28 @@ class TestBertOptimization(unittest.TestCase): } self.verify_node_count(model, expected_node_count, 'test_gpt2_past') - # def test_gpt2_past_fp16(self): - # input_model_path = _get_test_model_path('gpt2_past') - # model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True)) - # model.convert_model_float32_to_float16(cast_input_output=False) - # for input in model.graph().input[1:]: - # self.assertEqual(input.type.tensor_type.elem_type, TensorProto.FLOAT16) - # for output in model.graph().output: - # self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16) + def test_gpt2_past_fp16(self): + input_model_path = _get_test_model_path('gpt2_past') + model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True)) + model.convert_model_float32_to_float16(cast_input_output=False, use_symbolic_shape_infer=False) + for input in model.graph().input[1:]: + self.assertEqual(input.type.tensor_type.elem_type, TensorProto.FLOAT16) + for output in model.graph().output: + self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16) - # def test_gpt2_past_mask(self): - # input = _get_test_model_path('gpt2_past_mask') - # model = optimize_model(input, 'gpt2', num_heads=2, hidden_size=4) - # expected_node_count = { - # 'EmbedLayerNormalization': 0, - # 'Attention': 1, - # 'Gelu': 0, - # 'FastGelu': 1, - # 'BiasGelu': 0, - # 'LayerNormalization': 2, - # 'SkipLayerNormalization': 0 - # } - # self.verify_node_count(model, expected_node_count, 'test_gpt2_past_mask') + def test_gpt2_past_mask(self): + input = _get_test_model_path('gpt2_past_mask') + model = optimize_model(input, 'gpt2', num_heads=2, hidden_size=4) + expected_node_count = { + 'EmbedLayerNormalization': 0, + 'Attention': 1, + 'Gelu': 0, + 'FastGelu': 1, + 'BiasGelu': 0, + 'LayerNormalization': 2, + 'SkipLayerNormalization': 0 + } + self.verify_node_count(model, expected_node_count, 'test_gpt2_past_mask') def test_multiple_embed(self): input_model_path = _get_test_model_path('multiple_embed') @@ -209,9 +217,15 @@ class TestBertOptimization(unittest.TestCase): # self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0') @pytest.mark.slow - def test_huggingface_bert_fusion(self): + def test_huggingface_bert_fusion_1(self): self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1) + + @pytest.mark.slow + def test_huggingface_bert_fusion_2(self): self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2) + + @pytest.mark.slow + def test_huggingface_bert_fusion_3(self): self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=3) @pytest.mark.slow @@ -269,9 +283,15 @@ class TestBertOptimization(unittest.TestCase): self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30]) @pytest.mark.slow - def test_huggingface_bert_base_cased_from_tf2onnx(self): + def test_huggingface_bert_base_cased_from_tf2onnx_1(self): self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1) + + @pytest.mark.slow + def test_huggingface_bert_base_cased_from_tf2onnx_2(self): self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2) + + @pytest.mark.slow + def test_huggingface_bert_base_cased_from_tf2onnx_3(self): self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3) @pytest.mark.slow diff --git a/onnxruntime/test/python/transformers/test_profiler.py b/onnxruntime/test/python/transformers/test_profiler.py index 7202b78c55..e18fb2a784 100644 --- a/onnxruntime/test/python/transformers/test_profiler.py +++ b/onnxruntime/test/python/transformers/test_profiler.py @@ -24,11 +24,13 @@ class TestBertProfiler(unittest.TestCase): results = run(args) self.assertTrue(len(results) > 1) + @pytest.mark.slow def test_profiler_gpu(self): input_model_path = _get_test_model_path('bert_keras_squad') if 'CUDAExecutionProvider' in onnxruntime.get_available_providers(): self.run_profile(f'--model {input_model_path} --batch_size 1 --sequence_length 7 --use_gpu') + @pytest.mark.slow def test_profiler_cpu(self): input_model_path = _get_test_model_path('bert_keras_squad') self.run_profile(f'--model {input_model_path} --batch_size 1 --sequence_length 7 --dummy_inputs default')