mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
allow scalar axes for Unsqueeze for WebGPU (#22054)
### Description Align with CPU behavior. https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc#L60-L62
This commit is contained in:
parent
951b1b7160
commit
84f73327f5
1 changed files with 3 additions and 2 deletions
|
|
@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase {
|
|||
if (num_inputs == 2) { // axes is an input
|
||||
const Tensor* axes_tensor = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
|
||||
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
|
||||
"An axes tensor must be a vector tensor.");
|
||||
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 ||
|
||||
axes_tensor->Shape().NumDimensions() == 1,
|
||||
"An axes tensor must be a scalar or a vector tensor.");
|
||||
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
|
||||
const auto* data = axes_tensor->Data<int64_t>();
|
||||
axes.assign(data, data + nDims);
|
||||
|
|
|
|||
Loading…
Reference in a new issue