From 252222034f1baffe06776eba645ede99d4220201 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 28 Aug 2024 00:02:39 +0800 Subject: [PATCH] [js/webgpu] Support Reshape/Shape 21+ on jsep (#21871) ### Description #21618 With this PR, the cross device copying (`MemcpyToHost`) can totally be removed for model `wav2vec2`. And the overall time becomes 48ms from 604ms. ### Motivation and Context --- js/web/docs/webgpu-operators.md | 4 +-- .../providers/js/js_execution_provider.cc | 16 +++++++++--- .../core/providers/js/operators/reshape.cc | 26 ++++++++++++++++++- .../core/providers/js/operators/shape_op.cc | 26 ++++++++++++++++++- 4 files changed, 64 insertions(+), 8 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 425d479ad3..1c140de448 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -90,10 +90,10 @@ Do not modify directly.* | ReduceSum | ai.onnx(1-10,11-12,13+) | | | ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | | | Relu | ai.onnx(6-12,13,14+) | | -| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel | +| Reshape | ai.onnx(5-12,13,14-18,19-20,21+) | no GPU kernel | | Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | RotaryEmbedding | com.microsoft(1+) | | -| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix | +| Shape | ai.onnx(1-12,13-14,15-18,19-20,21+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | SimplifiedLayerNormalization | ai.onnx(1+) | | | Sin | ai.onnx(7+) | | diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index e289cba956..781083ea87 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -220,11 +220,15 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Les class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, 18, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Shape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 5, 12, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, 18, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze); @@ -484,11 +488,15 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/reshape.cc b/onnxruntime/core/providers/js/operators/reshape.cc index bd6772649a..e2a8d04e1a 100644 --- a/onnxruntime/core/providers/js/operators/reshape.cc +++ b/onnxruntime/core/providers/js/operators/reshape.cc @@ -10,7 +10,31 @@ namespace js { ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, - 14, + 21, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 19, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 14, 18, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedDataTypes()) diff --git a/onnxruntime/core/providers/js/operators/shape_op.cc b/onnxruntime/core/providers/js/operators/shape_op.cc index 733580a469..b1c0aa389f 100644 --- a/onnxruntime/core/providers/js/operators/shape_op.cc +++ b/onnxruntime/core/providers/js/operators/shape_op.cc @@ -32,10 +32,34 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 15, 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + // properly force CPU/GPU synch inside the kernel + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 19, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + // properly force CPU/GPU synch inside the kernel + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + ONNX_OPERATOR_KERNEL_EX( Shape, kOnnxDomain, - 15, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) // properly force CPU/GPU synch inside the kernel