mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebGPU EP] Batch Norm Implementation (#23525)
Increases operator coverage for webgpu ep. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
1fce51b3b2
commit
6b4f9c481d
4 changed files with 206 additions and 11 deletions
138
onnxruntime/core/providers/webgpu/nn/batch_norm.cc
Normal file
138
onnxruntime/core/providers/webgpu/nn/batch_norm.cc
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/providers/webgpu/nn/batch_norm.h"
|
||||
#include "core/providers/cpu/nn/batch_norm_helper.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webgpu {
|
||||
|
||||
#define WEBGPU_BATCH_NORM_VERSIONED_KERNEL(start, end, domain, is_nhwc) \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
domain, \
|
||||
start, \
|
||||
end, \
|
||||
kWebGpuExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", WebGpuSupportedFloatTypes()), \
|
||||
BatchNormalization<is_nhwc>);
|
||||
|
||||
#define WEBGPU_BATCH_NORM_KERNEL(version, domain, is_nhwc) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
domain, \
|
||||
version, \
|
||||
kWebGpuExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", WebGpuSupportedFloatTypes()), \
|
||||
BatchNormalization<is_nhwc>);
|
||||
|
||||
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kOnnxDomain, false)
|
||||
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kOnnxDomain, false)
|
||||
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kOnnxDomain, false)
|
||||
WEBGPU_BATCH_NORM_KERNEL(15, kOnnxDomain, false)
|
||||
|
||||
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kMSInternalNHWCDomain, true)
|
||||
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kMSInternalNHWCDomain, true)
|
||||
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kMSInternalNHWCDomain, true)
|
||||
WEBGPU_BATCH_NORM_KERNEL(15, kMSInternalNHWCDomain, true)
|
||||
|
||||
Status BatchNormalizationProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
||||
const ShaderVariableHelper& input_tensor = shader.AddInput("input_tensor", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
const ShaderVariableHelper& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
const ShaderVariableHelper& input_mean = shader.AddInput("input_mean", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
const ShaderVariableHelper& input_var = shader.AddInput("input_var", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
|
||||
|
||||
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
|
||||
<< " let idx = global_idx * " << components_ << ";\n"
|
||||
<< " var outputIndices = " << output.OffsetToIndices("idx") << ";\n";
|
||||
if (spatial_) {
|
||||
if (input_tensor.Rank() == 1) {
|
||||
shader.MainFunctionBody() << " let cOffset = 0u;\n";
|
||||
} else {
|
||||
if (format_ == DataLayout::NHWC) {
|
||||
shader.MainFunctionBody() << " let cOffset = outputIndices[" << input_tensor.Rank() - 1 << "] / " << components_ << ";\n";
|
||||
} else {
|
||||
shader.MainFunctionBody() << " let cOffset = outputIndices[1];\n";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (format_ == DataLayout::NCHW) {
|
||||
shader.MainFunctionBody() << " " << output.IndicesSet("outputIndices", "0", "0") << "\n"
|
||||
<< " let cOffset = " << output.IndicesToOffset("outputIndices") << ";\n";
|
||||
} else {
|
||||
// update C channel
|
||||
shader.MainFunctionBody() << " var cIndices = scale_indices_t(0);\n"
|
||||
<< " cIndices[0] = outputIndices[" << input_tensor.Rank() - 1 << "];\n";
|
||||
// update D1 x ... x Dn channels
|
||||
for (int i = 1; i < scale.Rank(); i++) {
|
||||
shader.MainFunctionBody() << " cIndices[" << i << "] = outputIndices[" << i << "];\n";
|
||||
}
|
||||
shader.MainFunctionBody() << " let cOffset = " << scale.IndicesToOffset("cIndices") << ";\n";
|
||||
}
|
||||
}
|
||||
|
||||
shader.MainFunctionBody() << " let scale = " << scale.GetByOffset("cOffset") << ";\n"
|
||||
<< " let B = " << B.GetByOffset("cOffset") << ";\n"
|
||||
<< " let input_mean = " << input_mean.GetByOffset("cOffset") << ";\n"
|
||||
<< " let input_var = " << input_var.GetByOffset("cOffset") << ";\n"
|
||||
<< " let x = " << input_tensor.GetByOffset("global_idx") << ";\n"
|
||||
<< " let value = (x - input_mean) * inverseSqrt(input_var + " << epsilon_ << ") * scale + B;\n"
|
||||
<< " " << output.SetByOffset("global_idx", "value") << "\n";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <bool is_nhwc>
|
||||
Status BatchNormalization<is_nhwc>::ComputeInternal(ComputeContext& context) const {
|
||||
if (training_mode_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization trainingMode is not supported yet.");
|
||||
}
|
||||
|
||||
if (context.InputCount() != 5) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization requires 5 inputs.");
|
||||
}
|
||||
|
||||
const auto* input_tensor = context.Input(0);
|
||||
const TensorShape& input_shape = input_tensor->Shape();
|
||||
size_t input_rank = input_shape.NumDimensions();
|
||||
const int components = spatial_ ? ((input_shape[input_rank - 1] % 4 == 0) ? 4 : ((input_shape[input_rank - 1] % 2 == 0) ? 2 : 1)) : 1;
|
||||
|
||||
auto output_dims = input_shape.AsShapeVector();
|
||||
TensorShape output_shape(output_dims);
|
||||
auto* output_tensor = context.Output(0, output_shape);
|
||||
int64_t output_size = output_tensor->Shape().Size() / static_cast<int64_t>(components);
|
||||
|
||||
if (output_size == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto* scale = context.Input<Tensor>(1);
|
||||
const auto* B = context.Input<Tensor>(2);
|
||||
const auto* input_mean = context.Input<Tensor>(3);
|
||||
const auto* input_var = context.Input<Tensor>(4);
|
||||
|
||||
ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(input_tensor, scale, B, input_mean, input_var, spatial_ == 1, format_ == DataLayout::NHWC));
|
||||
|
||||
BatchNormalizationProgram program{epsilon_, spatial_, format_, static_cast<int64_t>(components)};
|
||||
program
|
||||
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
|
||||
{scale, ProgramTensorMetadataDependency::TypeAndRank},
|
||||
{B, ProgramTensorMetadataDependency::TypeAndRank},
|
||||
{input_mean, ProgramTensorMetadataDependency::TypeAndRank},
|
||||
{input_var, ProgramTensorMetadataDependency::TypeAndRank}})
|
||||
.AddOutputs({output_tensor})
|
||||
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
|
||||
.AddUniformVariables({{static_cast<uint32_t>(output_size)}});
|
||||
return context.RunProgram(program);
|
||||
}
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
54
onnxruntime/core/providers/webgpu/nn/batch_norm.h
Normal file
54
onnxruntime/core/providers/webgpu/nn/batch_norm.h
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/webgpu/webgpu_kernel.h"
|
||||
#include "core/providers/webgpu/program.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webgpu {
|
||||
|
||||
class BatchNormalizationProgram final : public Program<BatchNormalizationProgram> {
|
||||
public:
|
||||
BatchNormalizationProgram(float epsilon, int64_t spatial, DataLayout format, int64_t components) : Program{"BatchNormalization"},
|
||||
epsilon_{epsilon},
|
||||
spatial_{spatial},
|
||||
format_{format},
|
||||
components_{components} {}
|
||||
|
||||
Status GenerateShaderCode(ShaderHelper& sh) const override;
|
||||
|
||||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});
|
||||
|
||||
private:
|
||||
float epsilon_;
|
||||
int64_t spatial_;
|
||||
DataLayout format_;
|
||||
int64_t components_;
|
||||
};
|
||||
|
||||
template <bool is_nhwc>
|
||||
class BatchNormalization final : public WebGpuKernel {
|
||||
public:
|
||||
BatchNormalization(const OpKernelInfo& info) : WebGpuKernel(info) {
|
||||
epsilon_ = info.GetAttrOrDefault<float>("epsilon", 1e-5f);
|
||||
momentum_ = info.GetAttrOrDefault<float>("momentum", 0.9f);
|
||||
spatial_ = info.GetAttrOrDefault<int64_t>("spatial", 1);
|
||||
training_mode_ = info.GetAttrOrDefault<int64_t>("training_mode", 0);
|
||||
// NCHW for ai.onnx domain, NHWC for com.ms.internal.nhwc domain
|
||||
format_ = is_nhwc ? DataLayout::NHWC : DataLayout::NCHW;
|
||||
}
|
||||
|
||||
Status ComputeInternal(ComputeContext& context) const override;
|
||||
|
||||
private:
|
||||
float epsilon_;
|
||||
float momentum_;
|
||||
int64_t spatial_;
|
||||
int64_t training_mode_;
|
||||
DataLayout format_;
|
||||
};
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -696,14 +696,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If)>,
|
||||
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
|
||||
|
|
|
|||
|
|
@ -924,7 +924,8 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
// TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
|
||||
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
|
||||
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
|
||||
kWebGpuExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(BatchNormTest, ForwardTrainingTestOpset14) {
|
||||
|
|
@ -953,7 +954,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
|
|||
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
|
||||
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
|
||||
kWebGpuExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(BatchNormTest, ForwardTrainingTestOpset15) {
|
||||
|
|
@ -982,7 +984,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
|
|||
// Same exclusions as the opset 14 test
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
|
||||
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
|
||||
kWebGpuExecutionProvider});
|
||||
}
|
||||
#endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue