From a436b3af1aa09bd3607a0ff2cd56ed63596cee1a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 7 Nov 2024 08:10:05 -0800 Subject: [PATCH] [webgpu] fix indices type when it's 4D (#22758) ### Description Fix indices type from `array` to `vec4` when the variable is 4D. --- onnxruntime/core/providers/webgpu/shader_variable.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 46b3d0d902..15020b801c 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -81,8 +81,8 @@ constexpr static const auto ELEMENT_TYPE = details::_to_std_array(ELEMENT_TYPE_A inline std::string GetIndicesType(int rank) { return rank < 2 ? "u32" - : (rank < 4 ? MakeStringWithClassicLocale("vec", rank, "") - : MakeStringWithClassicLocale("array")); + : (rank <= 4 ? MakeStringWithClassicLocale("vec", rank, "") + : MakeStringWithClassicLocale("array")); } } // namespace