Revert "[WebNN EP] Remove NHWC preferred layout" (#21905)

Reverts microsoft/onnxruntime#21570
This commit is contained in:
Wanming Lin 2024-08-30 09:01:56 +08:00 committed by GitHub
parent 0223e8647b
commit 7550fec4aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 296 additions and 39 deletions

View file

@ -19,9 +19,10 @@ common::Status ComputeConvPads(const std::vector<int64_t> input_shape,
const std::vector<int64_t>& onnx_strides,
const std::vector<int64_t>& onnx_dilations,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out) {
const int64_t input_size_y = input_shape[2];
const int64_t input_size_x = input_shape[3];
std::vector<int64_t>& pads_out,
bool use_nchw) {
const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
const int64_t stride_y = onnx_strides[0];
const int64_t stride_x = onnx_strides[1];
const int64_t dilation_y = onnx_dilations[0];
@ -53,15 +54,16 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
const std::vector<int64_t>& onnx_strides,
const std::vector<int64_t>& onnx_dilations,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out) {
std::vector<int64_t>& pads_out,
bool use_nchw) {
if (AutoPadType::SAME_UPPER == auto_pad_type) {
ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
onnx_pads, onnx_strides, onnx_dilations,
AutoPadType::SAME_UPPER, pads_out));
AutoPadType::SAME_UPPER, pads_out, use_nchw));
} else {
ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
onnx_pads, onnx_strides, onnx_dilations,
AutoPadType::SAME_LOWER, pads_out));
AutoPadType::SAME_LOWER, pads_out, use_nchw));
}
return Status::OK();
}
@ -109,9 +111,10 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
const std::vector<int64_t>& onnx_output_padding,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out,
std::vector<int64_t>& output_shape_out) {
const int64_t input_size_y = input_shape[2];
const int64_t input_size_x = input_shape[3];
std::vector<int64_t>& output_shape_out,
bool use_nchw) {
const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
const int64_t stride_y = onnx_strides[0];
const int64_t stride_x = onnx_strides[1];
const int64_t dilation_y = onnx_dilations[0];

View file

@ -21,7 +21,8 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
const std::vector<int64_t>& onnx_strides,
const std::vector<int64_t>& onnx_dilations,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out) ORT_MUST_USE_RESULT;
std::vector<int64_t>& pads_out,
bool use_nchw) ORT_MUST_USE_RESULT;
// Compute pads and output shape for ConvTranspose.
common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t> input_shape,
@ -33,7 +34,8 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
const std::vector<int64_t>& onnx_output_padding,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out,
std::vector<int64_t>& output_shape_out) ORT_MUST_USE_RESULT;
std::vector<int64_t>& output_shape_out,
bool use_nchw) ORT_MUST_USE_RESULT;
} // namespace webnn
} // namespace onnxruntime

View file

