diff --git a/.gitignore b/.gitignore index 113cf46c83..d5507e0c24 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,4 @@ onnxruntime/python/version_info.py /tools/perf_util/target/classes /tools/perf_util/src/main/resources /orttraining/orttraining/eager/ort_aten.g.cpp +/orttraining/orttraining/eager/ort_customops.g.cpp \ No newline at end of file diff --git a/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h b/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h new file mode 100644 index 0000000000..ffe83a1c67 --- /dev/null +++ b/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h @@ -0,0 +1 @@ +Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB); \ No newline at end of file diff --git a/orttraining/orttraining/eager/opgen/opgen.py b/orttraining/orttraining/eager/opgen/opgen.py index dbdc6616d1..71e0542894 100755 --- a/orttraining/orttraining/eager/opgen/opgen.py +++ b/orttraining/orttraining/eager/opgen/opgen.py @@ -3,120 +3,31 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - +from opgen.writer import SourceWriter as SourceWriter +from opgen.parser import cpp_create_from_file as CPPParser +import sys +import os +from opgen.generator import ORTGen as ORTGen +from importlib.machinery import SourceFileLoader import argparse -from pathlib import Path - -from copy import deepcopy - -from opgen.generator import \ - ORTGen as ORTGen, \ - ONNXOp as ONNXOp, \ - SignatureOnly as SignatureOnly, \ - MakeFallthrough as MakeFallthrough - -from opgen.onnxops import * - -kMSDomain = 'onnxruntime::kMSDomain' - parser = argparse.ArgumentParser(description='Generate ORT ATen operations') +parser.add_argument('--ops_module', type=str, + help='Python module containing the Onnx Operation signature and list of ops to map') parser.add_argument('--output_file', default=None, type=str, help='Output file [default to std out]') -parser.add_argument('--use_preinstalled_torch', action='store_true', help='Use pre-installed torch from the python environment') +parser.add_argument('--header_file', type=str, + help='Header file which contains ATen / Pytorch operation signature') +parser.add_argument('--custom_ops', action='store_true', help='Whether we are generating code for custom ops or native operation') args = parser.parse_args() +ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module() +ortgen = ORTGen(ops_module.ops, custom_ops=args.custom_ops) -class ReluGrad(ONNXOp): - def __init__(self, dY, X): - super().__init__('ReluGrad', 1, dY, X) - self.domain = kMSDomain - -class Gelu(ONNXOp): - def __init__(self, X): - super().__init__('Gelu', 1, X) - self.domain = kMSDomain - -class GeluGrad(ONNXOp): - def __init__(self, dY, X): - super().__init__('GeluGrad', 1, dY, X) - self.domain = kMSDomain - -ops = { - # Hand-Implemented Ops - 'aten::empty.memory_format': SignatureOnly(), - 'aten::empty_strided': SignatureOnly(), - 'aten::zero_': SignatureOnly(), - 'aten::copy_': SignatureOnly(), - 'aten::reshape': SignatureOnly(), - 'aten::view': SignatureOnly(), - - 'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'), - 'aten::t': Transpose('self'), - 'aten::mm': MatMul('self', 'mat2'), - 'aten::zeros_like': ConstantOfShape(Shape('self')), #the default constant is 0, so don't need to speicify attribute - - 'aten::sum.dim_IntList': ReduceSum('self', 'dim', keepdims='keepdim'), - 'aten::threshold_backward': ReluGrad('grad_output', 'self'), - - 'aten::fmod.Scalar': Mod('self', 'other', fmod=1), - 'aten::fmod.Tensor': Mod('self', 'other', fmod=1), - - 'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd' - 'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'), - 'aten::gelu' : Gelu('self'), - 'aten::gelu_backward' : GeluGrad('grad', 'self') -} - -for binary_op, onnx_op in { - 'add': Add('self', Mul('alpha', 'other')), - 'sub': Sub('self', Mul('alpha', 'other')), - 'mul': Mul('self', 'other'), - 'div': Div('self', 'other')}.items(): - for dtype in ['Tensor', 'Scalar']: - for variant in ['', '_']: - ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op) - -for unary_op in [ - 'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos', - 'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round', - 'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero', - 'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']: - aten_name = f'aten::{unary_op}' - onnx_op = onnx_ops[unary_op]('self') - ops[aten_name] = onnx_op - # produce the in-place variant as well for ops that support it - if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']: - ops[f'{aten_name}_'] = onnx_op - -ortgen = ORTGen(ops) - -import os -import sys - -from opgen.parser import cpp_create_from_file as CPPParser -from opgen.writer import SourceWriter as SourceWriter - -if args.use_preinstalled_torch: - import torch - regdecs_path = Path(torch.__file__).parent.joinpath('include/ATen/RegistrationDeclarations.h') -else: - regdecs_path = os.path.realpath(os.path.join( - os.path.dirname(__file__), - '..', - '..', - '..', - 'external', - 'pytorch', - 'build', - 'aten', - 'src', - 'ATen', - 'RegistrationDeclarations.h')) - +regdecs_path = args.header_file print(f"INFO: Using ATen RegistrationDeclations from: {regdecs_path}") output = sys.stdout -if not args.output_file is None: +if args.output_file: output = open(args.output_file, 'wt') with CPPParser(regdecs_path) as parser, SourceWriter(output) as writer: diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py new file mode 100644 index 0000000000..c07c7d281d --- /dev/null +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -0,0 +1,73 @@ +from copy import deepcopy + +from opgen.generator import \ + ORTGen as ORTGen, \ + ONNXOp as ONNXOp, \ + SignatureOnly as SignatureOnly, \ + MakeFallthrough as MakeFallthrough + +from opgen.onnxops import * + +kMSDomain = 'onnxruntime::kMSDomain' + +class ReluGrad(ONNXOp): + def __init__(self, dY, X): + super().__init__('ReluGrad', 1, dY, X) + self.domain = kMSDomain + +class Gelu(ONNXOp): + def __init__(self, X): + super().__init__('Gelu', 1, X) + self.domain = kMSDomain + +class GeluGrad(ONNXOp): + def __init__(self, dY, X): + super().__init__('GeluGrad', 1, dY, X) + self.domain = kMSDomain + +ops = { + # Hand-Implemented Ops + 'aten::empty.memory_format': SignatureOnly(), + 'aten::empty_strided': SignatureOnly(), + 'aten::zero_': SignatureOnly(), + 'aten::copy_': SignatureOnly(), + 'aten::reshape': SignatureOnly(), + 'aten::view': SignatureOnly(), + + 'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'), + 'aten::t': Transpose('self'), + 'aten::mm': MatMul('self', 'mat2'), + 'aten::zeros_like': ConstantOfShape(Shape('self')), #the default constant is 0, so don't need to speicify attribute + + 'aten::sum.dim_IntList': ReduceSum('self', 'dim', keepdims='keepdim'), + 'aten::threshold_backward': ReluGrad('grad_output', 'self'), + + 'aten::fmod.Scalar': Mod('self', 'other', fmod=1), + 'aten::fmod.Tensor': Mod('self', 'other', fmod=1), + + 'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd' + 'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'), + 'aten::gelu' : Gelu('self'), + 'aten::gelu_backward' : GeluGrad('grad', 'self') +} + +for binary_op, onnx_op in { + 'add': Add('self', Mul('alpha', 'other')), + 'sub': Sub('self', Mul('alpha', 'other')), + 'mul': Mul('self', 'other'), + 'div': Div('self', 'other')}.items(): + for dtype in ['Tensor', 'Scalar']: + for variant in ['', '_']: + ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op) + +for unary_op in [ + 'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos', + 'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round', + 'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero', + 'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']: + aten_name = f'aten::{unary_op}' + onnx_op = onnx_ops[unary_op]('self') + ops[aten_name] = onnx_op + # produce the in-place variant as well for ops that support it + if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']: + ops[f'{aten_name}_'] = onnx_op diff --git a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py new file mode 100644 index 0000000000..90ed820c83 --- /dev/null +++ b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py @@ -0,0 +1,15 @@ +from copy import deepcopy + +from opgen.generator import AttrType, ONNXAttr + +from opgen.generator import \ + ORTGen as ORTGen, \ + ONNXOp as ONNXOp, \ + SignatureOnly as SignatureOnly, \ + MakeFallthrough as MakeFallthrough + +from opgen.onnxops import * + +ops = { + 'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB') +} diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index 9c77b68c21..84eb936a7f 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -27,6 +27,7 @@ class AttrType: STRING = 'const char*' STRINGS = '' TENSOR = 'at::Tensor' + LONG = 'at::ScalarType::Long' class ONNXAttr: def __init__(self, value, type: AttrType=None): @@ -77,35 +78,36 @@ class MakeFallthrough(ONNXOp): class FunctionGenerationError(NotImplementedError): def __init__(self, cpp_func: ast.FunctionDecl, message: str): - super().__init__(f'{message} (torch: {cpp_func.torch_func.torch_schema})') + super().__init__(f'{message} ({cpp_func.identifier})') class MappedOpFunction: def __init__( self, - torch_op_namespace: str, - torch_op_name: str, + op_namespace: str, + mapped_op_name: str, onnx_op: ONNXOp, cpp_func: ast.FunctionDecl, - torch_func: ast.FunctionDecl, signature_only: bool, make_fallthrough: bool): - self.torch_op_namespace = torch_op_namespace - self.torch_op_name = torch_op_name + self.op_namespace = op_namespace + self.mapped_op_name = mapped_op_name self.onnx_op = onnx_op self.cpp_func = cpp_func - self.torch_func = torch_func self.signature_only = signature_only self.make_fallthrough = make_fallthrough class ORTGen: _mapped_ops: Dict[str, ONNXOp] + _custom_ops: bool def __init__( self, - ops: Optional[Dict[str, ONNXOp]] = None): - self._mapped_ops = {} + ops: Optional[Dict[str, ONNXOp]] = None, + custom_ops : bool = False): + self._mapped_ops = {} if ops: self.register_many(ops) + self._custom_ops = custom_ops def register(self, aten_name: str, onnx_op: ONNXOp): self._mapped_ops[aten_name] = onnx_op @@ -121,13 +123,13 @@ class ORTGen: current_ns = None for mapped_func in self._parse_mapped_function_decls(cpp_parser): - del self._mapped_ops[mapped_func.torch_func.identifier.value] + del self._mapped_ops[mapped_func.mapped_op_name] generated_funcs.append(mapped_func) if mapped_func.make_fallthrough: continue - ns = mapped_func.torch_op_namespace + ns = mapped_func.op_namespace if current_ns and current_ns != ns: current_ns = None writer.pop_namespace() @@ -137,7 +139,8 @@ class ORTGen: writer.push_namespace(ns) writer.writeline() - writer.writeline(f'// {mapped_func.torch_func.torch_schema}') + if mapped_func.cpp_func.torch_func: + writer.writeline(f'// {mapped_func.cpp_func.torch_func.torch_schema}') self._write_function_signature(writer, mapped_func.cpp_func) if mapped_func.signature_only: @@ -153,7 +156,10 @@ class ORTGen: current_ns = None writer.pop_namespace() - self._write_function_registrations(writer, generated_funcs) + if not self._custom_ops: + self._write_function_registrations(writer, generated_funcs) + else: + self._write_custom_ops_registrations(writer, generated_funcs) self._write_file_postlude(writer) if len(self._mapped_ops) > 0: @@ -164,6 +170,8 @@ class ORTGen: writer.writeline('// AUTO-GENERATED CODE! - DO NOT EDIT!') writer.writeline(f'// $ python {" ".join(sys.argv)}') writer.writeline() + writer.writeline('#include "python/onnxruntime_pybind_state_common.h"') + writer.writeline() writer.writeline('#include ') writer.writeline() writer.writeline('#include ') @@ -206,7 +214,7 @@ class ORTGen: assert(len(cpp_func.parameters) > 0) - return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) + return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None if return_alias_info and not return_alias_info.is_writable: return_alias_info = None in_place_param: ast.ParameterDecl = None @@ -225,16 +233,16 @@ class ORTGen: # Fetch the ORT invoker from an at::Tensor.device() # FIXME: find the first at::Tensor param anywhere in the signature # instead of simply the first parameter? - first_torch_param = cpp_func.torch_func.parameters[0].member + first_param = cpp_func.parameters[0].member if not isinstance( - first_torch_param.parameter_type.desugar(), - ast.TensorType): + first_param.parameter_type.desugar(), + ast.ConcreteType) or 'Tensor' not in first_param.parameter_type.desugar().identifier_tokens[0].value: raise FunctionGenerationError( cpp_func, 'First parameter must be an at::Tensor') writer.write('auto& invoker = GetORTInvoker(') - writer.write(first_torch_param.identifier.value) + writer.write(first_param.identifier.value) writer.writeline('.device());') writer.writeline() @@ -329,11 +337,15 @@ class ORTGen: # TODO: Handle mutliple results # TODO: Assert return type - if not return_alias_info: + if not return_alias_info: writer.writeline('return aten_tensor_from_ort(') writer.push_indent() - writer.writeline(f'std::move({return_outputs}[0]),') - writer.writeline(f'{first_torch_param.identifier.value}.options());') + if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector': + writer.writeline(f'{return_outputs},') + writer.writeline(f'{first_param.identifier.value}.options());') + else: + writer.writeline(f'std::move({return_outputs}[0]),') + writer.writeline(f'{first_param.identifier.value}.options());') writer.pop_indent() return @@ -354,13 +366,13 @@ class ORTGen: writer.writeline('ORT_LOG_DEBUG << "ATen init";') for mapped_func in generated_funcs: - cpp_func, torch_func = mapped_func.cpp_func, mapped_func.torch_func + cpp_func, torch_func = mapped_func.cpp_func, mapped_func.cpp_func.torch_func if mapped_func.make_fallthrough: reg_function_arg = 'torch::CppFunction::makeFallthrough()' else: - if mapped_func.torch_op_namespace: - reg_function_arg = f'{mapped_func.torch_op_namespace}::' + if mapped_func.op_namespace: + reg_function_arg = f'{mapped_func.op_namespace}::' else: reg_function_arg = '' reg_function_arg += cpp_func.identifier.value @@ -375,6 +387,24 @@ class ORTGen: writer.writeline('}') writer.writeline() + def _write_custom_ops_registrations( + self, + writer: writer.SourceWriter, + generated_funcs: List[MappedOpFunction]): + writer.writeline() + writer.writeline('void GenerateCustomOpsBindings(pybind11::module_ m) {') + writer.push_indent() + writer.writeline('ORT_LOG_DEBUG << "GenerateCustomOpsBindings init";') + + for mapped_func in generated_funcs: + cpp_func = mapped_func.cpp_func + writer.write('m.def(') + writer.writeline(f'"{cpp_func.identifier.value}", &{cpp_func.identifier.value});') + + writer.pop_indent() + writer.writeline('}') + writer.writeline() + def _get_alias_info(self, torch_type_or_param: Union[ast.Type, ast.ParameterDecl]): if isinstance(torch_type_or_param, ast.ParameterDecl): torch_type = torch_type_or_param.parameter_type @@ -383,28 +413,32 @@ class ORTGen: return getattr(torch_type.desugar(), 'alias_info', None) def _parse_mapped_function_decls(self, cpp_parser: parser.CPPParser): - for cpp_func, torch_func in self._parse_function_decls(cpp_parser): - torch_op_name = torch_func.identifier.value - if torch_op_name not in self._mapped_ops: - continue - onnx_op = self._mapped_ops[torch_op_name] + for cpp_func in self._parse_function_decls(cpp_parser): + torch_func = cpp_func.torch_func + if not torch_func: + op_namespace = None + op_name = cpp_func.identifier.value + else: + op_name = torch_func.identifier.value + + try: + op_namespace = op_name[0:op_name.index('::')] + op_namewithoutnamespace = op_name[len(op_namespace) + 2:] + except: + op_namespace = None + op_namewithoutnamespace = op_name + + cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '__') + + onnx_op = self._mapped_ops.get(op_name) if not onnx_op: continue - try: - torch_op_namespace = torch_op_name[0:torch_op_name.index('::')] - torch_op_name = torch_op_name[len(torch_op_namespace) + 2:] - except: - torch_op_namespace = None - - cpp_func.identifier.value = torch_op_name.replace('.', '__') - yield MappedOpFunction( - torch_op_namespace, - torch_op_name, + op_namespace, + op_name, onnx_op, cpp_func, - torch_func, isinstance(onnx_op, SignatureOnly), isinstance(onnx_op, MakeFallthrough)) @@ -415,7 +449,11 @@ class ORTGen: # Parse the Torch schema from the JSON comment that follows each C++ decl # and link associated Torch and C++ decls (functions, parameters, returns) for cpp_func in tu: - if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia: + if self._custom_ops == True: + # customops don't have torch schema + cpp_func.torch_func = None + yield cpp_func + elif cpp_func.semicolon and cpp_func.semicolon.trailing_trivia: for trivia in cpp_func.semicolon.trailing_trivia: if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT: yield self._parse_and_link_torch_function_decl(cpp_func, trivia) @@ -463,4 +501,4 @@ class ORTGen: torch_param = torch_func.parameters[i + j].member cpp_param.torch_param.append(torch_param) - return cpp_func, torch_func \ No newline at end of file + return cpp_func \ No newline at end of file diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 0caa145a35..066c05bae3 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -37,6 +37,19 @@ at::Tensor aten_tensor_from_ort( options)); } +const std::vector aten_tensor_from_ort( + std::vector& ortvalues, + const at::TensorOptions& options) { + const size_t num_outputs = ortvalues.size(); + std::vector atvalues = std::vector(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { + atvalues[i] = at::Tensor(c10::make_intrusive( + std::move(ortvalues[i]), + options)); + } + return atvalues; +} + onnxruntime::MLDataType ort_scalar_type_from_aten( at::ScalarType dtype) { switch (dtype){ diff --git a/orttraining/orttraining/eager/ort_aten.h b/orttraining/orttraining/eager/ort_aten.h index 977660fde2..951fdcba8d 100644 --- a/orttraining/orttraining/eager/ort_aten.h +++ b/orttraining/orttraining/eager/ort_aten.h @@ -17,6 +17,10 @@ at::Tensor aten_tensor_from_ort( OrtValue&& ot, const at::TensorOptions& options); +const std::vector aten_tensor_from_ort( + std::vector& ortvalues, + const at::TensorOptions& options); + onnxruntime::MLDataType ort_scalar_type_from_aten( at::ScalarType dtype); diff --git a/orttraining/orttraining/eager/ort_customops.h b/orttraining/orttraining/eager/ort_customops.h new file mode 100644 index 0000000000..0a0be430f4 --- /dev/null +++ b/orttraining/orttraining/eager/ort_customops.h @@ -0,0 +1,9 @@ +#include + +namespace torch_ort { +namespace eager { + +void GenerateCustomOpsBindings(pybind11::module_ module); + +} // namespace eager +} // namespace torch_ort \ No newline at end of file diff --git a/orttraining/orttraining/eager/ort_eager.cpp b/orttraining/orttraining/eager/ort_eager.cpp index d4dbc18f5b..0ce4cd27f0 100644 --- a/orttraining/orttraining/eager/ort_eager.cpp +++ b/orttraining/orttraining/eager/ort_eager.cpp @@ -11,6 +11,7 @@ #include "python/onnxruntime_pybind_state_common.h" #include "orttraining/core/framework/torch/dlpack_python.h" #include +#include "ort_customops.h" namespace onnxruntime{ namespace python{ @@ -67,6 +68,9 @@ void addObjectMethodsForEager(py::module& m){ if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); }); + + auto customop_module = m.def_submodule("custom_ops"); + torch_ort::eager::GenerateCustomOpsBindings(customop_module); } } diff --git a/orttraining/orttraining/eager/ort_util.h b/orttraining/orttraining/eager/ort_util.h index b8e651882f..44952bb1dc 100644 --- a/orttraining/orttraining/eager/ort_util.h +++ b/orttraining/orttraining/eager/ort_util.h @@ -36,9 +36,9 @@ inline void CopyVectorToTensor(onnxruntime::ORTInvoker& invoker, Ort::TypeToTensorType::type, &ort_value)); - execution_provider.GetDataTransfer()->CopyTensor( + ORT_THROW_IF_ERROR(execution_provider.GetDataTransfer()->CopyTensor( ort_value->Get(), - tensor); + tensor)); } // vector is specialized so we need to handle it separately diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index 9dafe899e5..8e2461218e 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -61,5 +61,13 @@ class OrtOpTests(unittest.TestCase): ort_zeros = torch.zeros_like(ones.to(device)) assert torch.allclose(cpu_zeros, ort_zeros.cpu()) + def test_gemm(self): + device = self.get_device() + cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + ort_ones = cpu_ones.to(device) + cpu_ans = cpu_ones * 4 + ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0) + assert torch.allclose(cpu_ans, ort_ans.cpu()) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc b/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc index 4aa33d642b..4362619e21 100644 --- a/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc @@ -49,7 +49,7 @@ constexpr const char* const k_gradient_norm_op_name = "ReduceAllL2"; constexpr const char* const k_unscale_op_name = "MixedPrecisionScale"; constexpr const char* const k_inplace_accumulator_op_name = "InPlaceAccumulator"; constexpr const char* const k_zero_gradient_op_name = "ZeroGradient"; -#if defined(USE_MPI) +#if defined(ORT_USE_NCCL) && defined(USE_MPI) constexpr const char* const k_adasum_op_name = "AdasumAllReduce"; #endif Status SetUpBaseGraph(Graph& graph); diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index a16d742d8f..0c8039368c 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -535,9 +535,14 @@ def parse_arguments(): parser.add_argument( "--ms_experimental", action='store_true', help="Build microsoft experimental operators.")\ + # eager mode parser.add_argument( "--build_eager_mode", action='store_true', help="Build ONNXRuntime micro-benchmarks.") + parser.add_argument('--eager_customop_module', default=None, + help='Module containing custom op mappings for eager mode.') + parser.add_argument('--eager_customop_header', default=None, + help='Header containing custom op definitions for eager mode.') return parser.parse_args() @@ -2146,21 +2151,56 @@ def main(): args.rocm_version = "" if args.build_eager_mode: - # generate the ort aten backend code - def gen_ort_aten_ops(eager_root_dir): - gen_cpp_name = os.path.join(eager_root_dir, "ort_aten.g.cpp") - if os.path.exists(gen_cpp_name): - os.remove(gen_cpp_name) - subprocess.check_call([ - sys.executable, - os.path.join(eager_root_dir, 'opgen', 'opgen.py'), - "--output_file", - gen_cpp_name, - "--use_preinstalled_torch" - ]) - eager_root_dir = os.path.join(source_dir, "orttraining", "orttraining", "eager") - gen_ort_aten_ops(eager_root_dir) + if args.eager_customop_module and not args.eager_customop_header: + raise Exception('eager_customop_header must be provided when eager_customop_module is') + elif args.eager_customop_header and not args.eager_customop_module: + raise Exception('eager_customop_module must be provided when eager_customop_header is') + + def gen_ops(gen_cpp_name: str, header_file: str, ops_module: str, custom_ops: bool): + gen_cpp_scratch_name = gen_cpp_name + '.working' + print(f'Generating ORT ATen overrides (output_file: {gen_cpp_name}, header_file: {header_file},' + f'ops_module: {ops_module}), custom_ops: {custom_ops}') + + cmd = [sys.executable, os.path.join(os.path.join(eager_root_dir, 'opgen', 'opgen.py')), + '--output_file', gen_cpp_scratch_name, + '--ops_module', ops_module, + '--header_file', header_file] + + if custom_ops: + cmd += ["--custom_ops"] + + subprocess.check_call(cmd) + + import filecmp + if (not os.path.isfile(gen_cpp_name) or + not filecmp.cmp(gen_cpp_name, gen_cpp_scratch_name, shallow=False)): + os.rename(gen_cpp_scratch_name, gen_cpp_name) + else: + os.remove(gen_cpp_scratch_name) + + def gen_ort_ops(): + # generate native aten ops + import torch + regdecs_path = os.path.join(os.path.dirname(torch.__file__), 'include/ATen/RegistrationDeclarations.h') + + ops_module = os.path.join(eager_root_dir, 'opgen/opgen/atenops.py') + gen_ops(os.path.join(eager_root_dir, 'ort_aten.g.cpp'), regdecs_path, ops_module, False) + + # generate custom ops + if not args.eager_customop_header: + args.eager_customop_header = os.path.realpath(os.path.join( + eager_root_dir, + "opgen", + "CustomOpDeclarations.h")) + + if not args.eager_customop_module: + args.eager_customop_module = os.path.join(eager_root_dir, 'opgen/opgen/custom_ops.py') + + gen_ops(os.path.join(eager_root_dir, 'ort_customops.g.cpp'), + args.eager_customop_header, args.eager_customop_module, True) + + gen_ort_ops() generate_build_tree( cmake_path, source_dir, build_dir, cuda_home, cudnn_home, rocm_home, mpi_home, nccl_home,