diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index 5eac5c51965..c702a68063c 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py index 775a2cf3075..fc2af8fc139 100644 --- a/tools/codegen/dest/register_dispatch_key.py +++ b/tools/codegen/dest/register_dispatch_key.py @@ -314,12 +314,13 @@ class StructuredRegisterDispatchKey(RegisterDispatchKey): set_output_super = f"{parent_class}::set_output(output_idx, sizes, strides, options, names);" else: set_output_super = "" + maybe_star = "*" if k is SchemaKind.functional else "" return f""" void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override {{ {textwrap.indent(self.gen_class_set_output_body(k), " ")} if (!names.empty()) {{ - namedinference::propagate_names(outputs_[output_idx], names); + namedinference::propagate_names({maybe_star}outputs_[output_idx], names); }} // super must happen after, so that downstream can use maybe_get_output // to retrieve the output @@ -417,8 +418,10 @@ if (resized) {{ def gen_class( self, f: NativeFunction, k: SchemaKind, *, class_name: str, parent_class: str, generate_super: bool ) -> str: + maybe_star = '' if k is SchemaKind.functional: - output_type = "Tensor" + output_type = "c10::ExclusivelyOwned" + maybe_star = '*' elif k is SchemaKind.inplace: output_type = "std::reference_wrapper" elif k is SchemaKind.out: @@ -441,7 +444,7 @@ if (resized) {{ f"{textwrap.indent(class_ctor_str, indent)}", f"{textwrap.indent(self.gen_class_set_output(k, parent_class, generate_super), indent)}", " const Tensor& maybe_get_output(int64_t output_idx) override {", - " return outputs_[output_idx];", + f" return {maybe_star}outputs_[output_idx];", " }", f" std::array<{output_type}, {len(f.func.returns)}> outputs_;", f"{textwrap.indent(guard_field, indent)}", @@ -555,10 +558,11 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si # After running meta, op.outputs_ is guaranteed to be valid; # add it to the context out_args = structured.out_arguments(self.g) + maybe_star = '*' if k is SchemaKind.functional else '' for i, out_arg in enumerate(out_args): assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type context.append(Expr( - expr=f"op.outputs_[{i}]", + expr=f"{maybe_star}op.outputs_[{i}]", # TODO: Stop hardcoding that the output type is a Tensor. Note # that for the codegen here this is fine because outputs_ is # hardcoded to be tensor already @@ -605,9 +609,9 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si # TODO: Do this in translate instead if k is SchemaKind.functional: if len(f.func.returns) == 1: - ret_expr = "std::move(op.outputs_[0])" # small optimization + ret_expr = "std::move(op.outputs_[0]).take()" # small optimization else: - moved = ', '.join(f"std::move(op.outputs_[{i}])" for i in range(len(f.func.returns))) + moved = ', '.join(f"std::move(op.outputs_[{i}]).take()" for i in range(len(f.func.returns))) ret_expr = f"std::make_tuple({moved})" elif k is SchemaKind.inplace: ret_expr = "self"