diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 80d937fa16..283883c2e3 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -2193,7 +2193,7 @@ Example 4: OpSchema::Variadic) .TypeConstraint( "T", - OpSchema::all_tensor_types(), + OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { for (int i = 0; i < static_cast(ctx.getNumOutputs()); ++i) { @@ -2270,7 +2270,7 @@ Example 4: OpSchema::Optional) .TypeConstraint( "T", - OpSchema::all_tensor_types(), + OpSchema::all_tensor_types_ir4(), "Constrain output types to any tensor type.") .TypeConstraint( "Tint",