[WebNN EP] Support LRN operator (#22775)

WebNN doesn't provide a dedicated op for LRN, so the WebNN EP emulates it
with a sequence of WebNN ops:
pow -> transpose -> pad -> averagePool2d -> transpose -> mul -> add -> pow
-> div
@Honry @fdwr PTAL, thanks!
This commit is contained in:
Bin Miao 2024-11-13 03:53:52 +08:00 committed by GitHub
parent fd5b1a18ee
commit 67f5be0da2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 314 additions and 0 deletions

View file

@ -57,6 +57,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | |
| Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | |
| LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 |
| LRN | ai.onnx(7-12, 13+) | pad, averagePool2d, transpose, add, mul, pow, div | ✓ | ✓ | |
| LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
| MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | |
| Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | |

View file

@ -261,5 +261,67 @@ bool IsMLTensorSupported() {
return is_supported;
}
// Convert an int8 value to a uint4/int4 nibble stored in a uint8.
//
// `data_type` selects the range check: TensorProto_DataType_UINT4 accepts [0, 15];
// any other value is treated as int4 and accepts [-8, 7]. Throws (ORT_THROW) when
// `value` is out of range. The 4-bit two's-complement pattern is returned in the
// HIGH nibble of the byte; the low nibble is left as zero.
// NOTE(review): confirm that consumers expect the scalar in the high nibble rather
// than the low nibble (or both) of the packed byte.
uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type) {
  if (data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
    if (value < 0 || value > 15) {
      ORT_THROW("Value cannot be safely converted to uint4.");
    }
  } else {
    if (value < -8 || value > 7) {
      ORT_THROW("Value cannot be safely converted to int4.");
    }
  }
  // Mask to an unsigned nibble before shifting: left-shifting a negative signed
  // value is undefined behavior prior to C++20.
  const uint8_t nibble = static_cast<uint8_t>(static_cast<uint8_t>(value) & 0x0F);
  return static_cast<uint8_t>(nibble << 4);
}
// Convert float32 to IEEE 754 binary16 (float16), returned as its raw uint16 bit pattern.
//
// Rounding is round-to-nearest, ties-to-even, for both normal and subnormal results.
// - Inf stays Inf, any NaN becomes a quiet NaN (sign preserved).
// - Values too large for float16 (including those that round up past 65504) become Inf.
// - Values too small for the smallest float16 subnormal (2^-24) round to signed zero;
//   this includes every float32 subnormal.
uint16_t PackFloat32ToUint16AsFloat16(float value) {
  uint32_t float32_bits;

  // Safely copy the float bits into an integer
  std::memcpy(&float32_bits, &value, sizeof(float));

  // Extract the sign, exponent, and mantissa from the float32 bits
  const uint32_t sign = (float32_bits >> 31) & 0x1;
  const uint32_t exponent = (float32_bits >> 23) & 0xFF;
  const uint32_t mantissa = float32_bits & 0x7FFFFF;

  // Shift the sign for float16
  const uint16_t sign_float16 = static_cast<uint16_t>(sign << 15);

  // Handle special cases: Infinity and NaN
  if (exponent == 255) {
    return static_cast<uint16_t>(sign_float16 | (0x1F << 10) | (mantissa ? 0x200 : 0));
  }

  // Zero and float32 subnormals (< 2^-126) are far below half of the smallest
  // float16 subnormal (2^-25), so they always round to signed zero.
  if (exponent == 0) {
    return sign_float16;
  }

  // Adjust the exponent for float16 (subtract bias difference: 127 - 15 = 112)
  const int exponent_float16 = static_cast<int>(exponent) - 112;

  // Handle exponent overflow (larger than float16 can represent)
  if (exponent_float16 >= 0x1F) {
    return static_cast<uint16_t>(sign_float16 | (0x1F << 10));
  }

  // Handle exponent underflow: the result is a float16 subnormal (or zero).
  if (exponent_float16 <= 0) {
    // Total right shift of the 24-bit significand (implicit leading 1 restored):
    // 13 bits for the mantissa width difference plus the denormalization shift.
    const int shift = 14 - exponent_float16;
    if (shift > 24) {
      // Magnitude is below 2^-25: rounds to zero. Also keeps the shift count in range.
      return sign_float16;
    }
    const uint32_t significand = mantissa | 0x800000;
    uint32_t subnormal = significand >> shift;
    const uint32_t remainder = significand & ((1u << shift) - 1);
    const uint32_t halfway = 1u << (shift - 1);
    if (remainder > halfway || (remainder == halfway && (subnormal & 1))) {
      // A carry out of the subnormal mantissa yields 0x400, i.e. the smallest normal.
      ++subnormal;
    }
    return static_cast<uint16_t>(sign_float16 | subnormal);
  }

  // Normal result: combine exponent and truncated mantissa, then round to nearest even.
  // Incrementing the combined value lets a mantissa carry propagate into the exponent
  // (and from the largest finite value into Infinity) correctly.
  uint32_t magnitude = (static_cast<uint32_t>(exponent_float16) << 10) | (mantissa >> 13);
  const uint32_t remainder = mantissa & 0x1FFF;
  if (remainder > 0x1000 || (remainder == 0x1000 && (magnitude & 1))) {
    ++magnitude;
  }

  // Combine sign with the rounded exponent/mantissa into the final float16 representation
  return static_cast<uint16_t>(sign_float16 | magnitude);
}
} // namespace webnn
} // namespace onnxruntime

