Fix SkipLayerNormalization shape inference (#18724)

SkipLayerNorm has more than one input, so `propagateShapeAndTypeFromFirstInput` is not enough.
This commit was authored by Patrice Vignola on 2024-01-16 09:42:59 -08:00 and committed via GitHub.
parent e2e488d6f8
commit 80f274ca6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 3 deletions

View file

@ -1285,7 +1285,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Output(3, "input_skip_bias_sum", "Sum of the input and skip inputs (and bias if it exists) with shape (batch_size, sequence_length, hidden_size).", "T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
.TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
.TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference));
ONNX_MS_OPERATOR_SET_SCHEMA(
SkipSimplifiedLayerNormalization, 1,
@ -1334,7 +1334,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
.TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
.TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference));
constexpr const char* NGramRepeatBlock_ver1_doc = R"DOC(
Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form a repeated n-gram if added to the back of the input_ids.

View file

@ -114,6 +114,45 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& c
}
}
// Shape/type inference for SkipLayerNormalization and SkipSimplifiedLayerNormalization.
//
// Outputs:
//   0: normalized output      — same element type and shape as input 0
//   1: mean (optional)        — stashed as float, input shape with last dim = 1
//   2: inv_std_var (optional) — stashed as float, input shape with last dim = 1
//   3: input_skip_bias_sum (optional) — same element type and shape as input 0
void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) {
  // Output 0 mirrors input 0 directly.
  propagateShapeAndTypeFromFirstInput(ctx);

  // Mean and inv_std_var are always stashed in float precision,
  // independent of the input element type.
  constexpr auto stash_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
  const auto num_outputs = ctx.getNumOutputs();

  for (size_t stat_index = 1; stat_index <= 2; ++stat_index) {
    if (num_outputs > stat_index) {
      ctx.getOutputType(stat_index)->mutable_tensor_type()->set_elem_type(static_cast<int32_t>(stash_type));
    }
  }

  // Output 3 keeps the element type of input 0.
  if (num_outputs > 3) {
    propagateElemTypeFromInputToOutput(ctx, 0, 3);
  }

  // Everything below requires a known shape on input 0.
  if (!hasNInputShapes(ctx, 1)) {
    return;
  }

  const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
  const int64_t rank = input_shape.dim_size();
  const int last_axis = static_cast<int>(rank - 1);

  // Mean and inv_std_var carry the input shape with the reduction
  // (last) dimension collapsed to 1.
  for (size_t stat_index = 1; stat_index <= 2; ++stat_index) {
    if (num_outputs > stat_index) {
      auto* stat_shape = ctx.getOutputType(stat_index)->mutable_tensor_type()->mutable_shape();
      stat_shape->CopyFrom(input_shape);
      stat_shape->mutable_dim(last_axis)->set_dim_value(1);
    }
  }

  // Output 3 has exactly the input shape.
  if (num_outputs > 3) {
    propagateShapeFromInputToOutput(ctx, 0, 3);
  }
}
// Shape inference for Attention and QAttention
void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) {
// Input 0, 1, 2 are input, weights and bias.

View file

@ -13,5 +13,6 @@ namespace onnxruntime {
namespace contrib {
void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index);
void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx);
void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx);
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime