Mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-14 20:48:00 +00:00)
Ported changes / bug fixes from torch/ort. (#8784)
* Ported changes / bug fixes from torch/ort.
* Fixed formatting.
* Renamed function.
* Renamed module_ to module.
* Revert "Renamed module_ to module." (reverts commit b17fc114b3db20d174283811d90592b5b8154c19)
* Include pybind common header to fix linker errors on Windows debug builds.
* Fix generation of more than one custom op.

Co-authored-by: Ashwin Hari <ashari@microsoft.com>
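For orientation before the diff itself: this commit splits the op mappings out of opgen.py into loadable modules and runs the generator twice, once for native ATen ops and once for custom ops. Below is a minimal sketch of that flow, assuming the opgen.py CLI added in this commit; the helper name, constants, and the commented example paths are illustrative, not part of the commit.

    import os
    import subprocess
    import sys

    # Illustrative driver that mirrors the gen_ops() helper added to build.py
    # at the end of this diff.
    EAGER_ROOT = os.path.join("orttraining", "orttraining", "eager")

    def run_opgen(output_file, ops_module, header_file, custom_ops=False):
        cmd = [sys.executable, os.path.join(EAGER_ROOT, "opgen", "opgen.py"),
               "--output_file", output_file,
               "--ops_module", ops_module,
               "--header_file", header_file]
        if custom_ops:
            # Switches codegen from ATen dispatcher registrations to
            # pybind11 custom-op bindings.
            cmd.append("--custom_ops")
        subprocess.check_call(cmd)

    # Pass 1: native ATen ops; pass 2: custom ops (the fix allows more than one).
    # run_opgen("ort_aten.g.cpp", "opgen/opgen/atenops.py", "RegistrationDeclarations.h")
    # run_opgen("ort_customops.g.cpp", "opgen/opgen/custom_ops.py", "CustomOpDeclarations.h", True)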
Parent: f51f2bad66
Commit: 2693af9799

14 changed files with 280 additions and 163 deletions
.gitignore (vendored, 1 addition)
@@ -52,3 +52,4 @@ onnxruntime/python/version_info.py
 /tools/perf_util/target/classes
 /tools/perf_util/src/main/resources
 /orttraining/orttraining/eager/ort_aten.g.cpp
+/orttraining/orttraining/eager/ort_customops.g.cpp

@@ -0,0 +1 @@
+Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
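Once the generated binding for this declaration is built, it is callable from Python through the custom_ops submodule added later in this diff (see the pybind changes and test_gemm). A hedged sketch; the device helper here is hypothetical, the test itself uses self.get_device():

    import torch
    import torch_ort  # assumes a torch_ort eager-mode build containing this commit

    device = torch_ort.device()  # hypothetical helper, not from this diff
    ones = torch.ones(3, 3).to(device)
    # gemm computes alpha * A @ B + beta * C; for all-ones 3x3 inputs: 3 + 1 = 4
    result = torch_ort.custom_ops.gemm(ones, ones, ones, 1.0, 1.0, 0, 0)
    assert torch.allclose(result.cpu(), torch.full((3, 3), 4.0))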

@@ -3,120 +3,31 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.

+from opgen.writer import SourceWriter as SourceWriter
+from opgen.parser import cpp_create_from_file as CPPParser
+import sys
+import os
+from opgen.generator import ORTGen as ORTGen
+from importlib.machinery import SourceFileLoader
+import argparse

-from pathlib import Path
-
-from copy import deepcopy
-
-from opgen.generator import \
-  ORTGen as ORTGen, \
-  ONNXOp as ONNXOp, \
-  SignatureOnly as SignatureOnly, \
-  MakeFallthrough as MakeFallthrough
-
-from opgen.onnxops import *
-
-kMSDomain = 'onnxruntime::kMSDomain'

 parser = argparse.ArgumentParser(description='Generate ORT ATen operations')
+parser.add_argument('--ops_module', type=str,
+                    help='Python module containing the Onnx Operation signature and list of ops to map')
 parser.add_argument('--output_file', default=None, type=str, help='Output file [default to std out]')
-parser.add_argument('--use_preinstalled_torch', action='store_true', help='Use pre-installed torch from the python environment')
+parser.add_argument('--header_file', type=str,
+                    help='Header file which contains ATen / Pytorch operation signature')
+parser.add_argument('--custom_ops', action='store_true', help='Whether we are generating code for custom ops or native operation')

 args = parser.parse_args()
+ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
+
+ortgen = ORTGen(ops_module.ops, custom_ops=args.custom_ops)

-class ReluGrad(ONNXOp):
-  def __init__(self, dY, X):
-    super().__init__('ReluGrad', 1, dY, X)
-    self.domain = kMSDomain
-
-class Gelu(ONNXOp):
-  def __init__(self, X):
-    super().__init__('Gelu', 1, X)
-    self.domain = kMSDomain
-
-class GeluGrad(ONNXOp):
-  def __init__(self, dY, X):
-    super().__init__('GeluGrad', 1, dY, X)
-    self.domain = kMSDomain
-
-ops = {
-  # Hand-Implemented Ops
-  'aten::empty.memory_format': SignatureOnly(),
-  'aten::empty_strided': SignatureOnly(),
-  'aten::zero_': SignatureOnly(),
-  'aten::copy_': SignatureOnly(),
-  'aten::reshape': SignatureOnly(),
-  'aten::view': SignatureOnly(),
-
-  'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
-  'aten::t': Transpose('self'),
-  'aten::mm': MatMul('self', 'mat2'),
-  'aten::zeros_like': ConstantOfShape(Shape('self')), # the default constant is 0, so don't need to specify attribute
-
-  'aten::sum.dim_IntList': ReduceSum('self', 'dim', keepdims='keepdim'),
-  'aten::threshold_backward': ReluGrad('grad_output', 'self'),
-
-  'aten::fmod.Scalar': Mod('self', 'other', fmod=1),
-  'aten::fmod.Tensor': Mod('self', 'other', fmod=1),
-
-  'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), # yes, bias is set to 'lambd'
-  'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
-  'aten::gelu' : Gelu('self'),
-  'aten::gelu_backward' : GeluGrad('grad', 'self')
-}
-
-for binary_op, onnx_op in {
-  'add': Add('self', Mul('alpha', 'other')),
-  'sub': Sub('self', Mul('alpha', 'other')),
-  'mul': Mul('self', 'other'),
-  'div': Div('self', 'other')}.items():
-  for dtype in ['Tensor', 'Scalar']:
-    for variant in ['', '_']:
-      ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op)
-
-for unary_op in [
-  'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
-  'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round',
-  'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero',
-  'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']:
-  aten_name = f'aten::{unary_op}'
-  onnx_op = onnx_ops[unary_op]('self')
-  ops[aten_name] = onnx_op
-  # produce the in-place variant as well for ops that support it
-  if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']:
-    ops[f'{aten_name}_'] = onnx_op
-
-ortgen = ORTGen(ops)
-
-import os
-import sys
-
-from opgen.parser import cpp_create_from_file as CPPParser
-from opgen.writer import SourceWriter as SourceWriter
-
-if args.use_preinstalled_torch:
-  import torch
-  regdecs_path = Path(torch.__file__).parent.joinpath('include/ATen/RegistrationDeclarations.h')
-else:
-  regdecs_path = os.path.realpath(os.path.join(
-    os.path.dirname(__file__),
-    '..',
-    '..',
-    '..',
-    'external',
-    'pytorch',
-    'build',
-    'aten',
-    'src',
-    'ATen',
-    'RegistrationDeclarations.h'))

+regdecs_path = args.header_file
+print(f"INFO: Using ATen RegistrationDeclarations from: {regdecs_path}")
 output = sys.stdout
-if not args.output_file is None:
+if args.output_file:
   output = open(args.output_file, 'wt')

 with CPPParser(regdecs_path) as parser, SourceWriter(output) as writer:

orttraining/orttraining/eager/opgen/opgen/atenops.py (new file, 73 lines)

@@ -0,0 +1,73 @@
+from copy import deepcopy
+
+from opgen.generator import \
+  ORTGen as ORTGen, \
+  ONNXOp as ONNXOp, \
+  SignatureOnly as SignatureOnly, \
+  MakeFallthrough as MakeFallthrough
+
+from opgen.onnxops import *
+
+kMSDomain = 'onnxruntime::kMSDomain'
+
+class ReluGrad(ONNXOp):
+  def __init__(self, dY, X):
+    super().__init__('ReluGrad', 1, dY, X)
+    self.domain = kMSDomain
+
+class Gelu(ONNXOp):
+  def __init__(self, X):
+    super().__init__('Gelu', 1, X)
+    self.domain = kMSDomain
+
+class GeluGrad(ONNXOp):
+  def __init__(self, dY, X):
+    super().__init__('GeluGrad', 1, dY, X)
+    self.domain = kMSDomain
+
+ops = {
+  # Hand-Implemented Ops
+  'aten::empty.memory_format': SignatureOnly(),
+  'aten::empty_strided': SignatureOnly(),
+  'aten::zero_': SignatureOnly(),
+  'aten::copy_': SignatureOnly(),
+  'aten::reshape': SignatureOnly(),
+  'aten::view': SignatureOnly(),
+
+  'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
+  'aten::t': Transpose('self'),
+  'aten::mm': MatMul('self', 'mat2'),
+  'aten::zeros_like': ConstantOfShape(Shape('self')), # the default constant is 0, so don't need to specify attribute
+
+  'aten::sum.dim_IntList': ReduceSum('self', 'dim', keepdims='keepdim'),
+  'aten::threshold_backward': ReluGrad('grad_output', 'self'),
+
+  'aten::fmod.Scalar': Mod('self', 'other', fmod=1),
+  'aten::fmod.Tensor': Mod('self', 'other', fmod=1),
+
+  'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), # yes, bias is set to 'lambd'
+  'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
+  'aten::gelu' : Gelu('self'),
+  'aten::gelu_backward' : GeluGrad('grad', 'self')
+}
+
+for binary_op, onnx_op in {
+  'add': Add('self', Mul('alpha', 'other')),
+  'sub': Sub('self', Mul('alpha', 'other')),
+  'mul': Mul('self', 'other'),
+  'div': Div('self', 'other')}.items():
+  for dtype in ['Tensor', 'Scalar']:
+    for variant in ['', '_']:
+      ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op)
+
+for unary_op in [
+  'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
+  'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round',
+  'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero',
+  'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']:
+  aten_name = f'aten::{unary_op}'
+  onnx_op = onnx_ops[unary_op]('self')
+  ops[aten_name] = onnx_op
+  # produce the in-place variant as well for ops that support it
+  if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']:
+    ops[f'{aten_name}_'] = onnx_op
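To make the two loops in atenops.py concrete, here is a tiny standalone rendition, worked from the code above rather than taken from the diff; plain strings stand in for the ONNXOp instances:

    from copy import deepcopy

    ops = {}

    # Binary ops: Tensor/Scalar overloads, plus the in-place ('_') variants.
    for binary_op in ['add', 'sub', 'mul', 'div']:
        for dtype in ['Tensor', 'Scalar']:
            for variant in ['', '_']:
                ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(f'{binary_op}-op')

    # Unary ops: base name, plus the in-place variant when supported.
    for unary_op in ['abs', 'isnan']:
        ops[f'aten::{unary_op}'] = f'{unary_op}-op'
        if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']:
            ops[f'aten::{unary_op}_'] = f'{unary_op}-op'

    print(sorted(k for k in ops if k.startswith('aten::add')))
    # ['aten::add.Scalar', 'aten::add.Tensor', 'aten::add_.Scalar', 'aten::add_.Tensor']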

orttraining/orttraining/eager/opgen/opgen/custom_ops.py (new file, 15 lines)

@@ -0,0 +1,15 @@
+from copy import deepcopy
+
+from opgen.generator import AttrType, ONNXAttr
+
+from opgen.generator import \
+  ORTGen as ORTGen, \
+  ONNXOp as ONNXOp, \
+  SignatureOnly as SignatureOnly, \
+  MakeFallthrough as MakeFallthrough
+
+from opgen.onnxops import *
+
+ops = {
+  'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB')
+}
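Since this commit specifically fixes generation of more than one custom op, a module with two entries should now work. A hypothetical example; only 'gemm' is part of this commit, and the second entry is illustrative (each key still needs a matching C++ declaration in the header passed via --header_file):

    # Hypothetical custom_ops.py with two entries.
    from opgen.onnxops import Gemm, Transpose

    ops = {
      'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB'),
      'transpose': Transpose('X'),  # illustrative second custom op
    }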

@@ -27,6 +27,7 @@ class AttrType:
   STRING = 'const char*'
   STRINGS = '<unsupported:STRINGS>'
   TENSOR = 'at::Tensor'
+  LONG = 'at::ScalarType::Long'

 class ONNXAttr:
   def __init__(self, value, type: AttrType=None):

@@ -77,35 +78,36 @@ class MakeFallthrough(ONNXOp):

 class FunctionGenerationError(NotImplementedError):
   def __init__(self, cpp_func: ast.FunctionDecl, message: str):
-    super().__init__(f'{message} (torch: {cpp_func.torch_func.torch_schema})')
+    super().__init__(f'{message} ({cpp_func.identifier})')

 class MappedOpFunction:
   def __init__(
     self,
-    torch_op_namespace: str,
-    torch_op_name: str,
+    op_namespace: str,
+    mapped_op_name: str,
     onnx_op: ONNXOp,
     cpp_func: ast.FunctionDecl,
     torch_func: ast.FunctionDecl,
     signature_only: bool,
     make_fallthrough: bool):
-    self.torch_op_namespace = torch_op_namespace
-    self.torch_op_name = torch_op_name
+    self.op_namespace = op_namespace
+    self.mapped_op_name = mapped_op_name
     self.onnx_op = onnx_op
     self.cpp_func = cpp_func
     self.torch_func = torch_func
     self.signature_only = signature_only
     self.make_fallthrough = make_fallthrough

 class ORTGen:
   _mapped_ops: Dict[str, ONNXOp]
+  _custom_ops: bool

   def __init__(
     self,
-    ops: Optional[Dict[str, ONNXOp]] = None):
-    self._mapped_ops = {}
+    ops: Optional[Dict[str, ONNXOp]] = None,
+    custom_ops : bool = False):
+    self._mapped_ops = {}
     if ops:
       self.register_many(ops)
+    self._custom_ops = custom_ops

   def register(self, aten_name: str, onnx_op: ONNXOp):
     self._mapped_ops[aten_name] = onnx_op

@@ -121,13 +123,13 @@ class ORTGen:
     current_ns = None

     for mapped_func in self._parse_mapped_function_decls(cpp_parser):
-      del self._mapped_ops[mapped_func.torch_func.identifier.value]
+      del self._mapped_ops[mapped_func.mapped_op_name]
       generated_funcs.append(mapped_func)

       if mapped_func.make_fallthrough:
         continue

-      ns = mapped_func.torch_op_namespace
+      ns = mapped_func.op_namespace
       if current_ns and current_ns != ns:
         current_ns = None
         writer.pop_namespace()

@@ -137,7 +139,8 @@ class ORTGen:
         writer.push_namespace(ns)

       writer.writeline()
-      writer.writeline(f'// {mapped_func.torch_func.torch_schema}')
+      if mapped_func.cpp_func.torch_func:
+        writer.writeline(f'// {mapped_func.cpp_func.torch_func.torch_schema}')

       self._write_function_signature(writer, mapped_func.cpp_func)
       if mapped_func.signature_only:

@@ -153,7 +156,10 @@ class ORTGen:
       current_ns = None
       writer.pop_namespace()

-    self._write_function_registrations(writer, generated_funcs)
+    if not self._custom_ops:
+      self._write_function_registrations(writer, generated_funcs)
+    else:
+      self._write_custom_ops_registrations(writer, generated_funcs)
     self._write_file_postlude(writer)

     if len(self._mapped_ops) > 0:

@@ -164,6 +170,8 @@ class ORTGen:
     writer.writeline('// AUTO-GENERATED CODE! - DO NOT EDIT!')
     writer.writeline(f'// $ python {" ".join(sys.argv)}')
     writer.writeline()
+    writer.writeline('#include "python/onnxruntime_pybind_state_common.h"')
+    writer.writeline()
     writer.writeline('#include <torch/extension.h>')
     writer.writeline()
     writer.writeline('#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>')

@@ -206,7 +214,7 @@ class ORTGen:
     assert(len(cpp_func.parameters) > 0)

-    return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type)
+    return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
     if return_alias_info and not return_alias_info.is_writable:
       return_alias_info = None
     in_place_param: ast.ParameterDecl = None

@@ -225,16 +233,16 @@ class ORTGen:
     # Fetch the ORT invoker from an at::Tensor.device()
     # FIXME: find the first at::Tensor param anywhere in the signature
    # instead of simply the first parameter?
-    first_torch_param = cpp_func.torch_func.parameters[0].member
+    first_param = cpp_func.parameters[0].member
     if not isinstance(
-      first_torch_param.parameter_type.desugar(),
-      ast.TensorType):
+      first_param.parameter_type.desugar(),
+      ast.ConcreteType) or 'Tensor' not in first_param.parameter_type.desugar().identifier_tokens[0].value:
       raise FunctionGenerationError(
         cpp_func,
         'First parameter must be an at::Tensor')

     writer.write('auto& invoker = GetORTInvoker(')
-    writer.write(first_torch_param.identifier.value)
+    writer.write(first_param.identifier.value)
     writer.writeline('.device());')
     writer.writeline()

@@ -329,11 +337,15 @@ class ORTGen:
     # TODO: Handle multiple results
     # TODO: Assert return type

     if not return_alias_info:
       writer.writeline('return aten_tensor_from_ort(')
       writer.push_indent()
-      writer.writeline(f'std::move({return_outputs}[0]),')
-      writer.writeline(f'{first_torch_param.identifier.value}.options());')
+      if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector':
+        writer.writeline(f'{return_outputs},')
+        writer.writeline(f'{first_param.identifier.value}.options());')
+      else:
+        writer.writeline(f'std::move({return_outputs}[0]),')
+        writer.writeline(f'{first_param.identifier.value}.options());')
       writer.pop_indent()
       return

@@ -354,13 +366,13 @@ class ORTGen:
     writer.writeline('ORT_LOG_DEBUG << "ATen init";')

     for mapped_func in generated_funcs:
-      cpp_func, torch_func = mapped_func.cpp_func, mapped_func.torch_func
+      cpp_func, torch_func = mapped_func.cpp_func, mapped_func.cpp_func.torch_func

       if mapped_func.make_fallthrough:
         reg_function_arg = 'torch::CppFunction::makeFallthrough()'
       else:
-        if mapped_func.torch_op_namespace:
-          reg_function_arg = f'{mapped_func.torch_op_namespace}::'
+        if mapped_func.op_namespace:
+          reg_function_arg = f'{mapped_func.op_namespace}::'
         else:
           reg_function_arg = ''
         reg_function_arg += cpp_func.identifier.value

@@ -375,6 +387,24 @@ class ORTGen:
     writer.writeline('}')
     writer.writeline()

+  def _write_custom_ops_registrations(
+    self,
+    writer: writer.SourceWriter,
+    generated_funcs: List[MappedOpFunction]):
+    writer.writeline()
+    writer.writeline('void GenerateCustomOpsBindings(pybind11::module_ m) {')
+    writer.push_indent()
+    writer.writeline('ORT_LOG_DEBUG << "GenerateCustomOpsBindings init";')
+
+    for mapped_func in generated_funcs:
+      cpp_func = mapped_func.cpp_func
+      writer.write('m.def(')
+      writer.writeline(f'"{cpp_func.identifier.value}", &{cpp_func.identifier.value});')
+
+    writer.pop_indent()
+    writer.writeline('}')
+    writer.writeline()
+
   def _get_alias_info(self, torch_type_or_param: Union[ast.Type, ast.ParameterDecl]):
     if isinstance(torch_type_or_param, ast.ParameterDecl):
       torch_type = torch_type_or_param.parameter_type

@@ -383,28 +413,32 @@ class ORTGen:
     return getattr(torch_type.desugar(), 'alias_info', None)

   def _parse_mapped_function_decls(self, cpp_parser: parser.CPPParser):
-    for cpp_func, torch_func in self._parse_function_decls(cpp_parser):
-      torch_op_name = torch_func.identifier.value
-      if torch_op_name not in self._mapped_ops:
-        continue
-      onnx_op = self._mapped_ops[torch_op_name]
-
-      try:
-        torch_op_namespace = torch_op_name[0:torch_op_name.index('::')]
-        torch_op_name = torch_op_name[len(torch_op_namespace) + 2:]
-      except:
-        torch_op_namespace = None
-
-      cpp_func.identifier.value = torch_op_name.replace('.', '__')
+    for cpp_func in self._parse_function_decls(cpp_parser):
+      torch_func = cpp_func.torch_func
+      if not torch_func:
+        op_namespace = None
+        op_name = cpp_func.identifier.value
+      else:
+        op_name = torch_func.identifier.value
+
+      try:
+        op_namespace = op_name[0:op_name.index('::')]
+        op_namewithoutnamespace = op_name[len(op_namespace) + 2:]
+      except:
+        op_namespace = None
+        op_namewithoutnamespace = op_name
+
+      cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '__')
+
+      onnx_op = self._mapped_ops.get(op_name)
+      if not onnx_op:
+        continue

       yield MappedOpFunction(
-        torch_op_namespace,
-        torch_op_name,
+        op_namespace,
+        op_name,
         onnx_op,
         cpp_func,
         torch_func,
         isinstance(onnx_op, SignatureOnly),
         isinstance(onnx_op, MakeFallthrough))

@@ -415,7 +449,11 @@ class ORTGen:
     # Parse the Torch schema from the JSON comment that follows each C++ decl
     # and link associated Torch and C++ decls (functions, parameters, returns)
     for cpp_func in tu:
-      if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
+      if self._custom_ops == True:
+        # customops don't have torch schema
+        cpp_func.torch_func = None
+        yield cpp_func
+      elif cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
         for trivia in cpp_func.semicolon.trailing_trivia:
           if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT:
             yield self._parse_and_link_torch_function_decl(cpp_func, trivia)

@@ -463,4 +501,4 @@ class ORTGen:
           torch_param = torch_func.parameters[i + j].member
           cpp_param.torch_param.append(torch_param)

-    return cpp_func, torch_func
+    return cpp_func

@@ -37,6 +37,19 @@ at::Tensor aten_tensor_from_ort(
       options));
 }

+const std::vector<at::Tensor> aten_tensor_from_ort(
+  std::vector<OrtValue>& ortvalues,
+  const at::TensorOptions& options) {
+  const size_t num_outputs = ortvalues.size();
+  std::vector<at::Tensor> atvalues = std::vector<at::Tensor>(num_outputs);
+  for (size_t i = 0; i < num_outputs; i++) {
+    atvalues[i] = at::Tensor(c10::make_intrusive<ORTTensorImpl>(
+      std::move(ortvalues[i]),
+      options));
+  }
+  return atvalues;
+}
+
 onnxruntime::MLDataType ort_scalar_type_from_aten(
   at::ScalarType dtype) {
   switch (dtype){

@@ -17,6 +17,10 @@ at::Tensor aten_tensor_from_ort(
   OrtValue&& ot,
   const at::TensorOptions& options);

+const std::vector<at::Tensor> aten_tensor_from_ort(
+  std::vector<OrtValue>& ortvalues,
+  const at::TensorOptions& options);
+
 onnxruntime::MLDataType ort_scalar_type_from_aten(
   at::ScalarType dtype);

orttraining/orttraining/eager/ort_customops.h (new file, 9 lines)

@@ -0,0 +1,9 @@
+#include <torch/extension.h>
+
+namespace torch_ort {
+namespace eager {
+
+void GenerateCustomOpsBindings(pybind11::module_ module);
+
+} // namespace eager
+} // namespace torch_ort

@@ -11,6 +11,7 @@
 #include "python/onnxruntime_pybind_state_common.h"
 #include "orttraining/core/framework/torch/dlpack_python.h"
 #include <core/session/provider_bridge_ort.h>
+#include "ort_customops.h"

 namespace onnxruntime{
 namespace python{

@@ -67,6 +68,9 @@ void addObjectMethodsForEager(py::module& m){
     if (!status.IsOK())
       throw std::runtime_error(status.ErrorMessage());
   });
+
+  auto customop_module = m.def_submodule("custom_ops");
+  torch_ort::eager::GenerateCustomOpsBindings(customop_module);
 }

 }

@@ -36,9 +36,9 @@ inline void CopyVectorToTensor(onnxruntime::ORTInvoker& invoker,
     Ort::TypeToTensorType<T>::type,
     &ort_value));

-  execution_provider.GetDataTransfer()->CopyTensor(
+  ORT_THROW_IF_ERROR(execution_provider.GetDataTransfer()->CopyTensor(
     ort_value->Get<onnxruntime::Tensor>(),
-    tensor);
+    tensor));
 }

 // vector<bool> is specialized so we need to handle it separately

@@ -61,5 +61,13 @@ class OrtOpTests(unittest.TestCase):
         ort_zeros = torch.zeros_like(ones.to(device))
         assert torch.allclose(cpu_zeros, ort_zeros.cpu())

+    def test_gemm(self):
+        device = self.get_device()
+        cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
+        ort_ones = cpu_ones.to(device)
+        cpu_ans = cpu_ones * 4
+        ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
+        assert torch.allclose(cpu_ans, ort_ans.cpu())
+
 if __name__ == '__main__':
     unittest.main()

@@ -49,7 +49,7 @@ constexpr const char* const k_gradient_norm_op_name = "ReduceAllL2";
 constexpr const char* const k_unscale_op_name = "MixedPrecisionScale";
 constexpr const char* const k_inplace_accumulator_op_name = "InPlaceAccumulator";
 constexpr const char* const k_zero_gradient_op_name = "ZeroGradient";
-#if defined(USE_MPI)
+#if defined(ORT_USE_NCCL) && defined(USE_MPI)
 constexpr const char* const k_adasum_op_name = "AdasumAllReduce";
 #endif
 Status SetUpBaseGraph(Graph& graph);

@@ -535,9 +535,14 @@ def parse_arguments():
     parser.add_argument(
         "--ms_experimental", action='store_true', help="Build microsoft experimental operators.")\

     # eager mode
     parser.add_argument(
         "--build_eager_mode", action='store_true',
         help="Build ONNXRuntime micro-benchmarks.")
+    parser.add_argument('--eager_customop_module', default=None,
+                        help='Module containing custom op mappings for eager mode.')
+    parser.add_argument('--eager_customop_header', default=None,
+                        help='Header containing custom op definitions for eager mode.')

     return parser.parse_args()
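A hypothetical invocation wiring the new flags together, assuming this hunk is from ONNX Runtime's tools/ci_build/build.py; the module and header paths are placeholders, not from this diff:

    import subprocess
    import sys

    # Per the check added to main() below, the two custom-op flags must be
    # supplied together (or both omitted, in which case defaults are used).
    subprocess.check_call([
        sys.executable, "tools/ci_build/build.py",
        "--build_eager_mode",
        "--eager_customop_module", "my_custom_ops.py",        # placeholder
        "--eager_customop_header", "MyCustomOpDeclarations.h" # placeholder
    ])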

@@ -2146,21 +2151,56 @@ def main():
         args.rocm_version = ""

     if args.build_eager_mode:
-        # generate the ort aten backend code
-        def gen_ort_aten_ops(eager_root_dir):
-            gen_cpp_name = os.path.join(eager_root_dir, "ort_aten.g.cpp")
-            if os.path.exists(gen_cpp_name):
-                os.remove(gen_cpp_name)
-            subprocess.check_call([
-                sys.executable,
-                os.path.join(eager_root_dir, 'opgen', 'opgen.py'),
-                "--output_file",
-                gen_cpp_name,
-                "--use_preinstalled_torch"
-            ])
-
         eager_root_dir = os.path.join(source_dir, "orttraining", "orttraining", "eager")
-        gen_ort_aten_ops(eager_root_dir)
+        if args.eager_customop_module and not args.eager_customop_header:
+            raise Exception('eager_customop_header must be provided when eager_customop_module is')
+        elif args.eager_customop_header and not args.eager_customop_module:
+            raise Exception('eager_customop_module must be provided when eager_customop_header is')
+
+        def gen_ops(gen_cpp_name: str, header_file: str, ops_module: str, custom_ops: bool):
+            gen_cpp_scratch_name = gen_cpp_name + '.working'
+            print(f'Generating ORT ATen overrides (output_file: {gen_cpp_name}, header_file: {header_file},'
+                  f'ops_module: {ops_module}), custom_ops: {custom_ops}')
+
+            cmd = [sys.executable, os.path.join(os.path.join(eager_root_dir, 'opgen', 'opgen.py')),
+                   '--output_file', gen_cpp_scratch_name,
+                   '--ops_module', ops_module,
+                   '--header_file', header_file]
+
+            if custom_ops:
+                cmd += ["--custom_ops"]
+
+            subprocess.check_call(cmd)
+
+            import filecmp
+            if (not os.path.isfile(gen_cpp_name) or
+                    not filecmp.cmp(gen_cpp_name, gen_cpp_scratch_name, shallow=False)):
+                os.rename(gen_cpp_scratch_name, gen_cpp_name)
+            else:
+                os.remove(gen_cpp_scratch_name)
+
+        def gen_ort_ops():
+            # generate native aten ops
+            import torch
+            regdecs_path = os.path.join(os.path.dirname(torch.__file__), 'include/ATen/RegistrationDeclarations.h')
+
+            ops_module = os.path.join(eager_root_dir, 'opgen/opgen/atenops.py')
+            gen_ops(os.path.join(eager_root_dir, 'ort_aten.g.cpp'), regdecs_path, ops_module, False)
+
+            # generate custom ops
+            if not args.eager_customop_header:
+                args.eager_customop_header = os.path.realpath(os.path.join(
+                    eager_root_dir,
+                    "opgen",
+                    "CustomOpDeclarations.h"))
+
+            if not args.eager_customop_module:
+                args.eager_customop_module = os.path.join(eager_root_dir, 'opgen/opgen/custom_ops.py')
+
+            gen_ops(os.path.join(eager_root_dir, 'ort_customops.g.cpp'),
+                    args.eager_customop_header, args.eager_customop_module, True)
+
+        gen_ort_ops()

     generate_build_tree(
         cmake_path, source_dir, build_dir, cuda_home, cudnn_home, rocm_home, mpi_home, nccl_home,