Add support for setting shape inference function on fused nodes (#7007)

* Add support for setting shape inference function on fused nodes
* Add test for fused node shape inference
This commit is contained in:
Pranav Prakash 2021-05-04 20:32:07 -07:00 committed by GitHub
parent d8cf960412
commit 053bada30f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 1 deletions

View file

@ -36,6 +36,10 @@ struct IndexedSubGraph {
NodeAttributes attributes; ///< Attributes of customized SubGraph/FunctionProto.
std::string doc_string; ///< Doc string of customized SubGraph/FunctionProto.
#if !defined(ORT_MINIMAL_BUILD)
/** Type and shape inference function that can optionally be defined for the fused node */
std::function<void (ONNX_NAMESPACE::InferenceContext&)> type_and_shape_inference_function;
#endif
};
/** Nodes covered by this subgraph. The NodeIndex values are from the parent Graph.*/

View file

@ -152,6 +152,11 @@ static std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const Graph& graph
op_schema->SetDomain(meta_def->domain);
op_schema->SetDoc(meta_def->doc_string);
op_schema->SinceVersion(meta_def->since_version);
if (meta_def->type_and_shape_inference_function) {
op_schema->TypeAndShapeInferenceFunction(meta_def->type_and_shape_inference_function);
}
int i = 0;
for (auto& input : meta_def->inputs) {

View file

@ -126,6 +126,18 @@ class FuseExecutionProvider : public IExecutionProvider {
meta_def->outputs = {"M"};
meta_def->since_version = 1;
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
meta_def->type_and_shape_inference_function = [](::onnx::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
::onnx::TensorShapeProto intermediary_shape;
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
intermediary_shape);
bidirectionalBroadcastShapeInference(
ctx.getInputType(1)->tensor_type().shape(),
intermediary_shape,
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
};
sub_graph->SetMetaDef(std::move(meta_def));
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
return result;
@ -1212,7 +1224,6 @@ TEST(ExecutionProviderTest, FunctionTest) {
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
InferenceSession session_object_2{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object_2.RegisterExecutionProvider(std::move(testCPUExecutionProvider)));
ASSERT_STATUS_OK(
session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
status = session_object_2.Load(model_file_name);
@ -1224,6 +1235,67 @@ TEST(ExecutionProviderTest, FunctionTest) {
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
}
// Checks that a type/shape inference function attached to a fused node's schema is actually
// invoked during Graph::Resolve().
//
// Flow:
//   1. Build and save a model computing M = (X + Y) + Z, all FLOAT tensors of shape {3, 2}.
//   2. Load it into a session that has FuseExecutionProvider registered and initialize it.
//   3. Expect the graph to contain a single fused node whose schema reports an inference function.
//   4. Clear the fused node's output shape, force a re-resolve, and expect the shape to be
//      re-populated with {3, 2}.
TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
  onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                           {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
  auto& graph = model.MainGraph();
  std::vector<onnxruntime::NodeArg*> inputs;
  std::vector<onnxruntime::NodeArg*> outputs;

  // FLOAT tensor of shape {3, 2}; used as the type of every NodeArg in the graph.
  ONNX_NAMESPACE::TypeProto float_tensor;
  float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
  float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
  float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);

  // node_1: node_1_out_1 = Add(X, Y)
  auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
  auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
  inputs.push_back(&input_arg_1);
  inputs.push_back(&input_arg_2);
  auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
  outputs.push_back(&output_arg);
  graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);

  // node_2: M = Add(node_1_out_1, Z)
  auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
  inputs.clear();
  inputs.push_back(&output_arg);
  inputs.push_back(&input_arg_3);
  auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
  outputs.clear();
  outputs.push_back(&output_arg_2);
  graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);

  ASSERT_STATUS_OK(graph.Resolve());

  // A failed save must stop the test here; otherwise the failure would only show up as a
  // confusing Load() error below.
  std::string model_file_name = "fused_node_shape_inference_test_graph.onnx";
  ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));

  SessionOptions so;
  so.session_logid = "ExecutionProviderTest.ShapeInferenceForFusedFunctionTest";
  InferenceSessionWrapper session{so, GetEnvironment()};
  ASSERT_STATUS_OK(
      session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
  ASSERT_STATUS_OK(session.Load(model_file_name));
  ASSERT_STATUS_OK(session.Initialize());

  Graph& fused_graph = session.GetMutableGraph();
  ASSERT_TRUE(fused_graph.NumberOfNodes() == 1);
  auto& fused_node = *fused_graph.Nodes().begin();
  ASSERT_TRUE(fused_node.NodeType() == Node::Type::Fused);
  ASSERT_TRUE(fused_node.Op()->has_type_and_shape_inference_function());

  // Clear shape inference data from output node to verify that assigned inference function is called
  auto& fused_node_output = *fused_node.MutableOutputDefs()[0];
  fused_node_output.ClearShape();
  fused_graph.SetGraphResolveNeeded();
  // The re-resolve is what should run the inference function, so its status must be checked
  // rather than discarded.
  ASSERT_STATUS_OK(fused_graph.Resolve());

  ASSERT_TRUE(fused_node_output.Shape() != nullptr);
  ASSERT_TRUE(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()) ==
              utils::GetTensorShapeFromTensorShapeProto(float_tensor.tensor_type().shape()));
}
TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
// The main graph contains a 'If' node: 'graph_0__if_0'
// Inside the then-branch of 'graph_0__if_0', there is a nested 'If' node: 'graph_0__if_0__else__if_0'

View file

@ -20,6 +20,10 @@ class InferenceSessionWrapper : public InferenceSession {
return model_->MainGraph();
}
// Returns a non-const reference to the main graph of the model held in model_, so that callers
// can modify the graph in place.
Graph& GetMutableGraph() const {
  Graph& main_graph = model_->MainGraph();
  return main_graph;
}
const SessionState& GetSessionState() const {
return InferenceSession::GetSessionState();
}