subgraph type override handling and unit test (#3560)

* unit test for subgraph type override

* unit test - re-wire input properly to subgraph

* update args

Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
ytaous 2020-04-17 19:33:34 -07:00 committed by GitHub
parent 2cb8cb816f
commit ca1bbff5d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 184 additions and 35 deletions

View file

@ -987,7 +987,8 @@ class Graph {
// perform type and shape inferencing on the subgraph and Resolve to validate
static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types);
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
const Graph::ResolveOptions& options);
// Apply type-inference and type-checking to all inputs and initializers:
common::Status TypeCheckInputsAndInitializers();

View file

@ -67,15 +67,17 @@ class NodeArg {
/** Validate and merge type [and shape] info from input_type.
@param strict If true, the shape update will fail if there are incompatible values.
If false, will be lenient and merge only shape info that can be validly processed.
@param override_types If true, a tensor element type mismatch is resolved by overriding the existing type with the type from input_type instead of failing.
@returns Success unless there is existing type or shape info that can't be successfully updated. */
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger);
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger);
/** Validate and merge type [and shape] info from node_arg.
@param strict If true, the shape update will fail if there are incompatible values.
If false, will be lenient and merge only shape info that can be validly processed.
@param override_types If true, a tensor element type mismatch is resolved by overriding the existing type with the type from node_arg instead of failing.
@returns Success unless there is existing type or shape info that can't be successfully updated. */
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger);
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger);
/** Gets this NodeArg as a ValueInfoProto. */
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

View file

