mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[webgpu] Add Alias def for Flatten (#23038)
### Description Add `Alias` definition for Flatten in WebGPU EP. also add int32/uint32 in type constraint T.
This commit is contained in:
parent
6d9636f07c
commit
22ae97c7dc
1 changed files with 21 additions and 6 deletions
|
|
@ -13,7 +13,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
1, 8,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -21,7 +24,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
9, 10,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -29,7 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
11, 12,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -37,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
13, 20,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
|
@ -45,8 +57,11 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
21,
|
||||
kWebGpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", WebGpuSupportedNumberTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Flatten);
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue