Refactor Softmax and remove debug logs

This commit is contained in:
vraspar 2025-01-29 16:33:20 -08:00
parent f9b61dbc6f
commit 87de60730a
5 changed files with 50 additions and 90 deletions

View file

@ -8,9 +8,6 @@
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/common/logging/logging.h"
namespace onnxruntime {
namespace webgpu {
@ -84,16 +81,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
int components = input.NumComponents();
LOGS_DEFAULT(VERBOSE) << "Input StorageType: " << input.StorageType() << "\n";
LOGS_DEFAULT(VERBOSE) << "Input ElementType: " << input.ElementType() << "\n";
LOGS_DEFAULT(VERBOSE) << "Input ValueType: " << input.ValueType() << "\n";
std::string threadMaxDecl = input.ElementType() == "f32" ?
"var threadMax = x_value_t(-3.402823e+38f);\n" :
"var threadMax = x_value_t(-65504.0h);\n";
std::string threadMaxDecl = input.ElementType() == "f32" ? "var threadMax = x_value_t(-3.402823e+38f);\n" : "var threadMax = x_value_t(-65504.0h);\n";
// Define shared memory for row max and row sum
shader.AdditionalImplementation()
@ -142,7 +130,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"
@ -163,7 +151,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " rowSumShared = x_value_t(" << SumVector("threadShared[0]", components) << ");\n"
<< " rowSumShared = x_value_t(" << SumVector("threadShared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"
@ -179,71 +167,44 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
Status Softmax::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
size_t input_rank = input_shape.NumDimensions();
int64_t input_rank = input_shape.NumDimensions();
auto* output_tensor = context.Output(0, input_shape);
// normalize axis
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
bool is_transpose_required = axis < input_rank - 1;
LOGS_DEFAULT(VERBOSE) <<"axis_: " << axis_ << " axis: " << axis << "\n";
LOGS_DEFAULT(VERBOSE) << "Transpose required: " << (is_transpose_required ? "true" : "false") << "\n";
LOGS_DEFAULT(VERBOSE) << "Input shape: " << input_shape.ToString() << "\n";
LOGS_DEFAULT(VERBOSE) << "Output shape: " << output_tensor->Shape().ToString() << "\n";
LOGS_DEFAULT(VERBOSE) << "Input rank: " << input_rank << "\n";
TensorShape transposed_input_shape = input_shape;
TensorShape transposed_input_shape;
Tensor transposed_input_tensor;
Tensor intermediate_output;
InlinedVector<size_t> perm;
InlinedVector<size_t> perm(input_rank);
if (is_transpose_required) {
AllocatorPtr alloc;
perm.resize(input_rank);
for (size_t i = 0; i < perm.size(); ++i) {
perm[i] = i;
}
std::iota(std::begin(perm), std::end(perm), 0);
perm[axis] = input_rank - 1;
perm[input_rank - 1] = axis;
LOGS_DEFAULT(VERBOSE) << "Allocating temporary tensors for transpose\n";
std::vector<int64_t> transposed_input_dims;
for (auto e : perm) {
transposed_input_dims.push_back(input_shape[e]);
}
// allocate a temporary tensor to hold transposed input
Tensor temp_input(input_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
LOGS_DEFAULT(VERBOSE) << "Performing transpose\n";
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(perm, *input_tensor, temp_input));
LOGS_DEFAULT(VERBOSE) << "Transpose done\n";
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
transposed_input_tensor = std::move(temp_input);
transposed_input_shape = transposed_input_tensor.Shape();
LOGS_DEFAULT(VERBOSE) << "Transposed input shape: " << transposed_input_shape.ToString() << "\n";
// Allocate memory for the intermediate output
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
Tensor temp_output(output_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
intermediate_output = std::move(temp_output);
transposed_input_shape = TensorShape(transposed_input_dims);
transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor));
intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
}
const size_t cols = transposed_input_shape[input_rank - 1];
const size_t rows = input_shape.Size() / cols;
const size_t components = GetMaxComponents(cols);
const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
const int64_t rows = input_shape.Size() / cols;
const int64_t components = GetMaxComponents(cols);
const auto packedCols = cols / components;
LOGS_DEFAULT(VERBOSE) << "Cols: " << cols << " Rows: " << rows << " Components: " << components << " PackedCols: " << packedCols << "\n";
size_t WG = rows == 1 ? 256: 64;
uint32_t WG = rows == 1 ? 256 : 64;
SoftmaxProgram program{WG};
if (is_transpose_required) {
if (is_transpose_required) {
program
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
} else {
program
@ -251,22 +212,17 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
}
program
.CacheHint(std::to_string(components), std::to_string(WG))
.SetWorkgroupSize(WG)
.SetDispatchGroupSize(rows)
.AddUniformVariables({
{static_cast<int32_t>(packedCols)}
});
.AddUniformVariables({{static_cast<int32_t>(packedCols)}});
ORT_RETURN_IF_ERROR(context.RunProgram(program));
// If transpose was required, transpose the result back
if (is_transpose_required) {
Tensor transposed_output_tensor;
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(perm, intermediate_output, *output_tensor));
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor));
}
return Status::OK();

