Mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-14 20:48:00 +00:00)
Fix SkipLayerNormalization shape inference (#18724)
SkipLayerNorm has more than one output, so `propagateShapeAndTypeFromFirstInput` (which only infers the first output from the first input) is not enough.
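For context, `propagateShapeAndTypeFromFirstInput` copies the element type and shape of input 0 onto output 0 and leaves all other outputs untouched, so the optional `mean`, `inv_std_var`, and `input_skip_bias_sum` outputs previously came out of shape inference with no type or shape at all. A minimal NumPy sketch of the fused computation shows where the expected shapes come from (the batch/seq/hidden names and sizes are illustrative assumptions; the schema only fixes the last axis as the normalization axis):

import numpy as np

batch, seq, hidden = 2, 4, 8   # illustrative sizes
eps = 1e-12                    # stands in for the op's epsilon attribute

x = np.random.randn(batch, seq, hidden).astype(np.float32)     # input (T)
skip = np.random.randn(batch, seq, hidden).astype(np.float32)  # skip (T)
gamma = np.ones(hidden, dtype=np.float32)                      # scale (T)
beta = np.zeros(hidden, dtype=np.float32)                      # shift (T)

# input_skip_bias_sum: elementwise sum, so same shape and element type as input 0.
total = x + skip
# mean and inv_std_var: reductions over the last axis, so their inferred shape
# is the input shape with the final dimension set to 1, stashed as float32 per
# the "U" type constraint.
mean = total.mean(axis=-1, keepdims=True)                             # (batch, seq, 1)
inv_std_var = 1.0 / np.sqrt(total.var(axis=-1, keepdims=True) + eps)  # (batch, seq, 1)
output = (total - mean) * inv_std_var * gamma + beta                  # (batch, seq, hidden)

These shape and type facts are exactly what the new `SkipLayerNormalizationShapeInference` function in this commit encodes.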
commit 80f274ca6f (parent e2e488d6f8)
3 changed files with 43 additions and 3 deletions
@@ -1285,7 +1285,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
         .Output(3, "input_skip_bias_sum", "Sum of the input and skip inputs (and bias if it exists) with shape (batch_size, sequence_length, hidden_size).", "T", OpSchema::Optional)
         .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
         .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
-        .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
+        .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference));
 
 ONNX_MS_OPERATOR_SET_SCHEMA(
     SkipSimplifiedLayerNormalization, 1,
@@ -1334,7 +1334,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                 OpSchema::Optional)
         .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
         .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
-        .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
+        .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference));
 
 constexpr const char* NGramRepeatBlock_ver1_doc = R"DOC(
 Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form a repeated n-gram if added to the back of the input_ids.
@@ -114,6 +114,45 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) {
   }
 }
 
+void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) {
+  propagateShapeAndTypeFromFirstInput(ctx);
+
+  auto stash_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+  if (ctx.getNumOutputs() > 1) {
+    auto output_type = ctx.getOutputType(1);
+    output_type->mutable_tensor_type()->set_elem_type(static_cast<int32_t>(stash_type));
+  }
+  if (ctx.getNumOutputs() > 2) {
+    auto output_type = ctx.getOutputType(2);
+    output_type->mutable_tensor_type()->set_elem_type(static_cast<int32_t>(stash_type));
+  }
+  if (ctx.getNumOutputs() > 3) {
+    propagateElemTypeFromInputToOutput(ctx, 0, 3);
+  }
+  if (!hasNInputShapes(ctx, 1)) {
+    return;
+  }
+  auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
+  int64_t input_ndim = input_shape.dim_size();
+  int axis = static_cast<int>(input_ndim - 1);
+
+  if (ctx.getNumOutputs() > 1) {
+    auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape();
+    mean_shape->CopyFrom(input_shape);
+    mean_shape->mutable_dim(axis)->set_dim_value(1);
+  }
+
+  if (ctx.getNumOutputs() > 2) {
+    auto inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape();
+    inv_std_dev_shape->CopyFrom(input_shape);
+    inv_std_dev_shape->mutable_dim(axis)->set_dim_value(1);
+  }
+
+  if (ctx.getNumOutputs() > 3) {
+    propagateShapeFromInputToOutput(ctx, 0, 3);
+  }
+}
+
 // Shape inference for Attention and QAttention
 void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) {
   // Input 0, 1, 2 are input, weights and bias.
@@ -13,5 +13,6 @@ namespace onnxruntime {
 namespace contrib {
 void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index);
 void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx);
+void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx);
 } // namespace contrib
-} // namespace onnxruntime
+} // namespace onnxruntime
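Beyond the diff, here is a quick smoke test of the op from Python (a sketch, not part of the commit; it assumes the onnx and onnxruntime packages, and the opset numbers, sizes, seed, and tolerances are arbitrary choices). The graph output is declared without a shape, so the shape reported after session creation comes from onnxruntime's own shape inference; only the required output is requested, since support for the optional outputs varies by execution provider:

import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as ort

batch, seq, hidden = 2, 4, 8
eps = 1e-12

# One SkipLayerNormalization node; the graph output deliberately carries no
# shape so that onnxruntime's shape inference has to fill it in.
node = helper.make_node(
    "SkipLayerNormalization",
    inputs=["input", "skip", "gamma", "beta"],
    outputs=["output"],
    domain="com.microsoft",
    epsilon=eps,
)
graph = helper.make_graph(
    [node],
    "skip_ln_smoke_test",
    [
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
        helper.make_tensor_value_info("skip", TensorProto.FLOAT, [batch, seq, hidden]),
        helper.make_tensor_value_info("gamma", TensorProto.FLOAT, [hidden]),
        helper.make_tensor_value_info("beta", TensorProto.FLOAT, [hidden]),
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, None)],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
print(sess.get_outputs()[0].shape)  # [2, 4, 8], filled in by shape inference

rng = np.random.default_rng(0)
feeds = {
    "input": rng.standard_normal((batch, seq, hidden)).astype(np.float32),
    "skip": rng.standard_normal((batch, seq, hidden)).astype(np.float32),
    "gamma": np.ones(hidden, dtype=np.float32),
    "beta": np.zeros(hidden, dtype=np.float32),
}
(got,) = sess.run(None, feeds)

# Check against a NumPy reference: LayerNorm over the last axis of input + skip.
total = feeds["input"] + feeds["skip"]
mean = total.mean(axis=-1, keepdims=True)
ref = (total - mean) / np.sqrt(total.var(axis=-1, keepdims=True) + eps)
ref = ref * feeds["gamma"] + feeds["beta"]
np.testing.assert_allclose(got, ref, rtol=1e-4, atol=1e-5)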