Generate c10::ExclusivelyOwned<Tensor> outputs in functional structured kernels (#59827)

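Summary:
The structured-kernel codegen now declares the outputs_ array of functional kernels as c10::ExclusivelyOwned<Tensor> instead of Tensor. Generated code dereferences each element with `*` wherever a Tensor reference is needed, and calls `.take()` on the moved element when returning it.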
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59827

ghstack-source-id: 133089541

Test Plan: existing CI

Reviewed By: ezyang, janeyx99

Differential Revision: D28965922

fbshipit-source-id: ffbc1d43e5d3ab3abfad3b0830b4da1ce899f505
This commit is contained in:
Scott Wolchok 2021-07-09 13:30:29 -07:00 committed by Facebook GitHub Bot
parent 711ded688d
commit a5c5b56cf5
2 changed files with 11 additions and 6 deletions

View file

@ -15,6 +15,7 @@
#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/Dispatch.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/Half.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>

View file

@ -314,12 +314,13 @@ class StructuredRegisterDispatchKey(RegisterDispatchKey):
set_output_super = f"{parent_class}::set_output(output_idx, sizes, strides, options, names);"
else:
set_output_super = ""
maybe_star = "*" if k is SchemaKind.functional else ""
return f"""
void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
TensorOptions options, DimnameList names) override {{
{textwrap.indent(self.gen_class_set_output_body(k), " ")}
if (!names.empty()) {{
namedinference::propagate_names(outputs_[output_idx], names);
namedinference::propagate_names({maybe_star}outputs_[output_idx], names);
}}
// super must happen after, so that downstream can use maybe_get_output
// to retrieve the output
@ -417,8 +418,10 @@ if (resized) {{
def gen_class(
self, f: NativeFunction, k: SchemaKind, *, class_name: str, parent_class: str, generate_super: bool
) -> str:
maybe_star = ''
if k is SchemaKind.functional:
output_type = "Tensor"
output_type = "c10::ExclusivelyOwned<Tensor>"
maybe_star = '*'
elif k is SchemaKind.inplace:
output_type = "std::reference_wrapper<Tensor>"
elif k is SchemaKind.out:
@ -441,7 +444,7 @@ if (resized) {{
f"{textwrap.indent(class_ctor_str, indent)}",
f"{textwrap.indent(self.gen_class_set_output(k, parent_class, generate_super), indent)}",
" const Tensor& maybe_get_output(int64_t output_idx) override {",
" return outputs_[output_idx];",
f" return {maybe_star}outputs_[output_idx];",
" }",
f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
f"{textwrap.indent(guard_field, indent)}",
@ -555,10 +558,11 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
# After running meta, op.outputs_ is guaranteed to be valid;
# add it to the context
out_args = structured.out_arguments(self.g)
maybe_star = '*' if k is SchemaKind.functional else ''
for i, out_arg in enumerate(out_args):
assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
context.append(Expr(
expr=f"op.outputs_[{i}]",
expr=f"{maybe_star}op.outputs_[{i}]",
# TODO: Stop hardcoding that the output type is a Tensor. Note
# that for the codegen here this is fine because outputs_ is
# hardcoded to be tensor already
@ -605,9 +609,9 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
# TODO: Do this in translate instead
if k is SchemaKind.functional:
if len(f.func.returns) == 1:
ret_expr = "std::move(op.outputs_[0])" # small optimization
ret_expr = "std::move(op.outputs_[0]).take()" # small optimization
else:
moved = ', '.join(f"std::move(op.outputs_[{i}])" for i in range(len(f.func.returns)))
moved = ', '.join(f"std::move(op.outputs_[{i}]).take()" for i in range(len(f.func.returns)))
ret_expr = f"std::make_tuple({moved})"
elif k is SchemaKind.inplace:
ret_expr = "self"