Added support to Eager CodeGen for multiple in-place parameters. (#10945)

* Added support to CodeGen for multiple inplace output parameters.

* Updated output Tensor to references.
This commit is contained in:
Chandru Ramakrishnan 2022-03-21 16:10:22 -04:00 committed by GitHub
parent 1cc2cfb7b8
commit 4a5b5328a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 81 additions and 40 deletions

View file

@ -1 +1,2 @@
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
std::tuple<Tensor&, Tensor&, Tensor&> batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const float epsilon, const float momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"}

View file

@ -25,7 +25,7 @@ ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops)
regdecs_path = args.header_file
print(f"INFO: Using ATen RegistrationDeclations from: {regdecs_path}")
print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}")
output = sys.stdout
if args.output_file:
output = open(args.output_file, 'wt')

View file

@ -219,14 +219,12 @@ class TupleType(Type):
class AliasInfo(Node):
before_set: List[str]
after_set: List[str]
contained_types: List[Type]
tokens: List[Token]
def __init__(self):
super().__init__()
self.before_set = []
self.after_set = []
self.contained_types = []
self.tokens = []
self.is_writable = False

View file

@ -11,7 +11,8 @@ from opgen.generator import \
from opgen.onnxops import *
ops = {
'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB')
'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB'),
'batchnorm_inplace': BatchNormalization('X', 'scale', 'B', 'input_mean', 'input_var', 'epsilon', 'momentum', 1)
}
type_promotion_ops = {}

View file

