diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md old mode 100644 new mode 100755 index 45b2f8da5f..fc6259e842 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1527,6 +1527,8 @@ This version of the operator has been available since version 1 of the 'com.micr
epsilon : float
The epsilon value to use to avoid division by zero.
+
mask_index_type : int
+
The mask index tensor type for shape inference (0: None, 1: 1D mask_index)
#### Inputs (7 - 9) @@ -1552,12 +1554,12 @@ This version of the operator has been available since version 1 of the 'com.micr
2D position ids with shape (batch_size, sequence_length) or (1, sequence_length)
-#### Outputs (2 - 3) +#### Outputs (1 - 3)
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
-
mask_index : T1
+
mask_index (optional) : T1
1D mask_index tensor with shape (batch_size)
embedding_sum (optional) : T
sum of word_embedding and position_embedding without layer normalization
diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc index 81d525c598..570f4108c3 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc @@ -149,7 +149,7 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { } // Calculate mask - if (nullptr != mask) { + if (nullptr != mask && nullptr != mask_index) { const int32_t* mask_data = mask->Data(); int32_t* mask_index_data = mask_index->MutableData(); for (int b = 0; b < batch_size; b++) { @@ -162,7 +162,7 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { } mask_index_data[b] = cur_sum; } - } else { + } else if (mask_index != nullptr) { memset(mask_index->MutableData(), 0, batch_size * sizeof(int32_t)); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/qembed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/quantization/qembed_layer_norm.cc index 50e77daec3..2907abbfe7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qembed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qembed_layer_norm.cc @@ -167,7 +167,7 @@ Status ComputeInternal(OpKernelContext* context, float epsilon) { } // Calculate mask - if (nullptr != mask) { + if (nullptr != mask && nullptr != mask_index) { const int32_t* mask_data = mask->Data(); int32_t* mask_index_data = mask_index->MutableData(); for (int b = 0; b < batch_size; b++) { @@ -180,7 +180,7 @@ Status ComputeInternal(OpKernelContext* context, float epsilon) { } mask_index_data[b] = cur_sum; } - } else { + } else if (mask_index != nullptr) { memset(mask_index->MutableData(), 0, batch_size * sizeof(int32_t)); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc index b4b2845b3d..864e2d1623 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc @@ -66,7 +66,7 @@ Status EmbedLayerNorm::ComputeInternal(OpKernelContext* context) const { return LaunchEmbedLayerNormKernel( Stream(context), output->MutableData(), - mask_index->MutableData(), + nullptr == mask_index ? nullptr : mask_index->MutableData(), input_ids->Data(), nullptr == segment_ids ? nullptr : segment_ids->Data(), nullptr == mask ? nullptr : mask->Data(), diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu index 3a88105603..a2dfca8cd6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -212,11 +212,14 @@ Status LaunchEmbedLayerNormKernel( void* embedding_sum, const int* position_ids, const bool broadcast_position_ids) { - if (nullptr == input_mask) { - CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream)); - } else { - ORT_RETURN_IF_ERROR( - ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); + + if (mask_index != nullptr) { + if (nullptr == input_mask) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream)); + } else { + ORT_RETURN_IF_ERROR( + ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); + } } if (element_size == 2) { diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 3767cb8b0a..c7466bdfa0 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -855,6 +855,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc(EmbedLayerNormalization_ver1_doc) .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultEmbedLayerNormEpsilon) + .Attr("mask_index_type", "The mask index tensor type for shape inference (0: None, 1: 1D mask_index)", AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "2D words IDs with shape (batch_size, sequence_length)", "T1") .Input(1, "segment_ids", "2D segment IDs with shape (batch_size, sequence_length)", "T1", OpSchema::Optional) .Input(2, "word_embedding", "2D with shape (,hidden_size)", "T") @@ -865,7 +866,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(7, "mask", "2D attention mask with shape (batch_size, sequence_length)", "T1", OpSchema::Optional) .Input(8, "position_ids", "2D position ids with shape (batch_size, sequence_length) or (1, sequence_length)", "T1", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .Output(1, "mask_index", "1D mask_index tensor with shape (batch_size)", "T1") + .Output(1, "mask_index", "1D mask_index tensor with shape (batch_size)", "T1", OpSchema::Optional) .Output(2, "embedding_sum", "sum of word_embedding and position_embedding without layer normalization", "T", OpSchema::Optional) .TypeConstraint("T1", {"tensor(int32)"}, "Constrain input and output integer tensors types") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output float tensors types.") diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc index 9c9e73bf1f..4324839c7a 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc @@ -3,12 +3,16 @@ #include "core/graph/contrib_ops/shape_inference_functions.h" #include +#include namespace onnxruntime { namespace contrib { void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 2, 0); - propagateElemTypeFromInputToOutput(ctx, 0, 1); + auto mask_index_type = getAttribute(ctx, "mask_index_type", 1); + if (mask_index_type > 0) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } if (!hasInputShape(ctx, 0)) { // TODO(kreeger): In this case update the output to (?, ?, hidden_size). return; @@ -97,11 +101,13 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& c updateOutputShape(ctx, 0, output_shape); // mask_index shape is (batch_size) - ONNX_NAMESPACE::TensorShapeProto mask_index_shape; - *mask_index_shape.add_dim() = input_ids_dims[0]; - updateOutputShape(ctx, 1, mask_index_shape); + if (mask_index_type > 0) { + ONNX_NAMESPACE::TensorShapeProto mask_index_shape; + *mask_index_shape.add_dim() = input_ids_dims[0]; + updateOutputShape(ctx, 1, mask_index_shape); + } - if (ctx.getNumOutputs() > 2) { + if (ctx.getNumOutputs() == 3 || (ctx.getNumOutputs() == 2 && mask_index_type == 0)) { updateOutputShape(ctx, 2, output_shape); propagateElemTypeFromInputToOutput(ctx, 0, 2); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp index 3efddb9392..6a8333cd72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp @@ -459,7 +459,7 @@ public: maskIndexOutputEdge.FromNodeOutputIndex = 0; outputEdges.push_back(std::move(maskIndexOutputEdge)); } - else + else if (maskIndexDesc.Desc) { // Insert the edge feeding into the MaskIndex output DML_OUTPUT_GRAPH_EDGE_DESC maskIndexOutputEdge = {}; diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index b46045444e..8caab3d2cd 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2262,9 +2262,10 @@ class SymbolicShapeInference: vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape)) - mask_index_shape = [input_ids_shape[0]] - vi = self.known_vi_[node.output[1]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape)) + if len(node.output) > 1 and node.output[1]: + mask_index_shape = [input_ids_shape[0]] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape)) if len(node.output) > 2: # Optional output of add before layer normalization is done diff --git a/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc index 884f4422d5..0f35a7ff4b 100644 --- a/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc @@ -93,7 +93,7 @@ static void RunTest(const embedlayernorm::OpData& data, ToFloat16(data.beta_data), /*is_initializer=*/true); tester.AddAttribute("epsilon", data.epsilon); - if (data.has_mask) { + if (data.has_mask && data.mask_data.size()) { tester.AddInput("mask", mask_dims, data.mask_data); } tester.AddOutput("output", output_dims, ToFloat16(data.output_data)); @@ -117,12 +117,17 @@ static void RunTest(const embedlayernorm::OpData& data, tester.AddInput("gamma", gamma_dims, data.gamma_data, /*is_initializer=*/true); tester.AddInput("beta", beta_dims, data.beta_data, /*is_initializer=*/true); tester.AddAttribute("epsilon", data.epsilon); - if (data.has_mask) { + if (data.has_mask && data.mask_data.size()) { tester.AddInput("mask", mask_dims, data.mask_data); } tester.AddOutput("output", output_dims, data.output_data); } - tester.AddOutput("mask_index", mask_index_dims, data.mask_index_data); + tester.AddAttribute("mask_index_type", static_cast(data.mask_index_type)); + if (data.mask_index_data.size()) { + tester.AddOutput("mask_index", mask_index_dims, data.mask_index_data); + } else { + tester.AddOptionalOutputEdge(); + } if (sum_output) { std::vector embedding_sum_output_dims = output_dims; if (use_float16) { @@ -188,6 +193,13 @@ TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_EmbeddingSum) { TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_EmbeddingSum_Float16) { RunTest(embedlayernorm::EmbedLayerNormBatch1_EmbeddingSum(), true, true); } + +TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_EmbeddingSum_NoMaskIndex) { + RunTest(embedlayernorm::EmbedLayerNormBatch1_EmbeddingSum_NoMaskIndex(), + /* use_float16 = */ false, + /* sum_output = */ true); +} + TEST(EmbedLayerNormTest, EmbedLayerNormBatch2) { RunTest(embedlayernorm::EmbedLayerNormBatch2()); } diff --git a/onnxruntime/test/contrib_ops/embed_layer_norm_test_vectors.h b/onnxruntime/test/contrib_ops/embed_layer_norm_test_vectors.h index 67f9c90c07..ceb16c21ad 100644 --- a/onnxruntime/test/contrib_ops/embed_layer_norm_test_vectors.h +++ b/onnxruntime/test/contrib_ops/embed_layer_norm_test_vectors.h @@ -31,11 +31,12 @@ class OpData { const std::vector& output_data, const std::vector& mask_index_data, float epsilon = kEpsilon, + int mask_index_type = 1, bool has_mask = true, bool has_segment = true, const std::vector& embedding_sum_data = {}, const std::vector& position_ids_data = {}) - : batch_size(batch_size), sequence_size(sequence_size), hidden_size(hidden_size), input_ids_data(input_ids_data), segment_ids_data(segment_ids_data), mask_data(mask_data), word_embedding_data(word_embedding_data), position_embedding_data(position_embedding_data), segment_embedding_data(segment_embedding_data), gamma_data(gamma_data), beta_data(beta_data), output_data(output_data), mask_index_data(mask_index_data), epsilon(epsilon), has_mask(has_mask), has_segment(has_segment), embedding_sum_data(embedding_sum_data), position_ids_data(position_ids_data) {} + : batch_size(batch_size), sequence_size(sequence_size), hidden_size(hidden_size), input_ids_data(input_ids_data), segment_ids_data(segment_ids_data), mask_data(mask_data), word_embedding_data(word_embedding_data), position_embedding_data(position_embedding_data), segment_embedding_data(segment_embedding_data), gamma_data(gamma_data), beta_data(beta_data), output_data(output_data), mask_index_data(mask_index_data), epsilon(epsilon), mask_index_type(mask_index_type), has_mask(has_mask), has_segment(has_segment), embedding_sum_data(embedding_sum_data), position_ids_data(position_ids_data) {} const int batch_size; const int sequence_size; @@ -51,6 +52,7 @@ class OpData { const std::vector output_data; const std::vector mask_index_data; const float epsilon; + const int mask_index_type; const bool has_mask = true; const bool has_segment = true; const std::vector embedding_sum_data; @@ -110,6 +112,7 @@ inline OpData EmbedLayerNormBatch2(bool has_mask = true) { int batch_size = 3; int sequence_size = 2; int hidden_size = 4; + int mask_index_type = 1; std::vector input_ids_data = { 1, 3, @@ -169,7 +172,7 @@ inline OpData EmbedLayerNormBatch2(bool has_mask = true) { return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data, mask_data, word_embedding_data, position_embedding_data, segment_embedding_data, - gamma_data, beta_data, output_data, mask_index_data, kEpsilon, has_mask); + gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type, has_mask); } inline OpData EmbedLayerNormLargeBatchSmallHiddenSize() { @@ -245,6 +248,7 @@ inline OpData EmbedLayerNormBatch_Distill() { int batch_size = 3; int sequence_size = 2; int hidden_size = 4; + int mask_index_type = 1; std::vector input_ids_data = { 1, 3, @@ -292,7 +296,7 @@ inline OpData EmbedLayerNormBatch_Distill() { return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data, mask_data, word_embedding_data, position_embedding_data, segment_embedding_data, - gamma_data, beta_data, output_data, mask_index_data, kEpsilon, + gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type, /*has_mask=*/true, /*has_segment=*/false); } @@ -301,6 +305,7 @@ inline OpData EmbedLayerNormBatch1_PositionIds(bool diff_order = false) { int batch_size = 1; int sequence_size = 2; int hidden_size = 4; + int mask_index_type = 1; std::vector input_ids_data = { 1, 3}; @@ -356,7 +361,7 @@ inline OpData EmbedLayerNormBatch1_PositionIds(bool diff_order = false) { return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data, mask_data, word_embedding_data, position_embedding_data, segment_embedding_data, - gamma_data, beta_data, output_data, mask_index_data, kEpsilon, + gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type, /*has_mask=*/true, /*has_segment=*/false, embedding_sum_output_data, @@ -367,6 +372,7 @@ inline OpData EmbedLayerNormBatch3_PositionIds_BroadCast() { int batch_size = 3; int sequence_size = 2; int hidden_size = 4; + int mask_index_type = 1; std::vector input_ids_data = { 1, 3, 1, 3, 1, 3}; @@ -416,7 +422,7 @@ inline OpData EmbedLayerNormBatch3_PositionIds_BroadCast() { return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data, mask_data, word_embedding_data, position_embedding_data, segment_embedding_data, - gamma_data, beta_data, output_data, mask_index_data, kEpsilon, + gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type, /*has_mask=*/true, /*has_segment=*/false, embedding_sum_output_data, @@ -427,6 +433,7 @@ inline OpData EmbedLayerNormBatch1_EmbeddingSum() { int batch_size = 1; int sequence_size = 2; int hidden_size = 4; + int mask_index_type = 1; std::vector input_ids_data = { 1, 3}; @@ -470,11 +477,64 @@ inline OpData EmbedLayerNormBatch1_EmbeddingSum() { return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data, mask_data, word_embedding_data, position_embedding_data, segment_embedding_data, - gamma_data, beta_data, output_data, mask_index_data, kEpsilon, + gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type, /*has_mask=*/true, /*has_segment=*/false, embedding_sum_data); } + +inline OpData EmbedLayerNormBatch1_EmbeddingSum_NoMaskIndex() { + int batch_size = 1; + int sequence_size = 2; + int hidden_size = 4; + int mask_index_type = 0; + + std::vector input_ids_data = { + 1, 3}; + + std::vector segment_ids_data = {}; + + std::vector mask_data = {}; + + std::vector word_embedding_data = { + 0.2f, 0.1f, 0.4f, -0.6f, + 0.3f, 0.2f, 0.5f, 0.6f, + 0.6f, 0.7f, 0.0f, -0.1f, + 0.8f, 0.6f, 0.9f, 1.2f, + 0.1f, 0.3f, 0.5f, 0.9f, + 1.0f, -2.0f, 1.1f, 0.8f}; + + std::vector position_embedding_data = { + 0.1f, 0.1f, 0.4f, 0.6f, + 0.6f, 0.0f, 0.8f, 0.6f, + 0.3f, 0.9f, -2.0f, 0.8f}; + + std::vector segment_embedding_data = {}; + + std::vector gamma_data = { + 0.25f, 0.15f, 0.45f, -0.66f}; + + std::vector beta_data = { + 0.6f, 0.2f, 0.5f, -0.6f}; + + std::vector output_data = { + 0.39587587118148804, 0.03670068085193634, 0.7449488639831543, -1.4981462955474854, + 0.61326867341995239, -0.046796366572380066, 0.81048583984375, -1.1954958438873291}; + + std::vector mask_index_data = {}; + + std::vector embedding_sum_data = { + 0.40000000596046448, 0.30000001192092896, 0.89999997615814209, 1.2000000476837158, + 1.4000000953674316, 0.60000002384185791, 1.7000000476837158, 1.8000000715255737}; + + return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data, + mask_data, word_embedding_data, position_embedding_data, segment_embedding_data, + gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type, + /*has_mask=*/true, + /*has_segment=*/false, + embedding_sum_data); +} + } // namespace embedlayernorm } // namespace test } // namespace onnxruntime