From 415c26e46e8fdd1ccc4a39b934de1f9e17bcaa30 Mon Sep 17 00:00:00 2001 From: Alexander Visheratin Date: Fri, 26 May 2023 18:53:05 -0400 Subject: [PATCH] [JS/WebGPU] Squeeze operator implementation (#16024) ### Description This PR adds an implementation of the `Squeeze` operator to WebGPU JSEP. The implementation follows the [operator schema](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) and allows one or two inputs. ### How was it tested 1. I created two models. Without `axes`: ```Python import onnx.helper node = onnx.helper.make_node( "Squeeze", inputs=["T"], outputs=["y"], ) graph = onnx.helper.make_graph([node], "test", [onnx.helper.make_tensor_value_info("T", 1, [3, 1, 4, 5])], [onnx.helper.make_tensor_value_info("y", 1, [3, 4, 5])]) onnx.save(onnx.helper.make_model(graph), "squeeze.onnx") ``` And with `axes`: ```Python import onnx.helper node = onnx.helper.make_node( "Squeeze", inputs=["T", "axes"], outputs=["y"], ) graph = onnx.helper.make_graph([node], "test", [onnx.helper.make_tensor_value_info("T", 1, [3, 1, 4, 5]), onnx.helper.make_tensor_value_info("axes", 7, [1])], [onnx.helper.make_tensor_value_info("y", 1, [3, 4, 5])]) onnx.save(onnx.helper.make_model(graph), "squeeze-dim.onnx") ``` 2. I compiled the runtime using @fs-eire's [instructions](https://gist.github.com/fs-eire/a55b2c7e10a6864b9602c279b8b75dce). 3. I ran the test models in the browser using this minimal setup: ```HTML ``` ### Motivation and Context Improve operator coverage for WebGPU JSEP. --- js/web/docs/webgpu-operators.md | 1 + .../providers/js/js_execution_provider.cc | 8 +++ .../core/providers/js/operators/squeeze.cc | 42 +++++++++++++++ .../core/providers/js/operators/squeeze.h | 52 +++++++++++++++++++ 4 files changed, 103 insertions(+) create mode 100644 onnxruntime/core/providers/js/operators/squeeze.cc create mode 100644 onnxruntime/core/providers/js/operators/squeeze.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 3578b76e1f..87b81bf6da 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -48,6 +48,7 @@ Do not modify directly.* | Sin | ai.onnx(7+) | | | Sinh | ai.onnx(9+) | | | Sqrt | ai.onnx(6-12,13+) | | +| Squeeze | ai.onnx(1-10,11-12,13+) | | | Sub | ai.onnx(7-12,13,14+) | | | Tan | ai.onnx(7+) | | | ThresholdedRelu | ai.onnx(10+) | | diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6d7ab8ae72..2399a2a465 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -141,6 +141,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); @@ -250,6 +254,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/squeeze.cc b/onnxruntime/core/providers/js/operators/squeeze.cc new file mode 100644 index 0000000000..a51fecfd8b --- /dev/null +++ b/onnxruntime/core/providers/js/operators/squeeze.cc @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "squeeze.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(0, 0), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(0, 0), + Squeeze); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/squeeze.h b/onnxruntime/core/providers/js/operators/squeeze.h new file mode 100644 index 0000000000..49958d549c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/squeeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/tensor/squeeze.h" +#include "core/providers/js/js_kernel.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace js { + +class Squeeze final : public JsKernel, public SqueezeBase { + public: + explicit Squeeze(const OpKernelInfo& info) : JsKernel(info), SqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace js +} // namespace onnxruntime