mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[JS/WebGPU] Added Flatten operator support. (#16860)
### Description Added Flatten operator support to JSEP. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
ec935a5533
commit
e67547b978
5 changed files with 116 additions and 9 deletions
|
|
@ -31,6 +31,7 @@ Do not modify directly.*
|
|||
| Erf | ai.onnx(9-12,13+) | |
|
||||
| Exp | ai.onnx(6-12,13+) | |
|
||||
| Expand | ai.onnx(8-12,13+) | |
|
||||
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
|
||||
| Floor | ai.onnx(6-12,13+) | |
|
||||
| Gemm | ai.onnx(7-8,9-10,11+) | |
|
||||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
|
|
|
|||
|
|
@ -524,15 +524,15 @@
|
|||
// "test_eyelike_populate_off_main_diagonal",
|
||||
// "test_eyelike_with_dtype",
|
||||
// "test_eyelike_without_dtype",
|
||||
// "test_flatten_axis0",
|
||||
// "test_flatten_axis1",
|
||||
// "test_flatten_axis2",
|
||||
// "test_flatten_axis3",
|
||||
// "test_flatten_default_axis",
|
||||
// "test_flatten_negative_axis1",
|
||||
// "test_flatten_negative_axis2",
|
||||
// "test_flatten_negative_axis3",
|
||||
// "test_flatten_negative_axis4",
|
||||
"test_flatten_axis0",
|
||||
"test_flatten_axis1",
|
||||
"test_flatten_axis2",
|
||||
"test_flatten_axis3",
|
||||
"test_flatten_default_axis",
|
||||
"test_flatten_negative_axis1",
|
||||
"test_flatten_negative_axis2",
|
||||
"test_flatten_negative_axis3",
|
||||
"test_flatten_negative_axis4",
|
||||
"test_floor_example",
|
||||
"test_floor",
|
||||
// "test_gather_0",
|
||||
|
|
|
|||
|
|
@ -251,6 +251,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten);
|
||||
|
||||
std::unique_ptr<KernelRegistry> RegisterKernels() {
|
||||
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
|
||||
|
||||
|
|
@ -438,6 +443,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten)>,
|
||||
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
50
onnxruntime/core/providers/js/operators/flatten.cc
Normal file
50
onnxruntime/core/providers/js/operators/flatten.cc
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "flatten.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Flatten,
|
||||
kOnnxDomain,
|
||||
1, 8,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Flatten,
|
||||
kOnnxDomain,
|
||||
9, 10,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Flatten,
|
||||
kOnnxDomain,
|
||||
11, 12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Flatten,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Flatten);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
45
onnxruntime/core/providers/js/operators/flatten.h
Normal file
45
onnxruntime/core/providers/js/operators/flatten.h
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/framework/data_transfer_manager.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class Flatten : public JsKernel {
|
||||
public:
|
||||
Flatten(const OpKernelInfo& info) : JsKernel(info) {
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK());
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
if (X == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
}
|
||||
const TensorShape& xShape = X->Shape();
|
||||
auto axis = axis_ >= 0 ? axis_ : HandleNegativeAxis(axis_, xShape.NumDimensions());
|
||||
ORT_ENFORCE(gsl::narrow_cast<int64_t>(xShape.NumDimensions()) >= axis, "The rank of input tensor must be >= axis");
|
||||
const TensorShape yShape = {xShape.SizeToDimension(onnxruntime::narrow<size_t>(axis)),
|
||||
xShape.SizeFromDimension(onnxruntime::narrow<size_t>(axis))};
|
||||
Tensor* Y = context->Output(0, yShape);
|
||||
const void* source = X->DataRaw();
|
||||
void* target = Y->MutableDataRaw();
|
||||
// If source and target pointers are not equal (non-inplace operation), we need to copy the data.
|
||||
|
||||
if (target != source) {
|
||||
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue