support print in ort eager mode (#9825)

* fix reshape implementation in eager mode

* test code

* update opgen script to support fallback to cpu

* enhance the eager backend to support torch cpu fallback

* add more testes

* disable the printensor test for now, as we need to erge a PR to pytorch first
This commit is contained in:
Tang, Cheng 2021-11-29 08:03:57 -08:00 committed by GitHub
parent 1e9e57df3e
commit 37bf46eb19
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1142 additions and 494 deletions

View file

@ -21,6 +21,24 @@ for schema in defs.get_all_schemas_with_history():
onnx_ops[key].since_version < schema.since_version:
onnx_ops[key] = schema
def convert_to_aten_type(onnx_type_strs):
type_map = {'tensor(float16)' : 'at::kHalf',
'tensor(float)' : 'at::kFloat',
'tensor(double)' : 'at::kDouble',
'tensor(bfloat16)' : 'at::kBFloat16',
'tensor(int32)' : 'at::kInt',
'tensor(int16)' : 'at::kShort',
'tensor(int8)' : 'at::kByte',
'tensor(int64)' : 'at::kLong',
'tensor(bool)' : 'at::kBool',
}
result = set({})
for onnx_type in onnx_type_strs:
# ONNX has more types, like tensor(string), ignore those types at this momemnt
if onnx_type in type_map:
result.add(type_map[onnx_type])
return result
with open(out_file, 'wt') as fp:
def write(s): fp.write(s)
def writeline(s = ''): fp.write(s + '\n')
@ -54,9 +72,17 @@ with open(out_file, 'wt') as fp:
writeline('):')
write(f' super().__init__(\'{schema.name}\', {len(schema.outputs)}')
writeline(',')
write(' ')
input_types = []
for input in schema.inputs:
write(f', {input.name}')
input_types.append(convert_to_aten_type(input.types))
write(str(input_types))
if len(schema.inputs) > 0:
writeline(',')
input_names = ','.join([input.name for input in schema.inputs])
write(f' {input_names}')
if len(schema.attributes) > 0:
writeline(',')

View file

@ -4,7 +4,7 @@ from opgen.generator import \
ORTGen as ORTGen, \
ONNXOp as ONNXOp, \
SignatureOnly as SignatureOnly, \
MakeFallthrough as MakeFallthrough
MakeTorchFallback as MakeTorchFallback
from opgen.onnxops import *
@ -12,17 +12,17 @@ kMSDomain = 'onnxruntime::kMSDomain'
class ReluGrad(ONNXOp):
def __init__(self, dY, X):
super().__init__('ReluGrad', 1, dY, X)
super().__init__('ReluGrad', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], dY, X)
self.domain = kMSDomain
class Gelu(ONNXOp):
def __init__(self, X):
super().__init__('Gelu', 1, X)
super().__init__('Gelu', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], X)
self.domain = kMSDomain
class GeluGrad(ONNXOp):
def __init__(self, dY, X):
super().__init__('GeluGrad', 1, dY, X)
super().__init__('GeluGrad', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], dY, X)
self.domain = kMSDomain
ops = {
@ -33,6 +33,7 @@ ops = {
'aten::copy_': SignatureOnly(),
'aten::_reshape_alias': SignatureOnly(),
'aten::view': SignatureOnly(),
'aten::_copy_from_and_resize' : SignatureOnly(),
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
'aten::t': Transpose('self'),
@ -48,7 +49,20 @@ ops = {
'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd'
'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
'aten::gelu' : Gelu('self'),
'aten::gelu_backward' : GeluGrad('grad', 'self')
'aten::gelu_backward' : GeluGrad('grad', 'self'),
'aten::max' : ReduceMax('self', keepdims=1),
'aten::min' : ReduceMin('self', keepdims=1),
'aten::ne.Scalar':MakeTorchFallback(),
'aten::ne.Scalar_out': MakeTorchFallback(),
'aten::ne.Tensor_out': MakeTorchFallback(),
'aten::eq.Tensor': MakeTorchFallback(),
'aten::eq.Tensor_out':MakeTorchFallback(),
'aten::bitwise_and.Tensor_out' : MakeTorchFallback(),
'aten::masked_select' : MakeTorchFallback(),
'aten::as_strided' : MakeTorchFallback(),
'aten::_local_scalar_dense' : MakeTorchFallback(),
'aten::gt.Scalar_out' : MakeTorchFallback(),
}
for binary_op, onnx_op in {
@ -64,7 +78,7 @@ for unary_op in [
'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round',
'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero',
'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']:
'sign', 'hardsigmoid', 'isinf', 'det']:
aten_name = f'aten::{unary_op}'
onnx_op = onnx_ops[unary_op]('self')
ops[aten_name] = onnx_op

View file

@ -6,7 +6,7 @@ from opgen.generator import \
ORTGen as ORTGen, \
ONNXOp as ONNXOp, \
SignatureOnly as SignatureOnly, \
MakeFallthrough as MakeFallthrough
MakeTorchFallback as MakeTorchFallback
from opgen.onnxops import *

View file

@ -48,6 +48,7 @@ class ONNXOp:
def __init__(self,
name: str,
outputs: int,
input_types: List,
*inputs: Union[str, Outputs],
**attributes: Optional[Union[str, Outputs]]):
self.name = name
@ -55,6 +56,7 @@ class ONNXOp:
self.inputs = inputs
self.attributes = attributes
self.domain = None
self.input_types = input_types
def eval(self, ctx: ONNXOpEvalContext):
evaluated_inputs = []
@ -71,10 +73,10 @@ class ONNXOp:
return self.outputs
class SignatureOnly(ONNXOp):
def __init__(self): super().__init__(None, 0)
def __init__(self): super().__init__(None, 0, [])
class MakeFallthrough(ONNXOp):
def __init__(self): super().__init__(None, 0)
class MakeTorchFallback(ONNXOp):
def __init__(self): super().__init__(None, 0, [])
class FunctionGenerationError(NotImplementedError):
def __init__(self, cpp_func: ast.FunctionDecl, message: str):
@ -88,13 +90,13 @@ class MappedOpFunction:
onnx_op: ONNXOp,
cpp_func: ast.FunctionDecl,
signature_only: bool,
make_fallthrough: bool):
make_torch_fallback: bool):
self.op_namespace = op_namespace
self.mapped_op_name = mapped_op_name
self.onnx_op = onnx_op
self.cpp_func = cpp_func
self.signature_only = signature_only
self.make_fallthrough = make_fallthrough
self.make_torch_fallback = make_torch_fallback
class ORTGen:
_mapped_ops: Dict[str, ONNXOp]
@ -126,9 +128,6 @@ class ORTGen:
del self._mapped_ops[mapped_func.mapped_op_name]
generated_funcs.append(mapped_func)
if mapped_func.make_fallthrough:
continue
ns = mapped_func.op_namespace
if current_ns and current_ns != ns:
current_ns = None
@ -173,6 +172,7 @@ class ORTGen:
writer.writeline('#include "python/onnxruntime_pybind_state_common.h"')
writer.writeline()
writer.writeline('#include <torch/extension.h>')
writer.writeline('#include <ATen/native/CPUFallback.h>')
writer.writeline()
writer.writeline('#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>')
writer.writeline()
@ -206,6 +206,27 @@ class ORTGen:
writer.pop_indent()
writer.write(')')
def _write_cpu_fall_back(self,
writer: writer.SourceWriter,
mapped_func: MappedOpFunction):
onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func
#return at::native::call_fallback_fn<
# &at::native::cpu_fallback,
# ATEN_OP(eq_Tensor)>::call(self, other);
writer.writeline('return native::call_fallback_fn<')
writer.push_indent()
writer.writeline('&native::cpu_fallback,')
writer.write('ATEN_OP(')
writer.write(cpp_func.identifier.value)
writer.write(')>::call(')
params = ', '.join([p.member.identifier.value for p \
in cpp_func.parameters if p.member.identifier])
writer.write(params)
writer.writeline(');')
writer.pop_indent()
def _write_function_body(
self,
writer: writer.SourceWriter,
@ -214,6 +235,15 @@ class ORTGen:
assert(len(cpp_func.parameters) > 0)
# Debug Logging
log_params = ', '.join([p.member.identifier.value for p \
in cpp_func.parameters if p.member.identifier])
writer.writeline(f'ORT_LOG_FN({log_params});')
writer.writeline()
if mapped_func.make_torch_fallback:
return self._write_cpu_fall_back(writer, mapped_func)
return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
if return_alias_info and not return_alias_info.is_writable:
return_alias_info = None
@ -224,11 +254,32 @@ class ORTGen:
onnx_op.eval(ctx)
ctx.prepare_outputs()
# Debug Logging
log_params = ', '.join([p.member.identifier.value for p \
in cpp_func.parameters if p.member.identifier])
writer.writeline(f'ORT_LOG_FN({log_params});')
writer.writeline()
# generate the type check
need_type_check = False
if not self._custom_ops:
for onnx_op_index, onnx_op in enumerate(ctx.ops):
for op_input in onnx_op.inputs:
if not isinstance(op_input, Outputs):
need_type_check = True
break
if need_type_check:
writer.write('if (')
i = 0
for onnx_op_index, onnx_op in enumerate(ctx.ops):
for idx, op_input in enumerate(onnx_op.inputs):
if isinstance(op_input, Outputs):
continue
writer.writeline(' || ' if i > 0 else '')
if i == 0:
writer.push_indent()
cpp_param = cpp_func.get_parameter(op_input)
supported_types = ','.join([type for type in onnx_op.input_types[idx]])
writer.write('!IsSupportedType(%s, {%s})' % (cpp_param.identifier.value, supported_types))
i += 1
writer.writeline(') {')
self._write_cpu_fall_back(writer, mapped_func)
writer.pop_indent()
writer.writeline('}')
# Fetch the ORT invoker from an at::Tensor.device()
# FIXME: find the first at::Tensor param anywhere in the signature
@ -258,10 +309,10 @@ class ORTGen:
continue
# See if this input is aliased as an in-place tensor
cpp_param = cpp_func.get_parameter(op_input)
if return_alias_info and cpp_param and \
len(cpp_param.torch_param) == 1 and \
self._get_alias_info(cpp_param.torch_param[0]) == return_alias_info:
in_place_param = cpp_param
if return_alias_info and cpp_param:
for torch_p in cpp_param.torch_param:
if self._get_alias_info(torch_p) == return_alias_info:
in_place_param = cpp_param
writer.write(f'auto ort_input_{op_input} = ')
writer.writeline(f'create_ort_value(invoker, {op_input});')
@ -367,18 +418,15 @@ class ORTGen:
for mapped_func in generated_funcs:
cpp_func, torch_func = mapped_func.cpp_func, mapped_func.cpp_func.torch_func
if mapped_func.make_fallthrough:
reg_function_arg = 'torch::CppFunction::makeFallthrough()'
if mapped_func.op_namespace:
reg_function_arg = f'{mapped_func.op_namespace}::'
else:
if mapped_func.op_namespace:
reg_function_arg = f'{mapped_func.op_namespace}::'
else:
reg_function_arg = ''
reg_function_arg += cpp_func.identifier.value
reg_function_arg = ''
reg_function_arg += cpp_func.identifier.value
writer.write('m.impl(')
if not mapped_func.make_fallthrough:
reg_function_arg = f'TORCH_FN({reg_function_arg})'
reg_function_arg = f'TORCH_FN({reg_function_arg})'
writer.writeline(f'"{torch_func.identifier.value}", {reg_function_arg});')
@ -427,7 +475,7 @@ class ORTGen:
op_namespace = None
op_namewithoutnamespace = op_name
cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '__')
cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '_')
onnx_op = self._mapped_ops.get(op_name)
if not onnx_op:
@ -439,7 +487,7 @@ class ORTGen:
onnx_op,
cpp_func,
isinstance(onnx_op, SignatureOnly),
isinstance(onnx_op, MakeFallthrough))
isinstance(onnx_op, MakeTorchFallback))
def _parse_function_decls(self, cpp_parser: parser.CPPParser):
# Parse the C++ declarations

