2020-02-07 19:01:03 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# coding: utf-8
|
|
|
|
|
# -------------------------------------------------------------------------
|
2020-03-10 19:57:49 +00:00
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
# Licensed under the MIT License. See License.txt in the project root for
|
2020-02-07 19:01:03 +00:00
|
|
|
# license information.
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import os
|
2020-06-11 21:19:55 +00:00
|
|
|
import pytest
|
2021-08-06 23:16:17 +00:00
|
|
|
from onnx import TensorProto, load_model
|
2021-01-12 18:38:39 +00:00
|
|
|
|
2021-08-09 17:55:49 +00:00
|
|
|
from parity_utilities import find_transformers_source
|
|
|
|
|
if find_transformers_source():
|
2021-08-06 23:16:17 +00:00
|
|
|
from optimizer import optimize_model
|
|
|
|
|
from onnx_model import OnnxModel
|
|
|
|
|
from onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt
|
|
|
|
|
from huggingface_models import MODELS
|
|
|
|
|
from benchmark_helper import Precision
|
|
|
|
|
else:
|
|
|
|
|
from onnxruntime.transformers.optimizer import optimize_model
|
|
|
|
|
from onnxruntime.transformers.onnx_model import OnnxModel
|
|
|
|
|
from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt
|
|
|
|
|
from onnxruntime.transformers.huggingface_models import MODELS
|
|
|
|
|
from onnxruntime.transformers.benchmark_helper import Precision
|
|
|
|
|
|
2020-02-16 07:59:49 +00:00
|
|
|
# Registry of the small ONNX models used by the tests in this file.
# Each entry maps a logical model name to a (sub-directory key, file name) pair.
# The sub-directory key "FUSION" is special-cased by _get_test_model_path; any
# other key is used as a folder name under the transformers test data folder.
BERT_TEST_MODELS = {
    # bert_mrpc_tensorflow2.1_opset10
    "bert_keras_0": ("models", "TFBertForSequenceClassification_1.onnx"),
    # bert_squad_tensorflow2.1_keras2onnx_opset11
    "bert_keras_squad": ("models", "TFBertForQuestionAnswering.onnx"),
    # gpt2_pytorch1.5_opset11
    "gpt2_past": ("models", "gpt2_past.onnx"),
    "gpt2_past_mask": ("FUSION", "gpt2_past_mask_one_layer.onnx"),
    "multiple_embed": ("FUSION", "embed_layer_norm_multiple.onnx"),
    "bert_tf2onnx_0": ("models", "bert_tf2onnx_0.onnx"),
}
|
|
|
|
|
|
2020-06-11 21:19:55 +00:00
|
|
|
|
|
|
|
|
def _get_test_model_path(name):
    """Return the path of the test model registered under ``name``.

    The entry is looked up in ``BERT_TEST_MODELS``. For each model two
    candidate locations are considered: one relative to the directory of this
    script and one relative to the current working directory. The
    script-relative location is returned when it exists on disk; otherwise the
    working-directory-relative location is returned without checking that it
    exists.

    Args:
        name: key in ``BERT_TEST_MODELS``.

    Returns:
        A path string for the model file.

    Raises:
        KeyError: if ``name`` is not a registered model.
    """
    folder, model_file = BERT_TEST_MODELS[name]
    script_dir = os.path.dirname(__file__)

    if folder == "FUSION":
        # Fusion test graphs live in the shared testdata tree rather than next to this script.
        beside_script = os.path.join(script_dir, '..', '..', 'testdata', 'transform', 'fusion', model_file)
        from_cwd = os.path.join('.', 'testdata', 'transform', 'fusion', model_file)
    else:
        beside_script = os.path.join(script_dir, 'test_data', folder, model_file)
        from_cwd = os.path.join('.', 'transformers', 'test_data', folder, model_file)

    return beside_script if os.path.exists(beside_script) else from_cwd
