diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.cc b/onnxruntime/core/providers/webgpu/tensor/flatten.cc index 81d28bd3c0..11ded865b6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/flatten.cc +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.cc @@ -13,7 +13,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 1, 8, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -21,7 +24,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 9, 10, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -29,7 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 11, 12, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -37,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 13, 20, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_KERNEL_EX( @@ -45,8 +57,11 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 21, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime