mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[NNAPI/CoreML EP] Add Onnx opset 14 support (#7211)
* Add opset 14 support for nnapi/coreml ep * Address CR comments
This commit is contained in:
parent
a98c2ebb8c
commit
afbbeaa30a
3 changed files with 41 additions and 12 deletions
|
|
@ -40,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const;
|
||||
|
||||
virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; }
|
||||
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 13; }
|
||||
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 14; }
|
||||
|
||||
private:
|
||||
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
|
||||
|
|
|
|||
|
|
@ -79,7 +79,9 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer
|
|||
return false;
|
||||
}
|
||||
|
||||
const auto& perm_dims = initializers.at(perm_name)->dims();
|
||||
const auto& perm_tensor = *initializers.at(perm_name);
|
||||
const int64_t* raw_perm = GetTensorInt64Data(perm_tensor);
|
||||
const auto& perm_dims = perm_tensor.dims();
|
||||
if (perm_dims.empty() || perm_dims[0] == 0) {
|
||||
LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty";
|
||||
return false;
|
||||
|
|
@ -94,6 +96,18 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer
|
|||
return false;
|
||||
}
|
||||
|
||||
// CoreML reshape does not support 0 as dimension
|
||||
NodeAttrHelper helper(node);
|
||||
const bool allow_zero = helper.Get("allowzero ", 0) == 1;
|
||||
if (allow_zero) {
|
||||
for (int64_t i = 0; i < perm_dims[0]; i++) {
|
||||
if (raw_perm[i] == 0) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class BaseOpSupportChecker : public IOpSupportChecker {
|
|||
virtual bool HasSupportedInputsImpl(const Node& node) const;
|
||||
|
||||
virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; }
|
||||
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 13; }
|
||||
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 14; }
|
||||
|
||||
private:
|
||||
bool HasSupportedOpSet(const Node& node) const;
|
||||
|
|
@ -389,15 +389,24 @@ bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& init
|
|||
return false;
|
||||
}
|
||||
|
||||
const auto& shape_tensor = *initializers.at(perm_name);
|
||||
const int64_t* raw_shape = GetTensorInt64Data(shape_tensor);
|
||||
const auto size = SafeInt<uint32_t>(shape_tensor.dims()[0]);
|
||||
const auto& perm_tensor = *initializers.at(perm_name);
|
||||
const int64_t* raw_perm = GetTensorInt64Data(perm_tensor);
|
||||
const auto perm_size = SafeInt<uint32_t>(perm_tensor.dims()[0]);
|
||||
|
||||
for (uint32_t i = 0; i < size; i++) {
|
||||
NodeAttrHelper helper(node);
|
||||
const bool allow_zero = helper.Get("allowzero ", 0) == 1;
|
||||
for (uint32_t i = 0; i < perm_size; i++) {
|
||||
// NNAPI reshape does not support 0 as dimension
|
||||
if (raw_shape[i] == 0 && i < input_shape.size() && input_shape[i] == 0) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape doesn't suppport 0 reshape dimension on a dynamic dimension";
|
||||
return false;
|
||||
if (raw_perm[i] == 0) {
|
||||
if (i < input_shape.size() && input_shape[i] == 0) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension on a dynamic dimension";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (allow_zero) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -689,10 +698,16 @@ bool ConvOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
|||
bool ConvOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params) const {
|
||||
const auto& op_type = node.OpType();
|
||||
const bool is_qlinear_conv = (op_type == "QLinearConv");
|
||||
|
||||
// We don't support nhwc com.microsoft.QLinearConv for now
|
||||
if (is_qlinear_conv && node.Domain() == kMSDomain) {
|
||||
LOGS_DEFAULT(VERBOSE) << "com.microsoft.QLinearConv is not supported";
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto input_defs = node.InputDefs();
|
||||
NodeAttrHelper helper(node);
|
||||
|
||||
bool is_qlinear_conv = (op_type == "QLinearConv");
|
||||
size_t w_idx = is_qlinear_conv ? 3 : 1;
|
||||
const auto group = helper.Get("group", 1);
|
||||
const auto weight_name = input_defs[w_idx]->Name();
|
||||
|
|
|
|||
Loading…
Reference in a new issue