[JS/WebGPU] Added Flatten operator support. (#16860)

### Description
Added Flatten operator support to JSEP.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-07-27 12:50:45 -07:00 committed by GitHub
parent ec935a5533
commit e67547b978
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 116 additions and 9 deletions

View file

@ -31,6 +31,7 @@ Do not modify directly.*
| Erf | ai.onnx(9-12,13+) | |
| Exp | ai.onnx(6-12,13+) | |
| Expand | ai.onnx(8-12,13+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| Gemm | ai.onnx(7-8,9-10,11+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |

View file

@ -524,15 +524,15 @@
// "test_eyelike_populate_off_main_diagonal",
// "test_eyelike_with_dtype",
// "test_eyelike_without_dtype",
// "test_flatten_axis0",
// "test_flatten_axis1",
// "test_flatten_axis2",
// "test_flatten_axis3",
// "test_flatten_default_axis",
// "test_flatten_negative_axis1",
// "test_flatten_negative_axis2",
// "test_flatten_negative_axis3",
// "test_flatten_negative_axis4",
"test_flatten_axis0",
"test_flatten_axis1",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_default_axis",
"test_flatten_negative_axis1",
"test_flatten_negative_axis2",
"test_flatten_negative_axis3",
"test_flatten_negative_axis4",
"test_floor_example",
"test_floor",
// "test_gather_0",

View file

@ -251,6 +251,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten);
std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
@ -438,6 +443,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "flatten.h"
namespace onnxruntime {
namespace js {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
1, 8,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
9, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Flatten);
ONNX_OPERATOR_KERNEL_EX(
Flatten,
kOnnxDomain,
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Flatten);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
#include "core/framework/data_transfer_manager.h"
#include "core/providers/common.h"
namespace onnxruntime {
namespace js {
class Flatten : public JsKernel {
public:
Flatten(const OpKernelInfo& info) : JsKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK());
}
Status Compute(OpKernelContext* context) const override {
const Tensor* X = context->Input<Tensor>(0);
if (X == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
}
const TensorShape& xShape = X->Shape();
auto axis = axis_ >= 0 ? axis_ : HandleNegativeAxis(axis_, xShape.NumDimensions());
ORT_ENFORCE(gsl::narrow_cast<int64_t>(xShape.NumDimensions()) >= axis, "The rank of input tensor must be >= axis");
const TensorShape yShape = {xShape.SizeToDimension(onnxruntime::narrow<size_t>(axis)),
xShape.SizeFromDimension(onnxruntime::narrow<size_t>(axis))};
Tensor* Y = context->Output(0, yShape);
const void* source = X->DataRaw();
void* target = Y->MutableDataRaw();
// If source and target pointers are not equal (non-inplace operation), we need to copy the data.
if (target != source) {
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y));
}
return Status::OK();
}
private:
int64_t axis_;
};
} // namespace js
} // namespace onnxruntime