From 053bada30fe6d027063a9ca10efbec08edf69776 Mon Sep 17 00:00:00 2001 From: Pranav Prakash Date: Tue, 4 May 2021 20:32:07 -0700 Subject: [PATCH] Add support for setting shape inference function on fused nodes (#7007) * Add support for setting shape inference function on fused nodes * Add test for fused node shape inference --- .../core/graph/indexed_sub_graph.h | 4 + onnxruntime/core/graph/function.cc | 5 ++ .../test/framework/inference_session_test.cc | 74 ++++++++++++++++++- .../util/include/inference_session_wrapper.h | 4 + 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index b0485c4225..47544c183a 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -36,6 +36,10 @@ struct IndexedSubGraph { NodeAttributes attributes; ///< Attributes of customized SubGraph/FunctionProto. std::string doc_string; ///< Doc string of customized SubGraph/FunctionProto. +#if !defined(ORT_MINIMAL_BUILD) + /** Type and shape inference function that can optionally be defined for the fused node */ + std::function<void(ONNX_NAMESPACE::InferenceContext&)> type_and_shape_inference_function; +#endif }; /** Nodes covered by this subgraph. 
The NodeIndex values are from the parent Graph.*/ diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 94a6ef691c..a726d01f5f 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -152,6 +152,11 @@ static std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const Graph& graph op_schema->SetDomain(meta_def->domain); op_schema->SetDoc(meta_def->doc_string); op_schema->SinceVersion(meta_def->since_version); + + if (meta_def->type_and_shape_inference_function) { + op_schema->TypeAndShapeInferenceFunction(meta_def->type_and_shape_inference_function); + } + int i = 0; for (auto& input : meta_def->inputs) { diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index b40b24675e..936ddc34c2 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -126,6 +126,18 @@ class FuseExecutionProvider : public IExecutionProvider { meta_def->outputs = {"M"}; meta_def->since_version = 1; meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + meta_def->type_and_shape_inference_function = [](::onnx::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + ::onnx::TensorShapeProto intermediary_shape; + bidirectionalBroadcastShapeInference( + ctx.getInputType(0)->tensor_type().shape(), + ctx.getInputType(1)->tensor_type().shape(), + intermediary_shape); + bidirectionalBroadcastShapeInference( + ctx.getInputType(1)->tensor_type().shape(), + intermediary_shape, + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()); + }; sub_graph->SetMetaDef(std::move(meta_def)); result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph))); return result; @@ -1212,7 +1224,6 @@ TEST(ExecutionProviderTest, FunctionTest) { VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); InferenceSession session_object_2{so, GetEnvironment()}; - 
ASSERT_STATUS_OK(session_object_2.RegisterExecutionProvider(std::move(testCPUExecutionProvider))); ASSERT_STATUS_OK( session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>())); status = session_object_2.Load(model_file_name); @@ -1224,6 +1235,67 @@ TEST(ExecutionProviderTest, FunctionTest) { VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); } +TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) { + onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + std::vector<onnxruntime::NodeArg*> inputs; + std::vector<onnxruntime::NodeArg*> outputs; + + // FLOAT tensor of shape [3, 2]; reused as the type of every graph input and output below. + ONNX_NAMESPACE::TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); + auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); + inputs.push_back(&input_arg_1); + inputs.push_back(&input_arg_2); + auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); + outputs.push_back(&output_arg); + graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); + + auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); + inputs.clear(); + inputs.push_back(&output_arg); + inputs.push_back(&input_arg_3); + auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_2); + graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()); + std::string model_file_name = "fused_node_shape_inference_test_graph.onnx"; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name)); + 
SessionOptions so; + so.session_logid = "ExecutionProviderTest.ShapeInferenceForFusedFunctionTest"; + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK( + session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>())); + status = session.Load(model_file_name); + ASSERT_TRUE(status.IsOK()); + status = session.Initialize(); + ASSERT_TRUE(status.IsOK()); + + Graph& fused_graph = session.GetMutableGraph(); + ASSERT_TRUE(fused_graph.NumberOfNodes() == 1); + auto &fused_node = *fused_graph.Nodes().begin(); + ASSERT_TRUE(fused_node.NodeType() == Node::Type::Fused); + ASSERT_TRUE(fused_node.Op()->has_type_and_shape_inference_function()); + + // Clear the shape on the fused node's output and re-resolve: the checks below then verify that the assigned inference function is called + auto &fused_node_output = *fused_node.MutableOutputDefs()[0]; + fused_node_output.ClearShape(); + fused_graph.SetGraphResolveNeeded(); + ASSERT_STATUS_OK(fused_graph.Resolve()); + + ASSERT_TRUE(fused_node_output.Shape() != nullptr); + ASSERT_TRUE(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()) + == utils::GetTensorShapeFromTensorShapeProto(float_tensor.tensor_type().shape())); +} + TEST(InferenceSessionTests, Test3LayerNestedSubgraph) { // The main graph contains a 'If' node: 'graph_0__if_0' // Inside the then-branch of 'graph_0__if_0', there is a nested 'If' node: 'graph_0__if_0__else__if_0' diff --git a/onnxruntime/test/util/include/inference_session_wrapper.h b/onnxruntime/test/util/include/inference_session_wrapper.h index 6810d34d38..c7cf979380 100644 --- a/onnxruntime/test/util/include/inference_session_wrapper.h +++ b/onnxruntime/test/util/include/inference_session_wrapper.h @@ -20,6 +20,10 @@ class InferenceSessionWrapper : public InferenceSession { return model_->MainGraph(); } + Graph& GetMutableGraph() const { + return model_->MainGraph(); + } + const SessionState& GetSessionState() const { return InferenceSession::GetSessionState(); }