[js/webgpu] Support Reshape/Shape 21+ on jsep (#21871)

### Description
<!-- Describe your changes. -->
#21618

With this PR, the cross-device copying (`MemcpyToHost`) can be removed
entirely for the `wav2vec2` model, and the overall time drops from 604ms
to 48ms.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Jiajia Qin 2024-08-28 00:02:39 +08:00 committed by GitHub
parent 5d54dc1462
commit 252222034f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 64 additions and 8 deletions

View file

@ -90,10 +90,10 @@ Do not modify directly.*
| ReduceSum | ai.onnx(1-10,11-12,13+) | |
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
| Relu | ai.onnx(6-12,13,14+) | |
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
| Reshape | ai.onnx(5-12,13,14-18,19-20,21+) | no GPU kernel |
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
| RotaryEmbedding | com.microsoft(1+) | |
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
| Shape | ai.onnx(1-12,13-14,15-18,19-20,21+) | no GPU kernel; an ORT warning is generated - need to fix |
| Sigmoid | ai.onnx(6-12,13+) | |
| SimplifiedLayerNormalization | ai.onnx(1+) | |
| Sin | ai.onnx(7+) | |

View file

@ -220,11 +220,15 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Les
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, 18, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 5, 12, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, 18, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
@ -484,11 +488,15 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, 18, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 5, 12, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, 18, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,

View file

@ -10,7 +10,31 @@ namespace js {
ONNX_OPERATOR_KERNEL_EX(
Reshape,
kOnnxDomain,
14,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes())
.TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
.Alias(0, 0)
.InputMemoryType(OrtMemTypeCPU, 1),
Reshape);
// Kernel definition for Reshape in the ONNX domain, version range 19-20,
// for the JS execution provider. The builder settings match the other
// Reshape definitions visible in this file; only the version range differs.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Reshape,
kOnnxDomain,
19, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes())
// The "shape" type parameter is restricted to int64 tensors.
.TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
// NOTE(review): presumably lets output 0 share input 0's buffer - confirm
// against KernelDefBuilder::Alias.
.Alias(0, 0)
// Input 1 is requested with the OrtMemTypeCPU memory type.
.InputMemoryType(OrtMemTypeCPU, 1),
Reshape);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Reshape,
kOnnxDomain,
14, 18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes())

View file

@ -32,10 +32,34 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);
// Kernel definition for Shape in the ONNX domain, version range 15-18,
// for the JS execution provider.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Shape,
kOnnxDomain,
15, 18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
// properly force CPU/GPU synch inside the kernel
// (output 0 is declared with the OrtMemTypeCPU memory type)
.OutputMemoryType(OrtMemTypeCPU, 0)
.TypeConstraint("T", JsepSupportedDataTypes())
// The "T1" type parameter is restricted to int64 tensors.
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);
// Kernel definition for Shape in the ONNX domain, version range 19-20,
// for the JS execution provider. Builder settings are identical to the
// 15-18 definition; only the version range differs.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Shape,
kOnnxDomain,
19, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
// properly force CPU/GPU synch inside the kernel
// (output 0 is declared with the OrtMemTypeCPU memory type)
.OutputMemoryType(OrtMemTypeCPU, 0)
.TypeConstraint("T", JsepSupportedDataTypes())
// The "T1" type parameter is restricted to int64 tensors.
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);
ONNX_OPERATOR_KERNEL_EX(
Shape,
kOnnxDomain,
15,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
// properly force CPU/GPU synch inside the kernel