From e2e4a80cdb010a43c6092e683e660db0f8a86bbb Mon Sep 17 00:00:00 2001 From: chunyuan Date: Tue, 13 Dec 2022 09:52:54 +0000 Subject: [PATCH] Inductor cpp wrapper: support None as output (#88560) Map `None` to `at::Tensor()` in the cpp wrapper Pull Request resolved: https://github.com/pytorch/pytorch/pull/88560 Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire --- test/inductor/test_torchinductor.py | 1 + torch/_inductor/codegen/wrapper.py | 24 +++++++++++++++++++++--- torch/_inductor/graph.py | 6 ------ torch/_inductor/ir.py | 3 +++ 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 6ee93bb31a5..83b53272a30 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4920,6 +4920,7 @@ class CommonTemplate: for name in [ "test_as_strided", # buffer reuse "test_cat", # alias + "test_lowmem_dropout1", # None as output "test_profiler_mark_wrapper_call", # TODO: fallback to default wrapper for now "test_relu", # multiple inputs "test_silu", # single input, single output diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index b68c54cfdf4..02c4e2e8f18 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List from .. import codecache, config, ir from ..codecache import cpp_compile_command, get_code_path -from ..utils import dynamo_utils, has_triton, sympy_dot, sympy_product +from ..utils import cache_on_self, dynamo_utils, has_triton, sympy_dot, sympy_product from ..virtualized import V from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel from .triton import texpr @@ -312,6 +312,10 @@ class WrapperCodeGen(CodeGen): self.write_get_cuda_stream ) + @cache_on_self + def get_output_refs(self): + return [x.codegen_reference() for x in V.graph.graph_outputs] + def write_prefix(self): self.prefix.splice( """ @@ -470,7 +474,7 @@ class WrapperCodeGen(CodeGen): else: self.wrapper_call.writeline(line) - output_refs = [x.codegen_reference() for x in V.graph.graph_outputs] + output_refs = self.get_output_refs() if config.triton.debug_sync_graph: self.wrapper_call.writeline("torch.cuda.synchronize()") self.generate_return(output_refs) @@ -562,6 +566,20 @@ class CppWrapperCodeGen(WrapperCodeGen): self._call_func_id = next(CppWrapperCodeGen.call_func_id) super().__init__() + @cache_on_self + def get_output_refs(self): + def has_cpp_codegen_func(x): + return hasattr(x, "cpp_wrapper_codegen_reference") and callable( + x.cpp_wrapper_codegen_reference + ) + + return [ + x.cpp_wrapper_codegen_reference() + if has_cpp_codegen_func(x) + else x.codegen_reference() + for x in V.graph.graph_outputs + ] + def write_prefix(self): self.prefix.splice( """ @@ -576,7 +594,7 @@ class CppWrapperCodeGen(WrapperCodeGen): ) with self.wrapper_call.indent(): inputs_len = len(V.graph.graph_inputs.keys()) - output_refs = [x.codegen_reference() for x in V.graph.graph_outputs] + output_refs = self.get_output_refs() if output_refs: if len(output_refs) == 1: output_types = "at::Tensor" diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 270f3dc22af..8dc7fd390d8 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -404,11 +404,6 @@ class GraphLowering(torch.fx.Interpreter): if value.get_dtype() != torch.float32: self.disable_cpp_wrapper("inputs not FP32") - def check_output_for_cpp_buffer(self): - for item in self.graph_outputs: - if isinstance(item, ir.NoneAsConstantBuffer): - self.disable_cpp_wrapper("NoneAsConstantBuffer") - def check_constant_for_cpp_buffer(self): if self.constants: self.disable_cpp_wrapper("Constants") @@ -418,7 +413,6 @@ class GraphLowering(torch.fx.Interpreter): self.check_profiler_mark_wrapper_call() self.check_device_for_cpp_buffer() self.check_input_for_cpp_buffer() - self.check_output_for_cpp_buffer() self.check_constant_for_cpp_buffer() def init_wrapper_code(self): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 2bcbc25d732..a6c834267fa 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1998,6 +1998,9 @@ class NoneAsConstantBuffer(IRNode): def codegen_reference(self): return "None" + def cpp_wrapper_codegen_reference(self): + return "at::Tensor()" + class ShapeAsConstantBuffer(IRNode): def __init__(self, shape):