mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/webgpu] Support Reshape/Shape 21+ on jsep (#21871)
### Description <!-- Describe your changes. --> #21618 With this PR, the cross device copying (`MemcpyToHost`) can totally be removed for model `wav2vec2`. And the overall time becomes 48ms from 604ms. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
5d54dc1462
commit
252222034f
4 changed files with 64 additions and 8 deletions
|
|
@ -90,10 +90,10 @@ Do not modify directly.*
|
|||
| ReduceSum | ai.onnx(1-10,11-12,13+) | |
|
||||
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
|
||||
| Relu | ai.onnx(6-12,13,14+) | |
|
||||
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
|
||||
| Reshape | ai.onnx(5-12,13,14-18,19-20,21+) | no GPU kernel |
|
||||
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
|
||||
| RotaryEmbedding | com.microsoft(1+) | |
|
||||
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
|
||||
| Shape | ai.onnx(1-12,13-14,15-18,19-20,21+) | no GPU kernel; an ORT warning is generated - need to fix |
|
||||
| Sigmoid | ai.onnx(6-12,13+) | |
|
||||
| SimplifiedLayerNormalization | ai.onnx(1+) | |
|
||||
| Sin | ai.onnx(7+) | |
|
||||
|
|
|
|||
|
|
@ -220,11 +220,15 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Les
|
|||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, 18, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Shape);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 5, 12, Reshape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, 18, Reshape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Reshape);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
|
||||
|
|
@ -484,11 +488,15 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, 18, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Shape)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 5, 12, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, 18, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Reshape)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,31 @@ namespace js {
|
|||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Reshape,
|
||||
kOnnxDomain,
|
||||
14,
|
||||
21,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedDataTypes())
|
||||
.TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.Alias(0, 0)
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Reshape);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Reshape,
|
||||
kOnnxDomain,
|
||||
19, 20,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedDataTypes())
|
||||
.TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.Alias(0, 0)
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Reshape);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Reshape,
|
||||
kOnnxDomain,
|
||||
14, 18,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedDataTypes())
|
||||
|
|
|
|||
|
|
@ -32,10 +32,34 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Shape,
|
||||
kOnnxDomain,
|
||||
15, 18,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
// properly force CPU/GPU synch inside the kernel
|
||||
.OutputMemoryType(OrtMemTypeCPU, 0)
|
||||
.TypeConstraint("T", JsepSupportedDataTypes())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Shape,
|
||||
kOnnxDomain,
|
||||
19, 20,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
// properly force CPU/GPU synch inside the kernel
|
||||
.OutputMemoryType(OrtMemTypeCPU, 0)
|
||||
.TypeConstraint("T", JsepSupportedDataTypes())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Shape,
|
||||
kOnnxDomain,
|
||||
15,
|
||||
21,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
// properly force CPU/GPU synch inside the kernel
|
||||
|
|
|
|||
Loading…
Reference in a new issue