From 84f73327f55b3dadbf20b69bc1a12cc2811986ed Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:33:37 -0700 Subject: [PATCH] allow scalar axes for Unsqueeze for WebGPU (#22054) ### Description Align with CPU behavior. https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc#L60-L62 --- onnxruntime/core/providers/js/operators/unsqueeze.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.h b/onnxruntime/core/providers/js/operators/unsqueeze.h index 7cbfdc38b7..f15a300889 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.h +++ b/onnxruntime/core/providers/js/operators/unsqueeze.h @@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase { if (num_inputs == 2) { // axes is an input const Tensor* axes_tensor = context->Input(1); ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data(); axes.assign(data, data + nDims);