[Android NNAPI EP] Add support for LRN/Grouped Conv ops, fix issues where NNAPI will fall back to CPU (#4582)

* add LRN/Grouped Conv Support, minor changes

* better pool ops sdk version requirement

* reduce string comparison for gemm/matmul ops

* fix nnapi fall back to cpu for softmax

* addressed review comments, correct a small error in the code
This commit is contained in:
gwang-msft 2020-07-23 00:05:39 -07:00 committed by GitHub
parent c5df918744
commit 03ebe33850
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 242 additions and 98 deletions

View file

@ -63,7 +63,7 @@ static Status UnpackInitializerTensor(const onnx::TensorProto& initializer,
CASE_UNPACK(UINT8, uint8_t, int32_data_size);
CASE_UNPACK(UINT16, uint16_t, int32_data_size);
CASE_UNPACK(UINT32, uint32_t, uint64_data_size);
CASE_UNPACK(UINT64, uint64_t, int64_data_size);
CASE_UNPACK(UINT64, uint64_t, uint64_data_size);
CASE_UNPACK(FLOAT16, onnxruntime::MLFloat16, int32_data_size);
CASE_UNPACK(BFLOAT16, onnxruntime::BFloat16, int32_data_size);
default:
@ -999,16 +999,28 @@ class PoolOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
int32_t GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& /* node */) const override {
return 28;
int32_t GetMinSupportedSdkVer(ModelBuilder& model_builder, const Node& /* node */) const override {
return model_builder.UseNCHW() ? 29 : 28;
}
void AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override;
};
bool PoolOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
const auto& op = node.OpType();
if (op == "AveragePool" || op == "MaxPool") {
const auto& op_type = node.OpType();
Shape input_shape;
if (!GetShape(*node.InputDefs()[0], input_shape))
return false;
const auto input_size = input_shape.size();
if (input_size != 4) {
LOGS_DEFAULT(VERBOSE)
<< op_type << " only supportes rank-4 tensor, input ["
<< node.InputDefs()[0]->Name() << "] has actual dim count " << input_size;
return false;
}
if (op_type == "AveragePool" || op_type == "MaxPool") {
NodeAttrHelper helper(node);
const auto count_include_pad = helper.Get("count_include_pad", 0);
@ -1043,18 +1055,9 @@ bool PoolOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const N
LOGS_DEFAULT(VERBOSE) << "Argmax in maxpooling is not supported";
return false;
}
} else if (op == "GlobalAveragePool" || op == "GlobalMaxPool") {
Shape input_shape;
if (!GetShape(*node.InputDefs()[0], input_shape))
return false;
const auto input_size = input_shape.size();
if (input_size != 4) {
LOGS_DEFAULT(VERBOSE)
<< "GlobalAveragePool/GlobalMaxPool Only rank-4 tensor is supported in "
<< node.InputDefs()[0]->Name() << ", actual dim count " << input_size;
return false;
}
} else if (op_type != "GlobalAveragePool" && op_type != "GlobalMaxPool") {
LOGS_DEFAULT(VERBOSE) << "PoolOpBuilder, unknown op: " << op_type;
return false;
}
return true;
@ -1085,19 +1088,20 @@ void PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Nod
}
const auto& output = node.OutputDefs()[0]->Name();
const auto& op = node.OpType();
const auto& op_type = node.OpType();
int32_t op_type;
if (op == "AveragePool" || op == "GlobalAveragePool")
op_type = ANEURALNETWORKS_AVERAGE_POOL_2D;
else // (op == "MaxPool" || op == "GlobalMaxPool")
op_type = ANEURALNETWORKS_MAX_POOL_2D;
int32_t op_code;
bool is_average_pool = op_type == "AveragePool";
if (is_average_pool || op_type == "GlobalAveragePool")
op_code = ANEURALNETWORKS_AVERAGE_POOL_2D;
else // (op_type == "MaxPool" || op_type == "GlobalMaxPool")
op_code = ANEURALNETWORKS_MAX_POOL_2D;
vector<int32_t> onnx_pads, onnx_strides, kernel_shape;
bool use_auto_pad = false;
int32_t nnapi_padding_code = ANEURALNETWORKS_PADDING_VALID;
const auto& input_shape = shaper[input];
if (op == "AveragePool" || op == "MaxPool") {
if (is_average_pool || op_type == "MaxPool") {
const auto auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET"));
kernel_shape = helper.Get("kernel_shape", vector<int32_t>{0, 0});
onnx_strides = helper.Get("strides", vector<int>{1, 1});
@ -1108,7 +1112,7 @@ void PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Nod
onnx_strides, {1, 1} /* onnx_dilations */,
auto_pad_type, use_nchw,
onnx_pads, nnapi_padding_code, use_auto_pad);
} else { // (op == "GlobalAveragePool" || op == "GlobalMaxPool")
} else { // (op_type == "GlobalAveragePool" || op_type == "GlobalMaxPool")
use_auto_pad = true;
nnapi_padding_code = ANEURALNETWORKS_PADDING_VALID;
onnx_strides = vector<int32_t>{1, 1};
@ -1141,15 +1145,16 @@ void PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Nod
input_indices.push_back(model_builder.AddOperandFromScalar(kernel_shape[0]));
input_indices.push_back(model_builder.AddOperandFromScalar(fuse_code));
// TODO support API 28
input_indices.push_back(model_builder.AddOperandFromScalar(use_nchw));
if (model_builder.GetAndroidSdkVer() > 28) { // nchw only supported on api 29+
input_indices.push_back(model_builder.AddOperandFromScalar(use_nchw));
}
shaper.Pool(input,
onnx_pads, onnx_strides, kernel_shape,
use_nchw,
output);
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
model_builder.AddOperation(op_type, input_indices, {output}, {output_operand_type}, {output_is_nhwc});
model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}, {output_is_nhwc});
}
#pragma endregion op_pool
@ -1233,6 +1238,11 @@ bool ConvOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& n
const auto onnx_dilations = helper.Get("dilations", vector<int>{1, 1});
if (onnx_dilations != vector<int>{1, 1}) {
if (group != 1 && tensor.dims()[1] != 1) {
LOGS_DEFAULT(VERBOSE) << "dilation is not supported on grouped conv";
return false;
}
const auto android_sdk_ver = model_builder.GetAndroidSdkVer();
if (android_sdk_ver < 29) {
LOGS_DEFAULT(VERBOSE) << op_type << " dilations is only supported on Android API levle 29+, "
@ -1240,11 +1250,6 @@ bool ConvOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& n
return false;
}
}
if (group != 1 && tensor.dims()[1] != 1) {
LOGS_DEFAULT(VERBOSE) << "group != 1 is not supported";
return false;
}
} else {
LOGS_DEFAULT(VERBOSE) << "The weight of convolution must be known";
return false;
@ -1344,8 +1349,22 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Nod
const auto& weight = input_defs[w_idx]->Name();
const auto& weight_tensor = initializers.at(weight);
bool conv2d = (group == 1);
bool depthwise_conv2d = (weight_tensor.dims()[1] == 1);
bool conv_2d = false,
depthwise_conv_2d = false,
grouped_conv_2d = false;
// For ONNX we only have 1 conv ops
// For NNAPI we have 3
// Input is (N, C, H, W)
// group == 1, --> regular conv
// group != 1 && weight is (M, 1, kH, kW), --> depthwise conv
// group != 1 && weight is (M, C/group, kH, kW), --> grouped conv
if (group == 1)
conv_2d = true;
else if ((weight_tensor.dims()[1] == 1))
depthwise_conv_2d = true;
else
grouped_conv_2d = true;
Shape onnx_weight_shape;
for (auto dim : weight_tensor.dims())
@ -1367,9 +1386,9 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Nod
OperandType onnx_weight_operand_type(onnx_weight_type, onnx_weight_shape, w_scale, w_zero_point);
// Pre-process weights
if (conv2d) {
if (conv_2d || grouped_conv_2d) {
AddInitializerInNewLayout(model_builder, weight, onnx_weight_operand_type, L_0231);
} else { // depthwise_conv2d
} else { // depthwise_conv_2d
AddInitializerInNewLayout(model_builder, weight, onnx_weight_operand_type, L_1230);
}
@ -1391,7 +1410,7 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Nod
if (!hasBias) {
const auto weight_dimen = shaper[weight];
Shape bias_dimen;
if (conv2d)
if (conv_2d || grouped_conv_2d)
bias_dimen = {weight_dimen[0]};
else
bias_dimen = {weight_dimen[3]};
@ -1450,31 +1469,42 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Nod
input_indices.push_back(model_builder.AddOperandFromScalar(onnx_strides[1]));
input_indices.push_back(model_builder.AddOperandFromScalar(onnx_strides[0]));
if (!conv2d && depthwise_conv2d) {
int32_t depthwiseMultiplier = shaper[weight][3] / group;
input_indices.push_back(model_builder.AddOperandFromScalar(depthwiseMultiplier));
if (!conv_2d) {
if (depthwise_conv_2d) {
int32_t depthwiseMultiplier = shaper[weight][3] / group;
input_indices.push_back(model_builder.AddOperandFromScalar(depthwiseMultiplier));
} else { // grouped_conv_2d
input_indices.push_back(model_builder.AddOperandFromScalar(group));
}
}
int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]);
input_indices.push_back(model_builder.AddOperandFromScalar(fuse_code));
// TODO support API 28
input_indices.push_back(model_builder.AddOperandFromScalar(use_nchw));
if (model_builder.GetAndroidSdkVer() > 28) {
input_indices.push_back(model_builder.AddOperandFromScalar(use_nchw));
if (onnx_dilations[1] != 1 || onnx_dilations[0] != 1) {
input_indices.push_back(model_builder.AddOperandFromScalar(onnx_dilations[1]));
input_indices.push_back(model_builder.AddOperandFromScalar(onnx_dilations[0]));
// 1. NNAPI Grouped Conv does not support dilations
// 2. There is a bug in NNAPI (not sure NNAPI itself or Qualcomm Hexagon driver),
// setting dilation (even it is the default (1,1)) will make the execution fall back to CPU
// so if dilations == (1,1) we simply ignore it
if (!grouped_conv_2d &&
(onnx_dilations[1] != 1 || onnx_dilations[0] != 1)) {
input_indices.push_back(model_builder.AddOperandFromScalar(onnx_dilations[1]));
input_indices.push_back(model_builder.AddOperandFromScalar(onnx_dilations[0]));
}
}
int32_t operationCode;
const auto& output = node.OutputDefs()[0]->Name();
if (conv2d) {
operationCode = ANEURALNETWORKS_CONV_2D;
if (conv_2d || grouped_conv_2d) {
operationCode = conv_2d ? ANEURALNETWORKS_CONV_2D
: ANEURALNETWORKS_GROUPED_CONV_2D;
shaper.Conv(input, weight,
onnx_pads, onnx_strides, onnx_dilations,
use_nchw,
output);
} else { // depthwise_conv2d
} else { // depthwise_conv_2d
operationCode = ANEURALNETWORKS_DEPTHWISE_CONV_2D;
shaper.DepthwiseConv(input, weight,
onnx_pads, onnx_strides, onnx_dilations,
@ -1503,10 +1533,10 @@ class CastOpBuilder : public BaseOpBuilder {
bool CastOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
NodeAttrHelper helper(node);
auto to = helper.Get("to", 0);
const auto to = helper.Get("to", 0);
if (to != ONNX_NAMESPACE::TensorProto::FLOAT &&
to != ONNX_NAMESPACE::TensorProto::INT32) {
LOGS_DEFAULT(VERBOSE) << "[Cast] Only support cast to int32 or float";
LOGS_DEFAULT(VERBOSE) << "[Cast] Only support cast to int32 or float, actual to type, " << to;
return false;
}
@ -1553,13 +1583,13 @@ class SoftMaxOpBuilder : public BaseOpBuilder {
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
int32_t GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& /* node */) const override {
return 29;
return 28;
}
void AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override;
};
bool SoftMaxOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
bool SoftMaxOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) {
Shape input_shape;
if (!GetShape(*node.InputDefs()[0], input_shape))
return false;
@ -1570,6 +1600,19 @@ bool SoftMaxOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, cons
<< input_size << "d shape";
return false;
}
const auto android_skd_ver = model_builder.GetAndroidSdkVer();
if (android_skd_ver < 29) {
NodeAttrHelper helper(node);
int32_t axis = helper.Get("axis", 1);
if (axis != 1) {
LOGS_DEFAULT(VERBOSE)
<< "SoftMax only support axis 1 on Android API level: " << android_skd_ver
<< " input axis: " << axis;
return false;
}
}
return true;
}
@ -1577,30 +1620,45 @@ void SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
auto& shaper(model_builder.GetShaper());
const auto& operand_indices(model_builder.GetOperandIndices());
const auto& operand_types(model_builder.GetOperandTypes());
const auto android_skd_ver = model_builder.GetAndroidSdkVer();
NodeAttrHelper helper(node);
auto input = node.InputDefs()[0]->Name();
if (model_builder.IsOperandNHWC(input)) {
// We want to transpose nhwc operand back to nchw before softmax
const auto& nhwc_input = node.InputDefs()[0]->Name();
if (!model_builder.GetNCHWOperand(nhwc_input, input)) {
input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw");
TransposeNHWCToNCHW(model_builder, nhwc_input, input);
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
bool output_is_nhwc = input_is_nhwc;
if (android_skd_ver < 29) {
if (model_builder.IsOperandNHWC(input)) {
output_is_nhwc = false;
// We want to transpose nhwc operand back to nchw before softmax
const auto& nhwc_input = node.InputDefs()[0]->Name();
if (!model_builder.GetNCHWOperand(nhwc_input, input)) {
input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw");
TransposeNHWCToNCHW(model_builder, nhwc_input, input);
}
}
}
int32_t axis = helper.Get("axis", 1);
if (output_is_nhwc) {
const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2};
axis = axis_nchw_to_nhwc[axis];
}
const auto& output = node.OutputDefs()[0]->Name();
float beta = 1.f;
int32_t axis = helper.Get("axis", 1);
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input));
input_indices.push_back(model_builder.AddOperandFromScalar(beta));
input_indices.push_back(model_builder.AddOperandFromScalar(axis));
if (android_skd_ver > 28) {
// you can only specify axis for android api level 29+
input_indices.push_back(model_builder.AddOperandFromScalar(axis));
}
shaper.Identity(input, output);
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
model_builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices, {output},
{output_operand_type}, {false});
{output_operand_type}, {output_is_nhwc});
}
#pragma endregion
@ -1668,11 +1726,12 @@ bool GemmOpBuilder::HasSupportedInputs(const Node& node) {
}
bool GemmOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) {
const auto& op = node.OpType();
const auto& op_type = node.OpType();
const auto input_defs(node.InputDefs());
const auto& initializers(model_builder.GetInitializerTensors());
size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C
if (op == "QLinearMatMul") {
bool is_qlinear_matmul = op_type == "QLinearMatMul";
if (is_qlinear_matmul) {
a_idx = 0;
b_idx = 3;
}
@ -1699,40 +1758,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& n
}
}
if (op == "MatMul") {
// Only support A*B B is an initializer
if (!Contains(initializers, input_defs[b_idx]->Name())) {
LOGS_DEFAULT(VERBOSE) << "B of MatMul must be known";
return false;
}
} else if (op == "QLinearMatMul") {
// For QLinearMatMul, we only support uint8 output now
int32_t output_type;
if (!GetType(*node.OutputDefs()[0], output_type))
return false;
if (output_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
LOGS_DEFAULT(VERBOSE) << "[" << op
<< "] output type: [" << output_type
<< "] is not supported for now";
return false;
}
// Only support A*B B is an initializer
// And all scale/zero points are initializer scalars
if (!Contains(initializers, input_defs[b_idx]->Name())) {
LOGS_DEFAULT(VERBOSE) << "B of MatMul must be known";
return false;
}
// a/b/y_scale
if (!IsQuantizationScaleSupported(model_builder, node, {1, 4, 6}))
return false;
// a/b/y_zero_point
if (!IsQuantizationZeroPointSupported(model_builder, node, {2, 5, 7}))
return false;
} else if (op == "Gemm") {
if (op_type == "Gemm") {
// Only support
// 1. A*B'+C
// 2. A*B+C and B is an initializer
@ -1767,6 +1793,37 @@ bool GemmOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& n
return false;
}
}
} else if (op_type == "MatMul" || is_qlinear_matmul) {
// Only support A*B B is an initializer
if (!Contains(initializers, input_defs[b_idx]->Name())) {
LOGS_DEFAULT(VERBOSE) << "B of MatMul must be known";
return false;
}
if (is_qlinear_matmul) {
// For QLinearMatMul, we only support uint8 output now
int32_t output_type;
if (!GetType(*node.OutputDefs()[0], output_type))
return false;
if (output_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
LOGS_DEFAULT(VERBOSE) << "[" << op_type
<< "] output type: [" << output_type
<< "] is not supported for now";
return false;
}
// All scale/zero points are initializer scalars
// a/b/y_scale
if (!IsQuantizationScaleSupported(model_builder, node, {1, 4, 6}))
return false;
// a/b/y_zero_point
if (!IsQuantizationZeroPointSupported(model_builder, node, {2, 5, 7}))
return false;
}
} else {
LOGS_DEFAULT(VERBOSE) << "GemmOpBuilder, unknown op: " << op_type;
}
return true;
@ -2304,6 +2361,88 @@ void DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
#pragma endregion
#pragma region op_LRN
// Op builder that maps an ONNX LRN node onto the NNAPI
// ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION operation.
class LRNOpBuilder : public BaseOpBuilder {
 private:
  // Reports whether the given LRN node can be handled by this builder.
  bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;