View file

@ -4,6 +4,7 @@
#pragma once
#include <cstring>
#include <core/common/status.h>
#include "core/common/inlined_containers.h"
#include <core/graph/basic_types.h>
@ -238,6 +239,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Log", "log"},
{"LpPool", "l2Pool2d"},
{"LSTM", "lstm"},
{"LRN", "averagePool2d"},
{"MatMul", "matmul"},
{"MatMulInteger", "matmulInteger"},
{"Max", "max"},
@ -345,5 +347,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type);
bool IsMLTensorSupported();
uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type);
uint16_t PackFloat32ToUint16AsFloat16(float value);
} // namespace webnn
} // namespace onnxruntime

View file

@ -0,0 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"
#include "base_op_builder.h"
namespace onnxruntime {
namespace webnn {
// Op builder that maps the ONNX LRN operator onto WebNN.
// Both hooks below override BaseOpBuilder and are private to this builder.
class LRNOpBuilder : public BaseOpBuilder {
 private:
  // Support check: decides whether a given LRN node can be handled.
  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                         const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;

  // Graph construction: emits the WebNN operands for a given LRN node.
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
};
// Emulates ONNX LRN with a chain of WebNN ops, since WebNN has no dedicated LRN op:
//
//   Y = X / (bias + alpha / size * square_sum) ^ beta
//
// where square_sum is the sum of X^2 over a window of `size` channels. The windowed
// sum divided by `size` is obtained with an explicit zero `pad` followed by
// `averagePool2d` (window (1, size), no pooling padding) over the channel axis moved
// to the last dimension.
//
// Registers the resulting operand under the node's first output name and returns
// Status::OK().
Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                           const Node& node,
                                           const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  const auto input_data_type = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
  const auto& node_name = node.Name();
  emscripten::val wnn_builder = model_builder.GetBuilder();

  NodeAttrHelper helper(node);
  const float alpha = helper.Get("alpha", 0.0001f);
  const float beta = helper.Get("beta", 0.75f);
  const float bias = helper.Get("bias", 1.0f);
  // ONNX requires 'size' >= 1; the padding arithmetic below relies on that.
  const uint32_t size = helper.Get("size", 1);

