mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
add bf16 for Tile CUDA executor (#20854)
### Description: Add bf16 support for the Tile CUDA executor. ### Motivation and Context: Required change to support the phimm model for ORT training.
This commit is contained in:
parent
0babc33725
commit
8aa2667ae6
2 changed files with 3 additions and 2 deletions
|
|
@ -827,7 +827,7 @@ Do not modify directly.*
|
|||
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|
||||
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|
||||
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|
||||
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|
||||
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<MLFloat16>()})
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<BFloat16>()})
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Tile);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue