mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
support type promotion in binary operators in eager mode (#10285)
Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
c67594694c
commit
2dcb69685e
7 changed files with 163 additions and 32 deletions
|
|
@ -22,7 +22,7 @@ parser.add_argument('--custom_ops', action='store_true', help='Whether we are ge
|
|||
args = parser.parse_args()
|
||||
ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
|
||||
|
||||
ortgen = ORTGen(ops_module.ops, custom_ops=args.custom_ops)
|
||||
ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops)
|
||||
|
||||
regdecs_path = args.header_file
|
||||
print(f"INFO: Using ATen RegistrationDeclations from: {regdecs_path}")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class GeluGrad(ONNXOp):
|
|||
self.domain = kMSDomain
|
||||
|
||||
ops = {}
|
||||
type_promotion_ops = []
|
||||
|
||||
for binary_op, onnx_op in {
|
||||
'add': Add('self', Mul('alpha', 'other')),
|
||||
|
|
@ -37,6 +38,7 @@ for binary_op, onnx_op in {
|
|||
name = f'aten::{binary_op}{variant}.{dtype}'
|
||||
if name not in ops:
|
||||
ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op)
|
||||
type_promotion_ops.append(f'aten::{binary_op}{variant}.{dtype}')
|
||||
|
||||
for unary_op in [
|
||||
'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
|
||||
|
|
@ -92,3 +94,8 @@ hand_implemented = {
|
|||
}
|
||||
|
||||
ops = {**ops, **hand_implemented}
|
||||
# TODO: this is a temporary whitelist of ops that need type promotion.
|
||||
# Need to enhance the support for ONNX type constraints to automatically
|
||||
# resolve whether the op needs type promotion.
|
||||
# Will remove this list in the future.
|
||||
type_promotion_ops = (*type_promotion_ops, 'aten::gelu_backward')
|
||||
|
|
|
|||
|
|
@ -13,3 +13,5 @@ from opgen.onnxops import *
|
|||
ops = {
|
||||
'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB')
|
||||
}
|
||||
|
||||
type_promotion_ops = {}
|
||||
|
|
|
|||
|
|
@ -105,11 +105,13 @@ class ORTGen:
|
|||
def __init__(
|
||||
self,
|
||||
ops: Optional[Dict[str, ONNXOp]] = None,
|
||||
custom_ops : bool = False):
|
||||
custom_ops : bool = False,
|
||||
type_promotion_ops : List = ()):
|
||||
self._mapped_ops = {}
|
||||
if ops:
|
||||
self.register_many(ops)
|
||||
self._custom_ops = custom_ops
|
||||
self._custom_ops = custom_ops
|
||||
self.type_promotion_ops = type_promotion_ops
|
||||
|
||||
def register(self, aten_name: str, onnx_op: ONNXOp):
|
||||
self._mapped_ops[aten_name] = onnx_op
|
||||
|
|
@ -310,6 +312,28 @@ class ORTGen:
|
|||
|
||||
# Perform kernel fission on the ATen op to yield a chain of ORT Invokes
|
||||
# e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y))
|
||||
|
||||
# whether need type promotion
|
||||
need_type_promotion = False
|
||||
if mapped_func.mapped_op_name in self.type_promotion_ops:
|
||||
types_from_tensor = []
|
||||
types_from_scalar = []
|
||||
for onnx_op_index, onnx_op in enumerate(ctx.ops):
|
||||
for op_input in onnx_op.inputs:
|
||||
if isinstance(op_input, Outputs):
|
||||
continue
|
||||
cpp_param = cpp_func.get_parameter(op_input)
|
||||
if cpp_param:
|
||||
if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Tensor':
|
||||
types_from_tensor.append(f'{op_input}.scalar_type()')
|
||||
elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar':
|
||||
types_from_scalar.append(f'{op_input}.type()')
|
||||
if len(types_from_tensor) > 0 or len(types_from_scalar) > 0 :
|
||||
need_type_promotion = True
|
||||
writer.writeline('auto promoted_type = PromoteScalarTypesWithCategory({%s}, {%s});'
|
||||
% (','.join(types_from_tensor), ','.join(types_from_scalar)))
|
||||
writer.writeline()
|
||||
|
||||
for onnx_op_index, onnx_op in enumerate(ctx.ops):
|
||||
# Torch -> ORT inputs
|
||||
for op_input in onnx_op.inputs:
|
||||
|
|
@ -324,6 +348,14 @@ class ORTGen:
|
|||
|
||||
writer.write(f'auto ort_input_{op_input} = ')
|
||||
writer.writeline(f'create_ort_value(invoker, {op_input});')
|
||||
if need_type_promotion:
|
||||
type_func_str = 'type()' if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar' else 'scalar_type()'
|
||||
writer.write(f'if ({op_input}.{type_func_str} != *promoted_type)')
|
||||
writer.writeline('{')
|
||||
writer.push_indent()
|
||||
writer.writeline(f'ort_input_{op_input} = CastToType(invoker, ort_input_{op_input}, *promoted_type);')
|
||||
writer.pop_indent()
|
||||
writer.writeline('}')
|
||||
|
||||
# Torch kwargs -> ORT attributes
|
||||
attrs = { k:v for k, v in onnx_op.attributes.items() if v and v.value }
|
||||
|
|
@ -403,17 +435,23 @@ class ORTGen:
|
|||
# TODO: Assert return type
|
||||
|
||||
if not return_alias_info:
|
||||
# tensor options
|
||||
writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}')
|
||||
if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
|
||||
writer.write('[0]')
|
||||
writer.write('.options()')
|
||||
if need_type_promotion:
|
||||
writer.write('.dtype(*promoted_type)')
|
||||
writer.writeline(';')
|
||||
|
||||
writer.writeline('return aten_tensor_from_ort(')
|
||||
writer.push_indent()
|
||||
if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector':
|
||||
writer.writeline(f'{return_outputs},')
|
||||
writer.writeline(f'{first_param.identifier.value}.options());')
|
||||
writer.writeline('tensor_options);')
|
||||
else:
|
||||
writer.writeline(f'std::move({return_outputs}[0]),')
|
||||
writer.write(first_param.identifier.value)
|
||||
if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
|
||||
writer.write('[0]')
|
||||
writer.writeline('.options());')
|
||||
writer.writeline('tensor_options);')
|
||||
writer.pop_indent()
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,15 @@
|
|||
#include <ATen/native/CPUFallback.h>
|
||||
#include <ATen/InferSize.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
||||
//#pragma region Helpers
|
||||
|
||||
using NodeAttributes = onnxruntime::NodeAttributes;
|
||||
namespace {
|
||||
inline bool is_device_supported(at::DeviceType type) {
|
||||
return type == at::kORT || type == at::kCPU;
|
||||
|
|
@ -218,6 +222,92 @@ bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>&
|
|||
return IsSupportedType(tensors[0], valid_types);
|
||||
}
|
||||
|
||||
// Translates an ATen scalar type into the matching ONNX TensorProto data type.
// Only the element types listed in the switch are handled; any other value is
// reported through ORT_THROW.
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
  switch (dtype) {
    // Floating-point element types.
    case at::ScalarType::Float:
      return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
    case at::ScalarType::Double:
      return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
    case at::ScalarType::Half:
      return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
    case at::ScalarType::BFloat16:
      return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
    // Signed integer element types.
    case at::ScalarType::Int:
      return ONNX_NAMESPACE::TensorProto_DataType_INT32;
    case at::ScalarType::Short:
      return ONNX_NAMESPACE::TensorProto_DataType_INT16;
    case at::ScalarType::Long:
      return ONNX_NAMESPACE::TensorProto_DataType_INT64;
    // Boolean.
    case at::ScalarType::Bool:
      return ONNX_NAMESPACE::TensorProto_DataType_BOOL;
    default:
      ORT_THROW("Unsupport aten scalar type: ", dtype);
  }
}
|
||||
|
||||
// Folds c10::promoteTypes over `types`, left to right, and returns the
// resulting scalar type. An empty list yields nullopt.
static c10::optional<at::ScalarType> PromoteScalarTypes(
    const std::vector<at::ScalarType>& types) {
  auto cursor = types.begin();
  if (cursor == types.end()) {
    return at::nullopt;
  }

  // Seed with the first entry, then promote against each remaining one.
  at::ScalarType promoted = *cursor;
  for (++cursor; cursor != types.end(); ++cursor) {
    promoted = c10::promoteTypes(promoted, *cursor);
  }
  return promoted;
}
|
||||
|
||||
|
||||
// Picks the common scalar type for an op whose inputs are a mix of tensors
// and scalars.
//
// Each group is first promoted on its own. If only one group is non-empty its
// promoted type is the answer (nullopt when both are empty). Otherwise the
// scalar group's type is chosen only when its category ranks strictly higher
// than the tensor group's category (bool < integral < floating point); in
// every other case the tensor group's type is returned.
c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
    const std::vector<at::ScalarType>& typesFromTensors,
    const std::vector<at::ScalarType>& typesFromScalars) {
  const auto tensorType = PromoteScalarTypes(typesFromTensors);
  const auto scalarType = PromoteScalarTypes(typesFromScalars);

  // With one side missing there is nothing to arbitrate.
  if (!scalarType.has_value()) {
    return tensorType;
  }
  if (!tensorType.has_value()) {
    return scalarType;
  }

  // Rank of a type's category; types outside the three known categories get 0.
  const auto categoryOf = [](c10::ScalarType t) -> int {
    if (t == c10::kBool) {
      return 1;
    }
    if (c10::isIntegralType(t, /*includeBool=*/false)) {
      return 2;
    }
    if (c10::isFloatingType(t)) {
      return 3;
    }
    return 0;
  };

  const int tensorCategory = categoryOf(tensorType.value());
  const int scalarCategory = categoryOf(scalarType.value());

  return scalarCategory > tensorCategory ? scalarType : tensorType;
}
|
||||
|
||||
// Casts `input` to the given ATen scalar type by invoking the ONNX "Cast" op
// through `invoker`, and returns the op's single output.
// Throws std::runtime_error if the invocation reports a non-OK status.
OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
  std::vector<OrtValue> output(1);
  NodeAttributes attrs(2);
  // "to" carries the target element type as an ONNX TensorProto data type.
  attrs["to"] = create_ort_attribute(
      "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long);

  // `input` is a const reference, so it is passed as-is: std::move on it
  // could never move and only hid the copy made into the argument list.
  auto status = invoker.Invoke("Cast", {input}, output, &attrs);

  if (!status.IsOK()) {
    throw std::runtime_error(
        "ORT return failure status:" + status.ErrorMessage());
  }

  // `output` is a local about to be destroyed, so hand its element out by move.
  return std::move(output[0]);
}
|
||||
|
||||
//#pragma endregion
|
||||
|
||||
//#pragma region Hand-Implemented ATen Ops
|
||||
|
|
@ -316,29 +406,6 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
|
|||
self.options());
|
||||
}
|
||||
|
||||
// Translates an ATen scalar type into the matching ONNX TensorProto data type.
// Only the element types listed in the switch are handled; any other value is
// reported through ORT_THROW.
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
  switch (dtype) {
    // Floating-point element types.
    case at::ScalarType::Float:
      return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
    case at::ScalarType::Double:
      return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
    case at::ScalarType::Half:
      return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
    case at::ScalarType::BFloat16:
      return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
    // Signed integer element types.
    case at::ScalarType::Int:
      return ONNX_NAMESPACE::TensorProto_DataType_INT32;
    case at::ScalarType::Short:
      return ONNX_NAMESPACE::TensorProto_DataType_INT16;
    case at::ScalarType::Long:
      return ONNX_NAMESPACE::TensorProto_DataType_INT64;
    // Boolean.
    case at::ScalarType::Bool:
      return ONNX_NAMESPACE::TensorProto_DataType_BOOL;
    default:
      ORT_THROW("Unsupport aten scalar type: ", dtype);
  }
}
|
||||
|
||||
at::Tensor& copy_(
|
||||
at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
|
|
|
|||
|
|
@ -118,5 +118,12 @@ bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarTyp
|
|||
|
||||
bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types);
|
||||
|
||||
c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
|
||||
const std::vector<at::ScalarType>& typesFromTensors,
|
||||
const std::vector<at::ScalarType>& typesFromScalars);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype);
|
||||
|
||||
OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type);
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
|
|
@ -17,6 +17,16 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_twos = cpu_ones + cpu_ones
|
||||
ort_twos = ort_ones + ort_ones
|
||||
assert torch.allclose(cpu_twos, ort_twos.cpu())
|
||||
|
||||
def test_type_promotion_add(self):
    """Adding an int64 tensor to a float32 tensor on the test device.

    The sum is expected to have dtype float32 and to match the value of the
    same addition performed on the original CPU tensors.
    """
    device = self.get_device()
    cpu_int = torch.ones(2, 5, dtype=torch.int64)
    cpu_float = torch.ones(2, 5, dtype=torch.float32)

    # Move both operands to the device under test before adding them.
    dev_int = cpu_int.to(device)
    dev_float = cpu_float.to(device)
    dev_sum = dev_int + dev_float

    assert dev_sum.dtype == torch.float32
    # Reference result comes from the same mixed-dtype add on CPU.
    expected = cpu_int + cpu_float
    assert torch.allclose(dev_sum.cpu(), expected)
|
||||
|
||||
def test_add_alpha(self):
|
||||
device = self.get_device()
|
||||
|
|
|
|||
Loading…
Reference in a new issue