Support type promotion in binary operators in eager mode (#10285)

Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
Tang, Cheng 2022-01-20 10:06:09 -08:00 committed by GitHub
parent c67594694c
commit 2dcb69685e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 163 additions and 32 deletions

View file

@ -22,7 +22,7 @@ parser.add_argument('--custom_ops', action='store_true', help='Whether we are ge
# Parse the command-line arguments declared on `parser`.
args = parser.parse_args()
# Load the op-mapping module from the file path held in args.ops_module.
ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
# NOTE(review): the two ORTGen constructions below look like the before/after
# sides of a single changed line in this diff view; presumably only the second
# (which also forwards ops_module.type_promotion_ops) is live - confirm against
# the real file.
ortgen = ORTGen(ops_module.ops, custom_ops=args.custom_ops)
ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops)
# Path of the ATen registration declarations header, taken from args.header_file.
regdecs_path = args.header_file
print(f"INFO: Using ATen RegistrationDeclations from: {regdecs_path}")

View file

@ -26,6 +26,7 @@ class GeluGrad(ONNXOp):
self.domain = kMSDomain
ops = {}
type_promotion_ops = []
for binary_op, onnx_op in {
'add': Add('self', Mul('alpha', 'other')),
@ -37,6 +38,7 @@ for binary_op, onnx_op in {
name = f'aten::{binary_op}{variant}.{dtype}'
if name not in ops:
ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op)
type_promotion_ops.append(f'aten::{binary_op}{variant}.{dtype}')
for unary_op in [
'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
@ -92,3 +94,8 @@ hand_implemented = {
}
ops = {**ops, **hand_implemented}
# TODO: this is a temporary allow-list of ops that need type promotion.
# We need to enhance the support for ONNX type constraints so that we can
# automatically resolve whether an op needs type promotion.
# This list will be removed in the future.
type_promotion_ops = (*type_promotion_ops, 'aten::gelu_backward')

View file

@ -13,3 +13,5 @@ from opgen.onnxops import *
# Custom-op table: maps an op name to the ONNX op expression generated for it.
ops = {
  'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB')
}
# Names of ops that need input type promotion. None of the custom ops do.
# This is a collection of op names that the generator only tests membership
# against, so use an empty tuple (the generator's own default) rather than an
# empty dict.
type_promotion_ops = ()

View file

@ -105,11 +105,13 @@ class ORTGen:
# Initializes the generator.
#   ops:                optional mapping of ATen op name -> ONNXOp; when truthy it
#                       is handed to register_many().
#   custom_ops:         flag stored on the instance as _custom_ops.
#   type_promotion_ops: collection of mapped op names, stored as-is; the code
#                       generator tests membership in it to decide whether to
#                       emit type-promotion code for an op.
# NOTE(review): in this diff view the end of the signature and the _custom_ops
# assignment each appear twice (old and new side of the change); presumably only
# the later copy of each is live - confirm against the real file.
def __init__(
self,
ops: Optional[Dict[str, ONNXOp]] = None,
custom_ops : bool = False):
custom_ops : bool = False,
type_promotion_ops : List = ()):
# Registry of ATen op name -> ONNXOp, filled through register()/register_many().
self._mapped_ops = {}
if ops:
self.register_many(ops)
self._custom_ops = custom_ops
self._custom_ops = custom_ops
self.type_promotion_ops = type_promotion_ops
def register(self, aten_name: str, onnx_op: ONNXOp):
self._mapped_ops[aten_name] = onnx_op
@ -310,6 +312,28 @@ class ORTGen:
# Perform kernel fission on the ATen op to yield a chain of ORT Invokes
# e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y))
# whether need type promotion
need_type_promotion = False
if mapped_func.mapped_op_name in self.type_promotion_ops:
types_from_tensor = []
types_from_scalar = []
for onnx_op_index, onnx_op in enumerate(ctx.ops):
for op_input in onnx_op.inputs:
if isinstance(op_input, Outputs):
continue
cpp_param = cpp_func.get_parameter(op_input)
if cpp_param:
if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Tensor':
types_from_tensor.append(f'{op_input}.scalar_type()')
elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar':
types_from_scalar.append(f'{op_input}.type()')
if len(types_from_tensor) > 0 or len(types_from_scalar) > 0 :
need_type_promotion = True
writer.writeline('auto promoted_type = PromoteScalarTypesWithCategory({%s}, {%s});'
% (','.join(types_from_tensor), ','.join(types_from_scalar)))
writer.writeline()
for onnx_op_index, onnx_op in enumerate(ctx.ops):
# Torch -> ORT inputs
for op_input in onnx_op.inputs:
@ -324,6 +348,14 @@ class ORTGen:
writer.write(f'auto ort_input_{op_input} = ')
writer.writeline(f'create_ort_value(invoker, {op_input});')
if need_type_promotion:
type_func_str = 'type()' if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar' else 'scalar_type()'
writer.write(f'if ({op_input}.{type_func_str} != *promoted_type)')
writer.writeline('{')
writer.push_indent()
writer.writeline(f'ort_input_{op_input} = CastToType(invoker, ort_input_{op_input}, *promoted_type);')
writer.pop_indent()
writer.writeline('}')
# Torch kwargs -> ORT attributes
attrs = { k:v for k, v in onnx_op.attributes.items() if v and v.value }
@ -403,17 +435,23 @@ class ORTGen:
# TODO: Assert return type
if not return_alias_info:
# tensor options
writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}')
if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
writer.write('[0]')
writer.write('.options()')
if need_type_promotion:
writer.write('.dtype(*promoted_type)')
writer.writeline(';')
writer.writeline('return aten_tensor_from_ort(')
writer.push_indent()
if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector':
writer.writeline(f'{return_outputs},')
writer.writeline(f'{first_param.identifier.value}.options());')
writer.writeline('tensor_options);')
else:
writer.writeline(f'std::move({return_outputs}[0]),')
writer.write(first_param.identifier.value)
if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
writer.write('[0]')
writer.writeline('.options());')
writer.writeline('tensor_options);')
writer.pop_indent()
return

