mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
[WebGPU EP] Flatten implementation (#22964)
Implements flatten operator for native webgpu.
This commit is contained in:
parent
9ed0c7fe26
commit
5c644d3747
3 changed files with 122 additions and 5 deletions
52
onnxruntime/core/providers/webgpu/tensor/flatten.cc
Normal file
52
onnxruntime/core/providers/webgpu/tensor/flatten.cc
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/webgpu/tensor/flatten.h"
|
||||
#include "core/providers/webgpu/webgpu_execution_provider.h"
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webgpu {
|
||||
|
||||
// Registers the WebGPU Flatten kernel for ONNX opset versions 1 through 8.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Flatten,
    kOnnxDomain,
    1, 8,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        // Flatten only reinterprets the shape, so allow output 0 to reuse input 0's buffer.
        // Flatten::Compute skips its copy when both tensors already point at the same data.
        .Alias(0, 0)
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        // NOTE(review): Flatten::Compute only reads input 0; confirm whether this
        // memory-type hint for input index 1 is intentional.
        .InputMemoryType(OrtMemTypeCPU, 1),
    Flatten);
|
||||
|
||||
// Registers the WebGPU Flatten kernel for ONNX opset versions 9 through 10.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Flatten,
    kOnnxDomain,
    9, 10,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        // Flatten only reinterprets the shape, so allow output 0 to reuse input 0's buffer.
        // Flatten::Compute skips its copy when both tensors already point at the same data.
        .Alias(0, 0)
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        // NOTE(review): Flatten::Compute only reads input 0; confirm whether this
        // memory-type hint for input index 1 is intentional.
        .InputMemoryType(OrtMemTypeCPU, 1),
    Flatten);
|
||||
|
||||
// Registers the WebGPU Flatten kernel for ONNX opset versions 11 through 12.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Flatten,
    kOnnxDomain,
    11, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        // Flatten only reinterprets the shape, so allow output 0 to reuse input 0's buffer.
        // Flatten::Compute skips its copy when both tensors already point at the same data.
        .Alias(0, 0)
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        // NOTE(review): Flatten::Compute only reads input 0; confirm whether this
        // memory-type hint for input index 1 is intentional.
        .InputMemoryType(OrtMemTypeCPU, 1),
    Flatten);
|
||||
|
||||
// Registers the WebGPU Flatten kernel for ONNX opset versions 13 through 20.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Flatten,
    kOnnxDomain,
    13, 20,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        // Flatten only reinterprets the shape, so allow output 0 to reuse input 0's buffer.
        // Flatten::Compute skips its copy when both tensors already point at the same data.
        .Alias(0, 0)
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        // NOTE(review): Flatten::Compute only reads input 0; confirm whether this
        // memory-type hint for input index 1 is intentional.
        .InputMemoryType(OrtMemTypeCPU, 1),
    Flatten);
|
||||
|
||||
// Registers the WebGPU Flatten kernel for ONNX opset version 21 and later.
ONNX_OPERATOR_KERNEL_EX(
    Flatten,
    kOnnxDomain,
    21,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        // Flatten only reinterprets the shape, so allow output 0 to reuse input 0's buffer.
        // Flatten::Compute skips its copy when both tensors already point at the same data.
        .Alias(0, 0)
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        // NOTE(review): Flatten::Compute only reads input 0; confirm whether this
        // memory-type hint for input index 1 is intentional.
        .InputMemoryType(OrtMemTypeCPU, 1),
    Flatten);
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
62
onnxruntime/core/providers/webgpu/tensor/flatten.h
Normal file
62
onnxruntime/core/providers/webgpu/tensor/flatten.h
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/nn/flatten.h"
|
||||
#include "core/framework/data_transfer_manager.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webgpu {
|
||||
|
||||
class Flatten final : public OpKernel {
|
||||
public:
|
||||
explicit Flatten(const OpKernelInfo& info) : OpKernel{info} {
|
||||
axis_ = info.GetAttrOrDefault<int64_t>("axis", 1);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
const Tensor* input_tensor = context->Input<Tensor>(0);
|
||||
const TensorShape& input_shape = input_tensor->Shape();
|
||||
int64_t input_rank = input_shape.NumDimensions();
|
||||
|
||||
// Handle negative axis
|
||||
int64_t axis = axis_;
|
||||
if (axis < 0) {
|
||||
axis += input_rank;
|
||||
}
|
||||
|
||||
if (axis > input_rank) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Invalid value for axis, must be less than or equal to input_rank");
|
||||
}
|
||||
|
||||
int64_t first_dim = 1;
|
||||
for (int64_t i = 0; i < axis; i++) {
|
||||
first_dim *= input_shape[i];
|
||||
}
|
||||
|
||||
int64_t second_dim = 1;
|
||||
for (int64_t i = axis; i < input_rank; i++) {
|
||||
second_dim *= input_shape[i];
|
||||
}
|
||||
|
||||
TensorShape output_shape({first_dim, second_dim});
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
|
||||
const void* source = input_tensor->DataRaw();
|
||||
void* target = output_tensor->MutableDataRaw();
|
||||
// If source and target pointers are not equal (non-inplace operation), we need to copy the data.
|
||||
if (target != source) {
|
||||
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output_tensor));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -347,7 +347,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13,
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Flatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile);
|
||||
|
||||
|
|
@ -667,10 +668,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice)>,
|
||||
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Flatten)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile)>,
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue