From 1d97d6ef55433298dee58634b0ea59f736e8a72e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 15 Jan 2025 21:01:05 -0800 Subject: [PATCH] [webgpu] fix Split operator implementation when input is 1D (#23376) ### Description [webgpu] fix Split operator implementation when input is 1D --- onnxruntime/core/providers/webgpu/tensor/split.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/split.cc b/onnxruntime/core/providers/webgpu/tensor/split.cc index 700fa07679..83bf832cc5 100644 --- a/onnxruntime/core/providers/webgpu/tensor/split.cc +++ b/onnxruntime/core/providers/webgpu/tensor/split.cc @@ -65,11 +65,11 @@ Status SplitProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size") << " var indices = " << input.OffsetToIndices("global_idx") << ";\n" - << " var index = indices[" << axis_ << "];\n" + << " var index = " << input.IndicesGet("indices", axis_) << ";\n" << " let output_number = calculate_output_index(index);\n" << " if (output_number != 0u) {\n" << " index -= uniforms.sizes_in_split_axis[output_number - 1u];\n" - << " indices[" << axis_ << "] = index;\n" + << " " << input.IndicesSet("indices", axis_, "index") << "\n" << " }\n" << " write_buffer_data(output_number, global_idx, indices);\n";