mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Refactor Softmax and remove debug logs
This commit is contained in:
parent
f9b61dbc6f
commit
87de60730a
5 changed files with 50 additions and 90 deletions
|
|
@ -8,9 +8,6 @@
|
|||
#include "core/providers/webgpu/shader_variable.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webgpu {
|
||||
|
||||
|
|
@ -84,16 +81,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
int components = input.NumComponents();
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Input StorageType: " << input.StorageType() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input ElementType: " << input.ElementType() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input ValueType: " << input.ValueType() << "\n";
|
||||
|
||||
|
||||
|
||||
std::string threadMaxDecl = input.ElementType() == "f32" ?
|
||||
"var threadMax = x_value_t(-3.402823e+38f);\n" :
|
||||
"var threadMax = x_value_t(-65504.0h);\n";
|
||||
|
||||
std::string threadMaxDecl = input.ElementType() == "f32" ? "var threadMax = x_value_t(-3.402823e+38f);\n" : "var threadMax = x_value_t(-65504.0h);\n";
|
||||
|
||||
// Define shared memory for row max and row sum
|
||||
shader.AdditionalImplementation()
|
||||
|
|
@ -142,7 +130,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
<< " workgroupBarrier();\n"
|
||||
<< " }\n"
|
||||
<< " if (lindex == 0) {\n"
|
||||
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
|
||||
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
|
||||
<< " }\n"
|
||||
<< " workgroupBarrier();\n"
|
||||
|
||||
|
|
@ -163,7 +151,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
<< " workgroupBarrier();\n"
|
||||
<< " }\n"
|
||||
<< " if (lindex == 0) {\n"
|
||||
<< " rowSumShared = x_value_t(" << SumVector("threadShared[0]", components) << ");\n"
|
||||
<< " rowSumShared = x_value_t(" << SumVector("threadShared[0]", components) << ");\n"
|
||||
<< " }\n"
|
||||
<< " workgroupBarrier();\n"
|
||||
|
||||
|
|
@ -179,71 +167,44 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
Status Softmax::ComputeInternal(ComputeContext& context) const {
|
||||
const auto* input_tensor = context.Input(0);
|
||||
const TensorShape& input_shape = input_tensor->Shape();
|
||||
size_t input_rank = input_shape.NumDimensions();
|
||||
|
||||
int64_t input_rank = input_shape.NumDimensions();
|
||||
auto* output_tensor = context.Output(0, input_shape);
|
||||
|
||||
// normalize axis
|
||||
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
|
||||
|
||||
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
|
||||
bool is_transpose_required = axis < input_rank - 1;
|
||||
LOGS_DEFAULT(VERBOSE) <<"axis_: " << axis_ << " axis: " << axis << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Transpose required: " << (is_transpose_required ? "true" : "false") << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input shape: " << input_shape.ToString() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Output shape: " << output_tensor->Shape().ToString() << "\n";
|
||||
LOGS_DEFAULT(VERBOSE) << "Input rank: " << input_rank << "\n";
|
||||
|
||||
TensorShape transposed_input_shape = input_shape;
|
||||
TensorShape transposed_input_shape;
|
||||
Tensor transposed_input_tensor;
|
||||
Tensor intermediate_output;
|
||||
InlinedVector<size_t> perm;
|
||||
InlinedVector<size_t> perm(input_rank);
|
||||
|
||||
if (is_transpose_required) {
|
||||
AllocatorPtr alloc;
|
||||
perm.resize(input_rank);
|
||||
for (size_t i = 0; i < perm.size(); ++i) {
|
||||
perm[i] = i;
|
||||
}
|
||||
std::iota(std::begin(perm), std::end(perm), 0);
|
||||
perm[axis] = input_rank - 1;
|
||||
perm[input_rank - 1] = axis;
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Allocating temporary tensors for transpose\n";
|
||||
std::vector<int64_t> transposed_input_dims;
|
||||
for (auto e : perm) {
|
||||
transposed_input_dims.push_back(input_shape[e]);
|
||||
}
|
||||
|
||||
// allocate a temporary tensor to hold transposed input
|
||||
Tensor temp_input(input_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Performing transpose\n";
|
||||
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(perm, *input_tensor, temp_input));
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Transpose done\n";
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
|
||||
transposed_input_tensor = std::move(temp_input);
|
||||
transposed_input_shape = transposed_input_tensor.Shape();
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Transposed input shape: " << transposed_input_shape.ToString() << "\n";
|
||||
|
||||
// Allocate memory for the intermediate output
|
||||
LOGS_DEFAULT(VERBOSE) << "Allocating memory for intermediate output\n";
|
||||
Tensor temp_output(output_tensor->DataType(), TensorShape(transposed_input_shape), alloc);
|
||||
intermediate_output = std::move(temp_output);
|
||||
transposed_input_shape = TensorShape(transposed_input_dims);
|
||||
transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape);
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor));
|
||||
intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
|
||||
}
|
||||
|
||||
|
||||
const size_t cols = transposed_input_shape[input_rank - 1];
|
||||
const size_t rows = input_shape.Size() / cols;
|
||||
const size_t components = GetMaxComponents(cols);
|
||||
const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
|
||||
const int64_t rows = input_shape.Size() / cols;
|
||||
const int64_t components = GetMaxComponents(cols);
|
||||
const auto packedCols = cols / components;
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Cols: " << cols << " Rows: " << rows << " Components: " << components << " PackedCols: " << packedCols << "\n";
|
||||
|
||||
size_t WG = rows == 1 ? 256: 64;
|
||||
uint32_t WG = rows == 1 ? 256 : 64;
|
||||
|
||||
SoftmaxProgram program{WG};
|
||||
if (is_transpose_required) {
|
||||
if (is_transpose_required) {
|
||||
program
|
||||
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
|
||||
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
|
||||
.AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
|
||||
} else {
|
||||
program
|
||||
|
|
@ -251,22 +212,17 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
|
|||
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
|
||||
}
|
||||
|
||||
|
||||
program
|
||||
.CacheHint(std::to_string(components), std::to_string(WG))
|
||||
.SetWorkgroupSize(WG)
|
||||
.SetDispatchGroupSize(rows)
|
||||
.AddUniformVariables({
|
||||
{static_cast<int32_t>(packedCols)}
|
||||
});
|
||||
|
||||
.AddUniformVariables({{static_cast<int32_t>(packedCols)}});
|
||||
|
||||
ORT_RETURN_IF_ERROR(context.RunProgram(program));
|
||||
|
||||
// If transpose was required, transpose the result back
|
||||
if (is_transpose_required) {
|
||||
Tensor transposed_output_tensor;
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(perm, intermediate_output, *output_tensor));
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -37,15 +37,15 @@ class Softmax final : public WebGpuKernel {
|
|||
|
||||
class SoftmaxProgram final : public Program<SoftmaxProgram> {
|
||||
public:
|
||||
SoftmaxProgram(size_t wg) : Program{"Softmax"}, WG{wg} {
|
||||
}
|
||||
SoftmaxProgram(uint32_t wg) : Program{"Softmax"}, WG{wg} {
|
||||
}
|
||||
|
||||
Status GenerateShaderCode(ShaderHelper& sh) const override;
|
||||
|
||||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});
|
||||
|
||||
private:
|
||||
size_t WG;
|
||||
uint32_t WG;
|
||||
};
|
||||
|
||||
} // namespace webgpu
|
||||
|
|
|
|||
|
|
@ -97,24 +97,27 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Transpose::DoTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) {
|
||||
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) {
|
||||
const auto& input_shape = input.Shape();
|
||||
const auto& input_dims = input_shape.GetDims();
|
||||
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());
|
||||
|
||||
|
||||
TensorShapeVector output_dims(rank);
|
||||
InlinedVector<size_t> default_perm(rank);
|
||||
const InlinedVector<size_t>* p_perm = nullptr;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(input, output_dims, default_perm, p_perm));
|
||||
|
||||
for (int32_t i = 0; i < rank; i++) {
|
||||
output_dims[i] = input_dims[permutations[i]];
|
||||
}
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
InlinedVector<int64_t> new_shape{};
|
||||
InlinedVector<int64_t> new_perm{};
|
||||
SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm);
|
||||
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
|
||||
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
|
||||
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
|
||||
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
|
||||
auto new_input_shape = input_shape;
|
||||
TensorShape new_output_shape(output_dims);
|
||||
|
||||
if (use_shared) {
|
||||
new_input_shape = channels_last
|
||||
|
|
@ -125,16 +128,16 @@ Status Transpose::DoTranspose(const gsl::span<const size_t>& permutations, const
|
|||
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
|
||||
}
|
||||
|
||||
uint32_t output_size = gsl::narrow_cast<int32_t>(input.Shape().Size());
|
||||
TransposeProgram program{*p_perm, use_shared};
|
||||
uint32_t output_size = gsl::narrow_cast<int32_t>(input_shape.Size());
|
||||
TransposeProgram program{permutations, use_shared};
|
||||
|
||||
if (use_shared) {
|
||||
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
|
||||
}
|
||||
|
||||
program
|
||||
.CacheHint(absl::StrJoin(*p_perm, "-"))
|
||||
.AddInputs({{*input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
|
||||
.AddOutputs({{*output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
|
||||
.CacheHint(absl::StrJoin(permutations, "-"))
|
||||
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
|
||||
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
|
||||
.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
|
||||
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)))
|
||||
.AddUniformVariables({
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
|
|||
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
|
||||
}
|
||||
Status ComputeInternal(ComputeContext& context) const override;
|
||||
static Status DoTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output);
|
||||
static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output);
|
||||
|
||||
constexpr static uint32_t TILE_SIZE = 16;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -170,11 +170,11 @@ TEST(SoftmaxOperator, ThreeAndFourDimsAxis0) {
|
|||
|
||||
RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 0,
|
||||
// axis=0 is not supported by TensorRT
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
|
||||
|
||||
RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 0,
|
||||
// axis=0 is not supported by TensorRT
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) {
|
||||
|
|
@ -201,10 +201,10 @@ TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) {
|
|||
0.040478885f, 0.033857856f, 0.080346674f, 0.06199841f, 0.040481992f};
|
||||
|
||||
RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 1,
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
|
||||
|
||||
RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 2,
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis_opset13) {
|
||||
|
|
@ -376,8 +376,9 @@ TEST(SoftmaxOperator, DimWithZero) {
|
|||
|
||||
RunTest(x_vals, expected_vals, dimensions, /*opset*/ -1, /*axis*/ 0,
|
||||
{kTensorrtExecutionProvider,
|
||||
kNnapiExecutionProvider, // NNAPI softmax does not support empty input
|
||||
kQnnExecutionProvider} // QNN doesn't support dim 0
|
||||
kNnapiExecutionProvider, // NNAPI softmax does not support empty input
|
||||
kWebGpuExecutionProvider, // WebGPU does not support dim 0
|
||||
kQnnExecutionProvider} // QNN doesn't support dim 0
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue