symbolic_shape_infer: Improve error message on mismatched types (#9809)

The previous assertion failure was basically impossible to debug.
This commit is contained in:
Gary Miguel 2021-11-19 09:39:26 -08:00 committed by GitHub
parent afd60a274c
commit 9d3c63263b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 3 deletions

View file

@ -642,14 +642,18 @@ class SymbolicShapeInference:
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
'''
'''
update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches
'''
dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence(
dst_type) else dst_type.tensor_type
src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence(
src_type) else src_type.tensor_type
assert dst_tensor_type.elem_type == src_tensor_type.elem_type
if dst_tensor_type.elem_type != src_tensor_type.elem_type:
node_id = node.name if node.name else node.op_type
raise ValueError(f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}")
if dst_tensor_type.HasField('shape'):
for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
if ds[0] != ds[1]:

View file

@ -36,6 +36,52 @@ class TestSymbolicShapeInference(unittest.TestCase):
int_max=100000,
guess_output_rank=True)
def test_mismatched_types(self):
graph = helper.make_graph(
[helper.make_node(
"If",
["x"],
["out"],
name="if_node",
then_branch=helper.make_graph(
[helper.make_node(
"Constant",
[],
["one_float"],
value=helper.make_tensor(
"one_float_value",
TensorProto.FLOAT,
[],
[1]),
)],
"then",
[],
[helper.make_tensor_value_info("one_float", TensorProto.FLOAT, [])],
),
else_branch=helper.make_graph(
[helper.make_node(
"Constant",
[],
["one_double"],
value=helper.make_tensor(
"one_double",
TensorProto.DOUBLE,
[],
[1]),
)],
"else",
[],
[helper.make_tensor_value_info("one_double", TensorProto.DOUBLE, [])],
))],
"graph",
[helper.make_tensor_value_info("x", TensorProto.BOOL, [])],
[helper.make_tensor_value_info("out", TensorProto.FLOAT, [])],
)
model = helper.make_model(graph, producer_name="test_mismatched_types")
with self.assertRaisesRegex(ValueError, r"if_node.*FLOAT.*DOUBLE"):
SymbolicShapeInference.infer_shapes(model, auto_merge=True)
class TestSymbolicShapeInferenceForOperators(unittest.TestCase):
def _check_shapes(self, graph, inferred_graph, vis): # type: (GraphProto, GraphProto, List[ValueInfoProto]) -> None
@ -238,7 +284,7 @@ class TestSymbolicShapeInferenceForOperators(unittest.TestCase):
def test_einsum_transpose(self):
self._test_einsum_one_input_impl(['a', 'b'], ['b', 'a'], "ij -> ji")
class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):