File diff suppressed because it is too large Load diff

View file

@ -3,6 +3,8 @@
#include "ort_aten.h"
#include "ort_tensor.h"
#include <c10/core/TensorImpl.h>
#include <ATen/native/CPUFallback.h>
namespace torch_ort {
namespace eager {
@ -158,13 +160,26 @@ onnx::AttributeProto create_ort_attribute(
return attr;
}
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
}
bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types){
return std::find(valid_types.begin(), valid_types.end(), tensor.scalar_type()) != valid_types.end();
}
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types){
return std::find(valid_types.begin(), valid_types.end(), at::kInt) != valid_types.end() ||
std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
}
//#pragma endregion
//#pragma region Hand-Implemented ATen Ops
namespace aten {
at::Tensor empty__memory_format(
at::Tensor empty_memory_format(
at::IntArrayRef size,
// *,
c10::optional<at::ScalarType> dtype_opt,
@ -186,7 +201,7 @@ at::Tensor empty__memory_format(
ort_scalar_type_from_aten(*dtype_opt),
size.vec(),
&ot);
return aten_tensor_from_ort(
std::move(ot),
at::TensorOptions()
@ -255,6 +270,29 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
self.options());
}
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){
switch (dtype){
case at::kFloat:
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
case at::kDouble:
return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
case at::kHalf:
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
case at::kBFloat16:
return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
case at::kInt:
return ONNX_NAMESPACE::TensorProto_DataType_INT32;
case at::kShort:
return ONNX_NAMESPACE::TensorProto_DataType_INT16;
case at::kLong:
return ONNX_NAMESPACE::TensorProto_DataType_INT64;
case at::kBool:
return ONNX_NAMESPACE::TensorProto_DataType_BOOL;
default:
ORT_THROW("Unsupport aten scalar type: ", dtype);
}
}
at::Tensor& copy_(
at::Tensor& self,
const at::Tensor& src,
@ -269,8 +307,45 @@ at::Tensor& copy_(
: src.device());
const auto ort_src = create_ort_value(invoker, src);
auto ort_self = create_ort_value(invoker, self);
if (self.scalar_type() != src.scalar_type()){
// invoke cast first
std::vector<OrtValue> ort_cast_output(1);
onnxruntime::NodeAttributes attrs(1);
attrs["to"] = create_ort_attribute(
"to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong);
copy(invoker, ort_src, ort_self);
auto status = invoker.Invoke("Cast", {
std::move(ort_src),
}, ort_cast_output, &attrs);
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
copy(invoker, ort_cast_output[0], ort_self);
}
else{
copy(invoker, ort_src, ort_self);
}
return self;
}
at::Tensor _copy_from_and_resize(
const at::Tensor& self,
const at::Tensor& dst){
ORT_LOG_FN(self, dst);
assert_tensor_supported(self);
assert_tensor_supported(dst);
auto& invoker = GetORTInvoker(self.device().type() == at::kORT
? self.device()
: dst.device());
const auto ort_self = create_ort_value(invoker, self);
auto ort_dst = create_ort_value(invoker, dst);
copy(invoker, ort_self, ort_dst);
return self;
}

View file

@ -73,5 +73,11 @@ onnx::AttributeProto create_ort_attribute(
const char* name,
const char* value);
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types);
bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types);
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types);
} // namespace eager
} // namespace torch_ort

