allow scalar axes for Unsqueeze for WebGPU (#22054)

### Description

Align with CPU behavior.


https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc#L60-L62
This commit is contained in:
Yulong Wang 2024-09-12 10:33:37 -07:00 committed by GitHub
parent 951b1b7160
commit 84f73327f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase {
if (num_inputs == 2) { // axes is an input
const Tensor* axes_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
"An axes tensor must be a vector tensor.");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 ||
axes_tensor->Shape().NumDimensions() == 1,
"An axes tensor must be a scalar or a vector tensor.");
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
const auto* data = axes_tensor->Data<int64_t>();
axes.assign(data, data + nDims);