  // Lowest Android API level this builder reports as supported.
  int32_t GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& /* node */) const override {
    constexpr int32_t kMinSdkVer = 28;
    return kMinSdkVer;
  }

  // Emits the NNAPI operands and operation for the given LRN node.
  void AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override;
};
// Returns true only when the shape of the node's first input can be
// retrieved and that shape has exactly four dimensions.
bool LRNOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
  Shape shape;
  const bool has_shape = GetShape(*node.InputDefs()[0], shape);
  if (!has_shape) {
    return false;
  }

  const auto rank = shape.size();
  if (rank == 4) {
    return true;
  }

  // Any other rank is rejected; log the actual rank for diagnostics.
  LOGS_DEFAULT(VERBOSE) << "LRN only support 4d shape, input is "
                        << rank << "d shape";
  return false;
}
// Adds an ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION operation for `node`.
//
// Operand order pushed to NNAPI: input, radius, bias, alpha, beta, and
// (API level 29+ only) the axis to normalize over.
void LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) {
  auto& shaper(model_builder.GetShaper());
  const auto& operand_indices(model_builder.GetOperandIndices());
  const auto& operand_types(model_builder.GetOperandTypes());
  NodeAttrHelper helper(node);
  const auto sdk_ver = model_builder.GetAndroidSdkVer();

  // `input` is a copy because it may be redirected to a transposed operand.
  auto input = node.InputDefs()[0]->Name();
  const auto& output = node.OutputDefs()[0]->Name();

  bool output_is_nhwc = model_builder.IsOperandNHWC(input);
  const bool axis_is_configurable = sdk_ver > 28;
  if (sdk_ver < 29) {
    // On Android API level 28 the axis cannot be passed, so the operation is
    // always fed an nhwc operand; transpose an nchw input first if needed.
    output_is_nhwc = true;
    if (!model_builder.IsOperandNHWC(input)) {
      const auto& nchw_name = node.InputDefs()[0]->Name();
      const bool reused_existing = model_builder.GetNHWCOperand(nchw_name, input);
      if (!reused_existing) {
        input = model_builder.GetUniqueName(nchw_name + "_nchw_to_nhwc");
        TransposeNCHWToNHWC(model_builder, nchw_name, input);
      }
    }
  }

  // Read the ONNX attributes (with their fallback values).
  auto alpha = helper.Get("alpha", 0.0001f);
  const auto beta = helper.Get("beta", 0.75f);
  const auto bias = helper.Get("bias", 1.0f);
  const auto size = helper.Get("size", 1);

  // NNAPI takes a radius rather than a window size.
  const auto radius = (size - 1) / 2;
  // NNAPI's alpha is different than ONNX's alpha
  alpha /= size;

  std::vector<uint32_t> input_indices;
  input_indices.push_back(operand_indices.at(input));
  input_indices.push_back(model_builder.AddOperandFromScalar(radius));
  input_indices.push_back(model_builder.AddOperandFromScalar(bias));
  input_indices.push_back(model_builder.AddOperandFromScalar(alpha));
  input_indices.push_back(model_builder.AddOperandFromScalar(beta));

  // Specifying the axis is only available on api level >= 29.
  if (axis_is_configurable) {
    // ONNX LRN is always performed on the C dimension:
    // index 1 for nchw, index 3 for nhwc.
    int32_t axis = 1;
    if (output_is_nhwc) {
      axis = 3;
    }
    input_indices.push_back(model_builder.AddOperandFromScalar(axis));
  }

  // The output operand reuses the input operand's type and shape.
  shaper.Identity(input, output);
  const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
  model_builder.AddOperation(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION,
                             input_indices, {output}, {output_operand_type}, {output_is_nhwc});
}
#pragma endregion
#pragma region CreateOpBuilders
std::unordered_map<std::string, std::shared_ptr<IOpBuilder>>
@ -2364,6 +2503,7 @@ CreateOpBuilders() {
op_map.emplace("Squeeze", std::make_shared<SqueezeOpBuilder>());
op_map.emplace("QuantizeLinear", std::make_shared<QuantizeLinearOpBuilder>());
op_map.emplace("DequantizeLinear", std::make_shared<DequantizeLinearOpBuilder>());
op_map.emplace("LRN", std::make_shared<LRNOpBuilder>());
return op_map;
}

View file

@ -106,6 +106,7 @@ enum {
ANEURALNETWORKS_GATHER = 51,
ANEURALNETWORKS_GREATER = 53,
ANEURALNETWORKS_GREATER_EQUAL = 54,
ANEURALNETWORKS_GROUPED_CONV_2D = 55,
ANEURALNETWORKS_LESS = 58,
ANEURALNETWORKS_LESS_EQUAL = 59,
ANEURALNETWORKS_LOG = 60,

View file

@ -626,6 +626,9 @@ TEST(ConvTest, Conv2D_group) {
auto expected_vals = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, 32.0f, 34.0f};
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
// NNAPI EP requires weight to be an initializer
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
}
TEST(ConvTest, ConvDimWithZero) {