mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Add support for setting shape inference function on fused nodes (#7007)
* Add support for setting shape inference function on fused nodes * Add test for fused node shape inference
This commit is contained in:
parent
d8cf960412
commit
053bada30f
4 changed files with 86 additions and 1 deletions
|
|
@ -36,6 +36,10 @@ struct IndexedSubGraph {
|
|||
NodeAttributes attributes; ///< Attributes of customized SubGraph/FunctionProto.
|
||||
|
||||
std::string doc_string; ///< Doc string of customized SubGraph/FunctionProto.
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/** Type and shape inference function that can optionally be defined for the fused node */
|
||||
std::function<void (ONNX_NAMESPACE::InferenceContext&)> type_and_shape_inference_function;
|
||||
#endif
|
||||
};
|
||||
|
||||
/** Nodes covered by this subgraph. The NodeIndex values are from the parent Graph.*/
|
||||
|
|
|
|||
|
|
@ -152,6 +152,11 @@ static std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const Graph& graph
|
|||
op_schema->SetDomain(meta_def->domain);
|
||||
op_schema->SetDoc(meta_def->doc_string);
|
||||
op_schema->SinceVersion(meta_def->since_version);
|
||||
|
||||
if (meta_def->type_and_shape_inference_function) {
|
||||
op_schema->TypeAndShapeInferenceFunction(meta_def->type_and_shape_inference_function);
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
|
||||
for (auto& input : meta_def->inputs) {
|
||||
|
|
|
|||
|
|
@ -126,6 +126,18 @@ class FuseExecutionProvider : public IExecutionProvider {
|
|||
meta_def->outputs = {"M"};
|
||||
meta_def->since_version = 1;
|
||||
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
|
||||
meta_def->type_and_shape_inference_function = [](::onnx::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
::onnx::TensorShapeProto intermediary_shape;
|
||||
bidirectionalBroadcastShapeInference(
|
||||
ctx.getInputType(0)->tensor_type().shape(),
|
||||
ctx.getInputType(1)->tensor_type().shape(),
|
||||
intermediary_shape);
|
||||
bidirectionalBroadcastShapeInference(
|
||||
ctx.getInputType(1)->tensor_type().shape(),
|
||||
intermediary_shape,
|
||||
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
|
||||
};
|
||||
sub_graph->SetMetaDef(std::move(meta_def));
|
||||
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
|
||||
return result;
|
||||
|
|
@ -1212,7 +1224,6 @@ TEST(ExecutionProviderTest, FunctionTest) {
|
|||
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
|
||||
|
||||
InferenceSession session_object_2{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object_2.RegisterExecutionProvider(std::move(testCPUExecutionProvider)));
|
||||
ASSERT_STATUS_OK(
|
||||
session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
status = session_object_2.Load(model_file_name);
|
||||
|
|
@ -1224,6 +1235,67 @@ TEST(ExecutionProviderTest, FunctionTest) {
|
|||
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
|
||||
}
|
||||
|
||||
TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
|
||||
onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
||||
// FLOAT tensor.
|
||||
ONNX_NAMESPACE::TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
|
||||
|
||||
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
|
||||
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
|
||||
inputs.push_back(&input_arg_1);
|
||||
inputs.push_back(&input_arg_2);
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
|
||||
outputs.push_back(&output_arg);
|
||||
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
|
||||
|
||||
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
|
||||
inputs.clear();
|
||||
inputs.push_back(&output_arg);
|
||||
inputs.push_back(&input_arg_3);
|
||||
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
|
||||
outputs.clear();
|
||||
outputs.push_back(&output_arg_2);
|
||||
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
std::string model_file_name = "fused_node_shape_inference_test_graph.onnx";
|
||||
status = onnxruntime::Model::Save(model, model_file_name);
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "ExecutionProviderTest.ShapeInferenceForFusedFunctionTest";
|
||||
InferenceSessionWrapper session{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(
|
||||
session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
status = session.Load(model_file_name);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
Graph& fused_graph = session.GetMutableGraph();
|
||||
ASSERT_TRUE(fused_graph.NumberOfNodes() == 1);
|
||||
auto &fused_node = *fused_graph.Nodes().begin();
|
||||
ASSERT_TRUE(fused_node.NodeType() == Node::Type::Fused);
|
||||
ASSERT_TRUE(fused_node.Op()->has_type_and_shape_inference_function());
|
||||
|
||||
// Clear shape inference data from output node to verify that assigned inference function is called
|
||||
auto &fused_node_output = *fused_node.MutableOutputDefs()[0];
|
||||
fused_node_output.ClearShape();
|
||||
fused_graph.SetGraphResolveNeeded();
|
||||
fused_graph.Resolve();
|
||||
|
||||
ASSERT_TRUE(fused_node_output.Shape() != nullptr);
|
||||
ASSERT_TRUE(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape())
|
||||
== utils::GetTensorShapeFromTensorShapeProto(float_tensor.tensor_type().shape()));
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
|
||||
// The main graph contains a 'If' node: 'graph_0__if_0'
|
||||
// Inside the then-branch of 'graph_0__if_0', there is a nested 'If' node: 'graph_0__if_0__else__if_0'
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@ class InferenceSessionWrapper : public InferenceSession {
|
|||
return model_->MainGraph();
|
||||
}
|
||||
|
||||
Graph& GetMutableGraph() const {
|
||||
return model_->MainGraph();
|
||||
}
|
||||
|
||||
const SessionState& GetSessionState() const {
|
||||
return InferenceSession::GetSessionState();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue