[js/webgpu] Add C++ registration for operator Tanh in JSEP (#17124)

add webgpu/tanh

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
This commit is contained in:
Guenther Schmuelling 2023-08-12 11:43:39 -07:00 committed by GitHub
parent e7adbb38f6
commit 9204cd7392
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 0 deletions

View file

@ -75,6 +75,7 @@ Do not modify directly.*
| Squeeze | ai.onnx(1-10,11-12,13+) | |
| Sub | ai.onnx(7-12,13,14+) | |
| Tan | ai.onnx(7+) | |
| Tanh | ai.onnx(6-12,13+) | |
| ThresholdedRelu | ai.onnx(10+) | |
| Transpose | ai.onnx(1-12,13+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13+) | |

View file

@ -107,6 +107,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Cosh
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Asinh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Acosh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Atanh);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tanh);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 10, Clip);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, Clip);
@ -329,6 +331,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(9, Asinh),
KERNEL_CREATE_INFO(9, Acosh),
KERNEL_CREATE_INFO(9, Atanh),
KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh),
KERNEL_CREATE_INFO(13, Tanh),
// activations
KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip),

View file

@ -89,6 +89,10 @@ JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh)
JSEP_KERNEL_IMPL(Atanh, Atanh)
JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh)
JSEP_KERNEL_IMPL(Tanh, Tanh)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh)
JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh)
// activation
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)