Inductor cpp wrapper: support None as output (#88560)

Map `None` graph outputs to `at::Tensor()` (a default-constructed, undefined tensor) in the cpp wrapper, instead of disabling the cpp wrapper whenever a graph output is `None`.
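In practice this covers graphs whose flattened outputs include a literal `None`, as exercised by `test_lowmem_dropout1` below. A minimal repro sketch, assuming the 2022-era entry point `torch._dynamo.optimize("inductor")` and an inductor config flag named `cpp_wrapper` (the flag name is an assumption, not taken from this diff):

    import torch
    import torch._dynamo
    import torch._inductor.config as inductor_config

    inductor_config.cpp_wrapper = True  # ASSUMED flag name for the C++ wrapper

    def fn(x):
        # One tensor output plus a literal None in the flattened outputs.
        return x.relu(), None

    compiled = torch._dynamo.optimize("inductor")(fn)
    t, n = compiled(torch.randn(8))
    assert n is None  # the wrapper's at::Tensor() round-trips back to None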

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88560
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
chunyuan 2022-12-13 09:52:54 +00:00 committed by PyTorch MergeBot
parent 93aee0cdc9
commit e2e4a80cdb
4 changed files with 25 additions and 9 deletions

test/inductor/test_torchinductor.py

@@ -4920,6 +4920,7 @@ class CommonTemplate:
     for name in [
         "test_as_strided",  # buffer reuse
         "test_cat",  # alias
+        "test_lowmem_dropout1",  # None as output
         "test_profiler_mark_wrapper_call",  # TODO: fallback to default wrapper for now
         "test_relu",  # multiple inputs
         "test_silu",  # single input, single output

torch/_inductor/codegen/wrapper.py

@@ -8,7 +8,7 @@ from typing import Any, Dict, List
 from .. import codecache, config, ir
 from ..codecache import cpp_compile_command, get_code_path
-from ..utils import dynamo_utils, has_triton, sympy_dot, sympy_product
+from ..utils import cache_on_self, dynamo_utils, has_triton, sympy_dot, sympy_product
 from ..virtualized import V
 from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel
 from .triton import texpr
@@ -312,6 +312,10 @@ class WrapperCodeGen(CodeGen):
             self.write_get_cuda_stream
         )
 
+    @cache_on_self
+    def get_output_refs(self):
+        return [x.codegen_reference() for x in V.graph.graph_outputs]
+
     def write_prefix(self):
         self.prefix.splice(
             """
@@ -470,7 +474,7 @@ class WrapperCodeGen(CodeGen):
             else:
                 self.wrapper_call.writeline(line)
 
-        output_refs = [x.codegen_reference() for x in V.graph.graph_outputs]
+        output_refs = self.get_output_refs()
        if config.triton.debug_sync_graph:
             self.wrapper_call.writeline("torch.cuda.synchronize()")
         self.generate_return(output_refs)
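The base `get_output_refs` here, and the cpp override in the next hunk, are both memoized with `cache_on_self` (newly imported above), so the output references are computed once per wrapper instance. A minimal sketch of the semantics relied on, assuming a per-instance memoizing decorator; inductor's actual implementation in `..utils` may differ in detail:

    import functools

    def cache_on_self(fn):
        # Cache a zero-argument method's result on the instance itself.
        key = "__" + fn.__name__ + "_cache"

        @functools.wraps(fn)
        def wrapper(self):
            if not hasattr(self, key):
                setattr(self, key, fn(self))
            return getattr(self, key)

        return wrapper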
@@ -562,6 +566,20 @@ class CppWrapperCodeGen(WrapperCodeGen):
         self._call_func_id = next(CppWrapperCodeGen.call_func_id)
         super().__init__()
 
+    @cache_on_self
+    def get_output_refs(self):
+        def has_cpp_codegen_func(x):
+            return hasattr(x, "cpp_wrapper_codegen_reference") and callable(
+                x.cpp_wrapper_codegen_reference
+            )
+
+        return [
+            x.cpp_wrapper_codegen_reference()
+            if has_cpp_codegen_func(x)
+            else x.codegen_reference()
+            for x in V.graph.graph_outputs
+        ]
+
     def write_prefix(self):
         self.prefix.splice(
             """
@@ -576,7 +594,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
         )
         with self.wrapper_call.indent():
             inputs_len = len(V.graph.graph_inputs.keys())
-            output_refs = [x.codegen_reference() for x in V.graph.graph_outputs]
+            output_refs = self.get_output_refs()
             if output_refs:
                 if len(output_refs) == 1:
                     output_types = "at::Tensor"

torch/_inductor/graph.py

@@ -404,11 +404,6 @@ class GraphLowering(torch.fx.Interpreter):
             if value.get_dtype() != torch.float32:
                 self.disable_cpp_wrapper("inputs not FP32")
 
-    def check_output_for_cpp_buffer(self):
-        for item in self.graph_outputs:
-            if isinstance(item, ir.NoneAsConstantBuffer):
-                self.disable_cpp_wrapper("NoneAsConstantBuffer")
-
     def check_constant_for_cpp_buffer(self):
         if self.constants:
             self.disable_cpp_wrapper("Constants")
@@ -418,7 +413,6 @@ class GraphLowering(torch.fx.Interpreter):
         self.check_profiler_mark_wrapper_call()
         self.check_device_for_cpp_buffer()
         self.check_input_for_cpp_buffer()
-        self.check_output_for_cpp_buffer()
         self.check_constant_for_cpp_buffer()
 
     def init_wrapper_code(self):

torch/_inductor/ir.py

@@ -1998,6 +1998,9 @@ class NoneAsConstantBuffer(IRNode):
     def codegen_reference(self):
         return "None"
 
+    def cpp_wrapper_codegen_reference(self):
+        return "at::Tensor()"
+
 
 class ShapeAsConstantBuffer(IRNode):
     def __init__(self, shape):