View file

@ -5,6 +5,7 @@
#include <c10/core/TensorImpl.h>
#include <core/framework/ort_value.h>
#include <iostream>
namespace torch_ort {
namespace eager {

View file

@ -120,6 +120,13 @@ class OrtEPTests(unittest.TestCase):
ort_device = torch_ort.device(1)
assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext
#disable the print test for now as we need to merge a PR to pytorch first.
#def test_print(self):
# x = torch.ones(1, 2)
# ort_x = x.to('ort')
# with OutputGrabber() as out:
# print(ort_x)
# assert "tensor([[1., 1.]], device='ort:0')" in out.capturedtext
if __name__ == '__main__':
unittest.main()

View file

@ -25,6 +25,14 @@ class OrtOpTests(unittest.TestCase):
assert torch.allclose(
torch.add(cpu_ones, cpu_ones, alpha=2.5),
torch.add(ort_ones, ort_ones, alpha=2.5).cpu())
def test_mul_bool(self):
device = self.get_device()
cpu_ones = torch.ones(3, 3, dtype=bool)
ort_ones = cpu_ones.to(device)
assert torch.allclose(
torch.mul(cpu_ones, cpu_ones),
torch.mul(ort_ones, ort_ones).cpu())
def test_add_(self):
device = self.get_device()
@ -68,6 +76,20 @@ class OrtOpTests(unittest.TestCase):
cpu_ans = cpu_ones * 4
ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
assert torch.allclose(cpu_ans, ort_ans.cpu())
def test_max(self):
cpu_tensor = torch.rand(10, 10)
ort_tensor = cpu_tensor.to('ort')
y = ort_tensor.max()
x = cpu_tensor.max()
assert torch.allclose(x, y.cpu())
def test_min(self):
cpu_tensor = torch.rand(10, 10)
ort_tensor = cpu_tensor.to('ort')
y = ort_tensor.min()
x = cpu_tensor.min()
assert torch.allclose(x, y.cpu())
if __name__ == '__main__':
unittest.main()