mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[JS/WebGPU] Squeeze operator implementation (#16024)
### Description

This PR adds an implementation of the `Squeeze` operator to WebGPU JSEP. The implementation follows the [operator schema](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) and allows one or two inputs.

### How was it tested

1. I created two models.

   Without `axes`:

   ```Python
   import onnx.helper

   node = onnx.helper.make_node(
       "Squeeze",
       inputs=["T"],
       outputs=["y"],
   )
   graph = onnx.helper.make_graph([node], "test", [onnx.helper.make_tensor_value_info("T", 1, [3, 1, 4, 5])], [onnx.helper.make_tensor_value_info("y", 1, [3, 4, 5])])
   onnx.save(onnx.helper.make_model(graph), "squeeze.onnx")
   ```

   And with `axes`:

   ```Python
   import onnx.helper

   node = onnx.helper.make_node(
       "Squeeze",
       inputs=["T", "axes"],
       outputs=["y"],
   )
   graph = onnx.helper.make_graph([node], "test", [onnx.helper.make_tensor_value_info("T", 1, [3, 1, 4, 5]), onnx.helper.make_tensor_value_info("axes", 7, [1])], [onnx.helper.make_tensor_value_info("y", 1, [3, 4, 5])])
   onnx.save(onnx.helper.make_model(graph), "squeeze-dim.onnx")
   ```

2. I compiled the runtime using @fs-eire's [instructions](https://gist.github.com/fs-eire/a55b2c7e10a6864b9602c279b8b75dce).

3. I ran the test models in the browser using this minimal setup:

   ```HTML
   <html>
   <script src=".\dist\ort.webgpu.min.js"></script>
   <script>
     async function run() {
       const session = await ort.InferenceSession.create('squeeze-dim.onnx', {executionProviders: ['webgpu']});
       console.log(session);
       const input = new ort.Tensor('float32', new Float32Array(60), [3, 1, 4, 5]);
       const dim = new ort.Tensor('int64', [-3n], [1]);
       const output = await session.run({ "T": input, "axes": dim });
       console.log(output);
     }
     run();
   </script>
   </html>
   ```

### Motivation and Context

Improve operator coverage for WebGPU JSEP.
This commit is contained in:
parent
5e41d1600a
commit
415c26e46e
4 changed files with 103 additions and 0 deletions
|
|
@ -48,6 +48,7 @@ Do not modify directly.*
|
|||
| Sin | ai.onnx(7+) | |
|
||||
| Sinh | ai.onnx(9+) | |
|
||||
| Sqrt | ai.onnx(6-12,13+) | |
|
||||
| Squeeze | ai.onnx(1-10,11-12,13+) | |
|
||||
| Sub | ai.onnx(7-12,13,14+) | |
|
||||
| Tan | ai.onnx(7+) | |
|
||||
| ThresholdedRelu | ai.onnx(10+) | |
|
||||
|
|
|
|||
|
|
@ -141,6 +141,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape);
|
||||
|
||||
// Forward declarations of the JS execution provider's Squeeze kernel classes in kOnnxDomain:
// two versioned ones (1-10 and 11-12) and one for version 13.
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose);
|
||||
|
||||
|
|
@ -250,6 +254,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose)>,
|
||||
|
||||
|
|
|
|||
42
onnxruntime/core/providers/js/operators/squeeze.cc
Normal file
42
onnxruntime/core/providers/js/operators/squeeze.cc
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "squeeze.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
// Kernel definition for Squeeze (kOnnxDomain, version 13) on the JS execution provider.
// - "T" accepts every fixed-size tensor type; "axes" is constrained to int64 tensors.
// - Alias(0, 0): presumably lets output 0 reuse input 0's buffer (in-place) — confirm against
//   KernelDefBuilder documentation.
// - InputMemoryType(OrtMemTypeCPU, 1): input 1 is requested with the CPU memory type.
ONNX_OPERATOR_KERNEL_EX(
    Squeeze, kOnnxDomain, 13, kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .TypeConstraint("axes", DataTypeImpl::GetTensorType<int64_t>())
        .Alias(0, 0)
        .InputMemoryType(OrtMemTypeCPU, 1),
    Squeeze);
|
||||
|
||||
// Kernel definition for Squeeze (kOnnxDomain, versions 11 through 12) on the JS execution
// provider. "T" accepts every fixed-size tensor type.
// Alias(0, 0): presumably lets output 0 reuse input 0's buffer (in-place) — confirm against
// KernelDefBuilder documentation.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Squeeze, kOnnxDomain, 11, 12, kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .Alias(0, 0),
    Squeeze);
|
||||
|
||||
// Kernel definition for Squeeze (kOnnxDomain, versions 1 through 10) on the JS execution
// provider. "T" accepts every fixed-size tensor type.
// Alias(0, 0): presumably lets output 0 reuse input 0's buffer (in-place) — confirm against
// KernelDefBuilder documentation.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Squeeze, kOnnxDomain, 1, 10, kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .Alias(0, 0),
    Squeeze);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
52
onnxruntime/core/providers/js/operators/squeeze.h
Normal file
52
onnxruntime/core/providers/js/operators/squeeze.h
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/tensor/squeeze.h"
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/framework/data_transfer_manager.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class Squeeze final : public JsKernel, public SqueezeBase {
|
||||
public:
|
||||
explicit Squeeze(const OpKernelInfo& info) : JsKernel(info), SqueezeBase(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
if (X == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set");
|
||||
}
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
|
||||
TensorShapeVector axes;
|
||||
size_t num_inputs = context->InputCount();
|
||||
if (num_inputs == 2) { // axes is an input
|
||||
const Tensor* axes_tensor = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
|
||||
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
|
||||
"An axes tensor must be a vector tensor.");
|
||||
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
|
||||
const auto* data = axes_tensor->Data<int64_t>();
|
||||
axes.assign(data, data + nDims);
|
||||
} else {
|
||||
axes.assign(axes_.begin(), axes_.end());
|
||||
}
|
||||
|
||||
TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes);
|
||||
Tensor* Y = context->Output(0, TensorShape(output_shape));
|
||||
const void* source = X->DataRaw();
|
||||
void* target = Y->MutableDataRaw();
|
||||
// If source and target pointers are not equal (non-inplace operation), we need to copy the data.
|
||||
if (target != source) {
|
||||
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue