From 83c0275354bf21845df21d83dc9e4e249c8e74ed Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 7 Nov 2023 10:10:01 -0800 Subject: [PATCH] add bfloat16 support for ConcatTraining and SplitTraining ops (#18280) ### Description Updates input/output type constraints on training operators ConcatTraining and SplitTraining to include bfloat16 which was introduced in IR version 4. ### Motivation and Context Enabling `meta-llama/Llama-2-70b` to be finetuned with ONNX Runtime training. Co-authored-by: Prathik Rao --- orttraining/orttraining/core/graph/training_op_defs.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 80d937fa16..283883c2e3 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -2193,7 +2193,7 @@ Example 4: OpSchema::Variadic) .TypeConstraint( "T", - OpSchema::all_tensor_types(), + OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { for (int i = 0; i < static_cast(ctx.getNumOutputs()); ++i) { @@ -2270,7 +2270,7 @@ Example 4: OpSchema::Optional) .TypeConstraint( "T", - OpSchema::all_tensor_types(), + OpSchema::all_tensor_types_ir4(), "Constrain output types to any tensor type.") .TypeConstraint( "Tint",