From c49d5f1d982d3545c9a4be491322a12798530fb2 Mon Sep 17 00:00:00 2001 From: Guoyu Wang <62914304+gwang-msft@users.noreply.github.com> Date: Tue, 24 Nov 2020 00:01:23 -0800 Subject: [PATCH] Reenable skip flatten/reshape if it's Gemm's input (#5904) --- .../nnapi_builtin/builders/op_builder.cc | 101 +++++++++--------- 1 file changed, 49 insertions(+), 52 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index 427f028acf..d5e0118dba 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -756,75 +756,72 @@ class ReshapeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT; - static bool CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank); + static bool CanSkipReshape(const ModelBuilder& model_builder, const Node& node, size_t input_rank, size_t output_rank); }; void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); } -// We can skip the Reshape if all the output edges satisfies, -// 1. The output of the reshape/flatten is the input 0 of the GEMM/Matmul, +// We can skip the Reshape if all the output edges satisfy both of the following conditions +// 1. The output of the reshape/flatten is not an output of the graph +// 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators, +// and not any other types of operator, // and the input rank >= 2 and output_rank == 2 // This is because Gemm/Matmul will map to ANEURALNETWORKS_FULLY_CONNECTED in NNAPI, // ANEURALNETWORKS_FULLY_CONNECTED will flatten the 2+ dim input 0 to 2d -// 2. 
Or the output the reshape/flatten is the output of the graph -// (no op in the graph is using the output except can be used by Gemm/Matmul satisfying condition 1 above) // The reason we want to skip Reshape is that Reshape is not running on Hardware (NPU,...) in NNAPI for // some CPU (e.g. Qualcomm SD for now), skipping unnecessary Reshape will prevent context switching // between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution // If we are going to skip the reshape, we will still add correct shape and operand type for the output in // onnxruntime::nnapi::Model. -// If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape -// information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape -/* static */ bool ReshapeOpBuilder::CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank) { - // - // TEMPORARILY DISABLED. Needs refinement. - // - // const auto& output = node.OutputDefs()[0]->Name(); - // // We will go through all the output edges - // for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { - // const auto& op_type = it->GetNode().OpType(); - // // TODO add quantized matmul when reshape support quantized input - // if (op_type != "Gemm" && op_type != "MatMul") { - // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul" - // << " or no op is using the output (output is graph output)" - // << ", output name, " << output - // << " is used by " << op_type; - // return false; - // } +/* static */ bool ReshapeOpBuilder::CanSkipReshape(const ModelBuilder& model_builder, const Node& node, + size_t input_rank, size_t output_rank) { + const auto& output = node.OutputDefs()[0]->Name(); + // We will go through all the output edges + for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { + const auto& op_type = 
it->GetNode().OpType(); + // TODO add quantized matmul when reshape support quantized input + if (op_type != "Gemm" && op_type != "MatMul") { + LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul" + << " or no op is using the output (output is graph output)" + << ", output name, " << output + << " is used by " << op_type; + return false; + } - // // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0 - // if (it->GetDstArgIndex() != 0) { - // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul" - // << ", output name, " << output; - // return false; - // } + // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0 + if (it->GetDstArgIndex() != 0) { + LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul" + << ", output name, " << output; + return false; + } - // // We only support 2d matmul/gemm here - // // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2 - // if (input_rank < 2 || output_rank != 2) { - // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2" - // << ", output name, " << output - // << ", the actual input_rank, " << input_rank - // << ", the actual output_rank, " << output_rank; - // return false; - // } - // } + // We only support 2d matmul/gemm here + // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2 + if (input_rank < 2 || output_rank != 2) { + LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2" + << ", output name, " << output + << ", the actual input_rank, " << input_rank + << ", the actual output_rank, " << output_rank; + return false; + } + } - // // If we reach here, we have either, - // // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here] - // // 
or - // // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care] - // // we can skip this Reshape now - // LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node [" - // << node.Name() << "] with output, " << output; - // return true; + // If we reach here, all the Reshape outputs are used by gemm/matmul, or Reshape has no output edge + // Check if the Reshape output is a graph output; if so, we cannot skip the Reshape + // We do not care about the case where the Reshape output is a dead end + for (const auto* node_arg : model_builder.GetGraphViewer().GetOutputs()) { + if (node_arg->Name() == output) { + LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can not be skipped when the output is a graph output" + << ", output name, " << output; + return false; + } + } - ORT_UNUSED_PARAMETER(node); - ORT_UNUSED_PARAMETER(input_rank); - ORT_UNUSED_PARAMETER(output_rank); - return false; + LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node [" + << node.Name() << "] with output, " << output; + return true; } /* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder, @@ -842,7 +839,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now) // We will try to see if we the skip the Reshape to prevent context switching between // NNAPI CPU impl and NNAPI hardware accelerator impl - if (CanSkipReshape(node, input_rank, output_rank)) { + if (CanSkipReshape(model_builder, node, input_rank, output_rank)) { // Since reshape can be skipped, only register the dimension and type, with same index and new name const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);