View file

@ -7,11 +7,15 @@
#include <ATen/native/CPUFallback.h>
#include <ATen/InferSize.h>
#include <torch/csrc/jit/ir/ir.h>
#include <c10/util/irange.h>
namespace torch_ort {
namespace eager {
//#pragma region Helpers
using NodeAttributes = onnxruntime::NodeAttributes;
namespace {
inline bool is_device_supported(at::DeviceType type) {
return type == at::kORT || type == at::kCPU;
@ -218,6 +222,92 @@ bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>&
return IsSupportedType(tensors[0], valid_types);
}
/// Translates an ATen scalar type into the matching ONNX tensor element type.
///
/// Handled dtypes: Float, Double, Half, BFloat16, Int, Short, Long and Bool.
/// Any other dtype is reported through ORT_THROW.
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
  struct DTypePair {
    at::ScalarType aten_type;
    ONNX_NAMESPACE::TensorProto_DataType onnx_type;
  };
  // One row per supported dtype; a dtype missing from this table is an error.
  static constexpr DTypePair kSupported[] = {
      {at::kFloat, ONNX_NAMESPACE::TensorProto_DataType_FLOAT},
      {at::kDouble, ONNX_NAMESPACE::TensorProto_DataType_DOUBLE},
      {at::kHalf, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16},
      {at::kBFloat16, ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16},
      {at::kInt, ONNX_NAMESPACE::TensorProto_DataType_INT32},
      {at::kShort, ONNX_NAMESPACE::TensorProto_DataType_INT16},
      {at::kLong, ONNX_NAMESPACE::TensorProto_DataType_INT64},
      {at::kBool, ONNX_NAMESPACE::TensorProto_DataType_BOOL},
  };

