mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Added support to Eager CodeGen for multiple in-place parameters. (#10945)
* Added support to CodeGen for multiple inplace output parameters. * Updated output Tensor to references.
This commit is contained in:
parent
1cc2cfb7b8
commit
4a5b5328a4
7 changed files with 81 additions and 40 deletions
|
|
@ -1 +1,2 @@
|
|||
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
|
||||
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const float epsilon, const float momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"}
|
||||
|
|
@ -25,7 +25,7 @@ ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
|
|||
ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops)
|
||||
|
||||
regdecs_path = args.header_file
|
||||
print(f"INFO: Using ATen RegistrationDeclations from: {regdecs_path}")
|
||||
print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}")
|
||||
output = sys.stdout
|
||||
if args.output_file:
|
||||
output = open(args.output_file, 'wt')
|
||||
|
|
|
|||
|
|
@ -219,14 +219,12 @@ class TupleType(Type):
|
|||
class AliasInfo(Node):
|
||||
before_set: List[str]
|
||||
after_set: List[str]
|
||||
contained_types: List[Type]
|
||||
tokens: List[Token]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.before_set = []
|
||||
self.after_set = []
|
||||
self.contained_types = []
|
||||
self.tokens = []
|
||||
self.is_writable = False
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ from opgen.generator import \
|
|||
from opgen.onnxops import *
|
||||
|
||||
ops = {
|
||||
'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB')
|
||||
'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB'),
|
||||
'batchnorm_inplace': BatchNormalization('X', 'scale', 'B', 'input_mean', 'input_var', 'epsilon', 'momentum', 1)
|
||||
}
|
||||
|
||||
type_promotion_ops = {}
|
||||
|
|
|
|||
|
|
@ -246,11 +246,6 @@ class ORTGen:
|
|||
if mapped_func.make_torch_fallback:
|
||||
return self._write_cpu_fall_back(writer, mapped_func)
|
||||
|
||||
return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
|
||||
if return_alias_info and not return_alias_info.is_writable:
|
||||
return_alias_info = None
|
||||
in_place_param: ast.ParameterDecl = None
|
||||
|
||||
# Eval the outer ONNX op to produce a topologically ordered list of ops
|
||||
ctx = ONNXOpEvalContext()
|
||||
onnx_op.eval(ctx)
|
||||
|
|
@ -339,13 +334,7 @@ class ORTGen:
|
|||
for op_input in onnx_op.inputs:
|
||||
if isinstance(op_input, Outputs):
|
||||
continue
|
||||
# See if this input is aliased as an in-place tensor
|
||||
cpp_param = cpp_func.get_parameter(op_input)
|
||||
if return_alias_info and cpp_param:
|
||||
for torch_p in cpp_param.torch_param:
|
||||
if self._get_alias_info(torch_p) == return_alias_info:
|
||||
in_place_param = cpp_param
|
||||
|
||||
writer.write(f'auto ort_input_{op_input} = ')
|
||||
writer.writeline(f'create_ort_value(invoker, {op_input});')
|
||||
if need_type_promotion:
|
||||
|
|
@ -388,11 +377,36 @@ class ORTGen:
|
|||
writer.write(f'std::vector<OrtValue> {onnx_op.outputs}')
|
||||
writer.writeline(f'({onnx_op.outputs.count});')
|
||||
|
||||
if in_place_param:
|
||||
assert(onnx_op.outputs.count == 1)
|
||||
# TODO: This assumes that the first output corresponds to the first input.
|
||||
# This may not work for more complicated ops.
|
||||
writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[0]};')
|
||||
return_info = cpp_func.torch_func.return_type if cpp_func.torch_func else None
|
||||
in_place_params = {}
|
||||
|
||||
if return_info:
|
||||
for input_index, op_input in enumerate(onnx_op.inputs):
|
||||
if isinstance(op_input, Outputs):
|
||||
continue
|
||||
|
||||
# See if this input is aliased as an in-place tensor
|
||||
cpp_param = cpp_func.get_parameter(op_input)
|
||||
if cpp_param:
|
||||
for torch_p in cpp_param.torch_param:
|
||||
if isinstance(return_info, ast.TupleType):
|
||||
for output_index, output_param in enumerate(return_info.elements):
|
||||
assert isinstance(output_param.member, ast.TupleMemberType), "output_param.member must be of TupleMemberType"
|
||||
output_alias = self._get_alias_info(output_param.member.element_type)
|
||||
if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable:
|
||||
writer.writeline(f'{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op.inputs[input_index]};')
|
||||
in_place_params[output_index] = cpp_param.identifier.value
|
||||
break
|
||||
else:
|
||||
output_alias = self._get_alias_info(return_info)
|
||||
if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable:
|
||||
writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[input_index]};')
|
||||
in_place_params[0] = cpp_param.identifier.value
|
||||
break
|
||||
|
||||
if len(in_place_params) != 0 and len(in_place_params) != (len(return_info.elements) if isinstance(return_info, ast.TupleType) else 1):
|
||||
raise Exception(f'Cannot mix and match inplace with non-inplace parameters - function: {cpp_func.identifier.value} ' +
|
||||
f'in_place_params={in_place_params}, return_elements={return_info.elements}')
|
||||
|
||||
# Perform the invocation
|
||||
writer.writeline()
|
||||
|
|
@ -434,7 +448,7 @@ class ORTGen:
|
|||
# TODO: Handle mutliple results
|
||||
# TODO: Assert return type
|
||||
|
||||
if not return_alias_info:
|
||||
if len(in_place_params) == 0:
|
||||
# tensor options
|
||||
writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}')
|
||||
if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
|
||||
|
|
@ -454,12 +468,20 @@ class ORTGen:
|
|||
writer.writeline('tensor_options);')
|
||||
writer.pop_indent()
|
||||
return
|
||||
|
||||
if not in_place_param:
|
||||
raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' +
|
||||
'has alias info on its return type but no associated parameter')
|
||||
|
||||
writer.writeline(f'return {in_place_param.identifier.value};')
|
||||
else:
|
||||
if len(in_place_params) == 1:
|
||||
writer.writeline(f'return {in_place_params[0]};')
|
||||
else:
|
||||
if not (isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::tuple'):
|
||||
raise Exception(f'')
|
||||
tensorRef = "Tensor&," * len(in_place_params)
|
||||
tensorRef = tensorRef[:len(tensorRef)-1]
|
||||
writer.write(f'return std::tuple<{tensorRef}>(')
|
||||
for index, key in enumerate(sorted(in_place_params)):
|
||||
if index > 0:
|
||||
writer.write(', ')
|
||||
writer.write(in_place_params[key])
|
||||
writer.writeline(');')
|
||||
|
||||
def _write_function_registrations(
|
||||
self,
|
||||
|
|
@ -550,16 +572,19 @@ class ORTGen:
|
|||
# Parse the Torch schema from the JSON comment that follows each C++ decl
|
||||
# and link associated Torch and C++ decls (functions, parameters, returns)
|
||||
for cpp_func in tu:
|
||||
if self._custom_ops == True:
|
||||
# customops don't have torch schema
|
||||
cpp_func.torch_func = None
|
||||
yield cpp_func
|
||||
elif cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
|
||||
hasSchema = False
|
||||
if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia:
|
||||
for trivia in cpp_func.semicolon.trailing_trivia:
|
||||
if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT:
|
||||
yield self._parse_and_link_torch_function_decl(cpp_func, trivia)
|
||||
hasSchema = True
|
||||
break
|
||||
|
||||
if not hasSchema:
|
||||
# customops might not have torch schema
|
||||
cpp_func.torch_func = None
|
||||
yield cpp_func
|
||||
|
||||
def _parse_and_link_torch_function_decl(
|
||||
self,
|
||||
cpp_func: ast.FunctionDecl,
|
||||
|
|
|
|||
|
|
@ -219,16 +219,19 @@ class TorchParser(ParserBase):
|
|||
else:
|
||||
raise UnexpectedTokenError("expression", self._peek_token())
|
||||
|
||||
def parse_type(self) -> Type:
|
||||
parsed_type, alias_info = self._parse_type_and_alias()
|
||||
if not alias_info:
|
||||
return parsed_type
|
||||
def _create_alias_info_type(self, parsed_type: Type, alias_info: AliasInfo) -> AliasInfoType:
|
||||
if isinstance(parsed_type, ModifiedType):
|
||||
parsed_type.base_type = AliasInfoType(parsed_type.base_type, alias_info)
|
||||
else:
|
||||
parsed_type = AliasInfoType(parsed_type, alias_info)
|
||||
return parsed_type
|
||||
|
||||
def parse_type(self) -> Type:
|
||||
parsed_type, alias_info = self._parse_type_and_alias()
|
||||
if not alias_info:
|
||||
return parsed_type
|
||||
return self._create_alias_info_type(parsed_type, alias_info)
|
||||
|
||||
def _parse_type_and_alias(self) -> Tuple[Type, AliasInfo]:
|
||||
parsed_type: Type = None
|
||||
alias_info: AliasInfo = None
|
||||
|
|
@ -239,8 +242,9 @@ class TorchParser(ParserBase):
|
|||
if self._peek_token(TokenKind.OPEN_PAREN):
|
||||
def parse_tuple_element():
|
||||
element_type, element_alias_info = self._parse_type_and_alias()
|
||||
if alias_info and element_alias_info:
|
||||
alias_info.add_contained_type(element_alias_info)
|
||||
if element_alias_info:
|
||||
element_type = self._create_alias_info_type(element_type, element_alias_info)
|
||||
|
||||
return TupleMemberType(
|
||||
element_type,
|
||||
self._read_token() \
|
||||
|
|
|
|||
|
|
@ -87,7 +87,19 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_ans = cpu_ones * 4
|
||||
ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
|
||||
assert torch.allclose(cpu_ans, ort_ans.cpu())
|
||||
|
||||
|
||||
def test_batchnormalization_inplace(self):
|
||||
device = self.get_device()
|
||||
x = torch.Tensor([[[[-1, 0, 1]], [[2., 3., 4.]]]]).to(device)
|
||||
s = torch.Tensor([1.0, 1.5]).to(device)
|
||||
bias = torch.Tensor([0., 1.]).to(device)
|
||||
mean = torch.Tensor([0., 3.]).to(device)
|
||||
var = torch.Tensor([1., 1.5]).to(device)
|
||||
y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9)
|
||||
assert torch.allclose(x.cpu(), y.cpu()), "x != y"
|
||||
assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out"
|
||||
assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out"
|
||||
|
||||
def test_max(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
ort_tensor = cpu_tensor.to('ort')
|
||||
|
|
|
|||
Loading…
Reference in a new issue