@ -208,7 +208,8 @@ void NodeArg::ClearShape() {
}
}
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger) {
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types,
const logging::Logger& logger) {
if (!utils::HasType(node_arg_info_)) {
*node_arg_info_.mutable_type() = input_type;
type_ = DataTypeUtils::ToType(node_arg_info_.type());
@ -229,10 +230,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
const auto& input_tensor_elem_type = input_tensor_type.elem_type();
const auto& current_tensor_elem_type = current_type.tensor_type().elem_type();
if (input_tensor_elem_type != current_tensor_elem_type)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
if (input_tensor_elem_type != current_tensor_elem_type) {
if (override_types) {
DataType inferred_type = DataTypeUtils::ToType(input_type);
// The "SetType" call resets the shape information to empty.
// If the original tensor has shape information, it needs to be set back afterwards.
if (Shape()) {
auto old_shape = *Shape();
SetType(inferred_type);
SetShape(old_shape);
} else {
SetType(inferred_type);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
}
}
if (utils::HasShape(input_tensor_type)) {
auto& current_tensor_type = *current_type.mutable_tensor_type();
@ -249,11 +264,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
const auto& input_tensor_type = input_type.sparse_tensor_type();
const auto input_tensor_elem_type = input_tensor_type.elem_type();
const auto current_tensor_elem_type = current_type.sparse_tensor_type().elem_type();
if (input_tensor_elem_type != current_tensor_elem_type) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
if (override_types) {
DataType inferred_type = DataTypeUtils::ToType(input_type);
if (Shape()) {
auto old_shape = *Shape();
SetType(inferred_type);
SetShape(old_shape);
} else {
SetType(inferred_type);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
}
}
if (utils::HasShape(input_tensor_type)) {
auto& current_tensor_type = *current_type.mutable_sparse_tensor_type();
if (utils::HasShape(current_tensor_type)) {
@ -275,11 +303,12 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
return Status::OK();
}
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger) {
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types,
const logging::Logger& logger) {
auto status = Status::OK();
if (utils::HasType(node_arg.node_arg_info_))
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, logger);
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, override_types, logger);
return status;
}
@ -771,7 +800,7 @@ Graph::Graph(const Model& owning_model,
// so we prefer the shape from the initializer
name_to_type_map[tensor.name()] = t;
if (matching_graph_input != nullptr) {
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, logger));
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, false, logger));
}
} else {
// v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these.
@ -1398,12 +1427,12 @@ bool FullyDefinedType(const TypeProto& type_proto) {
// parameters are the Graph instance for the subgraph, the input types from the control flow node that contains
// the subgraph, and the vector to write the output from the inferencing.
using SubgraphInferencingFunc =
std::function<Status(const Node&, Graph&, const std::vector<const TypeProto*>&, std::vector<const TypeProto*>&)>;
std::function<Status(const Node&, Graph&, const std::vector<const TypeProto*>&, std::vector<const TypeProto*>&, const Graph::ResolveOptions&)>;
class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
public:
GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func)
: node_(node), graph_(graph), inferencing_func_(inferencing_func) {
GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, const Graph::ResolveOptions& options)
: node_(node), graph_(graph), inferencing_func_(inferencing_func), options_(options) {
}
// Perform inferencing on the graph contained in GraphInferencer.
@ -1413,7 +1442,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
const std::vector<const TensorProto*>& /*input_data*/) override {
std::vector<const TypeProto*> output_types;
auto status = inferencing_func_(node_, graph_, input_types, output_types);
auto status = inferencing_func_(node_, graph_, input_types, output_types, options_);
if (status != Status::OK()) {
fail_type_inference("Graph attribute inferencing failed: ", status.ErrorMessage());
@ -1426,6 +1455,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
const Node& node_;
Graph& graph_;
SubgraphInferencingFunc& inferencing_func_;
const Graph::ResolveOptions& options_;
};
// An implementation of the InferenceContext interface required by operator-specific
@ -1436,10 +1466,12 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
public:
InferenceContextImpl(Node& node,
SubgraphInferencingFunc subgraph_inferencing_func,
const Graph& graph) noexcept
const Graph& graph,
const Graph::ResolveOptions& options) noexcept
: node_(node),
subgraph_inferencing_func_(subgraph_inferencing_func),
graph_(graph) {
graph_(graph),
options_(options) {
node_output_types_.resize(node.OutputDefs().size());
}
@ -1500,7 +1532,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
auto* subgraph = node_.GetMutableGraphAttribute(attribute_name);
if (subgraph) {
auto inferencer = onnxruntime::make_unique<GraphInferencerImpl>(node_, *subgraph, subgraph_inferencing_func_);
auto inferencer = onnxruntime::make_unique<GraphInferencerImpl>(node_, *subgraph, subgraph_inferencing_func_, options_);
graph_inferencer = inferencer.get();
graph_inferencers_.push_back(std::move(inferencer));
} else {
@ -1518,11 +1550,13 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
SubgraphInferencingFunc subgraph_inferencing_func_;
std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_;
const Graph& graph_;
const Graph::ResolveOptions& options_;
};
Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
const std::vector<const TypeProto*>& input_types,
std::vector<const TypeProto*>& output_types) {
std::vector<const TypeProto*>& output_types,
const Graph::ResolveOptions& options) {
auto status = Status::OK();
output_types.clear();
@ -1555,7 +1589,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
const auto& subgraph_input = *subgraph_inputs->at(i);
NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, subgraph.logger_);
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, options.override_types, subgraph.logger_);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
}
@ -1576,7 +1610,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
if (!subgraph_nodearg)
continue;
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, subgraph.logger_);
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, options.override_types, subgraph.logger_);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
}
@ -1588,8 +1622,6 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
// now that we have handled the input types, do the type/shape inferencing for the subgraph
// to flow the type/shape info through it
// TODO: Handle override-type option correctly for subgraphs.
Graph::ResolveOptions options;
status = subgraph.PerformTypeAndShapeInferencing(options);
ORT_RETURN_IF_ERROR(status);
@ -1695,7 +1727,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
// Once that completes, the outputs from the node containing the subgraph will be updated, and the final values
// returned here.
SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes);
InferenceContextImpl context(node, func, *this);
InferenceContextImpl context(node, func, *this, options);
try {
context.RunInferencing();

View file

@ -574,6 +574,116 @@ TEST(Loop, InfiniteLoopTermination) {
terminator_thread.join();
}
// Basic test that triggers the type override logic in Graph::InferAndVerifySubgraphTypes, as well as the
// type/shape inferencing that flows the type/shape info through the subgraph via
// subgraph.PerformTypeAndShapeInferencing(options).
// In this test the main graph defines the original input/expected output as "double", whereas the subgraph
// defines them as "float".
// The expectation is that the types get propagated properly in the subgraph and yield the correct output.
//
// TODO - when the input/output type in the main graph is float16, extra Cast nodes will be added and the input type
// will be changed by InsertCastTransformer for graph execution, which causes a type mismatch failure.
// Need to investigate how InsertCastTransformer works in future.
TEST(Loop, SubgraphTypeOverride) {
  // Builds the Loop body graph. Every data tensor in the body is declared as float, while the
  // tester below feeds double values, so resolving the model depends on override_types being set.
  auto build_loop_body = [](const RunOptions&) {
    Model body_model("Loop subgraph", false, DefaultLoggingManager().DefaultLogger());
    auto& body = body_model.MainGraph();

    /*
      Body graph inputs: iter_num_in, cond_in, fake_in. outer_scope_0 is an outer scope value.

        iter_num_in    cond_in      fake_in     [outer_scope_0]
         (no node)        |            |              |
                     [Identity]   [Identity]     [Identity]
                          |            |              |
                      cond_out     fake_out     loop_var_0_out
    */

    // Type descriptions used for the body graph's inputs and outputs.
    TypeProto int64_scalar;
    {
      auto* tensor = int64_scalar.mutable_tensor_type();
      tensor->set_elem_type(TensorProto_DataType_INT64);
      tensor->mutable_shape()->add_dim()->set_dim_value(1);
    }

    TypeProto bool_scalar;
    {
      auto* tensor = bool_scalar.mutable_tensor_type();
      tensor->set_elem_type(TensorProto_DataType_BOOL);
      tensor->mutable_shape()->add_dim()->set_dim_value(1);
    }

    TypeProto float_tensor;
    {
      auto* tensor = float_tensor.mutable_tensor_type();
      tensor->set_elem_type(TensorProto_DataType_FLOAT);
      // One dimension is added but no value is assigned to it.
      tensor->mutable_shape()->add_dim();
    }

    // Body graph inputs.
    auto& iter_num_arg = body.GetOrCreateNodeArg("iter_num_in", &int64_scalar);
    auto& cond_in_arg = body.GetOrCreateNodeArg("cond_in", &bool_scalar);
    auto& fake_in_arg = body.GetOrCreateNodeArg("fake_in", &float_tensor);

    // Value coming from the outer scope.
    auto& outer_scope_arg = body.GetOrCreateNodeArg("outer_scope_0", &float_tensor);
    // Registered as an outer scope value so that it is not treated as a graph input.
    body.AddOuterScopeNodeArg("outer_scope_0");

    // Body graph outputs.
    auto& cond_out_arg = body.GetOrCreateNodeArg("cond_out", &bool_scalar);
    auto& fake_out_arg = body.GetOrCreateNodeArg("fake_out", &float_tensor);
    auto& loop_var_out_arg = body.GetOrCreateNodeArg("loop_var_0_out", &float_tensor);

    // Adds a single Identity node that forwards `source` to `sink`.
    auto forward_with_identity = [&body](const std::string& node_name, const std::string& description,
                                         NodeArg& source, NodeArg& sink) {
      std::vector<NodeArg*> node_inputs{&source};
      std::vector<NodeArg*> node_outputs{&sink};
      body.AddNode(node_name, "Identity", description, node_inputs, node_outputs);
    };

    forward_with_identity("cond_in_identity", "Forward cond_in to cond_out", cond_in_arg, cond_out_arg);
    forward_with_identity("fake_in_identity", "Forward fake_in to fake_out", fake_in_arg, fake_out_arg);
    forward_with_identity("loop_var_out", "Forward outer_scope_0 to loop_var_0_out",
                          outer_scope_arg, loop_var_out_arg);

    body.SetInputs({&iter_num_arg, &cond_in_arg, &fake_in_arg});
    body.SetOutputs({&cond_out_arg, &fake_out_arg, &loop_var_out_arg});

    auto resolve_status = body.Resolve();
    EXPECT_EQ(resolve_status, Status::OK());

    return body.ToGraphProto();
  };

  LoopOpTester tester{{}, build_loop_body};

  // The outer graph feeds and expects double data, unlike the float declarations in the body graph.
  tester.AddInput<int64_t>("M", {1}, {1});
  tester.AddInput<bool>("cond", {1}, {true});
  tester.AddInput<double>("fake", {1}, {0.f});
  tester.AddInput<double>("outer_scope_0", {1}, {kOuterNodeAddValue});

  tester.AddOutput<double>("loop_fake_final", {1}, {0.f});
  tester.AddOutput<double>("loop_var_0_final", {1, 1}, {kOuterNodeAddValue});
  tester.AddOutput<int64_t>("outer_scope_0_out", {1}, {int64_t(kOuterNodeAddValue)});

  OrtRunOptions session_run_options;
  session_run_options.run_tag = "Loop.SubgraphTypeOverride";

  // Resolve with type overriding enabled.
  Graph::ResolveOptions resolve_options;
  resolve_options.override_types = true;

  tester.Run(OpTester::ExpectResult::kExpectSuccess, "",
             {kTensorrtExecutionProvider}, &session_run_options, nullptr,
             ExecutionMode::ORT_SEQUENTIAL, {}, resolve_options);
}
// Regression test that a subgraph input overrides an outer scope value of the same name.
// Replicate issue from https://github.com/onnx/onnx/issues/2082
TEST(Loop, SubgraphInputShadowsOuterScopeValue) {

View file

@ -626,14 +626,15 @@ void OpTester::Run(
const RunOptions* run_options,
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers,
ExecutionMode execution_mode,
const CustomOutputVerifierFn& custom_output_verifier) {
const CustomOutputVerifierFn& custom_output_verifier,
const Graph::ResolveOptions& options) {
SessionOptions so;
so.session_logid = op_;
so.session_log_verbosity_level = 1;
so.execution_mode = execution_mode;
so.graph_optimization_level = TransformerLevel::Default; // 'Default' == off
Run(so, expect_result, expected_failure_string, excluded_provider_types,
run_options, execution_providers, custom_output_verifier);
run_options, execution_providers, custom_output_verifier, options);
}
void OpTester::Run(
@ -643,7 +644,8 @@ void OpTester::Run(
const std::unordered_set<std::string>& excluded_provider_types,
const RunOptions* run_options,
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers,
const CustomOutputVerifierFn& custom_output_verifier) {
const CustomOutputVerifierFn& custom_output_verifier,
const Graph::ResolveOptions& options) {
std::string cur_provider = "not set";
try {
#ifndef NDEBUG
@ -660,12 +662,12 @@ void OpTester::Run(
expect_result == ExpectResult::kExpectFailure) {
// capture possible exceptions from shape inference for invalid testcase
try {
status = graph.Resolve();
status = graph.Resolve(options);
} catch (const std::exception& ex) {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
}
} else {
status = graph.Resolve();
status = graph.Resolve(options);
}
if (!status.IsOK()) {

View file

@ -413,7 +413,8 @@ class OpTester {
const RunOptions* run_options = nullptr,
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers = nullptr,
ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL,
const CustomOutputVerifierFn& custom_output_verifier = {});
const CustomOutputVerifierFn& custom_output_verifier = {},
const Graph::ResolveOptions& resolve_options = {});
void Run(SessionOptions session_options,
ExpectResult expect_result = ExpectResult::kExpectSuccess,
@ -421,7 +422,8 @@ class OpTester {
const std::unordered_set<std::string>& excluded_provider_types = {},
const RunOptions* run_options = nullptr,
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers = nullptr,
const CustomOutputVerifierFn& custom_output_verifier = {});
const CustomOutputVerifierFn& custom_output_verifier = {},
const Graph::ResolveOptions& resolve_options = {});
std::vector<MLValue> GetFetches() { return fetches_; }