@ -246,11 +246,6 @@ class ORTGen:
if mapped_func.make_torch_fallback:
return self._write_cpu_fall_back(writer, mapped_func)
return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
if return_alias_info and not return_alias_info.is_writable:
return_alias_info = None
in_place_param: ast.ParameterDecl = None
# Eval the outer ONNX op to produce a topologically ordered list of ops
ctx = ONNXOpEvalContext()
onnx_op.eval(ctx)
@ -339,13 +334,7 @@ class ORTGen:
for op_input in onnx_op.inputs:
if isinstance(op_input, Outputs):
continue
# See if this input is aliased as an in-place tensor
cpp_param = cpp_func.get_parameter(op_input)
if return_alias_info and cpp_param:
for torch_p in cpp_param.torch_param:
if self._get_alias_info(torch_p) == return_alias_info:
in_place_param = cpp_param
writer.write(f'auto ort_input_{op_input} = ')
writer.writeline(f'create_ort_value(invoker, {op_input});')
if need_type_promotion:
@ -388,11 +377,36 @@ class ORTGen:
writer.write(f'std::vector<OrtValue> {onnx_op.outputs}')
writer.writeline(f'({onnx_op.outputs.count});')
if in_place_param:
assert(onnx_op.outputs.count == 1)
# TODO: This assumes that the first output corresponds to the first input.
# This may not work for more complicated ops.
writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[0]};')
return_info = cpp_func.torch_func.return_type if cpp_func.torch_func else None
in_place_params = {}
if return_info:
for input_index, op_input in enumerate(onnx_op.inputs):
if isinstance(op_input, Outputs):
continue
# See if this input is aliased as an in-place tensor
cpp_param = cpp_func.get_parameter(op_input)
if cpp_param:
for torch_p in cpp_param.torch_param:
if isinstance(return_info, ast.TupleType):
for output_index, output_param in enumerate(return_info.elements):
assert isinstance(output_param.member, ast.TupleMemberType), "output_param.member must be of TupleMemberType"
output_alias = self._get_alias_info(output_param.member.element_type)
if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable:
writer.writeline(f'{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op.inputs[input_index]};')
in_place_params[output_index] = cpp_param.identifier.value
break
else:
output_alias = self._get_alias_info(return_info)
if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable:
writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[input_index]};')
in_place_params[0] = cpp_param.identifier.value
break
if len(in_place_params) != 0 and len(in_place_params) != (len(return_info.elements) if isinstance(return_info, ast.TupleType) else 1):
raise Exception(f'Cannot mix and match inplace with non-inplace parameters - function: {cpp_func.identifier.value} ' +
f'in_place_params={in_place_params}, return_elements={return_info.elements}')
# Perform the invocation
writer.writeline()
@ -434,7 +448,7 @@ class ORTGen:
# TODO: Handle mutliple results
# TODO: Assert return type
if not return_alias_info:
if len(in_place_params) == 0:
# tensor options
writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}')
if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
@ -454,12 +468,20 @@ class ORTGen:
writer.writeline('tensor_options);')
writer.pop_indent()
return
if not in_place_param:
raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' +
'has alias info on its return type but no associated parameter')
writer.writeline(f'return {in_place_param.identifier.value};')
else:
if len(in_place_params) == 1:
writer.writeline(f'return {in_place_params[0]};')
else:
if not (isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::tuple'):
raise Exception(f'')
tensorRef = "Tensor&," * len(in_place_params)
tensorRef = tensorRef[:len(tensorRef)-1]
writer.write(f'return std::tuple<{tensorRef}>(')
for index, key in enumerate(sorted(in_place_params)):
if index > 0:
writer.write(', ')
writer.write(in_place_params[key])
writer.writeline(');')
def _write_function_registrations(
self,
@ -550,16 +572,19 @@ class ORTGen:
# Parse the Torch schema from the JSON comment that follows each C++ decl
# and link associated Torch and C++ decls (functions, parameters, returns)
for cpp_func in tu:
if self._custom_ops == True:
# customops don't have torch schema
cpp_func.torch_func = None
yield cpp_func
elif cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
hasSchema = False
if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
for trivia in cpp_func.semicolon.trailing_trivia:
if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT:
yield self._parse_and_link_torch_function_decl(cpp_func, trivia)
hasSchema = True
break
if not hasSchema:
# customops might not have torch schema
cpp_func.torch_func = None
yield cpp_func
def _parse_and_link_torch_function_decl(
self,
cpp_func: ast.FunctionDecl,

View file

@ -219,16 +219,19 @@ class TorchParser(ParserBase):
else:
raise UnexpectedTokenError("expression", self._peek_token())
def parse_type(self) -> Type:
parsed_type, alias_info = self._parse_type_and_alias()
if not alias_info:
return parsed_type
def _create_alias_info_type(self, parsed_type: Type, alias_info: AliasInfo) -> AliasInfoType:
if isinstance(parsed_type, ModifiedType):
parsed_type.base_type = AliasInfoType(parsed_type.base_type, alias_info)
else:
parsed_type = AliasInfoType(parsed_type, alias_info)
return parsed_type
def parse_type(self) -> Type:
parsed_type, alias_info = self._parse_type_and_alias()
if not alias_info:
return parsed_type
return self._create_alias_info_type(parsed_type, alias_info)
def _parse_type_and_alias(self) -> Tuple[Type, AliasInfo]:
parsed_type: Type = None
alias_info: AliasInfo = None
@ -239,8 +242,9 @@ class TorchParser(ParserBase):
if self._peek_token(TokenKind.OPEN_PAREN):
def parse_tuple_element():
element_type, element_alias_info = self._parse_type_and_alias()
if alias_info and element_alias_info:
alias_info.add_contained_type(element_alias_info)
if element_alias_info:
element_type = self._create_alias_info_type(element_type, element_alias_info)
return TupleMemberType(
element_type,
self._read_token() \

View file

@ -87,7 +87,19 @@ class OrtOpTests(unittest.TestCase):
cpu_ans = cpu_ones * 4
ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
assert torch.allclose(cpu_ans, ort_ans.cpu())
def test_batchnormalization_inplace(self):
device = self.get_device()
x = torch.Tensor([[[[-1, 0, 1]], [[2., 3., 4.]]]]).to(device)
s = torch.Tensor([1.0, 1.5]).to(device)
bias = torch.Tensor([0., 1.]).to(device)
mean = torch.Tensor([0., 3.]).to(device)
var = torch.Tensor([1., 1.5]).to(device)
y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9)
assert torch.allclose(x.cpu(), y.cpu()), "x != y"
assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out"
assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out"
def test_max(self):
cpu_tensor = torch.rand(10, 10)
ort_tensor = cpu_tensor.to('ort')