mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix Reshape Fusion with graph inputs (#3729)
Use NodeArg to check root input; Add a check on constant initializer
This commit is contained in:
parent
75c24a5fac
commit
f487cc0b28
6 changed files with 62 additions and 11 deletions
|
|
@ -46,7 +46,7 @@ each of which is a constant initializer or a Shape->Gather->Unsqueeze chain with
|
|||
index corresponding to the index of the argument.)
|
||||
|
||||
Before fusion:
|
||||
[Sub-graph Root Node ]
|
||||
[Sub-graph Root]
|
||||
| / \
|
||||
| Shape Shape
|
||||
| | |
|
||||
|
|
@ -61,13 +61,14 @@ Before fusion:
|
|||
Reshape
|
||||
|
||||
After fusion:
|
||||
[Sub-graph Root Node] (Constant Initializer)
|
||||
[Sub-graph Root] (Constant Initializer)
|
||||
\ [0, a, 0, b]
|
||||
\ /
|
||||
Reshape
|
||||
*/
|
||||
bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::Logger& logger) {
|
||||
const Node* p_root = graph_utils::GetInputNode(reshape, 0);
|
||||
// The root could be either a graph input or a node so use node arg to compare.
|
||||
const NodeArg& root_input = *(reshape.InputDefs()[0]);
|
||||
|
||||
const Node* p_concat = graph_utils::GetInputNode(reshape, 1);
|
||||
if (nullptr == p_concat) {
|
||||
|
|
@ -90,11 +91,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L
|
|||
enum class NodeType { Unsqueeze, Gather, Shape };
|
||||
std::set<std::pair<NodeType, NodeIndex>> candidates_for_removal;
|
||||
for (int i = 0; i < concat_input_count; ++i) {
|
||||
// First check if the i-th argument is an initializer.
|
||||
// We do not check whether the initializer is constant.
|
||||
// Some model uses constant initializer and some does not.
|
||||
// Here we assume that no one will override the initializer using graph input.
|
||||
if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value)) {
|
||||
// First check if the i-th argument is a constant initializer.
|
||||
if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value, true)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -113,7 +111,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L
|
|||
const Node& gather = edges[1]->GetNode();
|
||||
const Node& shape = edges[2]->GetNode();
|
||||
|
||||
if (graph_utils::GetInputNode(shape, 0) != p_root) {
|
||||
const NodeArg& shape_input = *(shape.InputDefs()[0]);
|
||||
if (shape_input.Name() != root_input.Name()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -141,7 +141,11 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data) {
|
||||
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data, bool require_constant) {
|
||||
if (require_constant && !graph_utils::IsConstantInitializer(graph, input_arg.Name(), true)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
|
||||
if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam
|
|||
/** Get values of an integer tensor from initializer, and append them to a vector.
|
||||
@remarks only support int32 and int64 tensor. This function does not clear vector before appending.
|
||||
*/
|
||||
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data);
|
||||
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data, bool require_constant = true);
|
||||
|
||||
/** Check Shape of node input or output.
|
||||
@remarks when expected dim value > 0, the dim is expected to known and match the dim value.
|
||||
|
|
|
|||
|
|
@ -1116,6 +1116,27 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_F(GraphTransformationTests, ReshapeFusionGraphInputsTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_graph_inputs.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["Shape"], 1);
|
||||
ASSERT_EQ(op_to_count["Gather"], 1);
|
||||
ASSERT_EQ(op_to_count["Unsqueeze"], 1);
|
||||
ASSERT_EQ(op_to_count["Concat"], 1);
|
||||
ASSERT_EQ(op_to_count["Reshape"], 1);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(GraphTransformationTests, ExpandElimination) {
|
||||
auto model_uri = MODEL_FOLDER "expand_elimination.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
|
|||
|
|
@ -70,3 +70,30 @@ graph = helper.make_graph(
|
|||
|
||||
save_model(graph, 'reshape_fusion_internal_node_is_graph_output.onnx')
|
||||
|
||||
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
helper.make_node("Shape", ["query"], ["shape0_out"], "shape0"),
|
||||
helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0),
|
||||
helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
|
||||
helper.make_node("Concat", ["a", "unsqueeze0_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node("Reshape", ["doc_word_mask", "concat_out"], ["Result"], "reshape"),
|
||||
],
|
||||
"Reshape_Fusion", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('query', TensorProto.FLOAT, [1, 50]),
|
||||
helper.make_tensor_value_info('doc_word_mask', TensorProto.FLOAT, [1, 200, 50]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']),
|
||||
],
|
||||
[ # initializers
|
||||
helper.make_tensor('a', TensorProto.INT64, [1], [-1]),
|
||||
helper.make_tensor('indices0', TensorProto.INT64, [], [1]),
|
||||
]
|
||||
)
|
||||
|
||||
save_model(graph, 'reshape_fusion_with_graph_inputs.onnx')
|
||||
|
||||
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue