mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Fuse Pad even if Cast is present in-between (#21640)
### Description
This change enhances the existing Pad Fusion so that Pad is fused even when a Cast operator sits between Pad and Conv/MaxPool/AveragePool. The Cast node itself is kept unchanged.
<pre>
/*
 * Before Fusion:
 *     Pad
 *      |
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 *
 * After Fusion:
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 */
</pre>
### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
e6e4047a77
commit
702b2e28e0
1 changed files with 62 additions and 31 deletions
|
|
@ -8,25 +8,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
/*
|
||||
* It matches following pattern:
|
||||
* Pad
|
||||
* |
|
||||
* Conv/MaxPool/AveragePool
|
||||
*/
|
||||
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
// if Pad has input axis, don't fuse it.
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
|
||||
node.GetOutputEdgesCount() != 1 ||
|
||||
node.InputDefs().size() > 3) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (graph.NodeProducesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Node& child_node = *node.OutputNodesBegin();
|
||||
bool VerifyNotCastChild(const Node& child_node) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
|
||||
|
|
@ -54,6 +36,45 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
|
|||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Adds the padding amounts in `pads_values` to the "pads" attribute of `child_node`, in place.
//
// child_node:  node whose "pads" attribute is updated.
// pads_values: flat list of pad amounts; the loop treats it as two halves of
//              pads_size / 2 entries each and skips the first two entries of each half.
// pads_size:   number of entries of pads_values to consider (the caller in this file passes
//              pads_values.size()).
//
// NOTE(review): the first two entries of each half are presumably the batch and channel
// dimensions, which the caller requires to be zero ("padding is applied only on feature dims")
// -- confirm against PadFusion::Apply.
void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
|
||||
// NOTE(review): operator[] is used on the attribute map and the entries are then indexed
// directly with Get/Set below; confirm that callers guarantee "pads" already exists on
// child_node with enough entries.
auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
|
||||
uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());
|
||||
|
||||
// pads_index walks the first half of pads_values starting at entry 2, while child_index walks
// the first half of the child's pads starting at entry 0.
for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
|
||||
// Accumulate the first-half pad value into the matching child entry.
child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
|
||||
// The corresponding second-half entries sit half a list further on in each array.
uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
|
||||
uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
|
||||
// Accumulate the second-half pad value into the matching child entry.
child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
|
||||
}
|
||||
}
|
||||
/*
|
||||
* Before:
|
||||
* Pad
|
||||
* |
|
||||
* Cast (Optional)
|
||||
* |
|
||||
* Conv/MaxPool/AveragePool
|
||||
*
|
||||
* After:
|
||||
* Cast (Optional)
|
||||
* |
|
||||
* Conv/MaxPool/AveragePool
|
||||
*/
|
||||
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
// if Pad has input axis, don't fuse it.
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
|
||||
node.GetOutputEdgesCount() != 1 ||
|
||||
node.InputDefs().size() > 3) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (graph.NodeProducesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const NodeAttributes& pad_attributes = node.GetAttributes();
|
||||
if (pad_attributes.find("mode") != pad_attributes.end() &&
|
||||
pad_attributes.at("mode").s() != "constant") {
|
||||
|
|
@ -83,7 +104,19 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
|
|||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
const Node& child_node = *node.OutputNodesBegin();
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Cast", {1, 6, 9, 13})) {
|
||||
if (child_node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (graph.NodeProducesGraphOutput(child_node)) {
|
||||
return false;
|
||||
}
|
||||
return VerifyNotCastChild(*child_node.OutputNodesBegin());
|
||||
} else {
|
||||
return VerifyNotCastChild(child_node);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -100,8 +133,6 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
|
|||
pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end());
|
||||
}
|
||||
|
||||
assert(static_cast<uint32_t>(pads_values.size()) == (2 * static_cast<uint32_t>(pad_node.InputDefs()[0]->Shape()->dim_size())));
|
||||
|
||||
uint32_t pads_size = static_cast<uint32_t>(pads_values.size());
|
||||
// check if padding is applied only on feature dims
|
||||
if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 ||
|
||||
|
|
@ -115,18 +146,18 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
|
|||
}
|
||||
|
||||
Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index());
|
||||
auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
|
||||
uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());
|
||||
|
||||
for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
|
||||
child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
|
||||
uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
|
||||
uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
|
||||
child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
|
||||
}
|
||||
// We don't need to cast the pad constant value because this fusion requires the pad constant value
|
||||
// to be zero. See PadFusion::SatisfyCondition for details.
|
||||
Node& target_padding_node = (child_node.OpType() == "Cast") ? *graph.GetNode(child_node.OutputNodesBegin()->Index()) : child_node;
|
||||
UpdatePaddingAttribute(target_padding_node, pads_values, pads_size);
|
||||
|
||||
graph_utils::RemoveNodeOutputEdges(graph, pad_node);
|
||||
graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]);
|
||||
// Un-pad the output shape of Cast node
|
||||
if (child_node.OpType() == "Cast") {
|
||||
auto* cast_output_node_arg = child_node.MutableOutputDefs()[0];
|
||||
cast_output_node_arg->SetShape(*pad_node.MutableInputDefs()[0]->Shape());
|
||||
}
|
||||
graph.RemoveNode(pad_node.Index());
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
return Status::OK();
|
||||
|
|
|
|||
Loading…
Reference in a new issue