Modify generator for eager to use all inputs for determining promote type. (#12268)

* Sort supported types order so we get a consistently generated order of types.
* Fix promote type to include all the input types and not just the first one.
This commit is contained in:
Wil Brady 2022-07-21 17:21:10 -04:00 committed by GitHub
parent 30ac6e87fa
commit 45c0be8a25
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -283,8 +283,8 @@ class ORTGen:
if i == 0:
writer.push_indent()
cpp_param = cpp_func.get_parameter(op_input)
supported_types = ",".join([type for type in onnx_op.input_types[idx]])
writer.write("!IsSupportedType(%s, {%s})" % (cpp_param.identifier.value, supported_types))
supported_types = ",".join(sorted([type for type in onnx_op.input_types[idx]]))
writer.write(f"!IsSupportedType({cpp_param.identifier.value}, {{{supported_types}}})")
i += 1
writer.writeline(") {")
self._write_cpu_fall_back(writer, mapped_func)
@ -319,12 +319,12 @@ class ORTGen:
for op_input in onnx_op.inputs:
if isinstance(op_input, Outputs):
continue
cpp_param = cpp_func.get_parameter(op_input)
if cpp_param:
if cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Tensor":
types_from_tensor.append(f"{op_input}.scalar_type()")
elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Scalar":
types_from_scalar.append(f"{op_input}.type()")
cpp_param = cpp_func.get_parameter(op_input)
if cpp_param:
if cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Tensor":
types_from_tensor.append(f"{op_input}.scalar_type()")
elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Scalar":
types_from_scalar.append(f"{op_input}.type()")
if len(types_from_tensor) > 0 or len(types_from_scalar) > 0:
need_type_promotion = True
writer.writeline(