From 45c0be8a25f2e4f5629c5ceee15cc1c4ca5b6bf1 Mon Sep 17 00:00:00 2001 From: Wil Brady <25513670+WilBrady@users.noreply.github.com> Date: Thu, 21 Jul 2022 17:21:10 -0400 Subject: [PATCH] Modify generator for eager to use all inputs for determining promote type. (#12268) * Sort supported types order so we get a consistently generated order of types. * Fix promote type to include all the input types and not just the first one. --- .../orttraining/eager/opgen/opgen/generator.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index 6767158e75..9b1d37f309 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -283,8 +283,8 @@ class ORTGen: if i == 0: writer.push_indent() cpp_param = cpp_func.get_parameter(op_input) - supported_types = ",".join([type for type in onnx_op.input_types[idx]]) - writer.write("!IsSupportedType(%s, {%s})" % (cpp_param.identifier.value, supported_types)) + supported_types = ",".join(sorted([type for type in onnx_op.input_types[idx]])) + writer.write(f"!IsSupportedType({cpp_param.identifier.value}, {{{supported_types}}})") i += 1 writer.writeline(") {") self._write_cpu_fall_back(writer, mapped_func) @@ -319,12 +319,12 @@ class ORTGen: for op_input in onnx_op.inputs: if isinstance(op_input, Outputs): continue - cpp_param = cpp_func.get_parameter(op_input) - if cpp_param: - if cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Tensor": - types_from_tensor.append(f"{op_input}.scalar_type()") - elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Scalar": - types_from_scalar.append(f"{op_input}.type()") + cpp_param = cpp_func.get_parameter(op_input) + if cpp_param: + if cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Tensor": + types_from_tensor.append(f"{op_input}.scalar_type()") + elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Scalar": + types_from_scalar.append(f"{op_input}.type()") if len(types_from_tensor) > 0 or len(types_from_scalar) > 0: need_type_promotion = True writer.writeline(