add bfloat16 support for ConcatTraining and SplitTraining ops (#18280)

### Description
<!-- Describe your changes. -->

Updates the input/output type constraints on the training operators
ConcatTraining and SplitTraining to include bfloat16, which was
introduced in ONNX IR version 4.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Enables `meta-llama/Llama-2-70b` to be fine-tuned with ONNX Runtime
training.
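For context on why bfloat16 matters for large-model fine-tuning: bfloat16 keeps float32's sign bit and 8-bit exponent but truncates the mantissa to 7 bits, so it preserves float32's dynamic range at half the storage. A minimal stdlib-only sketch of that bit-level relationship (illustrative only, not ONNX Runtime code; the helper names are made up here):

```python
import struct

def float32_to_bfloat16_bits(x: float) -> int:
    # bfloat16 is the upper 16 bits of the float32 bit pattern:
    # same sign and exponent, mantissa truncated from 23 to 7 bits.
    bits32 = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits32 >> 16  # truncation (drops the low mantissa bits)

def bfloat16_bits_to_float32(b: int) -> float:
    # Widening back is exact: pad the mantissa with zeros.
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]
```

Because the exponent field is unchanged, any value representable in float32 stays in range after conversion; only mantissa precision is lost.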

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Commit 83c0275354 (parent a16d528399) by Prathik Rao, 2023-11-07 10:10:01 -08:00, committed by GitHub.

@@ -2193,7 +2193,7 @@ Example 4:
             OpSchema::Variadic)
         .TypeConstraint(
             "T",
-            OpSchema::all_tensor_types(),
+            OpSchema::all_tensor_types_ir4(),
             "Constrain input and output types to all tensor types.")
         .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
           for (int i = 0; i < static_cast<int>(ctx.getNumOutputs()); ++i) {
@@ -2270,7 +2270,7 @@ Example 4:
             OpSchema::Optional)
         .TypeConstraint(
             "T",
-            OpSchema::all_tensor_types(),
+            OpSchema::all_tensor_types_ir4(),
             "Constrain output types to any tensor type.")
         .TypeConstraint(
             "Tint",