  // Prepare WebNN constants for alpha, beta, bias attributes.
  // The constants are created as float values in the input's data type.
  // NOTE(review): this assumes the input data type is float32 or float16; confirm that
  // the op-support checks reject other input types before this builder runs.
  emscripten::val alpha_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, alpha);
  emscripten::val beta_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, beta);
  emscripten::val bias_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, bias);
  emscripten::val pow1_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, 2);

  /**
    WebNN doesn't support LRN. So decompose it into a series of ops:
    X --> Pow --> (Transpose) --> Pad --> AveragePool --> (Transpose) --> Mul --> Add --> Pow --> Div
           ^          ^                       ^               ^            ^       ^       ^       ^
           |          |                       |               |            |       |       |       |
          Y:2     (0,2,3,1)           Kernel:(1,size)     (0,3,1,2)     B:alpha  B:bias  B:beta  A:input
  */

  // pow(input, 2)
  emscripten::val label_options = emscripten::val::object();
  label_options.set("label", node_name + "_pow1");
  emscripten::val pow1_output = wnn_builder.call<emscripten::val>("pow", input, pow1_constant, label_options);

  // transpose(pow1_output, permutation=[0, 2, 3, 1])
  // LRN is one of NHWC layout sensitive ops. When preferred layout is NCHW, move dimension 1 to dimension 3 (rightmost).
  if (model_builder.GetPreferredLayout() == DataLayout::NCHW) {
    std::vector<uint32_t> perm{0, 2, 3, 1};
    emscripten::val transpose_options = emscripten::val::object();
    transpose_options.set("label", node_name + "_transpose_rightmost");
    transpose_options.set("permutation", emscripten::val::array(perm));
    pow1_output =
        wnn_builder.call<emscripten::val>("transpose", pow1_output, transpose_options);
  }

  // pad(pow1_output, beginning_padding = {0, 0, 0, leading_padding}, ending_padding = {0, 0, 0, trailing_padding})
  // Adding a Pad before averagePool2d and calling AveragePool with pads as 0's.
  // Per the ONNX LRN definition the window spans
  //   [c - floor((size - 1) / 2), c + ceil((size - 1) / 2)].
  // These are computed in pure integer arithmetic: calling floor()/ceil() on an
  // already-truncated integer quotient would make both equal to the floor, leaving
  // the trailing side one element short for even sizes.
  const uint32_t leading_padding = (size - 1) / 2;  // floor((size - 1) / 2)
  const uint32_t trailing_padding = size / 2;       // ceil((size - 1) / 2)
  std::vector<uint32_t> beginning_padding{0, 0, 0, leading_padding};
  std::vector<uint32_t> ending_padding{0, 0, 0, trailing_padding};
  emscripten::val pad_options = emscripten::val::object();
  pad_options.set("label", node_name + "_pad");
  emscripten::val pad_output =
      wnn_builder.call<emscripten::val>("pad", pow1_output, emscripten::val::array(beginning_padding),
                                        emscripten::val::array(ending_padding), pad_options);

  // averagePool2d(pad_output, pool_options)
  const std::vector<uint32_t> kernel_shape = {1, size};
  emscripten::val pool_options = emscripten::val::object();
  pool_options.set("label", node_name + "_averagePool2d");
  pool_options.set("windowDimensions", emscripten::val::array(kernel_shape));
  emscripten::val pool_output = wnn_builder.call<emscripten::val>("averagePool2d", pad_output, pool_options);

  // transpose(pool_output, permutation=[0, 3, 1, 2])
  // Move dimension 3 back to dimension 1.
  if (model_builder.GetPreferredLayout() == DataLayout::NCHW) {
    std::vector<uint32_t> perm{0, 3, 1, 2};
    emscripten::val transpose_options = emscripten::val::object();
    transpose_options.set("label", node_name + "_transpose_inverse");
    transpose_options.set("permutation", emscripten::val::array(perm));
    pool_output =
        wnn_builder.call<emscripten::val>("transpose", pool_output, transpose_options);
  }

  // mul(pool_output, alpha_constant)
  label_options.set("label", node_name + "_mul");
  emscripten::val mul_output =
      wnn_builder.call<emscripten::val>("mul", pool_output, alpha_constant, label_options);

  // add(mul_output, bias_constant)
  label_options.set("label", node_name + "_add");
  emscripten::val add_output = wnn_builder.call<emscripten::val>("add", mul_output, bias_constant, label_options);

  // pow(add_output, beta_constant)
  label_options.set("label", node_name + "_pow2");
  emscripten::val pow2_output = wnn_builder.call<emscripten::val>("pow", add_output, beta_constant, label_options);

  // div(input, pow2_output)
  label_options.set("label", node_name + "_div");
  emscripten::val div_output = wnn_builder.call<emscripten::val>("div", input, pow2_output, label_options);

  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(div_output));
  return Status::OK();
}
// Operator support related.
// Accepts an LRN node only when the shape of its first input can be retrieved and
// that shape has exactly four dimensions; otherwise returns false (logging the
// actual rank at VERBOSE level in the rank-mismatch case).
bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
                                     const Node& node,
                                     const WebnnDeviceType /* device_type */,
                                     const logging::Logger& logger) const {
  std::vector<int64_t> shape;
  if (!GetShape(*node.InputDefs()[0], shape, logger)) {
    return false;
  }

  const auto rank = shape.size();
  if (rank == 4) {
    return true;
  }

  LOGS(logger, VERBOSE) << "LRN only supports 4D input shape, input is "
                        << rank << "D shape";
  return false;
}
// Creates an LRNOpBuilder, stores it in the registrations' builder list, and maps
// `op_type` to that stored instance.
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto& owned_builders = op_registrations.builders;
  owned_builders.push_back(std::make_unique<LRNOpBuilder>());
  op_registrations.op_builder_map.emplace(op_type, owned_builders.back().get());
}
} // namespace webnn
} // namespace onnxruntime

View file

