From 22ae97c7dc7822eb19cafa0bd9bfee67da3075ec Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:19:43 -0800 Subject: [PATCH] [webgpu] Add Alias def for Flatten (#23038) ### Description Add `Alias` definition for Flatten in WebGPU EP. also add int32/uint32 in type constraint T. --- .../core/providers/webgpu/tensor/flatten.cc | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.cc b/onnxruntime/core/providers/webgpu/tensor/flatten.cc index 81d28bd3c0..11ded865b6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/flatten.cc +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.cc @@ -13,7 +13,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 1, 8, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -21,7 +24,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 9, 10, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -29,7 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 11, 12, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -37,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 13, 20, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_KERNEL_EX( @@ -45,8 +57,11 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 21, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime