Fix Reshape Fusion with graph inputs (#3729)

Use NodeArg to check root input; Add a check on constant initializer
This commit is contained in:
Tianlei Wu 2020-04-28 00:03:16 -07:00 committed by GitHub
parent 75c24a5fac
commit f487cc0b28
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 11 deletions

View file

@ -46,7 +46,7 @@ each of which is a constant initializer or a Shape->Gather->Unsqueeze chain with
index corresponding to the index of the argument.)
Before fusion:
[Sub-graph Root Node ]
[Sub-graph Root]
| / \
| Shape Shape
| | |
@ -61,13 +61,14 @@ Before fusion:
Reshape
After fusion:
[Sub-graph Root Node] (Constant Initializer)
[Sub-graph Root] (Constant Initializer)
\ [0, a, 0, b]
\ /
Reshape
*/
bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::Logger& logger) {
const Node* p_root = graph_utils::GetInputNode(reshape, 0);
// The root could be either a graph input or a node so use node arg to compare.
const NodeArg& root_input = *(reshape.InputDefs()[0]);
const Node* p_concat = graph_utils::GetInputNode(reshape, 1);
if (nullptr == p_concat) {
@ -90,11 +91,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L
enum class NodeType { Unsqueeze, Gather, Shape };
std::set<std::pair<NodeType, NodeIndex>> candidates_for_removal;
for (int i = 0; i < concat_input_count; ++i) {
// First check if the i-th argument is an initializer.
// We do not check whether the initializer is constant.
// Some model uses constant initializer and some does not.
// Here we assume that no one will override the initializer using graph input.
if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value)) {
// First check if the i-th argument is a constant initializer.
if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value, true)) {
continue;
}
@ -113,7 +111,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L
const Node& gather = edges[1]->GetNode();
const Node& shape = edges[2]->GetNode();
if (graph_utils::GetInputNode(shape, 0) != p_root) {
const NodeArg& shape_input = *(shape.InputDefs()[0]);
if (shape_input.Name() != root_input.Name()) {
return false;
}

View file

@ -141,7 +141,11 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam
return true;
}
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data) {
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data, bool require_constant) {
if (require_constant && !graph_utils::IsConstantInitializer(graph, input_arg.Name(), true)) {
return false;
}
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
return false;

View file

@ -42,7 +42,7 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam
/** Get values of an integer tensor from initializer, and append them to a vector.
@remarks only support int32 and int64 tensor. This function does not clear vector before appending.
*/
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data);
bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector<int64_t>& data, bool require_constant = true);
/** Check Shape of node input or output.
@remarks when expected dim value > 0, the dim is expected to known and match the dim value.

View file

@ -1116,6 +1116,27 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) {
}
}
TEST_F(GraphTransformationTests, ReshapeFusionGraphInputsTest) {
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_graph_inputs.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Shape"], 1);
ASSERT_EQ(op_to_count["Gather"], 1);
ASSERT_EQ(op_to_count["Unsqueeze"], 1);
ASSERT_EQ(op_to_count["Concat"], 1);
ASSERT_EQ(op_to_count["Reshape"], 1);
}
TEST_F(GraphTransformationTests, ExpandElimination) {
auto model_uri = MODEL_FOLDER "expand_elimination.onnx";
std::shared_ptr<Model> model;

View file

@ -70,3 +70,30 @@ graph = helper.make_graph(
save_model(graph, 'reshape_fusion_internal_node_is_graph_output.onnx')
graph = helper.make_graph(
[ # nodes
helper.make_node("Shape", ["query"], ["shape0_out"], "shape0"),
helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0),
helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Concat", ["a", "unsqueeze0_out"], ["concat_out"], "concat", axis=0),
helper.make_node("Reshape", ["doc_word_mask", "concat_out"], ["Result"], "reshape"),
],
"Reshape_Fusion", #name
[ # inputs
helper.make_tensor_value_info('query', TensorProto.FLOAT, [1, 50]),
helper.make_tensor_value_info('doc_word_mask', TensorProto.FLOAT, [1, 200, 50]),
],
[ # outputs
helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']),
],
[ # initializers
helper.make_tensor('a', TensorProto.INT64, [1], [-1]),
helper.make_tensor('indices0', TensorProto.INT64, [], [1]),
]
)
save_model(graph, 'reshape_fusion_with_graph_inputs.onnx')