mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
add bfloat16 support for where operator (#18118)
### Description <!-- Describe your changes. --> Adds bfloat16 as a valid input parameter type for where node for ONNX opset 16+. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Enabling `meta-llama/Llama-2-70b` to be finetuned with ONNX Runtime training. --------- Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
c8e1038eab
commit
8978bdc59d
4 changed files with 5 additions and 1 deletions
|
|
@ -801,7 +801,7 @@ Do not modify directly.*
|
|||
|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Upsample|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T**<br> *out* Y:**T**|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|
||||
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|
||||
|Where|*in* condition:**B**<br> *in* X:**T**<br> *in* Y:**T**<br> *out* output:**T**|16+|**B** = tensor(bool)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
|
||||
|Where|*in* condition:**B**<br> *in* X:**T**<br> *in* Y:**T**<br> *out* output:**T**|16+|**B** = tensor(bool)<br/> **T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
|
||||
|||[9, 15]|**B** = tensor(bool)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
|
||||
|Xor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|
||||
| |
|
||||
|
|
|
|||
|
|
@ -1228,6 +1228,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 18, Scan);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, Where);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, BFloat16, Where);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, Where);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double_t, Where);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, int32_t, Where);
|
||||
|
|
@ -2111,6 +2112,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 18, Scan)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, BFloat16, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double_t, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, int32_t, Where)>,
|
||||
|
|
|
|||
|
|
@ -216,5 +216,6 @@ SPECIALIZED_COMPUTE(int64_t)
|
|||
SPECIALIZED_COMPUTE(float)
|
||||
SPECIALIZED_COMPUTE(double_t)
|
||||
SPECIALIZED_COMPUTE(MLFloat16)
|
||||
SPECIALIZED_COMPUTE(BFloat16)
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -238,6 +238,7 @@ SPECIALIZED_IMPL(int64_t)
|
|||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(double_t)
|
||||
SPECIALIZED_IMPL(half)
|
||||
SPECIALIZED_IMPL(BFloat16)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue