From 67f5be0da200f4ac0fa065d5682c99517b7e4636 Mon Sep 17 00:00:00 2001 From: Bin Miao Date: Wed, 13 Nov 2024 03:53:52 +0800 Subject: [PATCH] [WebNN EP] Support LRN operator (#22775) WebNN doesn't provide a dedicated op for LRN, so use a series of WebNN ops to emulate it in the WebNN EP: pow -> transpose -> pad -> averagePool -> transpose -> mul -> add -> pow -> div @Honry @fdwr PTAL, thanks! --- js/web/docs/webnn-operators.md | 1 + .../core/providers/webnn/builders/helper.cc | 62 ++++++++ .../core/providers/webnn/builders/helper.h | 5 + .../webnn/builders/impl/lrn_op_builder.cc | 150 ++++++++++++++++++ .../providers/webnn/builders/model_builder.h | 91 +++++++++++ .../webnn/builders/op_builder_factory.cc | 4 + .../webnn/builders/op_builder_factory.h | 1 + 7 files changed, 314 insertions(+) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 0c86845f79..6b13658874 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -57,6 +57,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | | | Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | | | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 | +| LRN | ai.onnx(7-12, 13+) | pad, averagePool2d, transpose, add, mul, pow, div | ✓ | ✓ | | | LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 
'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | | | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 321f3b3d7e..537e552af2 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -261,5 +261,67 @@ bool IsMLTensorSupported() { return is_supported; } +// Convert int8 to uint4/int4 (stored as uint8) +uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type) { + uint8_t result = 0; + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + if (value < 0 || value > 15) { + ORT_THROW("Value cannot be safely converted to uint4."); + } + result |= (static_cast(value) << 4); + } else { + if (value < -8 || value > 7) { + ORT_THROW("Value cannot be safely converted to int4."); + } + result |= (value << 4); + } + + return result; +} + +// Convert float32 to float16 (stored as uint16) +uint16_t PackFloat32ToUint16AsFloat16(float value) { + uint32_t float32_bits; + + // Safely copy the float bits into an integer + std::memcpy(&float32_bits, &value, sizeof(float)); + + // Extract the sign, exponent, and mantissa from the float32 bits + uint32_t sign = (float32_bits >> 31) & 0x1; + uint32_t exponent = (float32_bits >> 23) & 0xFF; + uint32_t mantissa = float32_bits & 0x7FFFFF; + + // Shift the sign for float16 + uint16_t sign_float16 = sign << 15; + + // Handle special cases: Infinity and NaN + if (exponent == 255) { + return sign_float16 | (0x1F << 10) | (mantissa ? 
0x200 : 0); + } + // Handle zero and subnormal numbers in float32 + if (exponent == 0) { + return sign_float16 | (mantissa >> 13); + } + + // Adjust the exponent for float16 (subtract bias difference: 127 - 15 = 112) + int exponent_float16 = exponent - 112; + + // Handle exponent overflow (larger than float16 can represent) + if (exponent_float16 >= 0x1F) { + return sign_float16 | (0x1F << 10); + } + // Handle exponent underflow (smaller than float16 can represent) + if (exponent_float16 <= 0) { + mantissa = (mantissa | 0x800000) >> (1 - exponent_float16); + return sign_float16 | (mantissa >> 13); + } + + // Adjust the mantissa by shifting it to fit float16 format (round to nearest even) + uint16_t mantissa_float16 = (mantissa + 0x1000) >> 13; + + // Combine sign, exponent, and mantissa into the final float16 representation + return sign_float16 | (exponent_float16 << 10) | mantissa_float16; +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 7b4a773365..c27e82ee7f 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include "core/common/inlined_containers.h" #include @@ -238,6 +239,7 @@ static const InlinedHashMap op_map = { {"Log", "log"}, {"LpPool", "l2Pool2d"}, {"LSTM", "lstm"}, + {"LRN", "averagePool2d"}, {"MatMul", "matmul"}, {"MatMulInteger", "matmulInteger"}, {"Max", "max"}, @@ -345,5 +347,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); bool IsMLTensorSupported(); +uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type); +uint16_t PackFloat32ToUint16AsFloat16(float value); + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc new file mode 
100644 index 0000000000..bdd1283c72 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class LRNOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto input_data_type = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + const auto node_name = node.Name(); + emscripten::val wnn_builder = model_builder.GetBuilder(); + + NodeAttrHelper helper(node); + const float alpha = helper.Get("alpha", 0.0001f); + const float beta = helper.Get("beta", 0.75f); + const float bias = helper.Get("bias", 1.0f); + const uint32_t size = helper.Get("size", 1); + + // Prepare WebNN constants for alpha, beta, bias attributes. + // Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'. 
+ emscripten::val alpha_constant = model_builder.CreateOrGetScalarConstant(input_data_type, alpha); + emscripten::val beta_constant = model_builder.CreateOrGetScalarConstant(input_data_type, beta); + emscripten::val bias_constant = model_builder.CreateOrGetScalarConstant(input_data_type, bias); + emscripten::val pow1_constant = model_builder.CreateOrGetScalarConstant(input_data_type, 2); + + /** + WebNN doesn't support LRN. So decompose it into a series of ops: + X --> Pow --> (Transpose)--> Pad --> AveragePool--> (Transpose) --> Mul --> Add --> Pow --> Div + ^ ^ ^ ^ ^ ^ ^ ^ + | | | | | | | | + Y:2 (0,2,3,1) Kernel:(1,size) (0,3,1,2) B:alpha B:bias B:beta A:input + */ + // + // pow(input, 2) + emscripten::val label_options = emscripten::val::object(); + label_options.set("label", node_name + "_pow1"); + emscripten::val pow1_output = wnn_builder.call("pow", input, pow1_constant, label_options); + + // transpose(pow1_output, permutation=[0, 2, 3, 1]) + // LRN is one of NHWC layout sensitive ops. When preferred layout is NCHW, move dimension 1 to dimension 3 (rightmost). + if (model_builder.GetPreferredLayout() == DataLayout::NCHW) { + std::vector perm{0, 2, 3, 1}; + emscripten::val transpose_options = emscripten::val::object(); + transpose_options.set("label", node_name + "_transpose_rightmost"); + transpose_options.set("permutation", emscripten::val::array(perm)); + pow1_output = + wnn_builder.call("transpose", pow1_output, transpose_options); + } + + // pad(pow1_output, beginning_padding = {0, 0, 0, leading_padding}, ending_padding = {0, 0, 0, trailing_padding}) + // Adding a Pad before averagePool2d and calling AveragePool with pads as 0's. 
+ const uint32_t leading_padding = floor((size - 1) / 2); + const uint32_t trailing_padding = ceil((size - 1) / 2); + std::vector beginning_padding{0, 0, 0, leading_padding}; + std::vector ending_padding{0, 0, 0, trailing_padding}; + emscripten::val pad_options = emscripten::val::object(); + pad_options.set("label", node_name + "_pad"); + emscripten::val pad_output = + wnn_builder.call("pad", pow1_output, emscripten::val::array(beginning_padding), + emscripten::val::array(ending_padding), pad_options); + + // averagePool2d(pad_output, pool_options) + const std::vector kernel_shape = {1, size}; + emscripten::val pool_options = emscripten::val::object(); + pool_options.set("label", node_name + "_averagePool2d"); + pool_options.set("windowDimensions", emscripten::val::array(kernel_shape)); + emscripten::val pool_output = wnn_builder.call("averagePool2d", pad_output, pool_options); + + // transpose(pool_output, permutation=[0, 3, 1, 2]) + // Move dimension 3 back to dimension 1. + if (model_builder.GetPreferredLayout() == DataLayout::NCHW) { + std::vector perm{0, 3, 1, 2}; + emscripten::val transpose_options = emscripten::val::object(); + transpose_options.set("label", node_name + "_transpose_inverse"); + transpose_options.set("permutation", emscripten::val::array(perm)); + pool_output = + wnn_builder.call("transpose", pool_output, transpose_options); + } + + // mul(pool_output, alpha_constant) + label_options.set("label", node_name + "_mul"); + emscripten::val mul_output = + wnn_builder.call("mul", pool_output, alpha_constant, label_options); + + // add(mul_output, bias_constant) + label_options.set("label", node_name + "_add"); + emscripten::val add_output = wnn_builder.call("add", mul_output, bias_constant, label_options); + + // pow(add_output, beta_constant) + label_options.set("label", node_name + "_pow2"); + emscripten::val pow2_output = wnn_builder.call("pow", add_output, beta_constant, label_options); + + // div(input, pow2_output) + label_options.set("label", 
node_name + "_div"); + emscripten::val div_output = wnn_builder.call("div", input, pow2_output, label_options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(div_output)); + return Status::OK(); +} + +// Operator support related. +bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << "LRN only supports 4D input shape, input is " + << input_size << "D shape"; + return false; + } + + return true; +} + +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index aeb128adc9..c482e9d05b 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -40,6 +40,10 @@ class ModelBuilder { void AddOperand(const std::string& name, const emscripten::val& operand); const emscripten::val& GetZeroConstant( const int32_t& data_type, const std::vector& shape = {}); + + template + const emscripten::val& CreateOrGetScalarConstant(const int32_t& data_type, T value); + // Use the buffers to persist WebNN allocated data like transposed weight. // It ensures the validity during inference session. 
std::vector> mem_persist_buffers_; @@ -99,5 +103,92 @@ class ModelBuilder { static const IOpBuilder* GetOpBuilder(const Node& node); }; +// Create a scalar constant MLOperand of the specified value and data type. +// Workaround for the builder.constant(type, value) method, since it has not been implemented yet. +// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value +// BTW, the spec is discussing whether builder.constant(type, value) should be dropped at +// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision. +// +// This function enforces a mapping between the data_type and the value types: +// - TensorProto_DataType_INT4 <-> int8_t +// - TensorProto_DataType_UINT4 <-> int8_t +// - TensorProto_DataType_BOOL <-> bool +// - TensorProto_DataType_UINT8 <-> uint8_t +// - TensorProto_DataType_INT8 <-> int8_t +// - TensorProto_DataType_FLOAT16 <-> float +// - TensorProto_DataType_FLOAT <-> float +// - TensorProto_DataType_INT32 <-> int32_t +// - TensorProto_DataType_INT64 <-> int64_t +// - TensorProto_DataType_UINT32 <-> uint32_t +// - TensorProto_DataType_UINT64 <-> uint64_t +template +const emscripten::val& ModelBuilder::CreateOrGetScalarConstant(const int32_t& data_type, T value) { + std::string name = "webnn_scalar_constant_" + std::to_string(data_type) + "_" + std::to_string(value); + emscripten::val desc = emscripten::val::object(); + desc.set("shape", emscripten::val::array()); + emscripten::val scalar_buffer = emscripten::val::undefined(); + uint16_t value_uint16 = 0; + uint8_t value_uint8 = 0; + if (!SetWebnnDataType(desc, data_type)) { + ORT_THROW("Unsupported data type: " + std::to_string(data_type)); + } + + // If the operand does not exist, create it. 
+ if (wnn_operands_.find(name) == wnn_operands_.end()) { + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + scalar_buffer = emscripten::val::global("Uint8Array").new_(1); + value_uint8 = PackInt8ToUint8AsNibble(value, data_type); + scalar_buffer.call("fill", emscripten::val(value_uint8)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + scalar_buffer = emscripten::val::global("Uint8Array").new_(1); + scalar_buffer.call("fill", emscripten::val(value ? 1 : 0)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + scalar_buffer = emscripten::val::global("Uint8Array").new_(1); + scalar_buffer.call("fill", emscripten::val(value)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + scalar_buffer = emscripten::val::global("Int8Array").new_(1); + scalar_buffer.call("fill", emscripten::val(value)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + scalar_buffer = emscripten::val::global("Uint16Array").new_(1); + value_uint16 = PackFloat32ToUint16AsFloat16(value); + scalar_buffer.call("fill", emscripten::val(value_uint16)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + scalar_buffer = emscripten::val::global("Float32Array").new_(1); + scalar_buffer.call("fill", emscripten::val(value)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + scalar_buffer = emscripten::val::global("Int32Array").new_(1); + scalar_buffer.call("fill", emscripten::val(value)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + scalar_buffer = emscripten::val::global("Uint32Array").new_(1); + scalar_buffer.call("fill", emscripten::val(value)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + scalar_buffer = emscripten::val::global("BigInt64Array").new_(1); + scalar_buffer.call("fill", emscripten::val::global("BigInt")(value)); + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + scalar_buffer = 
emscripten::val::global("BigUint64Array").new_(1); + scalar_buffer.call("fill", emscripten::val::global("BigInt")(value)); + break; + default: + break; + } + + const emscripten::val scalar_constant = wnn_builder_.call("constant", desc, scalar_buffer); + wnn_operands_.insert(std::make_pair(name, scalar_constant)); + } + + return wnn_operands_.at(name); +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 1bce7c350a..ae095d24b8 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -137,6 +137,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateLogicalOpBuilder("Xor", op_registrations); } + { // LRN + CreateLRNOpBuilder("LRN", op_registrations); + } + { // LSTM CreateLstmOpBuilder("LSTM", op_registrations); } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 6ac5b8de18..a492d3e60f 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -37,6 +37,7 @@ void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void 
CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);