|
2020-06-11 21:19:55 +00:00
|
|
|
|
2020-03-20 21:34:10 +00:00
|
|
|
|
2020-02-07 19:01:03 +00:00
|
|
|
class TestBertOptimization(unittest.TestCase):
    """Tests that the transformer optimizer produces the expected fused operators.

    Two styles of check are used:

    * ``verify_node_count`` compares per-operator node counts of an optimized
      model against a ``{op_type: count}`` dictionary.
    * ``_test_optimizer_on_huggingface_model`` / ``_test_optimizer_on_tf_model``
      export a model and compare the recorded fusion statistics against a list
      of counts. The list is expected to follow this key order:
      EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu,
      LayerNormalization, SkipLayerNormalization.
    """

    def verify_node_count(self, bert_model, expected_node_count, test_name):
        """Assert that ``bert_model`` has the expected number of nodes per operator type.

        Args:
            bert_model: object exposing ``get_nodes_by_op_type(op_type)`` that returns a sized collection.
            expected_node_count: mapping of operator type to expected node count.
            test_name: name of the calling test, used in the diagnostic output.
        """
        # Query the model once per operator type; the values are reused for both the
        # diagnostic table and the assertions.
        actual_node_count = {
            op_type: len(bert_model.get_nodes_by_op_type(op_type)) for op_type in expected_node_count
        }

        for op_type, count in expected_node_count.items():
            if actual_node_count[op_type] != count:
                # Dump the whole table before failing so that all mismatches are visible at once.
                print(f"Counters is not expected in test: {test_name}")
                for op, counter in expected_node_count.items():
                    print("{}: {} expected={}".format(op, actual_node_count[op], counter))
            self.assertEqual(actual_node_count[op_type], count,
                             f"{test_name}: unexpected number of {op_type} nodes")

    def _export_and_check_fusion(self, export_model, model_name, expected_fusion_result_list, inputs_count,
                                 validate_model):
        """Export ``model_name`` with ``export_model`` and check the recorded fusion statistics.

        Shared implementation of the PyTorch and TensorFlow export tests.

        Args:
            export_model: exporter callable (``export_onnx_model_from_pt`` or ``export_onnx_model_from_tf``).
            model_name: key in ``MODELS``.
            expected_fusion_result_list: expected fusion counts, in the key order given in the class docstring.
            inputs_count: number of leading input names (from the ``MODELS`` entry) to export with.
            validate_model: when true, also require the exporter to report a valid ONNX model.
        """
        # Remove cached model so that CI machine will have space
        import shutil
        shutil.rmtree('./cache_models', ignore_errors=True)
        shutil.rmtree('./onnx_models', ignore_errors=True)

        # Populated by the exporter; read back after the export call.
        model_fusion_statistics = {}

        model_config = MODELS[model_name]
        input_names = model_config[0]

        # Imported lazily so that the fast tests in this file do not require torch.
        import torch
        with torch.no_grad():
            _, is_valid_onnx_model, _, _ = export_model(model_name, model_config[1], model_config[2],
                                                        model_config[3], None, './cache_models', './onnx_models',
                                                        input_names[:inputs_count], False, Precision.FLOAT32, True,
                                                        True, True, True, model_fusion_statistics)

        # Fail with a readable message instead of an IndexError when nothing was recorded.
        self.assertTrue(model_fusion_statistics, f"no fusion statistics were recorded for {model_name}")
        onnx_model = next(iter(model_fusion_statistics))
        fusion_result_list = list(model_fusion_statistics[onnx_model].values())

        if validate_model:
            self.assertTrue(is_valid_onnx_model, f"exported ONNX model for {model_name} is not valid")
        self.assertEqual(fusion_result_list, expected_fusion_result_list)

    # add test function for huggingface pytorch model
    def _test_optimizer_on_huggingface_model(self,
                                             model_name,
                                             expected_fusion_result_list,
                                             inputs_count=1,
                                             validate_model=True):
        """Export a model through ``export_onnx_model_from_pt`` and check its fusion counts."""
        self._export_and_check_fusion(export_onnx_model_from_pt, model_name, expected_fusion_result_list,
                                      inputs_count, validate_model)

    def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, inputs_count, validate_model=True):
        """Export a model through ``export_onnx_model_from_tf`` and check its fusion counts."""
        print("testing mode ", model_name)
        print("testing input number = ", inputs_count)
        self._export_and_check_fusion(export_onnx_model_from_tf, model_name, expected_fusion_result_list,
                                      inputs_count, validate_model)

    # def test_keras_model_1(self):
    #     input = _get_test_model_path('bert_keras_0')
    #
    #     bert_model = optimize_model(input, 'bert_keras', num_heads=2, hidden_size=8)
    #
    #     expected_node_count = {
    #         'EmbedLayerNormalization': 1,
    #         'Attention': 12,
    #         'LayerNormalization': 0,
    #         'SkipLayerNormalization': 24,
    #         'BiasGelu': 12,
    #         'Gelu': 0,
    #         'FastGelu': 0
    #     }
    #     self.verify_node_count(bert_model, expected_node_count, 'test_keras_model_1')

    # def test_keras_squad_model(self):
    #     input = _get_test_model_path('bert_keras_squad')
    #
    #     bert_model = optimize_model(input, 'bert_keras', num_heads=2, hidden_size=8)
    #
    #     print("fused_operator_statistics for test_keras_squad_model", bert_model.get_fused_operator_statistics())
    #
    #     self.assertTrue(bert_model.is_fully_optimized())

    def test_gpt2_past(self):
        """Optimize the ``gpt2_past`` test model and check the node counts."""
        model_path = _get_test_model_path('gpt2_past')
        model = optimize_model(model_path, 'gpt2', num_heads=2, hidden_size=4)

        expected_node_count = {
            'EmbedLayerNormalization': 0,
            'Attention': 12,
            'Gelu': 0,
            'FastGelu': 12,
            'BiasGelu': 0,
            'LayerNormalization': 25,
            'SkipLayerNormalization': 0
        }
        self.verify_node_count(model, expected_node_count, 'test_gpt2_past')

    def test_gpt2_past_fp16(self):
        """Convert ``gpt2_past`` to float16 without keeping I/O types and check the graph I/O element types."""
        input_model_path = _get_test_model_path('gpt2_past')
        model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True))
        model.convert_float_to_float16(keep_io_types=False, use_symbolic_shape_infer=False)

        # The first graph input is skipped (presumably the integer token-id input -- TODO confirm).
        for graph_input in model.graph().input[1:]:
            self.assertEqual(graph_input.type.tensor_type.elem_type, TensorProto.FLOAT16)
        for graph_output in model.graph().output:
            self.assertEqual(graph_output.type.tensor_type.elem_type, TensorProto.FLOAT16)

    def test_gpt2_past_mask(self):
        """Optimize the ``gpt2_past_mask`` test model and check the node counts."""
        model_path = _get_test_model_path('gpt2_past_mask')
        model = optimize_model(model_path, 'gpt2', num_heads=2, hidden_size=4)

        expected_node_count = {
            'EmbedLayerNormalization': 0,
            'Attention': 1,
            'Gelu': 0,
            'FastGelu': 1,
            'BiasGelu': 0,
            'LayerNormalization': 2,
            'SkipLayerNormalization': 0
        }
        self.verify_node_count(model, expected_node_count, 'test_gpt2_past_mask')

    def test_multiple_embed(self):
        """Optimize the ``multiple_embed`` test model and check the node counts."""
        input_model_path = _get_test_model_path('multiple_embed')
        model = optimize_model(input_model_path, 'bert', num_heads=2, hidden_size=4)

        expected_node_count = {
            'EmbedLayerNormalization': 2,
            'Attention': 2,
            'Gelu': 0,
            'FastGelu': 0,
            'BiasGelu': 0,
            'LayerNormalization': 0,
            'SkipLayerNormalization': 0
        }
        self.verify_node_count(model, expected_node_count, 'test_multiple_embed')

    # def test_bert_tf2onnx_0(self):
    #     input = _get_test_model_path('bert_tf2onnx_0')
    #     model = optimize_model(input, 'bert_tf', num_heads=2, hidden_size=8)
    #     expected_node_count = {
    #         'EmbedLayerNormalization': 0,
    #         'Attention': 6,
    #         'Gelu': 0,
    #         'FastGelu': 6,
    #         'BiasGelu': 0,
    #         'LayerNormalization': 0,
    #         'SkipLayerNormalization': 13
    #     }
    #     self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0')

    @pytest.mark.slow
    def test_huggingface_bert_fusion_1(self):
        self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1)

    @pytest.mark.slow
    def test_huggingface_bert_fusion_2(self):
        self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2)

    @pytest.mark.slow
    def test_huggingface_bert_fusion_3(self):
        self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=3)

    @pytest.mark.slow
    def test_huggingface_openaigpt_fusion(self):
        self._test_optimizer_on_huggingface_model("openai-gpt", [0, 12, 0, 12, 0, 24, 0])

    # @pytest.mark.slow
    # def test_huggingface_gpt2_fusion(self):
    #     self._test_optimizer_on_huggingface_model("gpt2", [0, 12, 0, 12, 0, 25, 0])

    @pytest.mark.slow
    def test_huggingface_xlm_fusion(self):
        self._test_optimizer_on_huggingface_model("xlm-mlm-ende-1024", [0, 6, 0, 0, 6, 0, 13])

    @pytest.mark.slow
    def test_huggingface_roberta_fusion(self):
        self._test_optimizer_on_huggingface_model("roberta-base", [0, 12, 0, 0, 12, 1, 24])

    @pytest.mark.slow
    def test_huggingface_distillbert_fusion(self):
        self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=1)
        self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=2)

    # @pytest.mark.slow
    # def test_huggingface_camembert_fusion(self):
    #     # output not close issue
    #     self._test_optimizer_on_huggingface_model("camembert-base", [0, 12, 0, 0, 12, 1, 24], validate_model=False)

    @pytest.mark.slow
    def test_huggingface_albert_fusion(self):
        self._test_optimizer_on_huggingface_model("albert-base-v1", [0, 12, 0, 0, 12, 1, 24])

    # @pytest.mark.slow
    # def test_huggingface_t5_fusion(self):
    #     self._test_optimizer_on_huggingface_model("t5-small", [0, 0, 0, 0, 0, 0, 0])

    @pytest.mark.slow
    def test_huggingface_xlmroberta_fusion(self):
        self._test_optimizer_on_huggingface_model("xlm-roberta-base", [0, 12, 0, 0, 12, 1, 24])

    @pytest.mark.slow
    def test_huggingface_flaubert_fusion(self):
        # output not close issue
        self._test_optimizer_on_huggingface_model("flaubert/flaubert_base_cased", [0, 12, 0, 0, 12, 0, 25],
                                                  validate_model=False)
        self._test_optimizer_on_huggingface_model("flaubert/flaubert_small_cased", [0, 6, 0, 0, 6, 12, 1],
                                                  validate_model=False)

    # @pytest.mark.slow
    # def test_huggingface_dialogpt_fusion(self):
    #     self._test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0])

    @pytest.mark.slow
    def test_huggingface_bart_fusion(self):
        self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])

    @pytest.mark.slow
    def test_huggingface_bert_base_cased_from_tf2onnx_1(self):
        self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1)

    @pytest.mark.slow
    def test_huggingface_bert_base_cased_from_tf2onnx_2(self):
        self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2)

    @pytest.mark.slow
    def test_huggingface_bert_base_cased_from_tf2onnx_3(self):
        self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3)

    @pytest.mark.slow
    def test_huggingface_distilgpt2_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("distilgpt2", [0, 0, 0, 0, 0, 12, 1], 1)

    @pytest.mark.slow
    def test_huggingface_albert_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("albert-base-v1", [0, 0, 0, 0, 0, 0, 25], 1)

    @pytest.mark.slow
    def test_huggingface_gpt2_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("gpt2", [0, 0, 0, 0, 0, 24, 1], 1, validate_model=False)

    @pytest.mark.slow
    def test_huggingface_roberta_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("roberta-base", [0, 12, 0, 0, 0, 0, 25], 1, validate_model=False)

    @pytest.mark.slow
    def test_huggingface_distilbert_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("distilbert-base-uncased", [0, 0, 0, 0, 0, 0, 13], 1, validate_model=False)

    @pytest.mark.slow
    def test_huggingface_xlm_from_tf2onnx(self):
        self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", [0, 0, 0, 0, 0, 1, 12], 1, validate_model=False)
|
2021-01-12 18:38:39 +00:00
|
|
|
|
2021-06-09 02:43:59 +00:00
|
|
|
|
2020-02-07 19:01:03 +00:00
|
|
|
# Allow running this file directly as a script; unittest discovers and runs the
# TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()
|