mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
handle fp16 for where op (#19969)
this prevents falling back from webgpu to cpu, aka helps performance
This commit is contained in:
parent
141966bb69
commit
a4ac727cbb
1 changed files with 14 additions and 12 deletions
|
|
@ -6,18 +6,19 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", \
|
||||
{DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>(), \
|
||||
DataTypeImpl::GetTensorType<uint32_t>(), \
|
||||
DataTypeImpl::GetTensorType<bool>()}), \
|
||||
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", \
|
||||
{DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>(), \
|
||||
DataTypeImpl::GetTensorType<uint32_t>(), \
|
||||
DataTypeImpl::GetTensorType<bool>()}), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
|
||||
|
|
@ -29,6 +30,7 @@ namespace js {
|
|||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", \
|
||||
{DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>(), \
|
||||
DataTypeImpl::GetTensorType<uint32_t>(), \
|
||||
DataTypeImpl::GetTensorType<bool>()}), \
|
||||
|
|
|
|||
Loading…
Reference in a new issue