Ported changes / bug fixes from torch/ort. (#8784)

* Ported changes / bug fixes from torch/ort.

* Fixed formatting

* Renamed function

* Renamed module_ to module.

* Revert "Renamed module_ to module."

This reverts commit b17fc114b3db20d174283811d90592b5b8154c19.

* Include pybind common header to fix linker errors in Windows debug builds.

* Fix generation of more than one custom op.

Co-authored-by: Ashwin Hari <ashari@microsoft.com>
Chandru Ramakrishnan 2021-08-23 17:45:40 -04:00 committed by GitHub
parent f51f2bad66
commit 2693af9799
14 changed files with 280 additions and 163 deletions

.gitignore

@@ -52,3 +52,4 @@ onnxruntime/python/version_info.py
/tools/perf_util/target/classes
/tools/perf_util/src/main/resources
/orttraining/orttraining/eager/ort_aten.g.cpp
/orttraining/orttraining/eager/ort_customops.g.cpp


@@ -0,0 +1 @@
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
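
For orientation: ONNX Gemm computes alpha * op(A) @ op(B) + beta * C, where op() transposes its argument when transA / transB is non-zero. A minimal NumPy sketch of those semantics (gemm_reference is a hypothetical helper used only for illustration, not part of this change):

import numpy as np

def gemm_reference(A, B, C, alpha=1.0, beta=1.0, transA=0, transB=0):
    # Illustrates what the declared custom op computes: alpha * op(A) @ op(B) + beta * C
    A = A.T if transA else A
    B = B.T if transB else B
    return alpha * (A @ B) + beta * C

ones = np.ones((3, 3))
print(gemm_reference(ones, ones, ones))  # every entry is 3 * 1 + 1 = 4, matching the new test below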


@@ -3,120 +3,31 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from opgen.writer import SourceWriter as SourceWriter
from opgen.parser import cpp_create_from_file as CPPParser
import sys
import os
from opgen.generator import ORTGen as ORTGen
from importlib.machinery import SourceFileLoader
import argparse
from pathlib import Path
from copy import deepcopy
from opgen.generator import \
ORTGen as ORTGen, \
ONNXOp as ONNXOp, \
SignatureOnly as SignatureOnly, \
MakeFallthrough as MakeFallthrough
from opgen.onnxops import *
kMSDomain = 'onnxruntime::kMSDomain'
parser = argparse.ArgumentParser(description='Generate ORT ATen operations')
parser.add_argument('--ops_module', type=str,
help='Python module containing the Onnx Operation signature and list of ops to map')
parser.add_argument('--output_file', default=None, type=str, help='Output file [default to std out]')
parser.add_argument('--use_preinstalled_torch', action='store_true', help='Use pre-installed torch from the python environment')
parser.add_argument('--header_file', type=str,
help='Header file which contains ATen / Pytorch operation signature')
parser.add_argument('--custom_ops', action='store_true', help='Whether we are generating code for custom ops or native operation')
args = parser.parse_args()
ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
ortgen = ORTGen(ops_module.ops, custom_ops=args.custom_ops)
class ReluGrad(ONNXOp):
def __init__(self, dY, X):
super().__init__('ReluGrad', 1, dY, X)
self.domain = kMSDomain
class Gelu(ONNXOp):
def __init__(self, X):
super().__init__('Gelu', 1, X)
self.domain = kMSDomain
class GeluGrad(ONNXOp):
def __init__(self, dY, X):
super().__init__('GeluGrad', 1, dY, X)
self.domain = kMSDomain
ops = {
# Hand-Implemented Ops
'aten::empty.memory_format': SignatureOnly(),
'aten::empty_strided': SignatureOnly(),
'aten::zero_': SignatureOnly(),
'aten::copy_': SignatureOnly(),
'aten::reshape': SignatureOnly(),
'aten::view': SignatureOnly(),
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
'aten::t': Transpose('self'),
'aten::mm': MatMul('self', 'mat2'),
'aten::zeros_like': ConstantOfShape(Shape('self')), # the default constant is 0, so no need to specify the attribute
'aten::sum.dim_IntList': ReduceSum('self', 'dim', keepdims='keepdim'),
'aten::threshold_backward': ReluGrad('grad_output', 'self'),
'aten::fmod.Scalar': Mod('self', 'other', fmod=1),
'aten::fmod.Tensor': Mod('self', 'other', fmod=1),
'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd'
'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
'aten::gelu' : Gelu('self'),
'aten::gelu_backward' : GeluGrad('grad', 'self')
}
for binary_op, onnx_op in {
'add': Add('self', Mul('alpha', 'other')),
'sub': Sub('self', Mul('alpha', 'other')),
'mul': Mul('self', 'other'),
'div': Div('self', 'other')}.items():
for dtype in ['Tensor', 'Scalar']:
for variant in ['', '_']:
ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op)
for unary_op in [
'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round',
'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero',
'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']:
aten_name = f'aten::{unary_op}'
onnx_op = onnx_ops[unary_op]('self')
ops[aten_name] = onnx_op
# produce the in-place variant as well for ops that support it
if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']:
ops[f'{aten_name}_'] = onnx_op
ortgen = ORTGen(ops)
import os
import sys
from opgen.parser import cpp_create_from_file as CPPParser
from opgen.writer import SourceWriter as SourceWriter
if args.use_preinstalled_torch:
import torch
regdecs_path = Path(torch.__file__).parent.joinpath('include/ATen/RegistrationDeclarations.h')
else:
regdecs_path = os.path.realpath(os.path.join(
os.path.dirname(__file__),
'..',
'..',
'..',
'external',
'pytorch',
'build',
'aten',
'src',
'ATen',
'RegistrationDeclarations.h'))
regdecs_path = args.header_file
print(f"INFO: Using ATen RegistrationDeclarations from: {regdecs_path}")
output = sys.stdout
if not args.output_file is None:
if args.output_file:
output = open(args.output_file, 'wt')
with CPPParser(regdecs_path) as parser, SourceWriter(output) as writer:
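
With the op table moved out to atenops.py (below), opgen.py can now be driven directly. A hedged sketch of regenerating the native ATen bindings outside the build, assuming a pre-installed torch and that it is run from the eager/opgen directory shown in this diff (the output path mirrors the one build.py uses below):

import os
import subprocess
import sys

import torch

# RegistrationDeclarations.h ships with the installed torch package,
# which is what the --use_preinstalled_torch branch above relies on.
regdecs = os.path.join(os.path.dirname(torch.__file__),
                       'include/ATen/RegistrationDeclarations.h')

subprocess.check_call([
    sys.executable, 'opgen.py',
    '--ops_module', 'opgen/atenops.py',   # the module introduced below
    '--header_file', regdecs,
    '--output_file', 'ort_aten.g.cpp',
])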


@@ -0,0 +1,73 @@
from copy import deepcopy
from opgen.generator import \
ORTGen as ORTGen, \
ONNXOp as ONNXOp, \
SignatureOnly as SignatureOnly, \
MakeFallthrough as MakeFallthrough
from opgen.onnxops import *
kMSDomain = 'onnxruntime::kMSDomain'
class ReluGrad(ONNXOp):
def __init__(self, dY, X):
super().__init__('ReluGrad', 1, dY, X)
self.domain = kMSDomain
class Gelu(ONNXOp):
def __init__(self, X):
super().__init__('Gelu', 1, X)
self.domain = kMSDomain
class GeluGrad(ONNXOp):
def __init__(self, dY, X):
super().__init__('GeluGrad', 1, dY, X)
self.domain = kMSDomain
ops = {
# Hand-Implemented Ops
'aten::empty.memory_format': SignatureOnly(),
'aten::empty_strided': SignatureOnly(),
'aten::zero_': SignatureOnly(),
'aten::copy_': SignatureOnly(),
'aten::reshape': SignatureOnly(),
'aten::view': SignatureOnly(),
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
'aten::t': Transpose('self'),
'aten::mm': MatMul('self', 'mat2'),
'aten::zeros_like': ConstantOfShape(Shape('self')), # the default constant is 0, so no need to specify the attribute
'aten::sum.dim_IntList': ReduceSum('self', 'dim', keepdims='keepdim'),
'aten::threshold_backward': ReluGrad('grad_output', 'self'),
'aten::fmod.Scalar': Mod('self', 'other', fmod=1),
'aten::fmod.Tensor': Mod('self', 'other', fmod=1),
'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd'
'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
'aten::gelu' : Gelu('self'),
'aten::gelu_backward' : GeluGrad('grad', 'self')
}
for binary_op, onnx_op in {
'add': Add('self', Mul('alpha', 'other')),
'sub': Sub('self', Mul('alpha', 'other')),
'mul': Mul('self', 'other'),
'div': Div('self', 'other')}.items():
for dtype in ['Tensor', 'Scalar']:
for variant in ['', '_']:
ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op)
for unary_op in [
'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round',
'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero',
'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']:
aten_name = f'aten::{unary_op}'
onnx_op = onnx_ops[unary_op]('self')
ops[aten_name] = onnx_op
# produce the in-place variant as well for ops that support it
if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']:
ops[f'{aten_name}_'] = onnx_op
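
To make the expansion above concrete: the binary-op loop registers four aten keys per arithmetic op (Tensor/Scalar x out-of-place/in-place), each bound to a deepcopy of the ONNXOp template so entries do not share state. A standalone sketch that reproduces just the generated keys:

# Reproduces only the key expansion performed by the loop above.
keys = [
    f'aten::{op}{variant}.{dtype}'
    for op in ['add', 'sub', 'mul', 'div']
    for dtype in ['Tensor', 'Scalar']
    for variant in ['', '_']
]
print(len(keys))   # 16
print(keys[:4])    # ['aten::add.Tensor', 'aten::add_.Tensor', 'aten::add.Scalar', 'aten::add_.Scalar']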


@@ -0,0 +1,15 @@
from copy import deepcopy
from opgen.generator import AttrType, ONNXAttr
from opgen.generator import \
ORTGen as ORTGen, \
ONNXOp as ONNXOp, \
SignatureOnly as SignatureOnly, \
MakeFallthrough as MakeFallthrough
from opgen.onnxops import *
ops = {
'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB')
}
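
The commit message's "Fix to generation of > 1 custom op" implies this table is meant to hold multiple entries. A hedged sketch of what a second op would involve (the my_matmul name and its signature are hypothetical; MatMul is the opgen.onnxops helper already used in atenops.py above):

# CustomOpDeclarations.h would gain a matching declaration (hypothetical):
#   Tensor my_matmul(const Tensor& A, const Tensor& B);
# and the table here would gain the corresponding entry, keyed by the declared function name:
ops['my_matmul'] = MatMul('A', 'B')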


@@ -27,6 +27,7 @@ class AttrType:
STRING = 'const char*'
STRINGS = '<unsupported:STRINGS>'
TENSOR = 'at::Tensor'
LONG = 'at::ScalarType::Long'
class ONNXAttr:
def __init__(self, value, type: AttrType=None):
@@ -77,35 +78,36 @@ class MakeFallthrough(ONNXOp):
class FunctionGenerationError(NotImplementedError):
def __init__(self, cpp_func: ast.FunctionDecl, message: str):
super().__init__(f'{message} (torch: {cpp_func.torch_func.torch_schema})')
super().__init__(f'{message} ({cpp_func.identifier})')
class MappedOpFunction:
def __init__(
self,
torch_op_namespace: str,
torch_op_name: str,
op_namespace: str,
mapped_op_name: str,
onnx_op: ONNXOp,
cpp_func: ast.FunctionDecl,
torch_func: ast.FunctionDecl,
signature_only: bool,
make_fallthrough: bool):
self.torch_op_namespace = torch_op_namespace
self.torch_op_name = torch_op_name
self.op_namespace = op_namespace
self.mapped_op_name = mapped_op_name
self.onnx_op = onnx_op
self.cpp_func = cpp_func
self.torch_func = torch_func
self.signature_only = signature_only
self.make_fallthrough = make_fallthrough
class ORTGen:
_mapped_ops: Dict[str, ONNXOp]
_custom_ops: bool
def __init__(
self,
ops: Optional[Dict[str, ONNXOp]] = None):
self._mapped_ops = {}
ops: Optional[Dict[str, ONNXOp]] = None,
custom_ops : bool = False):
self._mapped_ops = {}
if ops:
self.register_many(ops)
self._custom_ops = custom_ops
def register(self, aten_name: str, onnx_op: ONNXOp):
self._mapped_ops[aten_name] = onnx_op
@@ -121,13 +123,13 @@ class ORTGen:
current_ns = None
for mapped_func in self._parse_mapped_function_decls(cpp_parser):
del self._mapped_ops[mapped_func.torch_func.identifier.value]
del self._mapped_ops[mapped_func.mapped_op_name]
generated_funcs.append(mapped_func)
if mapped_func.make_fallthrough:
continue
ns = mapped_func.torch_op_namespace
ns = mapped_func.op_namespace
if current_ns and current_ns != ns:
current_ns = None
writer.pop_namespace()
@@ -137,7 +139,8 @@ class ORTGen:
writer.push_namespace(ns)
writer.writeline()
writer.writeline(f'// {mapped_func.torch_func.torch_schema}')
if mapped_func.cpp_func.torch_func:
writer.writeline(f'// {mapped_func.cpp_func.torch_func.torch_schema}')
self._write_function_signature(writer, mapped_func.cpp_func)
if mapped_func.signature_only:
@@ -153,7 +156,10 @@ class ORTGen:
current_ns = None
writer.pop_namespace()
self._write_function_registrations(writer, generated_funcs)
if not self._custom_ops:
self._write_function_registrations(writer, generated_funcs)
else:
self._write_custom_ops_registrations(writer, generated_funcs)
self._write_file_postlude(writer)
if len(self._mapped_ops) > 0:
@@ -164,6 +170,8 @@ class ORTGen:
writer.writeline('// AUTO-GENERATED CODE! - DO NOT EDIT!')
writer.writeline(f'// $ python {" ".join(sys.argv)}')
writer.writeline()
writer.writeline('#include "python/onnxruntime_pybind_state_common.h"')
writer.writeline()
writer.writeline('#include <torch/extension.h>')
writer.writeline()
writer.writeline('#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>')
@@ -206,7 +214,7 @@ class ORTGen:
assert(len(cpp_func.parameters) > 0)
return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type)
return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
if return_alias_info and not return_alias_info.is_writable:
return_alias_info = None
in_place_param: ast.ParameterDecl = None
@@ -225,16 +233,16 @@ class ORTGen:
# Fetch the ORT invoker from an at::Tensor.device()
# FIXME: find the first at::Tensor param anywhere in the signature
# instead of simply the first parameter?
first_torch_param = cpp_func.torch_func.parameters[0].member
first_param = cpp_func.parameters[0].member
if not isinstance(
first_torch_param.parameter_type.desugar(),
ast.TensorType):
first_param.parameter_type.desugar(),
ast.ConcreteType) or 'Tensor' not in first_param.parameter_type.desugar().identifier_tokens[0].value:
raise FunctionGenerationError(
cpp_func,
'First parameter must be an at::Tensor')
writer.write('auto& invoker = GetORTInvoker(')
writer.write(first_torch_param.identifier.value)
writer.write(first_param.identifier.value)
writer.writeline('.device());')
writer.writeline()
@@ -329,11 +337,15 @@ class ORTGen:
# TODO: Handle multiple results
# TODO: Assert return type
if not return_alias_info:
if not return_alias_info:
writer.writeline('return aten_tensor_from_ort(')
writer.push_indent()
writer.writeline(f'std::move({return_outputs}[0]),')
writer.writeline(f'{first_torch_param.identifier.value}.options());')
if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector':
writer.writeline(f'{return_outputs},')
writer.writeline(f'{first_param.identifier.value}.options());')
else:
writer.writeline(f'std::move({return_outputs}[0]),')
writer.writeline(f'{first_param.identifier.value}.options());')
writer.pop_indent()
return
@@ -354,13 +366,13 @@ class ORTGen:
writer.writeline('ORT_LOG_DEBUG << "ATen init";')
for mapped_func in generated_funcs:
cpp_func, torch_func = mapped_func.cpp_func, mapped_func.torch_func
cpp_func, torch_func = mapped_func.cpp_func, mapped_func.cpp_func.torch_func
if mapped_func.make_fallthrough:
reg_function_arg = 'torch::CppFunction::makeFallthrough()'
else:
if mapped_func.torch_op_namespace:
reg_function_arg = f'{mapped_func.torch_op_namespace}::'
if mapped_func.op_namespace:
reg_function_arg = f'{mapped_func.op_namespace}::'
else:
reg_function_arg = ''
reg_function_arg += cpp_func.identifier.value
@@ -375,6 +387,24 @@ class ORTGen:
writer.writeline('}')
writer.writeline()
def _write_custom_ops_registrations(
self,
writer: writer.SourceWriter,
generated_funcs: List[MappedOpFunction]):
writer.writeline()
writer.writeline('void GenerateCustomOpsBindings(pybind11::module_ m) {')
writer.push_indent()
writer.writeline('ORT_LOG_DEBUG << "GenerateCustomOpsBindings init";')
for mapped_func in generated_funcs:
cpp_func = mapped_func.cpp_func
writer.write('m.def(')
writer.writeline(f'"{cpp_func.identifier.value}", &{cpp_func.identifier.value});')
writer.pop_indent()
writer.writeline('}')
writer.writeline()
def _get_alias_info(self, torch_type_or_param: Union[ast.Type, ast.ParameterDecl]):
if isinstance(torch_type_or_param, ast.ParameterDecl):
torch_type = torch_type_or_param.parameter_type
@@ -383,28 +413,32 @@ class ORTGen:
return getattr(torch_type.desugar(), 'alias_info', None)
def _parse_mapped_function_decls(self, cpp_parser: parser.CPPParser):
for cpp_func, torch_func in self._parse_function_decls(cpp_parser):
torch_op_name = torch_func.identifier.value
if torch_op_name not in self._mapped_ops:
continue
onnx_op = self._mapped_ops[torch_op_name]
for cpp_func in self._parse_function_decls(cpp_parser):
torch_func = cpp_func.torch_func
if not torch_func:
op_namespace = None
op_name = cpp_func.identifier.value
else:
op_name = torch_func.identifier.value
try:
op_namespace = op_name[0:op_name.index('::')]
op_namewithoutnamespace = op_name[len(op_namespace) + 2:]
except ValueError:
op_namespace = None
op_namewithoutnamespace = op_name
cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '__')
onnx_op = self._mapped_ops.get(op_name)
if not onnx_op:
continue
try:
torch_op_namespace = torch_op_name[0:torch_op_name.index('::')]
torch_op_name = torch_op_name[len(torch_op_namespace) + 2:]
except:
torch_op_namespace = None
cpp_func.identifier.value = torch_op_name.replace('.', '__')
yield MappedOpFunction(
torch_op_namespace,
torch_op_name,
op_namespace,
op_name,
onnx_op,
cpp_func,
torch_func,
isinstance(onnx_op, SignatureOnly),
isinstance(onnx_op, MakeFallthrough))
@@ -415,7 +449,11 @@ class ORTGen:
# Parse the Torch schema from the JSON comment that follows each C++ decl
# and link associated Torch and C++ decls (functions, parameters, returns)
for cpp_func in tu:
if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
if self._custom_ops:
# custom ops don't have a torch schema
cpp_func.torch_func = None
yield cpp_func
elif cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
for trivia in cpp_func.semicolon.trailing_trivia:
if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT:
yield self._parse_and_link_torch_function_decl(cpp_func, trivia)
@@ -463,4 +501,4 @@ class ORTGen:
torch_param = torch_func.parameters[i + j].member
cpp_param.torch_param.append(torch_param)
return cpp_func, torch_func
return cpp_func


@@ -37,6 +37,19 @@ at::Tensor aten_tensor_from_ort(
options));
}
const std::vector<at::Tensor> aten_tensor_from_ort(
std::vector<OrtValue>& ortvalues,
const at::TensorOptions& options) {
const size_t num_outputs = ortvalues.size();
std::vector<at::Tensor> atvalues = std::vector<at::Tensor>(num_outputs);
for (size_t i = 0; i < num_outputs; i++) {
atvalues[i] = at::Tensor(c10::make_intrusive<ORTTensorImpl>(
std::move(ortvalues[i]),
options));
}
return atvalues;
}
onnxruntime::MLDataType ort_scalar_type_from_aten(
at::ScalarType dtype) {
switch (dtype){


@@ -17,6 +17,10 @@ at::Tensor aten_tensor_from_ort(
OrtValue&& ot,
const at::TensorOptions& options);
const std::vector<at::Tensor> aten_tensor_from_ort(
std::vector<OrtValue>& ortvalues,
const at::TensorOptions& options);
onnxruntime::MLDataType ort_scalar_type_from_aten(
at::ScalarType dtype);


@@ -0,0 +1,9 @@
#include <torch/extension.h>
namespace torch_ort {
namespace eager {
void GenerateCustomOpsBindings(pybind11::module_ module);
} // namespace eager
} // namespace torch_ort


@@ -11,6 +11,7 @@
#include "python/onnxruntime_pybind_state_common.h"
#include "orttraining/core/framework/torch/dlpack_python.h"
#include <core/session/provider_bridge_ort.h>
#include "ort_customops.h"
namespace onnxruntime{
namespace python{
@@ -67,6 +68,9 @@ void addObjectMethodsForEager(py::module& m){
if (!status.IsOK())
throw std::runtime_error(status.ErrorMessage());
});
auto customop_module = m.def_submodule("custom_ops");
torch_ort::eager::GenerateCustomOpsBindings(customop_module);
}
}
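
With the custom_ops submodule registered above, the generated bindings become callable from Python. A hedged usage sketch, assuming the torch_ort package used by the eager test below; the 'ort' device string is an assumption (the test resolves its device via self.get_device()):

import torch
import torch_ort  # eager-mode package, as used by the test below

device = torch.device('ort')  # assumed device name for the eager ORT backend
a = torch.ones(3, 3).to(device)
out = torch_ort.custom_ops.gemm(a, a, a, 1.0, 1.0, 0, 0)
print(out.cpu())              # expected: a 3x3 tensor of 4.0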


@@ -36,9 +36,9 @@ inline void CopyVectorToTensor(onnxruntime::ORTInvoker& invoker,
Ort::TypeToTensorType<T>::type,
&ort_value));
execution_provider.GetDataTransfer()->CopyTensor(
ORT_THROW_IF_ERROR(execution_provider.GetDataTransfer()->CopyTensor(
ort_value->Get<onnxruntime::Tensor>(),
tensor);
tensor));
}
// vector<bool> is specialized so we need to handle it separately


@@ -61,5 +61,13 @@ class OrtOpTests(unittest.TestCase):
ort_zeros = torch.zeros_like(ones.to(device))
assert torch.allclose(cpu_zeros, ort_zeros.cpu())
def test_gemm(self):
device = self.get_device()
cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
ort_ones = cpu_ones.to(device)
cpu_ans = cpu_ones * 4
ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
assert torch.allclose(cpu_ans, ort_ans.cpu())
if __name__ == '__main__':
unittest.main()


@@ -49,7 +49,7 @@ constexpr const char* const k_gradient_norm_op_name = "ReduceAllL2";
constexpr const char* const k_unscale_op_name = "MixedPrecisionScale";
constexpr const char* const k_inplace_accumulator_op_name = "InPlaceAccumulator";
constexpr const char* const k_zero_gradient_op_name = "ZeroGradient";
#if defined(USE_MPI)
#if defined(ORT_USE_NCCL) && defined(USE_MPI)
constexpr const char* const k_adasum_op_name = "AdasumAllReduce";
#endif
Status SetUpBaseGraph(Graph& graph);


@@ -535,9 +535,14 @@ def parse_arguments():
parser.add_argument(
"--ms_experimental", action='store_true', help="Build microsoft experimental operators.")\
# eager mode
parser.add_argument(
"--build_eager_mode", action='store_true',
help="Build ONNXRuntime eager mode.")
parser.add_argument('--eager_customop_module', default=None,
help='Module containing custom op mappings for eager mode.')
parser.add_argument('--eager_customop_header', default=None,
help='Header containing custom op definitions for eager mode.')
return parser.parse_args()
@@ -2146,21 +2151,56 @@ def main():
args.rocm_version = ""
if args.build_eager_mode:
# generate the ort aten backend code
def gen_ort_aten_ops(eager_root_dir):
gen_cpp_name = os.path.join(eager_root_dir, "ort_aten.g.cpp")
if os.path.exists(gen_cpp_name):
os.remove(gen_cpp_name)
subprocess.check_call([
sys.executable,
os.path.join(eager_root_dir, 'opgen', 'opgen.py'),
"--output_file",
gen_cpp_name,
"--use_preinstalled_torch"
])
eager_root_dir = os.path.join(source_dir, "orttraining", "orttraining", "eager")
gen_ort_aten_ops(eager_root_dir)
if args.eager_customop_module and not args.eager_customop_header:
raise Exception('eager_customop_header must be provided when eager_customop_module is')
elif args.eager_customop_header and not args.eager_customop_module:
raise Exception('eager_customop_module must be provided when eager_customop_header is')
def gen_ops(gen_cpp_name: str, header_file: str, ops_module: str, custom_ops: bool):
gen_cpp_scratch_name = gen_cpp_name + '.working'
print(f'Generating ORT ATen overrides (output_file: {gen_cpp_name}, header_file: {header_file}, '
f'ops_module: {ops_module}, custom_ops: {custom_ops})')
cmd = [sys.executable, os.path.join(eager_root_dir, 'opgen', 'opgen.py'),
'--output_file', gen_cpp_scratch_name,
'--ops_module', ops_module,
'--header_file', header_file]
if custom_ops:
cmd += ["--custom_ops"]
subprocess.check_call(cmd)
import filecmp
if (not os.path.isfile(gen_cpp_name) or
not filecmp.cmp(gen_cpp_name, gen_cpp_scratch_name, shallow=False)):
os.rename(gen_cpp_scratch_name, gen_cpp_name)
else:
os.remove(gen_cpp_scratch_name)
def gen_ort_ops():
# generate native aten ops
import torch
regdecs_path = os.path.join(os.path.dirname(torch.__file__), 'include/ATen/RegistrationDeclarations.h')
ops_module = os.path.join(eager_root_dir, 'opgen/opgen/atenops.py')
gen_ops(os.path.join(eager_root_dir, 'ort_aten.g.cpp'), regdecs_path, ops_module, False)
# generate custom ops
if not args.eager_customop_header:
args.eager_customop_header = os.path.realpath(os.path.join(
eager_root_dir,
"opgen",
"CustomOpDeclarations.h"))
if not args.eager_customop_module:
args.eager_customop_module = os.path.join(eager_root_dir, 'opgen/opgen/custom_ops.py')
gen_ops(os.path.join(eager_root_dir, 'ort_customops.g.cpp'),
args.eager_customop_header, args.eager_customop_module, True)
gen_ort_ops()
generate_build_tree(
cmake_path, source_dir, build_dir, cuda_home, cudnn_home, rocm_home, mpi_home, nccl_home,
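
Two notes on the generation flow above. The scratch-file plus filecmp step only moves a freshly generated .g.cpp into place when its content actually changed, so an unchanged generator output does not dirty the file and trigger needless recompiles. And when the new flags are omitted, the custom-op module and header default to the in-tree custom_ops.py and CustomOpDeclarations.h. A hedged sketch of driving a build with user-supplied custom ops (build.py lives under tools/ci_build in this repository; the module and header paths are hypothetical, and both must be given together per the check above):

import subprocess
import sys

subprocess.check_call([
    sys.executable, 'tools/ci_build/build.py',
    '--build_dir', 'build', '--config', 'RelWithDebInfo',   # usual build.py arguments
    '--build_eager_mode',
    '--eager_customop_module', 'my_ops.py',                 # hypothetical ops table (same shape as custom_ops.py above)
    '--eager_customop_header', 'MyOpDeclarations.h',        # hypothetical matching declarations header
])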