mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
symbolic_shape_infer: Improve error message on mismatched types (#9809)
The previous assertion failure was basically impossible to debug.
This commit is contained in:
parent
afd60a274c
commit
9d3c63263b
2 changed files with 53 additions and 3 deletions
|
|
@ -642,14 +642,18 @@ class SymbolicShapeInference:
|
|||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
|
||||
|
||||
def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
|
||||
'''
|
||||
'''
|
||||
update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches
|
||||
'''
|
||||
dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence(
|
||||
dst_type) else dst_type.tensor_type
|
||||
src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence(
|
||||
src_type) else src_type.tensor_type
|
||||
assert dst_tensor_type.elem_type == src_tensor_type.elem_type
|
||||
if dst_tensor_type.elem_type != src_tensor_type.elem_type:
|
||||
node_id = node.name if node.name else node.op_type
|
||||
raise ValueError(f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
|
||||
f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
|
||||
f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}")
|
||||
if dst_tensor_type.HasField('shape'):
|
||||
for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
|
||||
if ds[0] != ds[1]:
|
||||
|
|
|
|||
|
|
@ -36,6 +36,52 @@ class TestSymbolicShapeInference(unittest.TestCase):
|
|||
int_max=100000,
|
||||
guess_output_rank=True)
|
||||
|
||||
def test_mismatched_types(self):
|
||||
graph = helper.make_graph(
|
||||
[helper.make_node(
|
||||
"If",
|
||||
["x"],
|
||||
["out"],
|
||||
name="if_node",
|
||||
then_branch=helper.make_graph(
|
||||
[helper.make_node(
|
||||
"Constant",
|
||||
[],
|
||||
["one_float"],
|
||||
value=helper.make_tensor(
|
||||
"one_float_value",
|
||||
TensorProto.FLOAT,
|
||||
[],
|
||||
[1]),
|
||||
)],
|
||||
"then",
|
||||
[],
|
||||
[helper.make_tensor_value_info("one_float", TensorProto.FLOAT, [])],
|
||||
),
|
||||
else_branch=helper.make_graph(
|
||||
[helper.make_node(
|
||||
"Constant",
|
||||
[],
|
||||
["one_double"],
|
||||
value=helper.make_tensor(
|
||||
"one_double",
|
||||
TensorProto.DOUBLE,
|
||||
[],
|
||||
[1]),
|
||||
)],
|
||||
"else",
|
||||
[],
|
||||
[helper.make_tensor_value_info("one_double", TensorProto.DOUBLE, [])],
|
||||
))],
|
||||
"graph",
|
||||
[helper.make_tensor_value_info("x", TensorProto.BOOL, [])],
|
||||
[helper.make_tensor_value_info("out", TensorProto.FLOAT, [])],
|
||||
)
|
||||
model = helper.make_model(graph, producer_name="test_mismatched_types")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, r"if_node.*FLOAT.*DOUBLE"):
|
||||
SymbolicShapeInference.infer_shapes(model, auto_merge=True)
|
||||
|
||||
|
||||
class TestSymbolicShapeInferenceForOperators(unittest.TestCase):
|
||||
def _check_shapes(self, graph, inferred_graph, vis): # type: (GraphProto, GraphProto, List[ValueInfoProto]) -> None
|
||||
|
|
@ -238,7 +284,7 @@ class TestSymbolicShapeInferenceForOperators(unittest.TestCase):
|
|||
|
||||
def test_einsum_transpose(self):
|
||||
self._test_einsum_one_input_impl(['a', 'b'], ['b', 'a'], "ij -> ji")
|
||||
|
||||
|
||||
|
||||
class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
|
||||
def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
|
||||
|
|
|
|||
Loading…
Reference in a new issue