@ -18,6 +18,9 @@ namespace webnn {
class ConvOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
@ -30,6 +33,13 @@ class ConvOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override;
};
void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
// skip the weight for conv as we need to transpose for preferred layout NHWC.
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W
}
}
// Helper functions
common::Status SetConvBaseOptions(ModelBuilder& model_builder,
const Node& node, emscripten::val& options,
@ -38,6 +48,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
const std::vector<int64_t>& strides,
const std::vector<int64_t>& dilations,
std::vector<int64_t>& pads,
const bool is_nhwc,
const bool is_conv1d,
const logging::Logger& logger) {
NodeAttrHelper helper(node);
@ -50,7 +61,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
// Calculate explicit padding for autoPad.
if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3],
pads, strides, dilations, auto_pad_type, pads_out));
pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc));
pads = pads_out;
}
} else if (node.OpType() == "ConvTranspose") {
@ -71,7 +82,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
// Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER.
ORT_RETURN_IF_ERROR(ComputeConvTransposePadsAndOutputShape(input_shape, weight_shape[2], weight_shape[3],
pads, strides, dilations, output_padding,
auto_pad_type, pads_out, output_shape));
auto_pad_type, pads_out, output_shape, !is_nhwc));
if (output_shape[0] != -1 && output_shape[1] != -1) {
options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape)));
@ -100,6 +111,89 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
return Status::OK();
}
// Both depthwise Conv and ConvTranspose share the same logic to add the layout.
Status AddInitializerInNewLayout(ModelBuilder& model_builder,
const std::string& name,
bool is_conv,
bool is_conv1d) {
const auto& tensor = *model_builder.GetInitializerTensors().at(name);
auto data_type = tensor.data_type();
const auto& shape = tensor.dims();
std::vector<uint32_t> dims = GetVecUint32FromVecInt64(std::vector<int64_t>(std::begin(shape), std::end(shape)));
if (is_conv1d) {
// Support conv1d by prepending a 1 size dimension.
dims.push_back(1);
}
const uint8_t* src = nullptr;
Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath());
src = unpacked_tensor.DataAsByteSpan().data();
const auto out_t = dims[0], in_t = dims[1],
h_t = dims[2], w_t = dims[3];
std::vector<uint32_t> dest_shape;
if (is_conv == 1)
dest_shape = {out_t, h_t, w_t, in_t}; // L_0231
else
dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight
SafeInt<size_t> num_elements = SafeInt<size_t>(Product(dest_shape));
size_t element_size{0};
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
element_size = sizeof(uint8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
element_size = sizeof(int8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
element_size = sizeof(uint16_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
element_size = sizeof(float);
break;
default:
break;
}
std::unique_ptr<uint8_t[]> buffer_holder(new uint8_t[element_size * num_elements]);
uint8_t* buffer = buffer_holder.get();
for (uint32_t out = 0; out < out_t; out++) {
for (uint32_t in = 0; in < in_t; in++) {
for (uint32_t h = 0; h < h_t; h++) {
for (uint32_t w = 0; w < w_t; w++) {
auto onnx_idx = out * in_t * h_t * w_t +
in * h_t * w_t +
h * w_t +
w;
uint32_t nnapi_idx;
if (is_conv == 1) { // L_0231
nnapi_idx = out * h_t * w_t * in_t +
h * w_t * in_t +
w * in_t +
in;
} else { // L_1230 for depthwise conv weight
nnapi_idx = in * h_t * w_t * out_t +
h * w_t * out_t +
w * out_t +
out;
}
for (size_t i = 0; i < element_size; i++) {
buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i];
}
}
}
}
}
ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size,
dest_shape, data_type));
return Status::OK();
}
// Add operator related.
Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
@ -109,6 +203,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
const auto& op_type = node.OpType();
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val output = emscripten::val::object();
const auto& initializers(model_builder.GetInitializerTensors());
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
@ -121,11 +216,19 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
auto pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});
const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC;
const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3;
const bool is_constant_weight = Contains(initializers, weight_name);
// Support conv1d by prepending a 1 or 2 size dimensions.
if (is_conv1d) {
// Reshape input.
input_shape.push_back(1);
if (is_nhwc) {
// For NHWC preferred layout, the input has been transposed.
// For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2.
input_shape.insert(input_shape.begin() + 2, 1);
} else {
input_shape.push_back(1);
}
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(input_shape);
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
@ -141,19 +244,63 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
ORT_RETURN_IF_ERROR(SetConvBaseOptions(
model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_conv1d, logger));
model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger));
bool depthwise = false;
if (op_type == "Conv" || op_type == "ConvInteger") {
int groups = options["groups"].as<int>();
if (is_nhwc) {
depthwise = (groups == input_shape[3] && groups != 1);
options.set("inputLayout", emscripten::val("nhwc"));
if (is_constant_weight) {
ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d));
}
if (!depthwise) {
options.set("filterLayout", emscripten::val("ohwi"));
} else {
options.set("filterLayout", emscripten::val("ihwo"));
}
}
} else { // ConvTranspose
if (is_nhwc) {
options.set("inputLayout", emscripten::val("nhwc"));
options.set("filterLayout", emscripten::val("ohwi"));
if (is_constant_weight) {
ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d));
}
}
}
emscripten::val filter = model_builder.GetOperand(weight_name);
if (is_conv1d) {
// Reshape weight to 4D for conv1d.
// The weight_shape has been appended 1's, reshape weight operand.
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(weight_shape);
emscripten::val reshape_options = emscripten::val::object();
reshape_options.set("label", node.Name() + "_reshape_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
filter,
emscripten::val::array(new_shape),
reshape_options);
if (!is_nhwc || !is_constant_weight) {
// The weight_shape has been appended 1's, reshape weight operand.
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(weight_shape);
emscripten::val reshape_options = emscripten::val::object();
reshape_options.set("label", node.Name() + "_reshape_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
filter,
emscripten::val::array(new_shape),
reshape_options);
}
}
emscripten::val transpose_options = emscripten::val::object();
if (is_nhwc && !is_constant_weight) {
// For NHWC preferred layout, if the weight is input:
// - Transpose it from iohw -> ohwi for convTranspose.
// - Transpose it from oihw -> ihwo for depthwise conv.
// - Transpose it from oihw -> ohwi for conv.
std::vector<uint32_t> perm(4);
if (op_type == "ConvTranspose" || depthwise) {
perm = {1, 2, 3, 0}; // L_1230 for depthwise conv and convTranspose weight
} else {
perm = {0, 2, 3, 1}; // L_0231
}
transpose_options.set("permutation", emscripten::val::array(perm));
transpose_options.set("label", node.Name() + "_transpose_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("transpose", filter, transpose_options);
}
if (op_type == "Conv") {

View file

@ -79,6 +79,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name());
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("axis", rank - 1);
}
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization") {
@ -101,8 +104,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3;
ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank;
auto insertion_point = new_shape.begin() + 3;
auto insertion_point = new_shape.begin() + insertion_offset;
if (input_shape.size() < webnn_shape_rank) {
// Pad the shape with extra 1's to satisfy WebNN v1's rank requirements.
new_shape.insert(insertion_point, -excess_rank, 1);
@ -121,6 +125,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
reshape_input_options);
}
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
}
output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
// Reshape back to the original output shape for 3D input.
if (input_shape.size() != 4) {

View file

@ -70,7 +70,11 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
options.set("strides", emscripten::val::array(strides));
const auto dilations = helper.Get("dilations", std::vector<int32_t>{1, 1});
options.set("dilations", emscripten::val::array(dilations));
options.set("layout", emscripten::val("nchw"));
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
} else {
options.set("layout", emscripten::val("nchw"));
}
// Add Padding.
// Usually using autopadding is more efficient than using explicit padding.
@ -89,7 +93,8 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
helper.Get("strides", std::vector<int64_t>{1, 1}),
helper.Get("dilations", std::vector<int64_t>{1, 1}),
auto_pad_type,
pads_out));
pads_out,
model_builder.GetPreferredLayout() == DataLayout::NCHW));
pads = GetVecUint32FromVecInt64(pads_out);
}
// Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width],

