[webgpu] Add Alias def for Flatten (#23038)

### Description

Add `Alias` definition for Flatten in WebGPU EP.

also add int32/uint32 in type constraint T.
This commit is contained in:
Yulong Wang 2024-12-09 14:19:43 -08:00 committed by GitHub
parent 6d9636f07c
commit 22ae97c7dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -13,7 +13,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
1, 8,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -21,7 +24,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
9, 10,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -29,7 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -37,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
13, 20,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
ONNX_OPERATOR_KERNEL_EX(
@ -45,8 +57,11 @@ ONNX_OPERATOR_KERNEL_EX(
kOnnxDomain,
21,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);
} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime