mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Fix bug for PRelu on CPU backend. (#17543)
### Description WebNN CPU backend expects slope of PRelu to be a static value. For now, we will not support it. ### Motivation and Context Fallback this case to pass the CI.
This commit is contained in:
parent
4d931edd78
commit
a5302fec93
1 changed files with 22 additions and 0 deletions
|
|
@ -18,6 +18,10 @@ class BinaryOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -50,6 +54,24 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
||||
const Node& node,
|
||||
const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
||||
// XNNPACK prelu operator expects slope to be a static value.
|
||||
// https://github.com/google/XNNPACK/issues/4692
|
||||
// TODO: Remove this check after it is solved.
|
||||
if (op_type == "PRelu" && !Contains(initializers, input_defs[1]->Name()) && device_type == WebnnDeviceType::CPU) {
|
||||
LOGS(logger, VERBOSE) << "The second input (slope) for PRelu must be a constant initializer for WebNN CPU backend.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
Loading…
Reference in a new issue