View file

@ -37,15 +37,15 @@ class Softmax final : public WebGpuKernel {
class SoftmaxProgram final : public Program<SoftmaxProgram> {
public:
SoftmaxProgram(size_t wg) : Program{"Softmax"}, WG{wg} {
}
SoftmaxProgram(uint32_t wg) : Program{"Softmax"}, WG{wg} {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});
private:
size_t WG;
uint32_t WG;
};
} // namespace webgpu

View file

@ -97,24 +97,27 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}
Status Transpose::DoTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) {
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) {
const auto& input_shape = input.Shape();
const auto& input_dims = input_shape.GetDims();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());
TensorShapeVector output_dims(rank);
InlinedVector<size_t> default_perm(rank);
const InlinedVector<size_t>* p_perm = nullptr;
ORT_RETURN_IF_ERROR(ComputeOutputShape(input, output_dims, default_perm, p_perm));
for (int32_t i = 0; i < rank; i++) {
output_dims[i] = input_dims[permutations[i]];
}
TensorShape output_shape(output_dims);
InlinedVector<int64_t> new_shape{};
InlinedVector<int64_t> new_perm{};
SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm);
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
auto new_input_shape = input_shape;
TensorShape new_output_shape(output_dims);
if (use_shared) {
new_input_shape = channels_last
@ -125,16 +128,16 @@ Status Transpose::DoTranspose(const gsl::span<const size_t>& permutations, const
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
}
uint32_t output_size = gsl::narrow_cast<int32_t>(input.Shape().Size());
TransposeProgram program{*p_perm, use_shared};
uint32_t output_size = gsl::narrow_cast<int32_t>(input_shape.Size());
TransposeProgram program{permutations, use_shared};
if (use_shared) {
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
}
program
.CacheHint(absl::StrJoin(*p_perm, "-"))
.AddInputs({{*input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{*output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.CacheHint(absl::StrJoin(permutations, "-"))
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)))
.AddUniformVariables({

View file

@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
}
Status ComputeInternal(ComputeContext& context) const override;
static Status DoTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output);
static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output);
constexpr static uint32_t TILE_SIZE = 16;
};

View file

@ -170,11 +170,11 @@ TEST(SoftmaxOperator, ThreeAndFourDimsAxis0) {
RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 0,
// axis=0 is not supported by TensorRT
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 0,
// axis=0 is not supported by TensorRT
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
}
TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) {
@ -201,10 +201,10 @@ TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) {
0.040478885f, 0.033857856f, 0.080346674f, 0.06199841f, 0.040481992f};
RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 1,
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 2,
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
}
TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis_opset13) {
@ -376,8 +376,9 @@ TEST(SoftmaxOperator, DimWithZero) {
RunTest(x_vals, expected_vals, dimensions, /*opset*/ -1, /*axis*/ 0,
{kTensorrtExecutionProvider,
kNnapiExecutionProvider, // NNAPI softmax does not support empty input
kQnnExecutionProvider} // QNN doesn't support dim 0
kNnapiExecutionProvider, // NNAPI softmax does not support empty input
kWebGpuExecutionProvider, // WebGPU does not support dim 0
kQnnExecutionProvider} // QNN doesn't support dim 0
);
}