diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.cc b/onnxruntime/core/providers/webgpu/tensor/flatten.cc new file mode 100644 index 0000000000..81d28bd3c0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/flatten.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 1, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 9, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_KERNEL_EX( + Flatten, + kOnnxDomain, + 21, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.h b/onnxruntime/core/providers/webgpu/tensor/flatten.h new file mode 100644 index 0000000000..5fc49a844b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/flatten.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Flatten final : public OpKernel { + public: + explicit Flatten(const OpKernelInfo& info) : OpKernel{info} { + axis_ = info.GetAttrOrDefault("axis", 1); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* input_tensor = context->Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int64_t input_rank = input_shape.NumDimensions(); + + // Handle negative axis + int64_t axis = axis_; + if (axis < 0) { + axis += input_rank; + } + + if (axis > input_rank) { + return Status(common::ONNXRUNTIME, common::FAIL, "Invalid value for axis, must be less than or equal to input_rank"); + } + + int64_t first_dim = 1; + for (int64_t i = 0; i < axis; i++) { + first_dim *= input_shape[i]; + } + + int64_t second_dim = 1; + for (int64_t i = axis; i < input_rank; i++) { + second_dim *= input_shape[i]; + } + + TensorShape output_shape({first_dim, second_dim}); + Tensor* output_tensor = context->Output(0, output_shape); + + const void* source = input_tensor->DataRaw(); + void* target = output_tensor->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output_tensor)); + } + + return Status::OK(); + } + + private: + int64_t axis_; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 821c60ab60..90b6862758 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -347,7 +347,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile); @@ -667,10 +668,12 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo,