add bf16 for Tile CUDA executor (#20854)

### Description
add bf16 for Tile CUDA executor



### Motivation and Context
required change to support phimm model for ORT training
This commit is contained in:
Frank Dong 2024-06-17 05:52:13 -07:00 committed by GitHub
parent 0babc33725
commit 8aa2667ae6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 2 deletions

View file

@ -827,7 +827,7 @@ Do not modify directly.*
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)|
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|

View file

@ -36,7 +36,8 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<MLFloat16>()})
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<BFloat16>()})
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Tile);