Reenable skip flatten/reshape if it's Gemm's input (#5904)

This commit is contained in:
Guoyu Wang 2020-11-24 00:01:23 -08:00 committed by GitHub
parent 782303324e
commit c49d5f1d98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -756,75 +756,72 @@ class ReshapeOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT;
static bool CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank);
static bool CanSkipReshape(const ModelBuilder& model_builder, const Node& node, size_t input_rank, size_t output_rank);
};
void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
}
// We can skip the Reshape if all the output edges satisfies,
// 1. The output of the reshape/flatten is the input 0 of the GEMM/Matmul,
// We can skip the Reshape if all the output edges satisfy both of the following conditions
// 1. The output of the reshape/flatten is not an output of the graph
// 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
// and not any other types of operator,
// and the input rank >= 2 and output_rank == 2
// This is because Gemm/Matmul will map to ANEURALNETWORKS_FULLY_CONNECTED in NNAPI,
// ANEURALNETWORKS_FULLY_CONNECTED will flatten the 2+ dim input 0 to 2d
// 2. Or the output the reshape/flatten is the output of the graph
// (no op in the graph is using the output except can be used by Gemm/Matmul satisfying condition 1 above)
// The reason we want to skip Reshape is that Reshape is not running on Hardware (NPU,...) in NNAPI for
// some CPU (e.g. Qualcomm SD for now), skipping unnecessary Reshape will prevent context switching
// between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution
// If we are going to skip the reshape, we will still add correct shape and operand type for the output in
// onnxruntime::nnapi::Model.
// If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape
// information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank) {
//
// TEMPORARILY DISABLED. Needs refinement.
//
// const auto& output = node.OutputDefs()[0]->Name();
// // We will go through all the output edges
// for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
// const auto& op_type = it->GetNode().OpType();
// // TODO add quantized matmul when reshape support quantized input
// if (op_type != "Gemm" && op_type != "MatMul") {
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
// << " or no op is using the output (output is graph output)"
// << ", output name, " << output
// << " is used by " << op_type;
// return false;
// }
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const ModelBuilder& model_builder, const Node& node,
size_t input_rank, size_t output_rank) {
const auto& output = node.OutputDefs()[0]->Name();
// We will go through all the output edges
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
const auto& op_type = it->GetNode().OpType();
// TODO add quantized matmul when reshape support quantized input
if (op_type != "Gemm" && op_type != "MatMul") {
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
<< " or no op is using the output (output is graph output)"
<< ", output name, " << output
<< " is used by " << op_type;
return false;
}
// // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
// if (it->GetDstArgIndex() != 0) {
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
// << ", output name, " << output;
// return false;
// }
// NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
if (it->GetDstArgIndex() != 0) {
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
<< ", output name, " << output;
return false;
}
// // We only support 2d matmul/gemm here
// // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
// if (input_rank < 2 || output_rank != 2) {
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
// << ", output name, " << output
// << ", the actual input_rank, " << input_rank
// << ", the actual output_rank, " << output_rank;
// return false;
// }
// }
// We only support 2d matmul/gemm here
// And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
if (input_rank < 2 || output_rank != 2) {
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
<< ", output name, " << output
<< ", the actual input_rank, " << input_rank
<< ", the actual output_rank, " << output_rank;
return false;
}
}
// // If we reach here, we have either,
// // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
// // or
// // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
// // we can skip this Reshape now
// LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
// << node.Name() << "] with output, " << output;
// return true;
// If we reach here, either all the Reshape outputs are used by gemm/matmul, or Reshape has no output edge
// Check if the Reshape output is a graph output, if so we cannot skip the Reshape
// We do not care about the case where the Reshape output is a dead end
for (const auto* node_arg : model_builder.GetGraphViewer().GetOutputs()) {
if (node_arg->Name() == output) {
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can not be skipped when the output is a graph output"
<< ", output name, " << output;
return false;
}
}
ORT_UNUSED_PARAMETER(node);
ORT_UNUSED_PARAMETER(input_rank);
ORT_UNUSED_PARAMETER(output_rank);
return false;
LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
<< node.Name() << "] with output, " << output;
return true;
}
/* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder,
@ -842,7 +839,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
// Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now)
// We will try to see if we can skip the Reshape to prevent context switching between
// NNAPI CPU impl and NNAPI hardware accelerator impl
if (CanSkipReshape(node, input_rank, output_rank)) {
if (CanSkipReshape(model_builder, node, input_rank, output_rank)) {
// Since reshape can be skipped, only register the dimension and type, with same index and new name
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);