View file

@ -120,10 +120,18 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::vector<float> scales;
std::vector<int32_t> sizes;
std::vector<float> scales_hw;
std::vector<int32_t> sizes_hw;
std::vector<int32_t> axes;
std::string scales_name = GetTensorName(input_defs, 2);
const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC;
if (!scales_name.empty()) { // Use scales.
ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales");
std::vector<float> scales_hw = {scales[2], scales[3]};
if (is_nhwc) {
scales_hw = {scales[1], scales[2]};
} else {
scales_hw = {scales[2], scales[3]};
}
options.set("scales", emscripten::val::array(scales_hw));
} else { // Use sizes, we already checked inputs in IsOpSupportedImpl.
std::vector<int64_t> output_sizes;
@ -132,11 +140,19 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::transform(output_sizes.cbegin(), output_sizes.cend(),
std::back_inserter(sizes),
[](int64_t dim) -> int32_t { return SafeInt<int32_t>(dim); });
std::vector<int32_t> sizes_hw = {sizes[2], sizes[3]};
if (is_nhwc) {
sizes_hw = {sizes[1], sizes[2]};
} else {
sizes_hw = {sizes[2], sizes[3]};
}
options.set("sizes", emscripten::val::array(sizes_hw));
}
std::vector<int32_t> axes = {2, 3};
if (is_nhwc) {
axes = {1, 2};
} else {
axes = {2, 3};
}
options.set("axes", emscripten::val::array(axes));
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
@ -205,6 +221,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return false;
}
const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain;
// We want to check if the scales or sizes are not trying to resize on N/C channels here.
if (has_scales) { // We are using scales.
std::vector<float> scales;
@ -212,7 +229,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return false;
float scale_n = scales[0];
float scale_c = scales[1];
float scale_c = is_nhwc ? scales[3] : scales[1];
if (scale_n != 1.0f || scale_c != 1.0f) {
LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1"
<< "Resize of N/C channels are not supported"
@ -222,8 +239,8 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
// For now we only support upscale, so the scale_h and scale_w should be an integer >= 1.
// TODO support ResizeBilinear.
float scale_h = scales[2];
float scale_w = scales[3];
float scale_h = is_nhwc ? scales[1] : scales[2];
float scale_w = is_nhwc ? scales[2] : scales[3];
// Onnx spec requires scale to be a positive float, so we are not checking that here.
if (roundf(scale_h) != scale_h) {
@ -244,11 +261,12 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return false;
auto output_size_n = output_sizes[0];
if (output_size_n != input_shape[0] || output_sizes[1] != input_shape[1]) {
const int c_idx = is_nhwc ? 3 : 1;
if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) {
LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, "
<< "Resize of N/C channels are not supported"
<< ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n
<< ". input_size_c, " << input_shape[1] << ", output_size_c, " << output_sizes[1];
<< ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx];
return false;
}
}

View file

@ -20,10 +20,12 @@ namespace onnxruntime {
namespace webnn {
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const WebnnDeviceType wnn_device_type)
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type)
: graph_viewer_(graph_viewer),
logger_(logger),
wnn_context_(context),
preferred_layout_(preferred_layout),
wnn_device_type_(wnn_device_type) {
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
// is only allowed to be called once.
@ -252,6 +254,64 @@ Status ModelBuilder::AddOperations() {
return Status::OK();
}
Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
const std::string& name, const void* buffer, const size_t size,
const std::vector<uint32_t> shape, const int32_t data_type) {
auto persist_buffer = std::make_unique<uint8_t[]>(size);
uint8_t* dest = persist_buffer.get();
memcpy(dest, buffer, size);
emscripten::val view = emscripten::val::undefined();
emscripten::val desc = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t),
reinterpret_cast<const uint8_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t),
reinterpret_cast<const int8_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t),
reinterpret_cast<const uint16_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(float),
reinterpret_cast<const float*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int32_t),
reinterpret_cast<const int32_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int64_t),
reinterpret_cast<const int64_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint32_t),
reinterpret_cast<const uint32_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint64_t),
reinterpret_cast<const uint64_t*>(dest))};
break;
default:
break;
}
desc.set("dimensions", emscripten::val::array(shape));
emscripten::val operand = emscripten::val::object();
// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
// buffers in JS side. Simply create a copy to fix it.
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
AddOperand(name, operand);
mem_persist_buffers_.push_back(std::move(persist_buffer));
return Status::OK();
}
Status ModelBuilder::RegisterModelOutputs() {
for (const auto* node_arg : graph_viewer_.GetOutputs()) {
ORT_RETURN_IF_ERROR(RegisterModelInputOutput(*node_arg, false /* is_input */));

View file

@ -22,7 +22,8 @@ class IOpBuilder;
class ModelBuilder {
public:
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const WebnnDeviceType wnn_device_type);
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type);
~ModelBuilder() = default;
Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
@ -36,6 +37,15 @@ class ModelBuilder {
const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); }
void AddOperand(const std::string& name, const emscripten::val& operand);
const emscripten::val& GetZeroConstant(const std::string& data_type);
// Use the buffers to persist WebNN allocated data like transposed weight.
// It ensures the validity during inference session.
std::vector<std::unique_ptr<uint8_t[]>> mem_persist_buffers_;
// Add a constant operand (allocate persist buffer and move the ownership to mem_persist_buffers_).
Status AddOperandFromPersistMemoryBuffer(
const std::string& name, const void* buffer,
const size_t size, const std::vector<uint32_t> shape, const int32_t data_type);
DataLayout GetPreferredLayout() const { return preferred_layout_; }
WebnnDeviceType GetWebnnDeviceType() const { return wnn_device_type_; }
@ -54,6 +64,7 @@ class ModelBuilder {
emscripten::val wnn_context_ = emscripten::val::undefined();
emscripten::val wnn_builder_ = emscripten::val::undefined();
DataLayout preferred_layout_;
WebnnDeviceType wnn_device_type_;
InlinedHashMap<std::string, emscripten::val> wnn_operands_;
std::vector<std::string> input_names_;

View file

@ -19,9 +19,12 @@ namespace onnxruntime {
WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
: IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
// WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
if (webnn_device_flags.compare("cpu") == 0) {
preferred_layout_ = DataLayout::NHWC;
wnn_device_type_ = webnn::WebnnDeviceType::CPU;
} else {
preferred_layout_ = DataLayout::NCHW;
if (webnn_device_flags.compare("gpu") == 0) {
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
} else if (webnn_device_flags.compare("npu") == 0) {
@ -209,7 +212,8 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
Node& fused_node = fused_node_and_graph.fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_, wnn_device_type_);
webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
preferred_layout_, wnn_device_type_);
std::unique_ptr<webnn::Model> model;
ORT_RETURN_IF_ERROR(builder.Compile(model));

View file

@ -26,8 +26,7 @@ class WebNNExecutionProvider : public IExecutionProvider {
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_registries*/) const override;
// WebNN EP uses default NCHW layout for all backends.
DataLayout GetPreferredLayout() const override { return DataLayout::NCHW; }
DataLayout GetPreferredLayout() const override { return preferred_layout_; }
// We implement the Compile that takes FusedNodeAndGraph instances.
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
@ -45,6 +44,7 @@ class WebNNExecutionProvider : public IExecutionProvider {
private:
emscripten::val wnn_context_ = emscripten::val::undefined();
DataLayout preferred_layout_;
webnn::WebnnDeviceType wnn_device_type_;
InlinedHashMap<std::string, std::unique_ptr<onnxruntime::webnn::Model>> models_;
ModelMetadefIdGenerator metadef_id_generator_;