Refactor Softmax implementation for WebGPU

This commit is contained in:
vraspar 2025-01-24 19:22:05 -08:00
parent e7e373713e
commit f9b61dbc6f
3 changed files with 60 additions and 24 deletions

View file

@ -2,13 +2,15 @@
// Licensed under the MIT License.
#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/tensor/softmax.h"
#include "core/providers/webgpu/math/softmax.h"
#include "core/providers/webgpu/tensor/transpose.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/common/logging/logging.h"
namespace onnxruntime {
namespace webgpu {
@ -45,6 +47,8 @@ static std::string MaxVector(std::string name, int components) {
return name;
case 2:
return "max(" + name + ".x, " + name + ".y)";
case 3:
return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)";
case 4:
return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))";
default:
@ -76,13 +80,19 @@ static int GetMaxComponents(int64_t size) {
Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add input and output variables
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& output = shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
int components = input.NumComponents();
std::string threadMaxDecl = input.StorageType() == "f32" ?
"val threadMax = x_value_t(-3.402823e+38f);\n" :
"val threadMax = x_value_t(-65504.0h));\n";
LOGS_DEFAULT(VERBOSE) << "Input StorageType: " << input.StorageType() << "\n";
LOGS_DEFAULT(VERBOSE) << "Input ElementType: " << input.ElementType() << "\n";
LOGS_DEFAULT(VERBOSE) << "Input ValueType: " << input.ValueType() << "\n";
std::string threadMaxDecl = input.ElementType() == "f32" ?
"var threadMax = x_value_t(-3.402823e+38f);\n" :
"var threadMax = x_value_t(-65504.0h);\n";
// Define shared memory for row max and row sum
@ -132,7 +142,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " rowMaxShared = x_value_t(" << MaxVector('threadShared[0]', components) << ");\n"
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"
@ -174,9 +184,15 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
auto* output_tensor = context.Output(0, input_shape);
// normalize axis
int64_t axis = axis < 0 ? axis_ + input_rank : axis_;
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
bool is_transpose_required = axis < input_rank - 1;
LOGS_DEFAULT(VERBOSE) <<"axis_: " << axis_ << " axis: " << axis << "\n";
LOGS_DEFAULT(VERBOSE) << "Transpose required: " << (is_transpose_required ? "true" : "false") << "\n";
LOGS_DEFAULT(VERBOSE) << "Input shape: " << input_shape.ToString() << "\n";
LOGS_DEFAULT(VERBOSE) << "Output shape: " << output_tensor->Shape().ToString() << "\n";
LOGS_DEFAULT(VERBOSE) << "Input rank: " << input_rank << "\n";
TensorShape transposed_input_shape = input_shape;
Tensor transposed_input_tensor;
Tensor intermediate_output;
@ -184,25 +200,34 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
if (is_transpose_required) {
AllocatorPtr alloc;
perm.reserve(input_rank);
for (size_t i = 0; i < input_rank; ++i) {
perm.resize(input_rank);
for (size_t i = 0; i < perm.size(); ++i) {
perm[i] = i;
}
perm[axis] = input_rank - 1;
perm[input_rank - 1] = axis;
LOGS_DEFAULT(VERBOSE) << "Allocating temporary tensors for transpose\n";
// allocate a temporary tensor to hold transposed input
Tensor temp_input(input_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose( perm, *input_tensor, temp_input));
LOGS_DEFAULT(VERBOSE) << "Performing transpose\n";
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(perm, *input_tensor, temp_input));
LOGS_DEFAULT(VERBOSE) << "Transpose done\n";
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
transposed_input_tensor = std::move(temp_input);
transposed_input_shape = transposed_input_tensor.Shape();
LOGS_DEFAULT(VERBOSE) << "Transposed input shape: " << transposed_input_shape.ToString() << "\n";
// Allocate memory for the intermediate output
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
Tensor temp_output(output_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
intermediate_output = std::move(temp_output);
} else {
transposed_input_tensor = *input_tensor;
}
@ -211,15 +236,24 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
const size_t components = GetMaxComponents(cols);
const auto packedCols = cols / components;
LOGS_DEFAULT(VERBOSE) << "Cols: " << cols << " Rows: " << rows << " Components: " << components << " PackedCols: " << packedCols << "\n";
size_t WG = rows == 1 ? 256: 64;
SoftmaxProgram program{WG};
if (is_transpose_required) {
program
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
} else {
program
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
}
program
.CacheHint(std::to_string(components), std::to_string(WG))
.AddInputs({*transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutputs({ is_transpose_required ? *intermediate_output : output_tensor})
.SetWorkgroupSize(WG)
.SetDispatchGroupSize(rows)
.AddUniformVariables({

View file

@ -4,9 +4,9 @@
#pragma once
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/cpu/math/softmax.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace webgpu {
@ -15,8 +15,8 @@ class Softmax final : public WebGpuKernel {
public:
Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
int opset_ = info.node().SinceVersion();
size_t axis;
Status status = info.GetAttr<size_t>("axis", &axis);
int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);
if (status.IsOK()) {
axis_ = axis;
@ -32,12 +32,12 @@ class Softmax final : public WebGpuKernel {
Status ComputeInternal(ComputeContext& context) const override;
private:
size_t axis_;
int64_t axis_;
};
class SoftmaxProgram final : public Program<SoftmaxProgram> {
public:
SoftmaxProgram(size_t axis, int wg) : Program{"Softmax"}, axis_{axis}, WG_{wg} {
SoftmaxProgram(size_t wg) : Program{"Softmax"}, WG{wg} {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@ -45,7 +45,7 @@ class SoftmaxProgram final : public Program<SoftmaxProgram> {
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});
private:
int WG;
size_t WG;
};
} // namespace webgpu

View file

@ -176,6 +176,10 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
template <typename TOffset>
inline std::string GetByOffset(TOffset&& offset) const;
std::string_view StorageType() const;
std::string_view ValueType() const;
std::string_view ElementType() const;
private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);
@ -183,9 +187,7 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
std::string GetByOffsetImpl(std::string_view offset) const;
std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const;
std::string_view StorageType() const;
std::string_view ValueType() const;
std::string_view ElementType() const;
friend class ShaderHelper;
};