mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
[WebNN EP] Support LRN operator (#22775)
WebNN doesn't provide dedicate op for LRN, use a couple of WebNN ops to emulate it in WebNN EP: pow -> transpose -> pad -> averagePool -> transpose -> mul -> add -> pow -> div @Honry @fdwr PTAL, thanks!
This commit is contained in:
parent
fd5b1a18ee
commit
67f5be0da2
7 changed files with 314 additions and 0 deletions
|
|
@ -57,6 +57,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
|
|||
| LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | |
|
||||
| Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | |
|
||||
| LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 |
|
||||
| LRN | ai.onnx(7-12, 13+) | pad, averagePool2d, transpose, add, mul, pow, div | ✓ | ✓ | |
|
||||
| LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
|
||||
| MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | |
|
||||
| Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | |
|
||||
|
|
|
|||
|
|
@ -261,5 +261,67 @@ bool IsMLTensorSupported() {
|
|||
return is_supported;
|
||||
}
|
||||
|
||||
// Convert int8 to uint4/int4 (stored as uint8)
|
||||
uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type) {
|
||||
uint8_t result = 0;
|
||||
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
|
||||
if (value < 0 || value > 15) {
|
||||
ORT_THROW("Value cannot be safely converted to uint4.");
|
||||
}
|
||||
result |= (static_cast<uint8_t>(value) << 4);
|
||||
} else {
|
||||
if (value < -8 || value > 7) {
|
||||
ORT_THROW("Value cannot be safely converted to int4.");
|
||||
}
|
||||
result |= (value << 4);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Convert float32 to float16 (stored as uint16)
|
||||
uint16_t PackFloat32ToUint16AsFloat16(float value) {
|
||||
uint32_t float32_bits;
|
||||
|
||||
// Safely copy the float bits into an integer
|
||||
std::memcpy(&float32_bits, &value, sizeof(float));
|
||||
|
||||
// Extract the sign, exponent, and mantissa from the float32 bits
|
||||
uint32_t sign = (float32_bits >> 31) & 0x1;
|
||||
uint32_t exponent = (float32_bits >> 23) & 0xFF;
|
||||
uint32_t mantissa = float32_bits & 0x7FFFFF;
|
||||
|
||||
// Shift the sign for float16
|
||||
uint16_t sign_float16 = sign << 15;
|
||||
|
||||
// Handle special cases: Infinity and NaN
|
||||
if (exponent == 255) {
|
||||
return sign_float16 | (0x1F << 10) | (mantissa ? 0x200 : 0);
|
||||
}
|
||||
// Handle zero and subnormal numbers in float32
|
||||
if (exponent == 0) {
|
||||
return sign_float16 | (mantissa >> 13);
|
||||
}
|
||||
|
||||
// Adjust the exponent for float16 (subtract bias difference: 127 - 15 = 112)
|
||||
int exponent_float16 = exponent - 112;
|
||||
|
||||
// Handle exponent overflow (larger than float16 can represent)
|
||||
if (exponent_float16 >= 0x1F) {
|
||||
return sign_float16 | (0x1F << 10);
|
||||
}
|
||||
// Handle exponent underflow (smaller than float16 can represent)
|
||||
if (exponent_float16 <= 0) {
|
||||
mantissa = (mantissa | 0x800000) >> (1 - exponent_float16);
|
||||
return sign_float16 | (mantissa >> 13);
|
||||
}
|
||||
|
||||
// Adjust the mantissa by shifting it to fit float16 format (round to nearest even)
|
||||
uint16_t mantissa_float16 = (mantissa + 0x1000) >> 13;
|
||||
|
||||
// Combine sign, exponent, and mantissa into the final float16 representation
|
||||
return sign_float16 | (exponent_float16 << 10) | mantissa_float16;
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <core/common/status.h>
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include <core/graph/basic_types.h>
|
||||
|
|
@ -238,6 +239,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
|
|||
{"Log", "log"},
|
||||
{"LpPool", "l2Pool2d"},
|
||||
{"LSTM", "lstm"},
|
||||
{"LRN", "averagePool2d"},
|
||||
{"MatMul", "matmul"},
|
||||
{"MatMulInteger", "matmulInteger"},
|
||||
{"Max", "max"},
|
||||
|
|
@ -345,5 +347,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type);
|
|||
|
||||
bool IsMLTensorSupported();
|
||||
|
||||
uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type);
|
||||
uint16_t PackFloat32ToUint16AsFloat16(float value);
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
150
onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc
Normal file
150
onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class LRNOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto input_data_type = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
|
||||
const auto node_name = node.Name();
|
||||
emscripten::val wnn_builder = model_builder.GetBuilder();
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
const float alpha = helper.Get("alpha", 0.0001f);
|
||||
const float beta = helper.Get("beta", 0.75f);
|
||||
const float bias = helper.Get("bias", 1.0f);
|
||||
const uint32_t size = helper.Get("size", 1);
|
||||
|
||||
// Prepare WebNN constants for alpha, beta, bias attributes.
|
||||
// Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'.
|
||||
emscripten::val alpha_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, alpha);
|
||||
emscripten::val beta_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, beta);
|
||||
emscripten::val bias_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, bias);
|
||||
emscripten::val pow1_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, 2);
|
||||
|
||||
/**
|
||||
WebNN doesn't support LRN. So decompose it into a series of ops:
|
||||
X --> Pow --> (Transpose)--> Pad --> AveragePool--> (Transpose) --> Mul --> Add --> Pow --> Div
|
||||
^ ^ ^ ^ ^ ^ ^ ^
|
||||
| | | | | | | |
|
||||
Y:2 (0,2,3,1) Kernel:(1,size) (0,3,1,2) B:alpha B:bias B:beta A:input
|
||||
*/
|
||||
//
|
||||
// pow(input, 2)
|
||||
emscripten::val label_options = emscripten::val::object();
|
||||
label_options.set("label", node_name + "_pow1");
|
||||
emscripten::val pow1_output = wnn_builder.call<emscripten::val>("pow", input, pow1_constant, label_options);
|
||||
|
||||
// transpose(pow1_output, permutation=[0, 2, 3, 1])
|
||||
// LRN is one of NHWC layout sensitive ops. When preferred layout is NCHW, move dimension 1 to dimension 3 (rightmost).
|
||||
if (model_builder.GetPreferredLayout() == DataLayout::NCHW) {
|
||||
std::vector<uint32_t> perm{0, 2, 3, 1};
|
||||
emscripten::val transpose_options = emscripten::val::object();
|
||||
transpose_options.set("label", node_name + "_transpose_rightmost");
|
||||
transpose_options.set("permutation", emscripten::val::array(perm));
|
||||
pow1_output =
|
||||
wnn_builder.call<emscripten::val>("transpose", pow1_output, transpose_options);
|
||||
}
|
||||
|
||||
// pad(pow1_output, beginning_padding = {0, 0, 0, leading_padding}, ending_padding = {0, 0, 0, trailing_padding})
|
||||
// Adding a Pad before averagePool2d and calling AveragePool with pads as 0's.
|
||||
const uint32_t leading_padding = floor((size - 1) / 2);
|
||||
const uint32_t trailing_padding = ceil((size - 1) / 2);
|
||||
std::vector<uint32_t> beginning_padding{0, 0, 0, leading_padding};
|
||||
std::vector<uint32_t> ending_padding{0, 0, 0, trailing_padding};
|
||||
emscripten::val pad_options = emscripten::val::object();
|
||||
pad_options.set("label", node_name + "_pad");
|
||||
emscripten::val pad_output =
|
||||
wnn_builder.call<emscripten::val>("pad", pow1_output, emscripten::val::array(beginning_padding),
|
||||
emscripten::val::array(ending_padding), pad_options);
|
||||
|
||||
// averagePool2d(pad_output, pool_options)
|
||||
const std::vector<uint32_t> kernel_shape = {1, size};
|
||||
emscripten::val pool_options = emscripten::val::object();
|
||||
pool_options.set("label", node_name + "_averagePool2d");
|
||||
pool_options.set("windowDimensions", emscripten::val::array(kernel_shape));
|
||||
emscripten::val pool_output = wnn_builder.call<emscripten::val>("averagePool2d", pad_output, pool_options);
|
||||
|
||||
// transpose(pool_output, permutation=[0, 3, 1, 2])
|
||||
// Move dimension 3 back to dimension 1.
|
||||
if (model_builder.GetPreferredLayout() == DataLayout::NCHW) {
|
||||
std::vector<uint32_t> perm{0, 3, 1, 2};
|
||||
emscripten::val transpose_options = emscripten::val::object();
|
||||
transpose_options.set("label", node_name + "_transpose_inverse");
|
||||
transpose_options.set("permutation", emscripten::val::array(perm));
|
||||
pool_output =
|
||||
wnn_builder.call<emscripten::val>("transpose", pool_output, transpose_options);
|
||||
}
|
||||
|
||||
// mul(pool_output, alpha_constant)
|
||||
label_options.set("label", node_name + "_mul");
|
||||
emscripten::val mul_output =
|
||||
wnn_builder.call<emscripten::val>("mul", pool_output, alpha_constant, label_options);
|
||||
|
||||
// add(mul_output, bias_constant)
|
||||
label_options.set("label", node_name + "_add");
|
||||
emscripten::val add_output = wnn_builder.call<emscripten::val>("add", mul_output, bias_constant, label_options);
|
||||
|
||||
// pow(add_output, beta_constant)
|
||||
label_options.set("label", node_name + "_pow2");
|
||||
emscripten::val pow2_output = wnn_builder.call<emscripten::val>("pow", add_output, beta_constant, label_options);
|
||||
|
||||
// div(input, pow2_output)
|
||||
label_options.set("label", node_name + "_div");
|
||||
emscripten::val div_output = wnn_builder.call<emscripten::val>("div", input, pow2_output, label_options);
|
||||
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(div_output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related.
|
||||
bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
||||
const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
const auto input_size = input_shape.size();
|
||||
if (input_size != 4) {
|
||||
LOGS(logger, VERBOSE) << "LRN only supports 4D input shape, input is "
|
||||
<< input_size << "D shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<LRNOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -40,6 +40,10 @@ class ModelBuilder {
|
|||
void AddOperand(const std::string& name, const emscripten::val& operand);
|
||||
const emscripten::val& GetZeroConstant(
|
||||
const int32_t& data_type, const std::vector<uint32_t>& shape = {});
|
||||
|
||||
template <typename T>
|
||||
const emscripten::val& CreateOrGetScalarConstant(const int32_t& data_type, T value);
|
||||
|
||||
// Use the buffers to persist WebNN allocated data like transposed weight.
|
||||
// It ensures the validity during inference session.
|
||||
std::vector<std::unique_ptr<uint8_t[]>> mem_persist_buffers_;
|
||||
|
|
@ -99,5 +103,92 @@ class ModelBuilder {
|
|||
static const IOpBuilder* GetOpBuilder(const Node& node);
|
||||
};
|
||||
|
||||
// Create a scalar constant MLOperand of the specified value and data type.
|
||||
// Workaround for builer.constant(type, value) method since it has not been implemented now.
|
||||
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value
|
||||
// BTW, the spec is discussing if the builder.constant(type, value) should be dropped at
|
||||
// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision.
|
||||
//
|
||||
// This function enforces a mapping between the data_type and the value types:
|
||||
// - TensorProto_DataType_INT4 <-> int8_t
|
||||
// - TensorProto_DataType_UINT4 <-> int8_t
|
||||
// - TensorProto_DataType_BOOL <-> bool
|
||||
// - TensorProto_DataType_UINT8 <-> uint8_t
|
||||
// - TensorProto_DataType_INT8 <-> int8_t
|
||||
// - TensorProto_DataType_FLOAT16 <-> float
|
||||
// - TensorProto_DataType_FLOAT <-> float
|
||||
// - TensorProto_DataType_INT32 <-> int32_t
|
||||
// - TensorProto_DataType_INT64 <-> int64_t
|
||||
// - TensorProto_DataType_UINT32 <-> uint32_t
|
||||
// - TensorProto_DataType_UINT64 <-> uint64_t
|
||||
template <typename T>
|
||||
const emscripten::val& ModelBuilder::CreateOrGetScalarConstant(const int32_t& data_type, T value) {
|
||||
std::string name = "webnn_scalar_constant_" + std::to_string(data_type) + "_" + std::to_string(value);
|
||||
emscripten::val desc = emscripten::val::object();
|
||||
desc.set("shape", emscripten::val::array());
|
||||
emscripten::val scalar_buffer = emscripten::val::undefined();
|
||||
uint16_t value_uint16 = 0;
|
||||
uint8_t value_uint8 = 0;
|
||||
if (!SetWebnnDataType(desc, data_type)) {
|
||||
ORT_THROW("Unsupported data type: " + std::to_string(data_type));
|
||||
}
|
||||
|
||||
// If the operand does not exist, create it.
|
||||
if (wnn_operands_.find(name) == wnn_operands_.end()) {
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
|
||||
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
|
||||
value_uint8 = PackInt8ToUint8AsNibble(value, data_type);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value_uint8));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value ? 1 : 0));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
scalar_buffer = emscripten::val::global("Int8Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
scalar_buffer = emscripten::val::global("Uint16Array").new_(1);
|
||||
value_uint16 = PackFloat32ToUint16AsFloat16(value);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value_uint16));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
scalar_buffer = emscripten::val::global("Float32Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
scalar_buffer = emscripten::val::global("Int32Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
|
||||
scalar_buffer = emscripten::val::global("Uint32Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val(value));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
scalar_buffer = emscripten::val::global("BigInt64Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
scalar_buffer = emscripten::val::global("BigUint64Array").new_(1);
|
||||
scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
const emscripten::val scalar_constant = wnn_builder_.call<emscripten::val>("constant", desc, scalar_buffer);
|
||||
wnn_operands_.insert(std::make_pair(name, scalar_constant));
|
||||
}
|
||||
|
||||
return wnn_operands_.at(name);
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -137,6 +137,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateLogicalOpBuilder("Xor", op_registrations);
|
||||
}
|
||||
|
||||
{ // LRN
|
||||
CreateLRNOpBuilder("LRN", op_registrations);
|
||||
}
|
||||
|
||||
{ // LSTM
|
||||
CreateLstmOpBuilder("LSTM", op_registrations);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations&
|
|||
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
|
|
|||
Loading…
Reference in a new issue