From 2dcb69685e8b48a24d5a01e34e58bd0c0270dca9 Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Thu, 20 Jan 2022 10:06:09 -0800 Subject: [PATCH] support type promotion in binary poerators in eager mode (#10285) Co-authored-by: Cheng Tang --- orttraining/orttraining/eager/opgen/opgen.py | 2 +- .../orttraining/eager/opgen/opgen/atenops.py | 7 ++ .../eager/opgen/opgen/custom_ops.py | 2 + .../eager/opgen/opgen/generator.py | 52 ++++++-- orttraining/orttraining/eager/ort_aten.cpp | 115 ++++++++++++++---- orttraining/orttraining/eager/ort_aten.h | 7 ++ orttraining/orttraining/eager/test/ort_ops.py | 10 ++ 7 files changed, 163 insertions(+), 32 deletions(-) diff --git a/orttraining/orttraining/eager/opgen/opgen.py b/orttraining/orttraining/eager/opgen/opgen.py index 71e0542894..7fe7f02370 100755 --- a/orttraining/orttraining/eager/opgen/opgen.py +++ b/orttraining/orttraining/eager/opgen/opgen.py @@ -22,7 +22,7 @@ parser.add_argument('--custom_ops', action='store_true', help='Whether we are ge args = parser.parse_args() ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module() -ortgen = ORTGen(ops_module.ops, custom_ops=args.custom_ops) +ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops) regdecs_path = args.header_file print(f"INFO: Using ATen RegistrationDeclations from: {regdecs_path}") diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index e64f4a7d87..b9994b9b96 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -26,6 +26,7 @@ class GeluGrad(ONNXOp): self.domain = kMSDomain ops = {} +type_promotion_ops = [] for binary_op, onnx_op in { 'add': Add('self', Mul('alpha', 'other')), @@ -37,6 +38,7 @@ for binary_op, onnx_op in { name = f'aten::{binary_op}{variant}.{dtype}' if name not in ops: ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op) + type_promotion_ops.append(f'aten::{binary_op}{variant}.{dtype}') for unary_op in [ 'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos', @@ -92,3 +94,8 @@ hand_implemented = { } ops = {**ops, **hand_implemented} +# TODO: this is a temporary whitelist for ops need type promotion +# Need to enhance the support for onnx type constrains to automatically +# resolve whether the op need type promotion. +# Will remove this list in the future. +type_promotion_ops = (*type_promotion_ops, 'aten::gelu_backward') diff --git a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py index 4fe53bbbf9..a49d4be751 100644 --- a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py +++ b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py @@ -13,3 +13,5 @@ from opgen.onnxops import * ops = { 'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB') } + +type_promotion_ops = {} diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index d58eb5d813..7226c504bd 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -105,11 +105,13 @@ class ORTGen: def __init__( self, ops: Optional[Dict[str, ONNXOp]] = None, - custom_ops : bool = False): + custom_ops : bool = False, + type_promotion_ops : List = ()): self._mapped_ops = {} if ops: self.register_many(ops) - self._custom_ops = custom_ops + self._custom_ops = custom_ops + self.type_promotion_ops = type_promotion_ops def register(self, aten_name: str, onnx_op: ONNXOp): self._mapped_ops[aten_name] = onnx_op @@ -310,6 +312,28 @@ class ORTGen: # Perform kernel fission on the ATen op to yield a chain of ORT Invokes # e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y)) + + # whether need type promotion + need_type_promotion = False + if mapped_func.mapped_op_name in self.type_promotion_ops: + types_from_tensor = [] + types_from_scalar = [] + for onnx_op_index, onnx_op in enumerate(ctx.ops): + for op_input in onnx_op.inputs: + if isinstance(op_input, Outputs): + continue + cpp_param = cpp_func.get_parameter(op_input) + if cpp_param: + if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Tensor': + types_from_tensor.append(f'{op_input}.scalar_type()') + elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar': + types_from_scalar.append(f'{op_input}.type()') + if len(types_from_tensor) > 0 or len(types_from_scalar) > 0 : + need_type_promotion = True + writer.writeline('auto promoted_type = PromoteScalarTypesWithCategory({%s}, {%s});' + % (','.join(types_from_tensor), ','.join(types_from_scalar))) + writer.writeline() + for onnx_op_index, onnx_op in enumerate(ctx.ops): # Torch -> ORT inputs for op_input in onnx_op.inputs: @@ -324,6 +348,14 @@ class ORTGen: writer.write(f'auto ort_input_{op_input} = ') writer.writeline(f'create_ort_value(invoker, {op_input});') + if need_type_promotion: + type_func_str = 'type()' if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar' else 'scalar_type()' + writer.write(f'if ({op_input}.{type_func_str} != *promoted_type)') + writer.writeline('{') + writer.push_indent() + writer.writeline(f'ort_input_{op_input} = CastToType(invoker, ort_input_{op_input}, *promoted_type);') + writer.pop_indent() + writer.writeline('}') # Torch kwargs -> ORT attributes attrs = { k:v for k, v in onnx_op.attributes.items() if v and v.value } @@ -403,17 +435,23 @@ class ORTGen: # TODO: Assert return type if not return_alias_info: + # tensor options + writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}') + if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList': + writer.write('[0]') + writer.write('.options()') + if need_type_promotion: + writer.write('.dtype(*promoted_type)') + writer.writeline(';') + writer.writeline('return aten_tensor_from_ort(') writer.push_indent() if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector': writer.writeline(f'{return_outputs},') - writer.writeline(f'{first_param.identifier.value}.options());') + writer.writeline('tensor_options);') else: writer.writeline(f'std::move({return_outputs}[0]),') - writer.write(first_param.identifier.value) - if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList': - writer.write('[0]') - writer.writeline('.options());') + writer.writeline('tensor_options);') writer.pop_indent() return diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 4bbfc2c2db..b9d6630437 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -7,11 +7,15 @@ #include #include +#include +#include + + namespace torch_ort { namespace eager { //#pragma region Helpers - +using NodeAttributes = onnxruntime::NodeAttributes; namespace { inline bool is_device_supported(at::DeviceType type) { return type == at::kORT || type == at::kCPU; @@ -218,6 +222,92 @@ bool IsSupportedType(at::TensorList tensors, const std::vector& return IsSupportedType(tensors[0], valid_types); } +ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){ + switch (dtype){ + case at::kFloat: + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + case at::kDouble: + return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; + case at::kHalf: + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + case at::kBFloat16: + return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; + case at::kInt: + return ONNX_NAMESPACE::TensorProto_DataType_INT32; + case at::kShort: + return ONNX_NAMESPACE::TensorProto_DataType_INT16; + case at::kLong: + return ONNX_NAMESPACE::TensorProto_DataType_INT64; + case at::kBool: + return ONNX_NAMESPACE::TensorProto_DataType_BOOL; + default: + ORT_THROW("Unsupport aten scalar type: ", dtype); + } +} + +static c10::optional PromoteScalarTypes( + const std::vector& types) { + if (types.empty()) { + return at::nullopt; + } + auto st = types[0]; + for (const auto i : c10::irange(1, types.size())) { + st = c10::promoteTypes(st, types[i]); + } + return st; +} + + +c10::optional PromoteScalarTypesWithCategory( + const std::vector& typesFromTensors, + const std::vector& typesFromScalars) { + auto typeFromTensor = PromoteScalarTypes(typesFromTensors); + auto typeFromScalar = PromoteScalarTypes(typesFromScalars); + + auto getTypeCategory = [](c10::ScalarType t) { + if (c10::kBool == t) { + return 1; + } + if (c10::isIntegralType(t, /*includeBool=*/false)) { + return 2; + } + if (c10::isFloatingType(t)) { + return 3; + } + return 0; + }; + + if (c10::nullopt == typeFromScalar) { + return typeFromTensor; + } else if (c10::nullopt == typeFromTensor) { + return typeFromScalar; + } + + auto typeCategoryFromTensor = getTypeCategory(typeFromTensor.value()); + auto typeCategoryFromScalar = getTypeCategory(typeFromScalar.value()); + + if (typeCategoryFromScalar > typeCategoryFromTensor) { + return typeFromScalar; + } + return typeFromTensor; +} + +OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type){ + std::vector output(1); + NodeAttributes attrs(2); + attrs["to"] = create_ort_attribute( + "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long); + + auto status = invoker.Invoke("Cast", { + std::move(input), + }, output, &attrs); + + if (!status.IsOK()) + throw std::runtime_error( + "ORT return failure status:" + status.ErrorMessage()); + return output[0]; +} + //#pragma endregion //#pragma region Hand-Implemented ATen Ops @@ -316,29 +406,6 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) { self.options()); } -ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){ - switch (dtype){ - case at::kFloat: - return ONNX_NAMESPACE::TensorProto_DataType_FLOAT; - case at::kDouble: - return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; - case at::kHalf: - return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; - case at::kBFloat16: - return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; - case at::kInt: - return ONNX_NAMESPACE::TensorProto_DataType_INT32; - case at::kShort: - return ONNX_NAMESPACE::TensorProto_DataType_INT16; - case at::kLong: - return ONNX_NAMESPACE::TensorProto_DataType_INT64; - case at::kBool: - return ONNX_NAMESPACE::TensorProto_DataType_BOOL; - default: - ORT_THROW("Unsupport aten scalar type: ", dtype); - } -} - at::Tensor& copy_( at::Tensor& self, const at::Tensor& src, diff --git a/orttraining/orttraining/eager/ort_aten.h b/orttraining/orttraining/eager/ort_aten.h index 1f3190cabe..b01e370a25 100644 --- a/orttraining/orttraining/eager/ort_aten.h +++ b/orttraining/orttraining/eager/ort_aten.h @@ -118,5 +118,12 @@ bool IsSupportedType(c10::optional val, const std::vector& valid_types); +c10::optional PromoteScalarTypesWithCategory( + const std::vector& typesFromTensors, + const std::vector& typesFromScalars); + +ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype); + +OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type); } // namespace eager } // namespace torch_ort \ No newline at end of file diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index 84be88d73a..e12b71a2f1 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -17,6 +17,16 @@ class OrtOpTests(unittest.TestCase): cpu_twos = cpu_ones + cpu_ones ort_twos = ort_ones + ort_ones assert torch.allclose(cpu_twos, ort_twos.cpu()) + + def test_type_promotion_add(self): + device = self.get_device() + x = torch.ones(2, 5, dtype = torch.int64) + y = torch.ones(2, 5, dtype = torch.float32) + ort_x = x.to(device) + ort_y = y.to(device) + ort_z = ort_x + ort_y + assert ort_z.dtype == torch.float32 + assert torch.allclose(ort_z.cpu(), (x + y)) def test_add_alpha(self): device = self.get_device()