diff --git a/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h b/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h index ffe83a1c67..a7bf492f74 100644 --- a/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h +++ b/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h @@ -1 +1,2 @@ -Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB); \ No newline at end of file +Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB); +std::tuple batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const float epsilon, const float momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"} \ No newline at end of file diff --git a/orttraining/orttraining/eager/opgen/opgen.py b/orttraining/orttraining/eager/opgen/opgen.py index 7fe7f02370..e0f267a9fb 100755 --- a/orttraining/orttraining/eager/opgen/opgen.py +++ b/orttraining/orttraining/eager/opgen/opgen.py @@ -25,7 +25,7 @@ ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module() ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops) regdecs_path = args.header_file -print(f"INFO: Using ATen RegistrationDeclations from: {regdecs_path}") +print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}") output = sys.stdout if args.output_file: output = open(args.output_file, 'wt') diff --git a/orttraining/orttraining/eager/opgen/opgen/ast.py b/orttraining/orttraining/eager/opgen/opgen/ast.py index e01359ea7c..672d30a57b 100644 --- a/orttraining/orttraining/eager/opgen/opgen/ast.py +++ b/orttraining/orttraining/eager/opgen/opgen/ast.py @@ -219,14 +219,12 @@ class TupleType(Type): class AliasInfo(Node): before_set: List[str] after_set: List[str] - contained_types: List[Type] tokens: List[Token] def __init__(self): super().__init__() self.before_set = [] self.after_set = [] - self.contained_types = [] self.tokens = [] self.is_writable = False diff --git a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py index a49d4be751..d2631277cc 100644 --- a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py +++ b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py @@ -11,7 +11,8 @@ from opgen.generator import \ from opgen.onnxops import * ops = { - 'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB') + 'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB'), + 'batchnorm_inplace': BatchNormalization('X', 'scale', 'B', 'input_mean', 'input_var', 'epsilon', 'momentum', 1) } type_promotion_ops = {} diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index 7226c504bd..64073e5689 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -246,11 +246,6 @@ class ORTGen: if mapped_func.make_torch_fallback: return self._write_cpu_fall_back(writer, mapped_func) - return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None - if return_alias_info and not return_alias_info.is_writable: - return_alias_info = None - in_place_param: ast.ParameterDecl = None - # Eval the outer ONNX op to produce a topologically ordered list of ops ctx = ONNXOpEvalContext() onnx_op.eval(ctx) @@ -339,13 +334,7 @@ class ORTGen: for op_input in onnx_op.inputs: if isinstance(op_input, Outputs): continue - # See if this input is aliased as an in-place tensor cpp_param = cpp_func.get_parameter(op_input) - if return_alias_info and cpp_param: - for torch_p in cpp_param.torch_param: - if self._get_alias_info(torch_p) == return_alias_info: - in_place_param = cpp_param - writer.write(f'auto ort_input_{op_input} = ') writer.writeline(f'create_ort_value(invoker, {op_input});') if need_type_promotion: @@ -388,11 +377,36 @@ class ORTGen: writer.write(f'std::vector {onnx_op.outputs}') writer.writeline(f'({onnx_op.outputs.count});') - if in_place_param: - assert(onnx_op.outputs.count == 1) - # TODO: This assumes that the first output corresponds to the first input. - # This may not work for more complicated ops. - writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[0]};') + return_info = cpp_func.torch_func.return_type if cpp_func.torch_func else None + in_place_params = {} + + if return_info: + for input_index, op_input in enumerate(onnx_op.inputs): + if isinstance(op_input, Outputs): + continue + + # See if this input is aliased as an in-place tensor + cpp_param = cpp_func.get_parameter(op_input) + if cpp_param: + for torch_p in cpp_param.torch_param: + if isinstance(return_info, ast.TupleType): + for output_index, output_param in enumerate(return_info.elements): + assert isinstance(output_param.member, ast.TupleMemberType), "output_param.member must be of TupleMemberType" + output_alias = self._get_alias_info(output_param.member.element_type) + if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable: + writer.writeline(f'{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op.inputs[input_index]};') + in_place_params[output_index] = cpp_param.identifier.value + break + else: + output_alias = self._get_alias_info(return_info) + if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable: + writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[input_index]};') + in_place_params[0] = cpp_param.identifier.value + break + + if len(in_place_params) != 0 and len(in_place_params) != (len(return_info.elements) if isinstance(return_info, ast.TupleType) else 1): + raise Exception(f'Cannot mix and match inplace with non-inplace parameters - function: {cpp_func.identifier.value} ' + + f'in_place_params={in_place_params}, return_elements={return_info.elements}') # Perform the invocation writer.writeline() @@ -434,7 +448,7 @@ class ORTGen: # TODO: Handle mutliple results # TODO: Assert return type - if not return_alias_info: + if len(in_place_params) == 0: # tensor options writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}') if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList': @@ -454,12 +468,20 @@ class ORTGen: writer.writeline('tensor_options);') writer.pop_indent() return - - if not in_place_param: - raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' + - 'has alias info on its return type but no associated parameter') - - writer.writeline(f'return {in_place_param.identifier.value};') + else: + if len(in_place_params) == 1: + writer.writeline(f'return {in_place_params[0]};') + else: + if not (isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::tuple'): + raise Exception(f'') + tensorRef = "Tensor&," * len(in_place_params) + tensorRef = tensorRef[:len(tensorRef)-1] + writer.write(f'return std::tuple<{tensorRef}>(') + for index, key in enumerate(sorted(in_place_params)): + if index > 0: + writer.write(', ') + writer.write(in_place_params[key]) + writer.writeline(');') def _write_function_registrations( self, @@ -550,16 +572,19 @@ class ORTGen: # Parse the Torch schema from the JSON comment that follows each C++ decl # and link associated Torch and C++ decls (functions, parameters, returns) for cpp_func in tu: - if self._custom_ops == True: - # customops don't have torch schema - cpp_func.torch_func = None - yield cpp_func - elif cpp_func.semicolon and cpp_func.semicolon.trailing_trivia: + hasSchema = False + if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia: for trivia in cpp_func.semicolon.trailing_trivia: if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT: yield self._parse_and_link_torch_function_decl(cpp_func, trivia) + hasSchema = True break + if not hasSchema: + # customops might not have torch schema + cpp_func.torch_func = None + yield cpp_func + def _parse_and_link_torch_function_decl( self, cpp_func: ast.FunctionDecl, diff --git a/orttraining/orttraining/eager/opgen/opgen/parser.py b/orttraining/orttraining/eager/opgen/opgen/parser.py index 58d9a16702..ba6ef93795 100644 --- a/orttraining/orttraining/eager/opgen/opgen/parser.py +++ b/orttraining/orttraining/eager/opgen/opgen/parser.py @@ -219,16 +219,19 @@ class TorchParser(ParserBase): else: raise UnexpectedTokenError("expression", self._peek_token()) - def parse_type(self) -> Type: - parsed_type, alias_info = self._parse_type_and_alias() - if not alias_info: - return parsed_type + def _create_alias_info_type(self, parsed_type: Type, alias_info: AliasInfo) -> AliasInfoType: if isinstance(parsed_type, ModifiedType): parsed_type.base_type = AliasInfoType(parsed_type.base_type, alias_info) else: parsed_type = AliasInfoType(parsed_type, alias_info) return parsed_type + def parse_type(self) -> Type: + parsed_type, alias_info = self._parse_type_and_alias() + if not alias_info: + return parsed_type + return self._create_alias_info_type(parsed_type, alias_info) + def _parse_type_and_alias(self) -> Tuple[Type, AliasInfo]: parsed_type: Type = None alias_info: AliasInfo = None @@ -239,8 +242,9 @@ class TorchParser(ParserBase): if self._peek_token(TokenKind.OPEN_PAREN): def parse_tuple_element(): element_type, element_alias_info = self._parse_type_and_alias() - if alias_info and element_alias_info: - alias_info.add_contained_type(element_alias_info) + if element_alias_info: + element_type = self._create_alias_info_type(element_type, element_alias_info) + return TupleMemberType( element_type, self._read_token() \ diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index e12b71a2f1..b2e39431cd 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -87,7 +87,19 @@ class OrtOpTests(unittest.TestCase): cpu_ans = cpu_ones * 4 ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0) assert torch.allclose(cpu_ans, ort_ans.cpu()) - + + def test_batchnormalization_inplace(self): + device = self.get_device() + x = torch.Tensor([[[[-1, 0, 1]], [[2., 3., 4.]]]]).to(device) + s = torch.Tensor([1.0, 1.5]).to(device) + bias = torch.Tensor([0., 1.]).to(device) + mean = torch.Tensor([0., 3.]).to(device) + var = torch.Tensor([1., 1.5]).to(device) + y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9) + assert torch.allclose(x.cpu(), y.cpu()), "x != y" + assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out" + assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out" + def test_max(self): cpu_tensor = torch.rand(10, 10) ort_tensor = cpu_tensor.to('ort')