diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 0d80efbc64..cd5d30638e 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -642,14 +642,18 @@ class SymbolicShapeInference: vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): - ''' + ''' update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches ''' dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence( dst_type) else dst_type.tensor_type src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence( src_type) else src_type.tensor_type - assert dst_tensor_type.elem_type == src_tensor_type.elem_type + if dst_tensor_type.elem_type != src_tensor_type.elem_type: + node_id = node.name if node.name else node.op_type + raise ValueError(f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " + f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " + f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}") if dst_tensor_type.HasField('shape'): for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): if ds[0] != ds[1]: diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index e3935c4e40..ad046cce89 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -36,6 +36,52 @@ class TestSymbolicShapeInference(unittest.TestCase): int_max=100000, guess_output_rank=True) + def test_mismatched_types(self): + graph = helper.make_graph( + [helper.make_node( + "If", + ["x"], + ["out"], + name="if_node", + then_branch=helper.make_graph( + [helper.make_node( + "Constant", + [], + ["one_float"], + value=helper.make_tensor( + "one_float_value", + TensorProto.FLOAT, + [], + [1]), + )], + "then", + [], + [helper.make_tensor_value_info("one_float", TensorProto.FLOAT, [])], + ), + else_branch=helper.make_graph( + [helper.make_node( + "Constant", + [], + ["one_double"], + value=helper.make_tensor( + "one_double", + TensorProto.DOUBLE, + [], + [1]), + )], + "else", + [], + [helper.make_tensor_value_info("one_double", TensorProto.DOUBLE, [])], + ))], + "graph", + [helper.make_tensor_value_info("x", TensorProto.BOOL, [])], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, [])], + ) + model = helper.make_model(graph, producer_name="test_mismatched_types") + + with self.assertRaisesRegex(ValueError, r"if_node.*FLOAT.*DOUBLE"): + SymbolicShapeInference.infer_shapes(model, auto_merge=True) + class TestSymbolicShapeInferenceForOperators(unittest.TestCase): def _check_shapes(self, graph, inferred_graph, vis): # type: (GraphProto, GraphProto, List[ValueInfoProto]) -> None @@ -238,7 +284,7 @@ class TestSymbolicShapeInferenceForOperators(unittest.TestCase): def test_einsum_transpose(self): self._test_einsum_one_input_impl(['a', 'b'], ['b', 'a'], "ij -> ji") - + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):