From 80f274ca6f2f4572d827edd6dc7f736d7a8c036a Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 16 Jan 2024 09:42:59 -0800 Subject: [PATCH] Fix SkipLayerNormalization shape inference (#18724) SkipLayerNorm has more than one input, so `propagateShapeAndTypeFromFirstInput` is not enough. --- .../core/graph/contrib_ops/bert_defs.cc | 4 +- .../contrib_ops/shape_inference_functions.cc | 39 +++++++++++++++++++ .../contrib_ops/shape_inference_functions.h | 3 +- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index df8d0a59cb..0317ffcfb0 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1285,7 +1285,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(3, "input_skip_bias_sum", "Sum of the input and skip inputs (and bias if it exists) with shape (batch_size, sequence_length, hidden_size).", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); ONNX_MS_OPERATOR_SET_SCHEMA( SkipSimplifiedLayerNormalization, 1, @@ -1334,7 +1334,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); constexpr const char* NGramRepeatBlock_ver1_doc = R"DOC( Enforce no repetition of n-grams. 
Scores are set to `-inf` for tokens that form a repeated n-gram if added to the back of the input_ids. diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc index eeef20e9df..8b1812f62b 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc @@ -114,6 +114,45 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& c } } +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + + auto stash_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + if (ctx.getNumOutputs() > 1) { + auto output_type = ctx.getOutputType(1); + output_type->mutable_tensor_type()->set_elem_type(static_cast<int32_t>(stash_type)); + } + if (ctx.getNumOutputs() > 2) { + auto output_type = ctx.getOutputType(2); + output_type->mutable_tensor_type()->set_elem_type(static_cast<int32_t>(stash_type)); + } + if (ctx.getNumOutputs() > 3) { + propagateElemTypeFromInputToOutput(ctx, 0, 3); + } + if (!hasNInputShapes(ctx, 1)) { + return; + } + auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + int64_t input_ndim = input_shape.dim_size(); + int axis = static_cast<int>(input_ndim - 1); + + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); + mean_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 2) { + auto inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape(); + inv_std_dev_shape->CopyFrom(input_shape); + inv_std_dev_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 3) { + propagateShapeFromInputToOutput(ctx, 0, 3); + } +} + // Shape inference for Attention and QAttention void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int 
past_input_index) { // Input 0, 1, 2 are input, weights and bias. diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h index 93cf5b304f..6eb06af153 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h @@ -13,5 +13,6 @@ namespace onnxruntime { namespace contrib { void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index); void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime