mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
add bfloat16 support for ConcatTraining and SplitTraining ops (#18280)
### Description Updates input/output type constraints on training operators ConcatTraining and SplitTraining to include bfloat16, which was introduced in IR version 4. ### Motivation and Context Enabling `meta-llama/Llama-2-70b` to be fine-tuned with ONNX Runtime training. Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
a16d528399
commit
83c0275354
1 changed file with 2 additions and 2 deletions
|
|
@@ -2193,7 +2193,7 @@ Example 4:
|
|||
OpSchema::Variadic)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
OpSchema::all_tensor_types_ir4(),
|
||||
"Constrain input and output types to all tensor types.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
for (int i = 0; i < static_cast<int>(ctx.getNumOutputs()); ++i) {
|
||||
|
|
@@ -2270,7 +2270,7 @@ Example 4:
|
|||
OpSchema::Optional)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
OpSchema::all_tensor_types_ir4(),
|
||||
"Constrain output types to any tensor type.")
|
||||
.TypeConstraint(
|
||||
"Tint",
|
||||
|
|
|
|||
Loading…
Reference in a new issue