From e67547b9783655dc09ab76d072edd3bd6368cfde Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Thu, 27 Jul 2023 12:50:45 -0700 Subject: [PATCH] [JS/WebGPU] Added Flatten operator support. (#16860) ### Description Added Flatten operator support to JSEP. ### Motivation and Context --- js/web/docs/webgpu-operators.md | 1 + js/web/test/suite-test-list.jsonc | 18 +++---- .../providers/js/js_execution_provider.cc | 11 ++++ .../core/providers/js/operators/flatten.cc | 50 +++++++++++++++++++ .../core/providers/js/operators/flatten.h | 45 +++++++++++++++++ 5 files changed, 116 insertions(+), 9 deletions(-) create mode 100644 onnxruntime/core/providers/js/operators/flatten.cc create mode 100644 onnxruntime/core/providers/js/operators/flatten.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 7eea38db17..722d8d0421 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -31,6 +31,7 @@ Do not modify directly.* | Erf | ai.onnx(9-12,13+) | | | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | +| Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | Gemm | ai.onnx(7-8,9-10,11+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 1f3b2b979c..c7f09a2768 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -524,15 +524,15 @@ // "test_eyelike_populate_off_main_diagonal", // "test_eyelike_with_dtype", // "test_eyelike_without_dtype", - // "test_flatten_axis0", - // "test_flatten_axis1", - // "test_flatten_axis2", - // "test_flatten_axis3", - // "test_flatten_default_axis", - // "test_flatten_negative_axis1", - // "test_flatten_negative_axis2", - // "test_flatten_negative_axis3", - // "test_flatten_negative_axis4", + "test_flatten_axis0", + "test_flatten_axis1", + "test_flatten_axis2", + "test_flatten_axis3", + "test_flatten_default_axis", + "test_flatten_negative_axis1", + "test_flatten_negative_axis2", + "test_flatten_negative_axis3", + "test_flatten_negative_axis4", "test_floor_example", "test_floor", // "test_gather_0", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 0365e0ae0d..6f0a0be780 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -251,6 +251,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -438,6 +443,12 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/flatten.cc b/onnxruntime/core/providers/js/operators/flatten.cc new file mode 100644 index 0000000000..7e4b4c3509 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/flatten.cc @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "flatten.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 1, 8, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 9, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + +ONNX_OPERATOR_KERNEL_EX( + Flatten, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/flatten.h b/onnxruntime/core/providers/js/operators/flatten.h new file mode 100644 index 0000000000..9977a47b03 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/flatten.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/framework/data_transfer_manager.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace js { + +class Flatten : public JsKernel { + public: + Flatten(const OpKernelInfo& info) : JsKernel(info) { + ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK()); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + } + const TensorShape& xShape = X->Shape(); + auto axis = axis_ >= 0 ? axis_ : HandleNegativeAxis(axis_, xShape.NumDimensions()); + ORT_ENFORCE(gsl::narrow_cast(xShape.NumDimensions()) >= axis, "The rank of input tensor must be >= axis"); + const TensorShape yShape = {xShape.SizeToDimension(onnxruntime::narrow(axis)), + xShape.SizeFromDimension(onnxruntime::narrow(axis))}; + Tensor* Y = context->Output(0, yShape); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + return Status::OK(); + } + + private: + int64_t axis_; +}; + +} // namespace js +} // namespace onnxruntime