mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Refactor Softmax implementation for WebGPU
This commit is contained in:
parent
e7e373713e
commit
f9b61dbc6f
3 changed files with 60 additions and 24 deletions
|
|
@ -2,13 +2,15 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/providers/webgpu/tensor/softmax.h"
|
||||
#include "core/providers/webgpu/math/softmax.h"
|
||||
#include "core/providers/webgpu/tensor/transpose.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/webgpu/shader_variable.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webgpu {
|
||||
|
||||
|
|
@ -45,6 +47,8 @@ static std::string MaxVector(std::string name, int components) {
|
|||
return name;
|
||||
case 2:
|
||||
return "max(" + name + ".x, " + name + ".y)";
|
||||
case 3:
|
||||
return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)";
|
||||
case 4:
|
||||
return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))";
|
||||
default:
|
||||
|
|
@ -76,13 +80,19 @@ static int GetMaxComponents(int64_t size) {
|
|||
|
||||
Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
||||
// Add input and output variables
|
||||
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
const auto& output = shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
|
||||
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
int components = input.NumComponents();
|
||||
|
||||
std::string threadMaxDecl = input.StorageType() == "f32" ?
|
||||
"val threadMax = x_value_t(-3.402823e+38f);\n" :
|
||||
"val threadMax = x_value_t(-65504.0h));\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input StorageType: " << input.StorageType() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input ElementType: " << input.ElementType() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input ValueType: " << input.ValueType() << "\n";
|
||||
|
||||
|
||||
|
||||
std::string threadMaxDecl = input.ElementType() == "f32" ?
|
||||
"var threadMax = x_value_t(-3.402823e+38f);\n" :
|
||||
"var threadMax = x_value_t(-65504.0h);\n";
|
||||
|
||||
|
||||
// Define shared memory for row max and row sum
|
||||
|
|
@ -132,7 +142,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
<< " workgroupBarrier();\n"
|
||||
<< " }\n"
|
||||
<< " if (lindex == 0) {\n"
|
||||
<< " rowMaxShared = x_value_t(" << MaxVector('threadShared[0]', components) << ");\n"
|
||||
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
|
||||
<< " }\n"
|
||||
<< " workgroupBarrier();\n"
|
||||
|
||||
|
|
@ -174,9 +184,15 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
|
|||
auto* output_tensor = context.Output(0, input_shape);
|
||||
|
||||
// normalize axis
|
||||
int64_t axis = axis < 0 ? axis_ + input_rank : axis_;
|
||||
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
|
||||
|
||||
bool is_transpose_required = axis < input_rank - 1;
|
||||
LOGS_DEFAULT(VERBOSE) <<"axis_: " << axis_ << " axis: " << axis << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Transpose required: " << (is_transpose_required ? "true" : "false") << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input shape: " << input_shape.ToString() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Output shape: " << output_tensor->Shape().ToString() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input rank: " << input_rank << "\n";
|
||||
|
||||
TensorShape transposed_input_shape = input_shape;
|
||||
Tensor transposed_input_tensor;
|
||||
Tensor intermediate_output;
|
||||
|
|
@ -184,25 +200,34 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
|
|||
|
||||
if (is_transpose_required) {
|
||||
AllocatorPtr alloc;
|
||||
perm.reserve(input_rank);
|
||||
for (size_t i = 0; i < input_rank; ++i) {
|
||||
perm.resize(input_rank);
|
||||
for (size_t i = 0; i < perm.size(); ++i) {
|
||||
perm[i] = i;
|
||||
}
|
||||
perm[axis] = input_rank - 1;
|
||||
perm[input_rank - 1] = axis;
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Allocating temporary tensors for transpose\n";
|
||||
|
||||
// allocate a temporary tensor to hold transposed input
|
||||
Tensor temp_input(input_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
|
||||
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose( perm, *input_tensor, temp_input));
|
||||
LOGS_DEFAULT(VERBOSE) << "Performing transpose\n";
|
||||
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(perm, *input_tensor, temp_input));
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Transpose done\n";
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
|
||||
transposed_input_tensor = std::move(temp_input);
|
||||
transposed_input_shape = transposed_input_tensor.Shape();
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Transposed input shape: " << transposed_input_shape.ToString() << "\n";
|
||||
|
||||
// Allocate memory for the intermediate output
|
||||
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
|
||||
Tensor temp_output(output_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
|
||||
intermediate_output = std::move(temp_output);
|
||||
} else {
|
||||
transposed_input_tensor = *input_tensor;
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -211,15 +236,24 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
|
|||
const size_t components = GetMaxComponents(cols);
|
||||
const auto packedCols = cols / components;
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Cols: " << cols << " Rows: " << rows << " Components: " << components << " PackedCols: " << packedCols << "\n";
|
||||
|
||||
size_t WG = rows == 1 ? 256: 64;
|
||||
|
||||
SoftmaxProgram program{WG};
|
||||
if (is_transpose_required) {
|
||||
program
|
||||
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
|
||||
.AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
|
||||
} else {
|
||||
program
|
||||
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
|
||||
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
|
||||
}
|
||||
|
||||
|
||||
program
|
||||
.CacheHint(std::to_string(components), std::to_string(WG))
|
||||
.AddInputs({*transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
|
||||
.AddOutputs({ is_transpose_required ? *intermediate_output : output_tensor})
|
||||
.SetWorkgroupSize(WG)
|
||||
.SetDispatchGroupSize(rows)
|
||||
.AddUniformVariables({
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
#include "core/providers/cpu/math/softmax.h"
|
||||
#include "core/providers/webgpu/webgpu_kernel.h"
|
||||
#include "core/providers/webgpu/program.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webgpu {
|
||||
|
|
@ -15,8 +15,8 @@ class Softmax final : public WebGpuKernel {
|
|||
public:
|
||||
Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
|
||||
int opset_ = info.node().SinceVersion();
|
||||
size_t axis;
|
||||
Status status = info.GetAttr<size_t>("axis", &axis);
|
||||
int64_t axis;
|
||||
Status status = info.GetAttr<int64_t>("axis", &axis);
|
||||
|
||||
if (status.IsOK()) {
|
||||
axis_ = axis;
|
||||
|
|
@ -32,12 +32,12 @@ class Softmax final : public WebGpuKernel {
|
|||
Status ComputeInternal(ComputeContext& context) const override;
|
||||
|
||||
private:
|
||||
size_t axis_;
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
class SoftmaxProgram final : public Program<SoftmaxProgram> {
|
||||
public:
|
||||
SoftmaxProgram(size_t axis, int wg) : Program{"Softmax"}, axis_{axis}, WG_{wg} {
|
||||
SoftmaxProgram(size_t wg) : Program{"Softmax"}, WG{wg} {
|
||||
}
|
||||
|
||||
Status GenerateShaderCode(ShaderHelper& sh) const override;
|
||||
|
|
@ -45,7 +45,7 @@ class SoftmaxProgram final : public Program<SoftmaxProgram> {
|
|||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});
|
||||
|
||||
private:
|
||||
int WG;
|
||||
size_t WG;
|
||||
};
|
||||
|
||||
} // namespace webgpu
|
||||
|
|
|
|||
|
|
@ -176,6 +176,10 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
|
|||
template <typename TOffset>
|
||||
inline std::string GetByOffset(TOffset&& offset) const;
|
||||
|
||||
std::string_view StorageType() const;
|
||||
std::string_view ValueType() const;
|
||||
std::string_view ElementType() const;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);
|
||||
|
||||
|
|
@ -183,9 +187,7 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
|
|||
|
||||
std::string GetByOffsetImpl(std::string_view offset) const;
|
||||
std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const;
|
||||
std::string_view StorageType() const;
|
||||
std::string_view ValueType() const;
|
||||
std::string_view ElementType() const;
|
||||
|
||||
|
||||
friend class ShaderHelper;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue