diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 3578b76e1f..87b81bf6da 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -48,6 +48,7 @@ Do not modify directly.* | Sin | ai.onnx(7+) | | | Sinh | ai.onnx(9+) | | | Sqrt | ai.onnx(6-12,13+) | | +| Squeeze | ai.onnx(1-10,11-12,13+) | | | Sub | ai.onnx(7-12,13,14+) | | | Tan | ai.onnx(7+) | | | ThresholdedRelu | ai.onnx(10+) | | diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6d7ab8ae72..2399a2a465 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -141,6 +141,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); @@ -250,6 +254,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/squeeze.cc b/onnxruntime/core/providers/js/operators/squeeze.cc new file mode 100644 index 0000000000..a51fecfd8b --- /dev/null +++ b/onnxruntime/core/providers/js/operators/squeeze.cc @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "squeeze.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(0, 0), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(0, 0), + Squeeze); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/squeeze.h b/onnxruntime/core/providers/js/operators/squeeze.h new file mode 100644 index 0000000000..49958d549c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/squeeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/tensor/squeeze.h" +#include "core/providers/js/js_kernel.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace js { + +class Squeeze final : public JsKernel, public SqueezeBase { + public: + explicit Squeeze(const OpKernelInfo& info) : JsKernel(info), SqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace js +} // namespace onnxruntime