@ -40,6 +40,10 @@ class ModelBuilder {
void AddOperand(const std::string& name, const emscripten::val& operand);
const emscripten::val& GetZeroConstant(
const int32_t& data_type, const std::vector<uint32_t>& shape = {});
template <typename T>
const emscripten::val& CreateOrGetScalarConstant(const int32_t& data_type, T value);
// Use the buffers to persist WebNN allocated data like transposed weight.
// It ensures the validity during inference session.
std::vector<std::unique_ptr<uint8_t[]>> mem_persist_buffers_;
@ -99,5 +103,92 @@ class ModelBuilder {
static const IOpBuilder* GetOpBuilder(const Node& node);
};
// Create a scalar constant MLOperand of the specified value and data type.
// Workaround for builder.constant(type, value) method since it has not been implemented now.
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value
// BTW, the spec is discussing if the builder.constant(type, value) should be dropped at
// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision.
//
// Constants are cached in wnn_operands_ under a name derived from the data type and the
// value, so repeated requests for the same (data_type, value) return the same operand.
// Throws (ORT_THROW) if SetWebnnDataType rejects `data_type`.
//
// This function enforces a mapping between the data_type and the value types:
// - TensorProto_DataType_INT4 <-> int8_t
// - TensorProto_DataType_UINT4 <-> int8_t
// - TensorProto_DataType_BOOL <-> bool
// - TensorProto_DataType_UINT8 <-> uint8_t
// - TensorProto_DataType_INT8 <-> int8_t
// - TensorProto_DataType_FLOAT16 <-> float
// - TensorProto_DataType_FLOAT <-> float
// - TensorProto_DataType_INT32 <-> int32_t
// - TensorProto_DataType_INT64 <-> int64_t
// - TensorProto_DataType_UINT32 <-> uint32_t
// - TensorProto_DataType_UINT64 <-> uint64_t
template <typename T>
const emscripten::val& ModelBuilder::CreateOrGetScalarConstant(const int32_t& data_type, T value) {
  std::string name = "webnn_scalar_constant_" + std::to_string(data_type) + "_";
  if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
      data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
    // std::to_string on a floating-point value prints only 6 decimals, so distinct small
    // values (e.g. 1e-7 and 2e-7 both become "0.000000") would share one cache entry and
    // the wrong constant would be reused. Key on the exact float32 bit pattern instead.
    const float value_float = static_cast<float>(value);
    uint32_t value_bits = 0;
    std::memcpy(&value_bits, &value_float, sizeof(value_bits));
    name += std::to_string(value_bits);
  } else {
    name += std::to_string(value);
  }

  emscripten::val desc = emscripten::val::object();
  desc.set("shape", emscripten::val::array());
  emscripten::val scalar_buffer = emscripten::val::undefined();
  uint16_t value_uint16 = 0;
  uint8_t value_uint8 = 0;
  if (!SetWebnnDataType(desc, data_type)) {
    ORT_THROW("Unsupported data type: " + std::to_string(data_type));
  }

  // If the operand does not exist, create it.
  if (wnn_operands_.find(name) == wnn_operands_.end()) {
    // Each case fills a one-element typed array matching the WebNN data type.
    switch (data_type) {
      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
        // 4-bit types are carried in a byte; see PackInt8ToUint8AsNibble.
        scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
        value_uint8 = PackInt8ToUint8AsNibble(value, data_type);
        scalar_buffer.call<void>("fill", emscripten::val(value_uint8));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
        scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val(value ? 1 : 0));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
        scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val(value));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
        scalar_buffer = emscripten::val::global("Int8Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val(value));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
        // float16 is passed as its raw 16-bit pattern in a Uint16Array.
        scalar_buffer = emscripten::val::global("Uint16Array").new_(1);
        value_uint16 = PackFloat32ToUint16AsFloat16(value);
        scalar_buffer.call<void>("fill", emscripten::val(value_uint16));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
        scalar_buffer = emscripten::val::global("Float32Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val(value));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
        scalar_buffer = emscripten::val::global("Int32Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val(value));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
        scalar_buffer = emscripten::val::global("Uint32Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val(value));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
        scalar_buffer = emscripten::val::global("BigInt64Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
        scalar_buffer = emscripten::val::global("BigUint64Array").new_(1);
        scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
        break;
      default:
        break;
    }

    const emscripten::val scalar_constant = wnn_builder_.call<emscripten::val>("constant", desc, scalar_buffer);
    wnn_operands_.insert(std::make_pair(name, scalar_constant));
  }

  return wnn_operands_.at(name);
}
} // namespace webnn
} // namespace onnxruntime

View file

@ -137,6 +137,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateLogicalOpBuilder("Xor", op_registrations);
}
{ // LRN
CreateLRNOpBuilder("LRN", op_registrations);
}
{ // LSTM
CreateLstmOpBuilder("LSTM", op_registrations);
}

View file

@ -37,6 +37,7 @@ void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations&
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);