diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 0bc5523779..728b0813be 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1797,7 +1797,7 @@ class SymbolicShapeInference: if self.auto_merge_: if node.op_type in [ 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', 'MatMulInteger16', 'Concat', - 'Where', 'Sum' + 'Where', 'Sum', 'Equal', 'Less', 'Greater', 'LessOrEqual', 'GreaterOrEqual' ]: shapes = [self._get_shape(node, i) for i in range(len(node.input))] if node.op_type in ['MatMul', 'MatMulInteger', 'MatMulInteger16']: