fix test_optimizer.py (#12219)

* fix optimizer test
* update message and skip test instead of uncomment
* fix deprecated warning
This commit is contained in:
Tianlei Wu 2022-07-20 19:21:26 -07:00 committed by GitHub
parent c72bb8aaa9
commit 568d08994f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 110 additions and 125 deletions

View file

@ -76,11 +76,11 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit
# convert raw_data (bytes type)
if tensor.raw_data:
# convert n.raw_data to float
float32_list = np.fromstring(tensor.raw_data, dtype="float32")
float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
# convert float to float16
float16_list = convert_np_to_float16(float32_list, min_positive_val, max_finite_val)
# convert float16 to bytes and write back to raw_data
tensor.raw_data = float16_list.tostring()
tensor.raw_data = float16_list.tobytes()
return tensor
@ -384,7 +384,7 @@ def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=
float32_data = np.array(tensor.float_data)
if tensor.raw_data:
float32_data = np.fromstring(tensor.raw_data, dtype="float32")
float32_data = np.frombuffer(tensor.raw_data, dtype="float32")
float16_data = convert_np_to_float16(float32_data, min_positive_val, max_finite_val)
return np.amax(np.abs(float32_data - np.float32(float16_data)))

View file

@ -207,7 +207,10 @@ class FusionAttention(Fusion):
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
if q_weight is None:
print(f"{q_matmul.input[1]} is not initializer. Please set do_constant_folding=True in torch.onnx.export")
print(
f"{q_matmul.input[1]} is not an initializer. "
"Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
)
return None
if not (k_weight and v_weight and q_bias and k_bias):
return None
@ -227,7 +230,8 @@ class FusionAttention(Fusion):
if hidden_size > 0 and hidden_size != qw_in_size:
logger.warning(
f"Input hidden size {hidden_size} is not same as weight matrix dimension of q,k,v paths {qw_in_size}, provide correct input hidden size or pass 0"
f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
"Please provide a correct input hidden size or pass in 0"
)
is_qkv_diff_dims = False

View file

@ -12,7 +12,7 @@
# For Bert model file like name.onnx, optimized model for GPU or CPU from OnnxRuntime will output as
# name_ort_gpu.onnx or name_ort_cpu.onnx in the same directory.
#
# This script is retained for experiment purpose. Useful senarios like the following:
# This script is retained for experiment purpose. Useful scenarios like the following:
# (1) Change model from fp32 to fp16 for mixed precision inference in GPU with Tensor Core.
# (2) Change input data type from int64 to int32.
# (3) Some model cannot be handled by OnnxRuntime, and you can modify this script to get optimized model.
@ -142,7 +142,8 @@ def optimize_by_fusion(
if model.producer_name and producer != model.producer_name:
logger.warning(
f"Model producer not matched: Expect {producer}, Got {model.producer_name} {model.producer_version}. Please specify correct --model_type parameter."
f'Model producer not matched: Expected "{producer}", Got "{model.producer_name}".'
"Please specify correct --model_type parameter."
)
if optimization_options is None:
@ -168,7 +169,7 @@ def optimize_model(
num_heads: int = 0,
hidden_size: int = 0,
optimization_options: Optional[FusionOptions] = None,
opt_level: int = None,
opt_level: Optional[int] = None,
use_gpu: bool = False,
only_onnxruntime: bool = False,
):
@ -213,7 +214,7 @@ def optimize_model(
if model_type != "bert" and (num_heads == 0 or hidden_size == 0):
logger.warning("Please specify parameters of num_heads and hidden_size when model_type is not 'bert'")
(optimizer_class, producer, default_opt_level) = MODEL_TYPES[model_type]
(optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type]
if opt_level is None:
opt_level = default_opt_level
@ -226,7 +227,8 @@ def optimize_model(
if only_onnxruntime
else [
"MatMulScaleFusion",
"MatMulAddFusion" "SimplifiedLayerNormFusion",
"MatMulAddFusion",
"SimplifiedLayerNormFusion",
"GemmActivationFusion",
"BiasSoftmaxFusion",
]
@ -238,7 +240,7 @@ def optimize_model(
disabled_optimizers=disabled_optimizers,
)
elif opt_level == 1:
# basic optimizations (like constant folding and cast elimation) are not specified to exection provider.
# basic optimizations (like constant folding and cast elimination) are not specified to execution provider.
# CPU provider is used here so that there is no extra node for GPU memory copy.
temp_model_path = optimize_by_onnxruntime(input, use_gpu=False, opt_level=1)
@ -255,7 +257,7 @@ def optimize_model(
# Remove the temporary model.
if temp_model_path:
os.remove(temp_model_path)
logger.debug("Remove tempoary model: {}".format(temp_model_path))
logger.debug("Remove temporary model: {}".format(temp_model_path))
return optimizer

View file

@ -8,28 +8,30 @@
# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG
import os
import shutil
import unittest
import pytest
import torch
from model_loader import get_fusion_test_model, get_test_data_path
from onnx import TensorProto, load_model
from parity_utilities import find_transformers_source
from transformers import is_tf_available
if find_transformers_source():
from benchmark_helper import OptimizerInfo, Precision
from benchmark_helper import ConfigModifier, OptimizerInfo, Precision
from huggingface_models import MODELS
from onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf
from onnx_model import OnnxModel
from optimizer import optimize_model
else:
from onnxruntime.transformers.benchmark_helper import OptimizerInfo, Precision
from onnxruntime.transformers.benchmark_helper import ConfigModifier, OptimizerInfo, Precision
from onnxruntime.transformers.huggingface_models import MODELS
from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.optimizer import optimize_model
BERT_TEST_MODELS = {
TEST_MODELS = {
"bert_keras_0": (
"models",
"TFBertForSequenceClassification_1.onnx",
@ -46,22 +48,22 @@ BERT_TEST_MODELS = {
def _get_test_model_path(name):
sub_dir, file = BERT_TEST_MODELS[name]
sub_dir, file = TEST_MODELS[name]
if sub_dir == "FUSION":
return get_fusion_test_model(file)
else:
return get_test_data_path(sub_dir, file)
class TestBertOptimization(unittest.TestCase):
def verify_node_count(self, bert_model, expected_node_count, test_name):
class TestModelOptimization(unittest.TestCase):
def verify_node_count(self, onnx_model, expected_node_count, test_name):
for op_type, count in expected_node_count.items():
if len(bert_model.get_nodes_by_op_type(op_type)) != count:
if len(onnx_model.get_nodes_by_op_type(op_type)) != count:
print(f"Counters is not expected in test: {test_name}")
for op, counter in expected_node_count.items():
print("{}: {} expected={}".format(op, len(bert_model.get_nodes_by_op_type(op)), counter))
print("{}: {} expected={}".format(op, len(onnx_model.get_nodes_by_op_type(op)), counter))
self.assertEqual(len(bert_model.get_nodes_by_op_type(op_type)), count)
self.assertEqual(len(onnx_model.get_nodes_by_op_type(op_type)), count)
# add test function for huggingface pytorch model
def _test_optimizer_on_huggingface_model(
@ -72,8 +74,6 @@ class TestBertOptimization(unittest.TestCase):
validate_model=True,
):
# Remove cached model so that CI machine will have space
import shutil
shutil.rmtree("./cache_models", ignore_errors=True)
shutil.rmtree("./onnx_models", ignore_errors=True)
# expect fusion result list have the following keys
@ -82,15 +82,17 @@ class TestBertOptimization(unittest.TestCase):
input_names = MODELS[model_name][0]
import torch
config_modifier = ConfigModifier(None)
fusion_options = None
model_class = "AutoModel"
with torch.no_grad():
_, is_valid_onnx_model, _, _ = export_onnx_model_from_pt(
model_name,
MODELS[model_name][1],
MODELS[model_name][2],
MODELS[model_name][3],
None,
MODELS[model_name][1], # opset version
MODELS[model_name][2], # use_external_data_format
MODELS[model_name][3], # optimization model type
model_class,
config_modifier,
"./cache_models",
"./onnx_models",
input_names[:inputs_count],
@ -101,6 +103,7 @@ class TestBertOptimization(unittest.TestCase):
True,
True,
model_fusion_statistics,
fusion_options,
)
onnx_model = list(model_fusion_statistics.keys())[0]
@ -110,73 +113,6 @@ class TestBertOptimization(unittest.TestCase):
self.assertEqual(is_valid_onnx_model, True)
self.assertEqual(fusion_result_list, expected_fusion_result_list)
def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, inputs_count, validate_model=True):
# Remove cached model so that CI machine will have space
import shutil
shutil.rmtree("./cache_models", ignore_errors=True)
shutil.rmtree("./onnx_models", ignore_errors=True)
# expect fusion result list have the following keys
# EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization
model_fusion_statistics = {}
print("testing mode ", model_name)
print("testing input number = ", inputs_count)
input_names = MODELS[model_name][0]
import torch
with torch.no_grad():
_, is_valid_onnx_model, _, _ = export_onnx_model_from_tf(
model_name,
MODELS[model_name][1],
MODELS[model_name][2],
MODELS[model_name][3],
None,
"./cache_models",
"./onnx_models",
input_names[:inputs_count],
False,
Precision.FLOAT32,
True,
True,
True,
True,
model_fusion_statistics,
)
onnx_model = list(model_fusion_statistics.keys())[0]
fusion_result_list = list(model_fusion_statistics[onnx_model].values())
if validate_model:
self.assertEqual(is_valid_onnx_model, True)
self.assertEqual(fusion_result_list, expected_fusion_result_list)
# def test_keras_model_1(self):
# input = _get_test_model_path('bert_keras_0')
# bert_model = optimize_model(input, 'bert_keras', num_heads=2, hidden_size=8)
# expected_node_count = {
# 'EmbedLayerNormalization': 1,
# 'Attention': 12,
# 'LayerNormalization': 0,
# 'SkipLayerNormalization': 24,
# 'BiasGelu': 12,
# 'Gelu': 0,
# 'FastGelu': 0
# }
# self.verify_node_count(bert_model, expected_node_count, 'test_keras_model_1')
# def test_keras_squad_model(self):
# input = _get_test_model_path('bert_keras_squad')
# bert_model = optimize_model(input, 'bert_keras', num_heads=2, hidden_size=8)
# print("fused_operator_statistics for test_keras_squad_model", bert_model.get_fused_operator_statistics())
# self.assertTrue(bert_model.is_fully_optimized())
def test_gpt2_past(self):
input = _get_test_model_path("gpt2_past")
model = optimize_model(input, "gpt2", num_heads=2, hidden_size=4)
@ -247,20 +183,6 @@ class TestBertOptimization(unittest.TestCase):
}
self.verify_node_count(model, expected_node_count, file)
# def test_bert_tf2onnx_0(self):
# input = _get_test_model_path('bert_tf2onnx_0')
# model = optimize_model(input, 'bert_tf', num_heads=2, hidden_size=8)
# expected_node_count = {
# 'EmbedLayerNormalization': 0,
# 'Attention': 6,
# 'Gelu': 0,
# 'FastGelu': 6,
# 'BiasGelu': 0,
# 'LayerNormalization': 0,
# 'SkipLayerNormalization': 13
# }
# self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0')
@pytest.mark.slow
def test_huggingface_bert_fusion_1(self):
self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1)
@ -277,11 +199,13 @@ class TestBertOptimization(unittest.TestCase):
def test_huggingface_openaigpt_fusion(self):
self._test_optimizer_on_huggingface_model("openai-gpt", [0, 12, 0, 12, 0, 24, 0])
# @pytest.mark.slow
# def test_huggingface_gpt2_fusion(self):
# self._test_optimizer_on_huggingface_model("gpt2", [0, 12, 0, 12, 0, 25, 0])
@pytest.mark.slow
@unittest.skip("skip failed fusion test of gpt-2 on PyTorch 1.12 and transformers 4.18. TODO: fix it")
def test_huggingface_gpt2_fusion(self):
self._test_optimizer_on_huggingface_model("gpt2", [0, 12, 0, 12, 0, 25, 0])
@pytest.mark.slow
@unittest.skip("skip failed fusion test of xlm on PyTorch 1.12 and transformers 4.18. TODO: fix it")
def test_huggingface_xlm_fusion(self):
self._test_optimizer_on_huggingface_model("xlm-mlm-ende-1024", [0, 6, 0, 0, 6, 0, 13])
@ -294,26 +218,28 @@ class TestBertOptimization(unittest.TestCase):
self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=1)
self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=2)
# @pytest.mark.slow
# def test_huggingface_camembert_fusion(self):
# # output not close issue
# self._test_optimizer_on_huggingface_model("camembert-base", [0, 12, 0, 0, 12, 1, 24], validate_model=False)
@pytest.mark.slow
@unittest.skip("skip failed fusion test of camembert on PyTorch 1.12 and transformers 4.18. TODO: fix it")
def test_huggingface_camembert_fusion(self):
self._test_optimizer_on_huggingface_model("camembert-base", [0, 12, 0, 0, 12, 1, 24], validate_model=False)
@pytest.mark.slow
@unittest.skip("skip failed fusion test of albert on PyTorch 1.12 and transformers 4.18. TODO: fix it")
def test_huggingface_albert_fusion(self):
self._test_optimizer_on_huggingface_model("albert-base-v1", [0, 12, 0, 0, 12, 1, 24])
# @pytest.mark.slow
# def test_huggingface_t5_fusion(self):
# self._test_optimizer_on_huggingface_model("t5-small", [0, 0, 0, 0, 0, 0, 0])
@pytest.mark.slow
@unittest.skip("skip fusion test of t5 since it is not implemented yet")
def test_huggingface_t5_fusion(self):
self._test_optimizer_on_huggingface_model("t5-small", [0, 0, 0, 0, 0, 0, 0])
@pytest.mark.slow
def test_huggingface_xlmroberta_fusion(self):
self._test_optimizer_on_huggingface_model("xlm-roberta-base", [0, 12, 0, 0, 12, 1, 24])
@pytest.mark.slow
@unittest.skip("skip failed fusion test of flaubert on PyTorch 1.12 and transformers 4.18. TODO: fix it")
def test_huggingface_flaubert_fusion(self):
# output not close issue
self._test_optimizer_on_huggingface_model(
"flaubert/flaubert_base_cased",
[0, 12, 0, 0, 12, 0, 25],
@ -325,14 +251,67 @@ class TestBertOptimization(unittest.TestCase):
validate_model=False,
)
# @pytest.mark.slow
# def test_huggingface_dialogpt_fusion(self):
# self._test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0])
@pytest.mark.slow
@unittest.skip("skip failed fusion test of dialogpt on PyTorch 1.12 and transformers 4.18. TODO: fix it")
def test_huggingface_dialogpt_fusion(self):
self._test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0])
@pytest.mark.slow
def test_huggingface_bart_fusion(self):
self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])
@unittest.skipUnless(is_tf_available(), "skip TestBertOptimizationTF since tensorflow is not available")
class TestTensorflowModelOptimization(unittest.TestCase):
def Setup(self):
try:
import tf2onnx
except ImportError:
self.skipTest("skip TestBertOptimizationTF since tf2onnx not installed")
def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, inputs_count, validate_model=True):
# Remove cached model so that CI machine will have space
shutil.rmtree("./cache_models", ignore_errors=True)
shutil.rmtree("./onnx_models", ignore_errors=True)
# expect fusion result list have the following keys
# EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization
model_fusion_statistics = {}
print("testing mode ", model_name)
print("testing input number = ", inputs_count)
input_names = MODELS[model_name][0]
config_modifier = ConfigModifier(None)
fusion_options = None
model_class = "AutoModel"
with torch.no_grad():
_, is_valid_onnx_model, _, _ = export_onnx_model_from_tf(
model_name,
MODELS[model_name][1], # opset version
MODELS[model_name][2], # use_external_data_format
MODELS[model_name][3], # optimization model
model_class,
config_modifier,
"./cache_models",
"./onnx_models",
input_names[:inputs_count],
False,
Precision.FLOAT32,
True,
True,
True,
True,
model_fusion_statistics,
fusion_options,
)
onnx_model = list(model_fusion_statistics.keys())[0]
fusion_result_list = list(model_fusion_statistics[onnx_model].values())
if validate_model:
self.assertEqual(is_valid_onnx_model, True)
self.assertEqual(fusion_result_list, expected_fusion_result_list)
@pytest.mark.slow
def test_huggingface_bert_base_cased_from_tf2onnx_1(self):
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1)