From c401cf4b5163bc5dd0d60970de40d4067febcf75 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Fri, 28 Apr 2023 08:45:20 -0700 Subject: [PATCH] =?UTF-8?q?Fix=20issue=20there=209573-quantizing-distilber?= =?UTF-8?q?t-models-after-optimizing-wi=E2=80=A6=20(#15659)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …th-ort-leads-to-invalid-node-input-names ### Description Fix issue where quantizing DistilBERT models after optimizing with ORT leads to invalid node input names. ### Motivation and Context Optional node inputs that are omitted (for example the segment inputs of EmbedLayerNormalization) appear as empty input names. Without this change such empty names fell through to the regular input-quantization path; they are now passed through unchanged, with empty scale and zero-point names. --- .../tools/quantization/onnx_quantizer.py | 7 +- .../quantization/test_op_embed_layernorm.py | 78 ++++++++++++++++--- 2 files changed, 74 insertions(+), 11 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 33cd8bb7bc..6625a15637 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -802,7 +802,12 @@ class ONNXQuantizer: zero_point_names.append(quantized_value.zp_name) quantized_input_names.append(quantized_value.q_name) continue - + # adding this for case embed_layernorm.py has optional segment_embedding + if not node_input: + quantized_input_names.append("") + scale_names.append("") + zero_point_names.append("") + continue # Quantize the input initializer = find_by_name(node_input, self.model.initializer()) if initializer is not None: diff --git a/onnxruntime/test/python/quantization/test_op_embed_layernorm.py b/onnxruntime/test/python/quantization/test_op_embed_layernorm.py index c68dac39b0..0b726abff4 100644 --- a/onnxruntime/test/python/quantization/test_op_embed_layernorm.py +++ b/onnxruntime/test/python/quantization/test_op_embed_layernorm.py @@ -27,7 +27,7 @@ class TestOpEmbedLayerNormalization(unittest.TestCase): dr = TestDataFeeds(input_data_list) return dr - def construct_model(self, batch, hidden_size, sequence_length, model_path): + def 
construct_model(self, batch, hidden_size, sequence_length, model_path, empty_segment=False): # # \ / # (EmbedLayerNormalization) @@ -72,10 +72,10 @@ class TestOpEmbedLayerNormalization(unittest.TestCase): # EmbedLayerNormalization Node: embed_layer_norm_inputs = [ "input_ids", - "segment_ids", + "segment_ids" if not empty_segment else "", "word_embed", "pos_embed", - "seg_embed", + "seg_embed" if not empty_segment else "", "gamma", "beta", ] @@ -92,13 +92,17 @@ class TestOpEmbedLayerNormalization(unittest.TestCase): graph_name = "embed_layernorm_graph" inputs = [input_ids_tensor, segment_ids_tensor] outputs = [layernorm_out_tensor, mask_index_out_tensor] - initializers = [ - word_embed_initializer, - pos_embed_initializer, - seg_embed_initializer, - gamma_initializer, - beta_initializer, - ] + initializers = ( + [ + word_embed_initializer, + pos_embed_initializer, + seg_embed_initializer, + gamma_initializer, + beta_initializer, + ] + if not empty_segment + else [word_embed_initializer, pos_embed_initializer, gamma_initializer, beta_initializer] + ) graph = helper.make_graph(nodes, graph_name, inputs, outputs, initializer=initializers) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)]) @@ -132,6 +136,33 @@ class TestOpEmbedLayerNormalization(unittest.TestCase): check_model_correctness(self, model_f32_path, model_uint8_path, data_reader.get_next()) + def test_quantize_batch_size_1_empty_segment(self): + batch = 1 + hidden_size = 4 + sequence_length = 4 + + model_f32_path = "test_embed_layer_norm_unit_test_batch1_empty_segment.onnx" + model_uint8_path = "test_embed_layer_norm_unit_test_batch1_uint8_empty_segment.onnx" + + self.construct_model(batch, hidden_size, sequence_length, model_f32_path, empty_segment=True) + + data_reader = self.input_feeds_int32( + 1, + { + "input_ids": [batch, sequence_length], + "segment_ids": [batch, sequence_length], + }, + ) + + quantize_dynamic(model_f32_path, model_uint8_path) + + # Quantization should not 
have any DequantizeLinear nodes: + qnode_counts = {"DequantizeLinear": 0, "QEmbedLayerNormalization": 1} + check_op_type_count(self, model_uint8_path, **qnode_counts) + data_reader.rewind() + + check_model_correctness(self, model_f32_path, model_uint8_path, data_reader.get_next()) + def test_quantize_batch_size_2(self): batch = 2 hidden_size = 4 @@ -159,6 +190,33 @@ class TestOpEmbedLayerNormalization(unittest.TestCase): check_model_correctness(self, model_f32_path, model_uint8_path, data_reader.get_next()) + def test_quantize_batch_size_2_empty_segment(self): + batch = 2 + hidden_size = 4 + sequence_length = 4 + + model_f32_path = "test_embed_layer_norm_unit_test_batch2_empty_segment.onnx" + model_uint8_path = "test_embed_layer_norm_unit_test_batch2_uint8_empty_segment.onnx" + + self.construct_model(batch, hidden_size, sequence_length, model_f32_path, empty_segment=True) + + data_reader = self.input_feeds_int32( + 1, + { + "input_ids": [batch, sequence_length], + "segment_ids": [batch, sequence_length], + }, + ) + + quantize_dynamic(model_f32_path, model_uint8_path) + + # Quantization should not have any DequantizeLinear nodes: + qnode_counts = {"DequantizeLinear": 0, "QEmbedLayerNormalization": 1} + check_op_type_count(self, model_uint8_path, **qnode_counts) + data_reader.rewind() + + check_model_correctness(self, model_f32_path, model_uint8_path, data_reader.get_next()) + if __name__ == "__main__": unittest.main()