From f9b61dbc6f9ae29a22ccfcd33228ffa9cea38ae2 Mon Sep 17 00:00:00 2001 From: vraspar Date: Fri, 24 Jan 2025 19:22:05 -0800 Subject: [PATCH] Refactor Softmax implementation for WebGPU --- .../core/providers/webgpu/math/softmax.cc | 64 ++++++++++++++----- .../core/providers/webgpu/math/softmax.h | 12 ++-- .../core/providers/webgpu/shader_variable.h | 8 ++- 3 files changed, 60 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index 796c56f67c..4abefa704c 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -2,13 +2,15 @@ // Licensed under the MIT License. #include "core/common/inlined_containers.h" -#include "core/providers/webgpu/tensor/softmax.h" +#include "core/providers/webgpu/math/softmax.h" #include "core/providers/webgpu/tensor/transpose.h" #include "core/providers/cpu/tensor/utils.h" #include "core/providers/webgpu/shader_variable.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/common/logging/logging.h" + namespace onnxruntime { namespace webgpu { @@ -45,6 +47,8 @@ static std::string MaxVector(std::string name, int components) { return name; case 2: return "max(" + name + ".x, " + name + ".y)"; + case 3: + return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)"; case 4: return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))"; default: @@ -76,13 +80,19 @@ static int GetMaxComponents(int64_t size) { Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { // Add input and output variables - const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const auto& output = shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); int components = input.NumComponents(); - std::string threadMaxDecl = input.StorageType() == "f32" ? - "val threadMax = x_value_t(-3.402823e+38f);\n" : - "val threadMax = x_value_t(-65504.0h));\n"; + LOGS_DEFAULT(VERBOSE) << "Input StorageType: " << input.StorageType() << "\n"; + LOGS_DEFAULT(VERBOSE) << "Input ElementType: " << input.ElementType() << "\n"; + LOGS_DEFAULT(VERBOSE) << "Input ValueType: " << input.ValueType() << "\n"; + + + + std::string threadMaxDecl = input.ElementType() == "f32" ? + "var threadMax = x_value_t(-3.402823e+38f);\n" : + "var threadMax = x_value_t(-65504.0h);\n"; // Define shared memory for row max and row sum @@ -132,7 +142,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << " workgroupBarrier();\n" << " }\n" << " if (lindex == 0) {\n" - << " rowMaxShared = x_value_t(" << MaxVector('threadShared[0]', components) << ");\n" + << " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n" << " }\n" << " workgroupBarrier();\n" @@ -174,9 +184,15 @@ Status Softmax::ComputeInternal(ComputeContext& context) const { auto* output_tensor = context.Output(0, input_shape); // normalize axis - int64_t axis = axis < 0 ? axis_ + input_rank : axis_; + int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_; bool is_transpose_required = axis < input_rank - 1; + LOGS_DEFAULT(VERBOSE) <<"axis_: " << axis_ << " axis: " << axis << "\n"; + LOGS_DEFAULT(VERBOSE) << "Transpose required: " << (is_transpose_required ? "true" : "false") << "\n"; + LOGS_DEFAULT(VERBOSE) << "Input shape: " << input_shape.ToString() << "\n"; + LOGS_DEFAULT(VERBOSE) << "Output shape: " << output_tensor->Shape().ToString() << "\n"; + LOGS_DEFAULT(VERBOSE) << "Input rank: " << input_rank << "\n"; + TensorShape transposed_input_shape = input_shape; Tensor transposed_input_tensor; Tensor intermediate_output; @@ -184,25 +200,34 @@ Status Softmax::ComputeInternal(ComputeContext& context) const { if (is_transpose_required) { AllocatorPtr alloc; - perm.reserve(input_rank); - for (size_t i = 0; i < input_rank; ++i) { + perm.resize(input_rank); + for (size_t i = 0; i < perm.size(); ++i) { perm[i] = i; } perm[axis] = input_rank - 1; perm[input_rank - 1] = axis; + LOGS_DEFAULT(VERBOSE) << "Allocating temporary tensors for transpose\n"; + // allocate a temporary tensor to hold transposed input Tensor temp_input(input_tensor->DataType(), TensorShape(transposed_input_shape), alloc); - ORT_RETURN_IF_ERROR(Transpose::DoTranspose( perm, *input_tensor, temp_input)); + LOGS_DEFAULT(VERBOSE) << "Performing transpose\n"; + + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(perm, *input_tensor, temp_input)); + + LOGS_DEFAULT(VERBOSE) << "Transpose done\n"; + + LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n"; transposed_input_tensor = std::move(temp_input); transposed_input_shape = transposed_input_tensor.Shape(); + LOGS_DEFAULT(VERBOSE) << "Transposed input shape: " << transposed_input_shape.ToString() << "\n"; + // Allocate memory for the intermediate output + LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n"; Tensor temp_output(output_tensor->DataType(), TensorShape(transposed_input_shape), alloc); intermediate_output = std::move(temp_output); - } else { - transposed_input_tensor = *input_tensor; } @@ -211,15 +236,24 @@ Status Softmax::ComputeInternal(ComputeContext& context) const { const size_t components = GetMaxComponents(cols); const auto packedCols = cols / components; + LOGS_DEFAULT(VERBOSE) << "Cols: " << cols << " Rows: " << rows << " Components: " << components << " PackedCols: " << packedCols << "\n"; + size_t WG = rows == 1 ? 256: 64; SoftmaxProgram program{WG}; + if (is_transpose_required) { + program + .AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) + .AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); + } else { + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); + } program .CacheHint(std::to_string(components), std::to_string(WG)) - .AddInputs({*transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({ is_transpose_required ? *intermediate_output : output_tensor}) .SetWorkgroupSize(WG) .SetDispatchGroupSize(rows) .AddUniformVariables({ diff --git a/onnxruntime/core/providers/webgpu/math/softmax.h b/onnxruntime/core/providers/webgpu/math/softmax.h index b8bc37a0c0..b67425471d 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.h +++ b/onnxruntime/core/providers/webgpu/math/softmax.h @@ -4,9 +4,9 @@ #pragma once #include "core/providers/webgpu/webgpu_supported_types.h" -#include "core/providers/cpu/math/softmax.h" #include "core/providers/webgpu/webgpu_kernel.h" #include "core/providers/webgpu/program.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { namespace webgpu { @@ -15,8 +15,8 @@ class Softmax final : public WebGpuKernel { public: Softmax(const OpKernelInfo& info) : WebGpuKernel{info} { int opset_ = info.node().SinceVersion(); - size_t axis; - Status status = info.GetAttr("axis", &axis); + int64_t axis; + Status status = info.GetAttr("axis", &axis); if (status.IsOK()) { axis_ = axis; @@ -32,12 +32,12 @@ class Softmax final : public WebGpuKernel { Status ComputeInternal(ComputeContext& context) const override; private: - size_t axis_; + int64_t axis_; }; class SoftmaxProgram final : public Program { public: - SoftmaxProgram(size_t axis, int wg) : Program{"Softmax"}, axis_{axis}, WG_{wg} { + SoftmaxProgram(size_t wg) : Program{"Softmax"}, WG{wg} { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -45,7 +45,7 @@ class SoftmaxProgram final : public Program { WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32}); private: - int WG; + size_t WG; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 4c87bc9158..3b8ed7bf42 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -176,6 +176,10 @@ class ShaderVariableHelper : public ShaderIndicesHelper { template inline std::string GetByOffset(TOffset&& offset) const; + std::string_view StorageType() const; + std::string_view ValueType() const; + std::string_view ElementType() const; + private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper); @@ -183,9 +187,7 @@ class ShaderVariableHelper : public ShaderIndicesHelper { std::string GetByOffsetImpl(std::string_view offset) const; std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; - std::string_view StorageType() const; - std::string_view ValueType() const; - std::string_view ElementType() const; + friend class ShaderHelper; };