Fix issue #9573: quantizing DistilBERT models after optimizing with ORT leads to invalid node input names (#15659)

### Description
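Fix an issue where quantizing DistilBERT models after optimizing them with
ORT leads to invalid node input names. The quantizer now passes empty node
input names (omitted optional inputs, such as the segment embedding of
EmbedLayerNormalization) through unchanged, with empty scale and zero-point
names, instead of trying to quantize them.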



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Jian Chen 2023-04-28 08:45:20 -07:00 committed by GitHub
parent 7e6331d5c7
commit c401cf4b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 11 deletions

View file

@ -802,7 +802,12 @@ class ONNXQuantizer:
zero_point_names.append(quantized_value.zp_name)
quantized_input_names.append(quantized_value.q_name)
continue
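# An omitted optional input (e.g. the segment embedding handled in embed_layernorm.py) has an
# empty name: keep it empty, with empty scale and zero-point names, and skip quantization.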
if not node_input:
quantized_input_names.append("")
scale_names.append("")
zero_point_names.append("")
continue
# Quantize the input
initializer = find_by_name(node_input, self.model.initializer())
if initializer is not None:

View file

@ -27,7 +27,7 @@ class TestOpEmbedLayerNormalization(unittest.TestCase):
dr = TestDataFeeds(input_data_list)
return dr
def construct_model(self, batch, hidden_size, sequence_length, model_path):
def construct_model(self, batch, hidden_size, sequence_length, model_path, empty_segment=False):
# <segment_ids> <input_ids>
# \ /
# (EmbedLayerNormalization)
@ -72,10 +72,10 @@ class TestOpEmbedLayerNormalization(unittest.TestCase):
# EmbedLayerNormalization Node:
embed_layer_norm_inputs = [
"input_ids",
"segment_ids",
"segment_ids" if not empty_segment else "",
"word_embed",
"pos_embed",
"seg_embed",
"seg_embed" if not empty_segment else "",
"gamma",
"beta",
]
@ -92,13 +92,17 @@ class TestOpEmbedLayerNormalization(unittest.TestCase):
graph_name = "embed_layernorm_graph"
inputs = [input_ids_tensor, segment_ids_tensor]
outputs = [layernorm_out_tensor, mask_index_out_tensor]
initializers = [
word_embed_initializer,
pos_embed_initializer,
seg_embed_initializer,
gamma_initializer,
beta_initializer,
]
initializers = (
[
word_embed_initializer,
pos_embed_initializer,
seg_embed_initializer,
gamma_initializer,
beta_initializer,
]
if not empty_segment
else [word_embed_initializer, pos_embed_initializer, gamma_initializer, beta_initializer]
)
graph = helper.make_graph(nodes, graph_name, inputs, outputs, initializer=initializers)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])
@ -132,6 +136,33 @@ class TestOpEmbedLayerNormalization(unittest.TestCase):
check_model_correctness(self, model_f32_path, model_uint8_path, data_reader.get_next())
def test_quantize_batch_size_1_empty_segment(self):
    """Quantize a batch-1 model that was built with ``empty_segment=True``.

    Builds the float32 model, runs ``quantize_dynamic`` on it, then checks
    the op-type counts of the quantized model and hands both models to
    ``check_model_correctness`` together with the first feed from the
    data reader.
    """
    batch = 1
    hidden_size = 4
    sequence_length = 4
    float_model = "test_embed_layer_norm_unit_test_batch1_empty_segment.onnx"
    quantized_model = "test_embed_layer_norm_unit_test_batch1_uint8_empty_segment.onnx"

    self.construct_model(batch, hidden_size, sequence_length, float_model, empty_segment=True)

    # Both feeds use the same [batch, sequence_length] shape; a fresh list is built per name.
    feed_shapes = {name: [batch, sequence_length] for name in ("input_ids", "segment_ids")}
    reader = self.input_feeds_int32(1, feed_shapes)

    quantize_dynamic(float_model, quantized_model)

    # Quantization should not have any DequantizeLinear nodes:
    check_op_type_count(self, quantized_model, DequantizeLinear=0, QEmbedLayerNormalization=1)

    reader.rewind()
    check_model_correctness(self, float_model, quantized_model, reader.get_next())
def test_quantize_batch_size_2(self):
batch = 2
hidden_size = 4
@ -159,6 +190,33 @@ class TestOpEmbedLayerNormalization(unittest.TestCase):
check_model_correctness(self, model_f32_path, model_uint8_path, data_reader.get_next())
def test_quantize_batch_size_2_empty_segment(self):
    """Quantize a batch-2 model that was built with ``empty_segment=True``.

    Builds the float32 model, runs ``quantize_dynamic`` on it, then checks
    the op-type counts of the quantized model and hands both models to
    ``check_model_correctness`` together with the first feed from the
    data reader.
    """
    batch = 2
    hidden_size = 4
    sequence_length = 4
    float_model = "test_embed_layer_norm_unit_test_batch2_empty_segment.onnx"
    quantized_model = "test_embed_layer_norm_unit_test_batch2_uint8_empty_segment.onnx"

    self.construct_model(batch, hidden_size, sequence_length, float_model, empty_segment=True)

    # Both feeds use the same [batch, sequence_length] shape; a fresh list is built per name.
    feed_shapes = {name: [batch, sequence_length] for name in ("input_ids", "segment_ids")}
    reader = self.input_feeds_int32(1, feed_shapes)

    quantize_dynamic(float_model, quantized_model)

    # Quantization should not have any DequantizeLinear nodes:
    check_op_type_count(self, quantized_model, DequantizeLinear=0, QEmbedLayerNormalization=1)

    reader.rewind()
    check_model_correctness(self, float_model, quantized_model, reader.get_next())
# Run the test suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    unittest.main()