[Symbolic Shape Infer] add more ops for auto merge (#8824)

As Less/Equal/Greater/LessOrEqual/GreaterOrEqual ops can broadcast
This commit is contained in:
KeDengMS 2021-08-24 16:33:23 -07:00 committed by GitHub
parent 7f1e880649
commit ddd4586a2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1797,7 +1797,7 @@ class SymbolicShapeInference:
if self.auto_merge_:
if node.op_type in [
'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', 'MatMulInteger16', 'Concat',
'Where', 'Sum'
'Where', 'Sum', 'Equal', 'Less', 'Greater', 'LessOrEqual', 'GreaterOrEqual'
]:
shapes = [self._get_shape(node, i) for i in range(len(node.input))]
if node.op_type in ['MatMul', 'MatMulInteger', 'MatMulInteger16']: