handle fp16 for where op (#19969)

This prevents the Where op from falling back from WebGPU to CPU for fp16 inputs, which improves performance.
This commit is contained in:
Guenther Schmuelling 2024-03-18 13:42:51 -07:00 committed by GitHub
parent 141966bb69
commit a4ac727cbb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,18 +6,19 @@
namespace onnxruntime {
namespace js {
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", \
{DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>(), \
DataTypeImpl::GetTensorType<uint32_t>(), \
DataTypeImpl::GetTensorType<bool>()}), \
// REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS)
//
// Expands to a single ONNX_OPERATOR_KERNEL_EX invocation that passes OP_TYPE,
// kOnnxDomain, VERSION, kJsExecutionProvider, a KernelDefBuilder and
// KERNEL_CLASS, in that order.
//
// The KernelDefBuilder constrains the type parameter "T" to tensors of:
//   float, MLFloat16, int32_t, uint32_t, bool.
//
// NOTE(review): ONNX_OPERATOR_KERNEL_EX is defined outside this view; its exact
// registration semantics should be confirmed against its definition.
// Do not place comments inside the macro body: every line there ends in a
// line continuation.
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", \
{DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>(), \
DataTypeImpl::GetTensorType<uint32_t>(), \
DataTypeImpl::GetTensorType<bool>()}), \
KERNEL_CLASS);
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
@ -29,6 +30,7 @@ namespace js {
KernelDefBuilder() \
.TypeConstraint("T", \
{DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>(), \
DataTypeImpl::GetTensorType<uint32_t>(), \
DataTypeImpl::GetTensorType<bool>()}), \