[JS/WebGPU] Squeeze operator implementation (#16024)

### Description

This PR adds an implementation of the `Squeeze` operator to WebGPU JSEP.
The implementation follows the [operator
schema](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze)
and accepts either one input (the tensor alone) or two inputs (the tensor plus `axes`).

### How was it tested

1. I created two models. The first one has no `axes` input:

```Python
import onnx.helper

node = onnx.helper.make_node(
    "Squeeze",
    inputs=["T"],
    outputs=["y"],
)
graph = onnx.helper.make_graph([node], "test", [onnx.helper.make_tensor_value_info("T", 1, [3, 1, 4, 5])], 
    [onnx.helper.make_tensor_value_info("y", 1, [3, 4, 5])])
onnx.save(onnx.helper.make_model(graph), "squeeze.onnx")
```

The second one takes `axes` as an input:

```Python
import onnx.helper

node = onnx.helper.make_node(
    "Squeeze",
    inputs=["T", "axes"],
    outputs=["y"],
)
graph = onnx.helper.make_graph([node], "test", [onnx.helper.make_tensor_value_info("T", 1, [3, 1, 4, 5]), onnx.helper.make_tensor_value_info("axes", 7, [1])], [onnx.helper.make_tensor_value_info("y", 1, [3, 4, 5])])
onnx.save(onnx.helper.make_model(graph), "squeeze-dim.onnx")
```

2. I compiled the runtime using @fs-eire's
[instructions](https://gist.github.com/fs-eire/a55b2c7e10a6864b9602c279b8b75dce).
3. I ran the test models in the browser using this minimal setup:
```HTML
<html>
    <script src=".\dist\ort.webgpu.min.js"></script>
    <script>
        async function run() {
            const session = await ort.InferenceSession.create('squeeze-dim.onnx', {executionProviders: ['webgpu']});
            console.log(session);
            const input = new ort.Tensor('float32', new Float32Array(60), [3, 1, 4, 5]);
            const dim = new ort.Tensor('int64', [-3n], [1]);
            const output = await session.run({ "T": input, "axes": dim });
            console.log(output);
        }
        run();
    </script>
</html>
```

### Motivation and Context

Improve operator coverage for WebGPU JSEP.
This commit is contained in:
Alexander Visheratin 2023-05-26 18:53:05 -04:00 committed by GitHub
parent 5e41d1600a
commit 415c26e46e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 103 additions and 0 deletions

View file

@ -48,6 +48,7 @@ Do not modify directly.*
| Sin | ai.onnx(7+) | |
| Sinh | ai.onnx(9+) | |
| Sqrt | ai.onnx(6-12,13+) | |
| Squeeze | ai.onnx(1-10,11-12,13+) | |
| Sub | ai.onnx(7-12,13,14+) | |
| Tan | ai.onnx(7+) | |
| ThresholdedRelu | ai.onnx(10+) | |

View file

@ -141,6 +141,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
// Forward declarations of the macro-named kernel classes for the JS execution provider.
// Squeeze is declared for three opset ranges (1-10, 11-12, and 13 onward), next to the
// neighbouring Reshape and Transpose declarations.
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose);
@ -250,6 +254,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose)>,

View file

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "squeeze.h"
namespace onnxruntime {
namespace js {
// Squeeze, opset 13 onward: the axes to remove can be supplied as a second input.
ONNX_OPERATOR_KERNEL_EX(
    Squeeze, kOnnxDomain, 13, kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .TypeConstraint("axes", DataTypeImpl::GetTensorType<int64_t>())
        // Input 1 (`axes`) is requested in CPU memory; the kernel's Compute reads its
        // int64 values directly on the host.
        .InputMemoryType(OrtMemTypeCPU, 1)
        // Output 0 is declared as an alias of input 0; Compute skips the copy when the
        // two buffers turn out to be the same.
        .Alias(0, 0),
    Squeeze);
// Squeeze, opsets 11 through 12: only the data tensor is registered as a typed input.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Squeeze, kOnnxDomain, 11, 12, kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        // Output 0 is declared as an alias of input 0.
        .Alias(0, 0)
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
    Squeeze);
// Squeeze, opsets 1 through 10: same kernel definition as the 11-12 registration,
// bound to the older version range.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Squeeze, kOnnxDomain, 1, 10, kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        // Output 0 is declared as an alias of input 0.
        .Alias(0, 0)
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
    Squeeze);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cpu/tensor/squeeze.h"
#include "core/providers/js/js_kernel.h"
#include "core/framework/data_transfer_manager.h"
namespace onnxruntime {
namespace js {
/**
 * Squeeze kernel for the JS execution provider.
 *
 * Combines JsKernel with SqueezeBase: the output shape comes from
 * SqueezeBase::ComputeOutputShape, and the tensor contents are copied through the
 * data transfer manager only when the output does not already share the input buffer.
 */
class Squeeze final : public JsKernel, public SqueezeBase {
 public:
  explicit Squeeze(const OpKernelInfo& info) : JsKernel(info), SqueezeBase(info) {}

  /**
   * Produces output 0 as input 0 reshaped according to the squeeze axes.
   *
   * The axes come from input 1 when the node has exactly two inputs; otherwise the
   * `axes_` member inherited from SqueezeBase is used.
   *
   * @param context Kernel context providing the inputs and the output.
   * @return Status::OK() on success; a FAIL status when the data input, the axes input
   *         or the output is missing, or when the axes tensor is not 1-D; otherwise the
   *         status returned by the tensor copy.
   */
  Status Compute(OpKernelContext* context) const override {
    const Tensor* X = context->Input<Tensor>(0);
    if (X == nullptr) {
      return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set");
    }
    const TensorShape& X_shape = X->Shape();

    TensorShapeVector axes;
    if (context->InputCount() == 2) {  // axes is an input
      // Validation failures are returned as Status, the same way as the null-input
      // check above, so every precondition in this method fails along one path.
      const Tensor* axes_tensor = context->Input<Tensor>(1);
      if (axes_tensor == nullptr) {
        return Status(common::ONNXRUNTIME, common::FAIL, "Axes input is null");
      }
      if (axes_tensor->Shape().NumDimensions() != 1) {
        return Status(common::ONNXRUNTIME, common::FAIL, "An axes tensor must be a vector tensor.");
      }
      const auto axes_count = static_cast<size_t>(axes_tensor->Shape()[0]);
      const auto* axes_data = axes_tensor->Data<int64_t>();
      axes.assign(axes_data, axes_data + axes_count);
    } else {
      axes.assign(axes_.begin(), axes_.end());
    }

    TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes);
    Tensor* Y = context->Output(0, TensorShape(output_shape));
    if (Y == nullptr) {
      // Guard the dereference below instead of assuming the output was allocated.
      return Status(common::ONNXRUNTIME, common::FAIL, "Output tensor is not set");
    }

    const void* source = X->DataRaw();
    void* target = Y->MutableDataRaw();
    // If source and target pointers are not equal (non-inplace operation), we need to copy the data.
    if (target != source) {
      ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y));
    }
    return Status::OK();
  }
};
} // namespace js
} // namespace onnxruntime