diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index dcdf73cbdb..bbfc33a915 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -801,7 +801,7 @@ Do not modify directly.* |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| +|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |||[9, 15]|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |Xor|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| | | diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 4f5469ad8d..2d242d7d6f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1228,6 +1228,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, BFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, int32_t, Where); @@ -2111,6 +2112,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index b3f92c913a..4d98b3a8a1 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -216,5 +216,6 @@ SPECIALIZED_COMPUTE(int64_t) SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double_t) SPECIALIZED_COMPUTE(MLFloat16) +SPECIALIZED_COMPUTE(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/where_impl.cu b/onnxruntime/core/providers/cuda/tensor/where_impl.cu index 0fbd2062ca..d7909454e9 100644 --- a/onnxruntime/core/providers/cuda/tensor/where_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/where_impl.cu @@ -238,6 +238,7 @@ SPECIALIZED_IMPL(int64_t) SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double_t) SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime