mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
gen ExclusivelyOwned in structured kernels (#59827)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59827 ghstack-source-id: 133089541 Test Plan: existing CI Reviewed By: ezyang, janeyx99 Differential Revision: D28965922 fbshipit-source-id: ffbc1d43e5d3ab3abfad3b0830b4da1ce899f505
This commit is contained in:
parent
711ded688d
commit
a5c5b56cf5
2 changed files with 11 additions and 6 deletions
|
|
@ -15,6 +15,7 @@
|
|||
#include <ATen/Utils.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <c10/util/ExclusivelyOwned.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <c10/core/UndefinedTensorImpl.h>
|
||||
|
|
|
|||
|
|
@ -314,12 +314,13 @@ class StructuredRegisterDispatchKey(RegisterDispatchKey):
|
|||
set_output_super = f"{parent_class}::set_output(output_idx, sizes, strides, options, names);"
|
||||
else:
|
||||
set_output_super = ""
|
||||
maybe_star = "*" if k is SchemaKind.functional else ""
|
||||
return f"""
|
||||
void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
|
||||
TensorOptions options, DimnameList names) override {{
|
||||
{textwrap.indent(self.gen_class_set_output_body(k), " ")}
|
||||
if (!names.empty()) {{
|
||||
namedinference::propagate_names(outputs_[output_idx], names);
|
||||
namedinference::propagate_names({maybe_star}outputs_[output_idx], names);
|
||||
}}
|
||||
// super must happen after, so that downstream can use maybe_get_output
|
||||
// to retrieve the output
|
||||
|
|
@ -417,8 +418,10 @@ if (resized) {{
|
|||
def gen_class(
|
||||
self, f: NativeFunction, k: SchemaKind, *, class_name: str, parent_class: str, generate_super: bool
|
||||
) -> str:
|
||||
maybe_star = ''
|
||||
if k is SchemaKind.functional:
|
||||
output_type = "Tensor"
|
||||
output_type = "c10::ExclusivelyOwned<Tensor>"
|
||||
maybe_star = '*'
|
||||
elif k is SchemaKind.inplace:
|
||||
output_type = "std::reference_wrapper<Tensor>"
|
||||
elif k is SchemaKind.out:
|
||||
|
|
@ -441,7 +444,7 @@ if (resized) {{
|
|||
f"{textwrap.indent(class_ctor_str, indent)}",
|
||||
f"{textwrap.indent(self.gen_class_set_output(k, parent_class, generate_super), indent)}",
|
||||
" const Tensor& maybe_get_output(int64_t output_idx) override {",
|
||||
" return outputs_[output_idx];",
|
||||
f" return {maybe_star}outputs_[output_idx];",
|
||||
" }",
|
||||
f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
|
||||
f"{textwrap.indent(guard_field, indent)}",
|
||||
|
|
@ -555,10 +558,11 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
|
|||
# After running meta, op.outputs_ is guaranteed to be valid;
|
||||
# add it to the context
|
||||
out_args = structured.out_arguments(self.g)
|
||||
maybe_star = '*' if k is SchemaKind.functional else ''
|
||||
for i, out_arg in enumerate(out_args):
|
||||
assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
|
||||
context.append(Expr(
|
||||
expr=f"op.outputs_[{i}]",
|
||||
expr=f"{maybe_star}op.outputs_[{i}]",
|
||||
# TODO: Stop hardcoding that the output type is a Tensor. Note
|
||||
# that for the codegen here this is fine because outputs_ is
|
||||
# hardcoded to be tensor already
|
||||
|
|
@ -605,9 +609,9 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
|
|||
# TODO: Do this in translate instead
|
||||
if k is SchemaKind.functional:
|
||||
if len(f.func.returns) == 1:
|
||||
ret_expr = "std::move(op.outputs_[0])" # small optimization
|
||||
ret_expr = "std::move(op.outputs_[0]).take()" # small optimization
|
||||
else:
|
||||
moved = ', '.join(f"std::move(op.outputs_[{i}])" for i in range(len(f.func.returns)))
|
||||
moved = ', '.join(f"std::move(op.outputs_[{i}]).take()" for i in range(len(f.func.returns)))
|
||||
ret_expr = f"std::make_tuple({moved})"
|
||||
elif k is SchemaKind.inplace:
|
||||
ret_expr = "self"
|
||||
|
|
|
|||
Loading…
Reference in a new issue