  for (const auto& entry : kSupported) {
    if (entry.aten_type == dtype) {
      return entry.onnx_type;
    }
  }
  ORT_THROW("Unsupport aten scalar type: ", dtype);
}
/// Folds a list of scalar types into one promoted type using c10::promoteTypes.
///
/// @param types Scalar types to combine, in order.
/// @return The promoted type, or an empty optional when `types` is empty.
static c10::optional<at::ScalarType> PromoteScalarTypes(
    const std::vector<at::ScalarType>& types) {
  c10::optional<at::ScalarType> promoted;  // stays empty when there is no input
  for (const at::ScalarType current : types) {
    // The first element seeds the accumulator; later ones are promoted into it.
    promoted = promoted.has_value() ? c10::promoteTypes(*promoted, current) : current;
  }
  return promoted;
}
/// Picks the result dtype for an op from its tensor-input and scalar-input dtypes.
///
/// Each list is promoted on its own first. If only one side yields a type, that
/// type is returned (an empty optional when both lists are empty). When both
/// sides yield a type, the scalar-derived type is chosen only if its category
/// ranks strictly above the tensor-derived type's category; in every other case
/// the tensor-derived type is returned.
///
/// @param typesFromTensors dtypes taken from tensor arguments.
/// @param typesFromScalars dtypes taken from scalar arguments.
c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
    const std::vector<at::ScalarType>& typesFromTensors,
    const std::vector<at::ScalarType>& typesFromScalars) {
  const auto tensor_side = PromoteScalarTypes(typesFromTensors);
  const auto scalar_side = PromoteScalarTypes(typesFromScalars);

  // With a single populated side there is nothing to arbitrate.
  if (!scalar_side.has_value()) {
    return tensor_side;
  }
  if (!tensor_side.has_value()) {
    return scalar_side;
  }

  // Category rank: bool (1) < integral (2) < floating point (3).
  // Types that match none of these checks rank 0.
  const auto category_rank = [](c10::ScalarType t) -> int {
    if (t == c10::kBool) {
      return 1;
    }
    if (c10::isIntegralType(t, /*includeBool=*/false)) {
      return 2;
    }
    return c10::isFloatingType(t) ? 3 : 0;
  };

  const int tensor_rank = category_rank(*tensor_side);
  const int scalar_rank = category_rank(*scalar_side);
  return scalar_rank > tensor_rank ? scalar_side : tensor_side;
}
/// Casts an ORT value to the ONNX element type that corresponds to `type`.
///
/// Executes the ONNX "Cast" operator through `invoker`, with its "to" attribute
/// set to GetONNXTensorProtoDataType(type).
///
/// @param invoker Invoker used to run the Cast op.
/// @param input   Value to cast; taken by const reference and left untouched.
/// @param type    Target ATen scalar type.
/// @return The single output produced by the Cast op.
/// @throws std::runtime_error when the invoker reports a non-OK status.
OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
  std::vector<OrtValue> output(1);
  NodeAttributes attrs(2);
  attrs["to"] = create_ort_attribute(
      "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long);

  // `input` is a const reference, so wrapping it in std::move could never move
  // it; pass it plainly so the copy into the argument list is explicit.
  auto status = invoker.Invoke("Cast", {input}, output, &attrs);
  if (!status.IsOK()) {
    throw std::runtime_error(
        "ORT return failure status:" + status.ErrorMessage());
  }

  // `output` is local and dies here, so hand its element over instead of
  // copying it (this degrades to a copy if OrtValue has no move constructor).
  return std::move(output[0]);
}
//#pragma endregion
//#pragma region Hand-Implemented ATen Ops
@ -316,29 +406,6 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
self.options());
}
/// Translates an ATen scalar type into the matching ONNX tensor element type.
///
/// Handled dtypes: Float, Double, Half, BFloat16, Int, Short, Long and Bool.
/// Any other dtype is reported through ORT_THROW.
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
  struct DTypePair {
    at::ScalarType aten_type;
    ONNX_NAMESPACE::TensorProto_DataType onnx_type;
  };
  // One row per supported dtype; a dtype missing from this table is an error.
  static constexpr DTypePair kSupported[] = {
      {at::kFloat, ONNX_NAMESPACE::TensorProto_DataType_FLOAT},
      {at::kDouble, ONNX_NAMESPACE::TensorProto_DataType_DOUBLE},
      {at::kHalf, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16},
      {at::kBFloat16, ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16},
      {at::kInt, ONNX_NAMESPACE::TensorProto_DataType_INT32},
      {at::kShort, ONNX_NAMESPACE::TensorProto_DataType_INT16},
      {at::kLong, ONNX_NAMESPACE::TensorProto_DataType_INT64},
      {at::kBool, ONNX_NAMESPACE::TensorProto_DataType_BOOL},
  };

  for (const auto& entry : kSupported) {
    if (entry.aten_type == dtype) {
      return entry.onnx_type;
    }
  }
  ORT_THROW("Unsupport aten scalar type: ", dtype);
}
at::Tensor& copy_(
at::Tensor& self,
const at::Tensor& src,

View file

@ -118,5 +118,12 @@ bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarTyp
bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types);
/// Picks the result dtype for an op from its tensor-input and scalar-input dtypes.
/// Each list is promoted separately. If one list is empty, the other's promoted
/// type is returned (an empty optional when both are empty). Otherwise the
/// scalar-derived type wins only when its category (bool < integral < floating
/// point) ranks strictly above the tensor-derived type's category.
c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
const std::vector<at::ScalarType>& typesFromTensors,
const std::vector<at::ScalarType>& typesFromScalars);
/// Maps an ATen scalar type (Float, Double, Half, BFloat16, Int, Short, Long,
/// Bool) to the matching ONNX tensor element type; any other dtype raises via
/// ORT_THROW.
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype);
/// Runs the ONNX "Cast" op on `input` through `invoker`, targeting the ONNX type
/// that corresponds to `type`, and returns the op's single output. Throws
/// std::runtime_error when the invoker reports a non-OK status.
OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type);
} // namespace eager
} // namespace torch_ort

View file

@ -17,6 +17,16 @@ class OrtOpTests(unittest.TestCase):
cpu_twos = cpu_ones + cpu_ones
ort_twos = ort_ones + ort_ones
assert torch.allclose(cpu_twos, ort_twos.cpu())
def test_type_promotion_add(self):
device = self.get_device()
x = torch.ones(2, 5, dtype = torch.int64)
y = torch.ones(2, 5, dtype = torch.float32)
ort_x = x.to(device)
ort_y = y.to(device)
ort_z = ort_x + ort_y
assert ort_z.dtype == torch.float32
assert torch.allclose(ort_z.cpu(), (x + y))
def test_add_alpha(self):
device = self.get_device()