mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Support Split before opset13 (#19988)
### Description Support Split before opset13, where the `split` is an attribute. ### Motivation and Context Support more models which use the earlier opset.
This commit is contained in:
parent
dfa891a2d8
commit
ea3082edc6
1 changed files with 44 additions and 64 deletions
|
|
@ -28,8 +28,6 @@ class SplitOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
|
||||
int GetMinSupportedOpSet(const Node& node) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -57,53 +55,35 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));
|
||||
options.set("axis", axis);
|
||||
|
||||
if (!GetTensorName(input_defs, 1).empty()) {
|
||||
// Inputs contains optional 'split' input
|
||||
std::vector<int32_t> splits;
|
||||
uint32_t split_count = 0;
|
||||
std::vector<uint32_t> splits = helper.Get("split", std::vector<uint32_t>{});
|
||||
|
||||
// Read either the split count or explicit split lengths from the various attributes over opset versions.
|
||||
if (helper.HasAttr("num_outputs")) {
|
||||
split_count = helper.Get("num_outputs", 0);
|
||||
} else if (GetTensorName(input_defs, 1).size()) {
|
||||
const auto& initializers(model_builder.GetInitializerTensors());
|
||||
const auto& split_tensor = *initializers.at(input_defs[1]->Name());
|
||||
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get split.");
|
||||
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
|
||||
input,
|
||||
emscripten::val::array(splits),
|
||||
options);
|
||||
ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == static_cast<int32_t>(splits.size()),
|
||||
"The size of outputs must be equal to the size of 'split' input.");
|
||||
} else {
|
||||
if (helper.HasAttr("num_outputs")) {
|
||||
const int32_t num_outputs = helper.Get("num_outputs", 1);
|
||||
ORT_RETURN_IF_NOT(num_outputs > 0, "The 'num_outputs' must be a positive integer.");
|
||||
if (input_shape[axis] % num_outputs == 0) {
|
||||
// The 'num_outputs' evenly divide the dim value at 'axis' specified.
|
||||
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
|
||||
input,
|
||||
num_outputs,
|
||||
options);
|
||||
} else {
|
||||
std::vector<int64_t> mapping_split;
|
||||
mapping_split.insert(mapping_split.begin(), num_outputs - 1, input_shape[axis] / num_outputs);
|
||||
mapping_split.insert(mapping_split.end(), input_shape[axis] % num_outputs);
|
||||
std::vector<uint32_t> converted_splits = GetVecUint32FromVecInt64(mapping_split);
|
||||
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
|
||||
input,
|
||||
emscripten::val::array(converted_splits),
|
||||
options);
|
||||
}
|
||||
ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == num_outputs,
|
||||
"The size of outputs must be equal to 'num_outputs'.");
|
||||
} else {
|
||||
// w/o 'split' input for opset 13
|
||||
// Refer to https://github.com/microsoft/onnxruntime/blob/a7ad859e3ab60bddfcf2fefa96bfcb550f0fc04c/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L984-L989
|
||||
// split input stream equally across output streams.
|
||||
const auto& output_defs = node.OutputDefs();
|
||||
const size_t output_count = output_defs.size();
|
||||
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
|
||||
input, static_cast<int32_t>(output_count),
|
||||
options);
|
||||
ORT_RETURN_IF_NOT(output_array["length"].as<size_t>() == output_count,
|
||||
"The size of outputs must be equal to the count of output nodes.");
|
||||
}
|
||||
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get input for split.");
|
||||
} else if (!helper.HasAttr("split")) {
|
||||
split_count = node.OutputDefs().size();
|
||||
}
|
||||
|
||||
// Check that the splits evenly divide.
|
||||
if (split_count > 0 && splits.empty() && input_shape[axis] % split_count != 0) {
|
||||
// Divide inputs into variable size outputs:
|
||||
splits.insert(splits.end(), split_count - 1, gsl::narrow<uint32_t>(input_shape[axis]) / split_count);
|
||||
splits.insert(splits.end(), gsl::narrow<uint32_t>(input_shape[axis]) % split_count);
|
||||
}
|
||||
|
||||
if (splits.empty()) {
|
||||
output_array = model_builder.GetBuilder().call<emscripten::val>(
|
||||
"split", input, split_count, options);
|
||||
} else {
|
||||
output_array = model_builder.GetBuilder().call<emscripten::val>(
|
||||
"split", input, emscripten::val::array(splits), options);
|
||||
}
|
||||
|
||||
for (size_t i = 0, count = output_array["length"].as<size_t>(); i < count; i++) {
|
||||
model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i]));
|
||||
}
|
||||
|
|
@ -112,11 +92,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
|
||||
// Operator support related.
|
||||
|
||||
int SplitOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
|
||||
// Since opset 13, Split has optional 'split' input.
|
||||
return 13;
|
||||
}
|
||||
|
||||
bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
||||
const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
|
|
@ -132,6 +107,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
NodeAttrHelper helper(node);
|
||||
int32_t axis = helper.Get("axis", 0);
|
||||
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));
|
||||
std::vector<uint32_t> split = helper.Get("split", std::vector<uint32_t>{});
|
||||
|
||||
const std::string split_name = GetTensorName(input_defs, 1);
|
||||
// Inputs contain optional 'split' input.
|
||||
|
|
@ -141,7 +117,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return false;
|
||||
}
|
||||
// Values should be >= 0. Sum of the values must be equal to the dim value at 'axis' specified.
|
||||
std::vector<int64_t> split;
|
||||
const auto& split_tensor = *initializers.at(input_defs[1]->Name());
|
||||
if (split_tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
|
||||
LOGS(logger, VERBOSE) << "The type of tensor's element data must be INT64.";
|
||||
|
|
@ -151,18 +126,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
LOGS(logger, VERBOSE) << "Cannot get split.";
|
||||
return false;
|
||||
}
|
||||
int64_t sum = 0;
|
||||
for (size_t i = 0; i < split.size(); i++) {
|
||||
if (split[i] < 0) {
|
||||
LOGS(logger, VERBOSE) << "Value of split should be greater than or equal to 0.";
|
||||
return false;
|
||||
}
|
||||
sum += split[i];
|
||||
}
|
||||
if (sum != input_shape[axis]) {
|
||||
LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (helper.HasAttr("num_outputs")) {
|
||||
// Split has 'num_outputs' attribute when opset is 18.
|
||||
|
|
@ -179,6 +142,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!split.empty()) {
|
||||
int64_t sum = 0;
|
||||
// TODO: Allow 0 size dimensions.
|
||||
// https://github.com/webmachinelearning/webnn/issues/391
|
||||
for (uint32_t split_value : split) {
|
||||
if (split_value <= 0) {
|
||||
LOGS(logger, VERBOSE) << "Value of split should be greater than 0.";
|
||||
return false;
|
||||
}
|
||||
sum += split_value;
|
||||
}
|
||||
if (sum != input_shape[axis]) {
|
||||
LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue