mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Reenable skip flatten/reshape if it's Gemm's input (#5904)
This commit is contained in:
parent
782303324e
commit
c49d5f1d98
1 changed files with 49 additions and 52 deletions
|
|
@ -756,75 +756,72 @@ class ReshapeOpBuilder : public BaseOpBuilder {
|
|||
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT;
|
||||
static bool CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank);
|
||||
static bool CanSkipReshape(const ModelBuilder& model_builder, const Node& node, size_t input_rank, size_t output_rank);
|
||||
};
|
||||
|
||||
void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
|
||||
}
|
||||
|
||||
// We can skip the Reshape if all the output edges satisfies,
|
||||
// 1. The output of the reshape/flatten is the input 0 of the GEMM/Matmul,
|
||||
// We can skip the Reshape if all the output edges satisfies both the following conditions
|
||||
// 1. The output the reshape/flatten is not an output of the graph
|
||||
// 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
|
||||
// and not any other types of operator,
|
||||
// and the input rank >= 2 and output_rank == 2
|
||||
// This is because Gemm/Matmul will map to ANEURALNETWORKS_FULLY_CONNECTED in NNAPI,
|
||||
// ANEURALNETWORKS_FULLY_CONNECTED will flatten the 2+ dim input 0 to 2d
|
||||
// 2. Or the output the reshape/flatten is the output of the graph
|
||||
// (no op in the graph is using the output except can be used by Gemm/Matmul satisfying condition 1 above)
|
||||
// The reason we want to skip Reshape is that Reshape is not running on Hardware (NPU,...) in NNAPI for
|
||||
// some CPU (e.g. Qualcomm SD for now), skipping unnecessary Reshape will prevent context switching
|
||||
// between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution
|
||||
// If we are going to skip the reshape, we will still add correct shape and operand type for the output in
|
||||
// onnxruntime::nnapi::Model.
|
||||
// If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape
|
||||
// information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape
|
||||
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank) {
|
||||
//
|
||||
// TEMPORARILY DISABLED. Needs refinement.
|
||||
//
|
||||
// const auto& output = node.OutputDefs()[0]->Name();
|
||||
// // We will go through all the output edges
|
||||
// for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
|
||||
// const auto& op_type = it->GetNode().OpType();
|
||||
// // TODO add quantized matmul when reshape support quantized input
|
||||
// if (op_type != "Gemm" && op_type != "MatMul") {
|
||||
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
|
||||
// << " or no op is using the output (output is graph output)"
|
||||
// << ", output name, " << output
|
||||
// << " is used by " << op_type;
|
||||
// return false;
|
||||
// }
|
||||
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const ModelBuilder& model_builder, const Node& node,
|
||||
size_t input_rank, size_t output_rank) {
|
||||
const auto& output = node.OutputDefs()[0]->Name();
|
||||
// We will go through all the output edges
|
||||
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
|
||||
const auto& op_type = it->GetNode().OpType();
|
||||
// TODO add quantized matmul when reshape support quantized input
|
||||
if (op_type != "Gemm" && op_type != "MatMul") {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
|
||||
<< " or no op is using the output (output is graph output)"
|
||||
<< ", output name, " << output
|
||||
<< " is used by " << op_type;
|
||||
return false;
|
||||
}
|
||||
|
||||
// // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
|
||||
// if (it->GetDstArgIndex() != 0) {
|
||||
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
|
||||
// << ", output name, " << output;
|
||||
// return false;
|
||||
// }
|
||||
// NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
|
||||
if (it->GetDstArgIndex() != 0) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
|
||||
<< ", output name, " << output;
|
||||
return false;
|
||||
}
|
||||
|
||||
// // We only support 2d matmul/gemm here
|
||||
// // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
|
||||
// if (input_rank < 2 || output_rank != 2) {
|
||||
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
|
||||
// << ", output name, " << output
|
||||
// << ", the actual input_rank, " << input_rank
|
||||
// << ", the actual output_rank, " << output_rank;
|
||||
// return false;
|
||||
// }
|
||||
// }
|
||||
// We only support 2d matmul/gemm here
|
||||
// And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
|
||||
if (input_rank < 2 || output_rank != 2) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
|
||||
<< ", output name, " << output
|
||||
<< ", the actual input_rank, " << input_rank
|
||||
<< ", the actual output_rank, " << output_rank;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// // If we reach here, we have either,
|
||||
// // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
|
||||
// // or
|
||||
// // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
|
||||
// // we can skip this Reshape now
|
||||
// LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
|
||||
// << node.Name() << "] with output, " << output;
|
||||
// return true;
|
||||
// If we reach here, we have all the Reshape outputs are used by gemm/matmul, or Reshape has no output edge
|
||||
// Check if the Reshape output is a graph output, if so we cannot skip the Reshape
|
||||
// We do not care the case where the Reshape output is a dead end
|
||||
for (const auto* node_arg : model_builder.GetGraphViewer().GetOutputs()) {
|
||||
if (node_arg->Name() == output) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can not be skipped when the output is a graph output"
|
||||
<< ", output name, " << output;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ORT_UNUSED_PARAMETER(node);
|
||||
ORT_UNUSED_PARAMETER(input_rank);
|
||||
ORT_UNUSED_PARAMETER(output_rank);
|
||||
return false;
|
||||
LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
|
||||
<< node.Name() << "] with output, " << output;
|
||||
return true;
|
||||
}
|
||||
|
||||
/* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder,
|
||||
|
|
@ -842,7 +839,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
|
|||
// Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now)
|
||||
// We will try to see if we the skip the Reshape to prevent context switching between
|
||||
// NNAPI CPU impl and NNAPI hardware accelerator impl
|
||||
if (CanSkipReshape(node, input_rank, output_rank)) {
|
||||
if (CanSkipReshape(model_builder, node, input_rank, output_rank)) {
|
||||
// Since reshape can be skipped, only register the dimension and type, with same index and new name
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);
|
||||
|
|
|
|||
Loading…
Reference in a new issue