Inductor cpp wrapper: support None as output (#88560)
Map `None` to `at::Tensor()` in the cpp wrapper.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88560
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
parent 93aee0cdc9
commit e2e4a80cdb

4 changed files with 25 additions and 9 deletions
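For context, a minimal sketch (not part of the patch) of the kind of graph this change enables: a function whose second output is simply `None`, as in the `test_lowmem_dropout1` case re-enabled below. The `torch._dynamo.optimize` entry point and the repro itself are assumptions based on the Inductor API of this era, not code from the commit.

```python
# Hypothetical repro, not taken from the patch: a graph whose second
# output is None. Before this change the cpp wrapper bailed out on such
# graphs; with it, the None slot is emitted as at::Tensor().
import torch
import torch._dynamo

def fn(x):
    # Mirrors the lowmem-dropout pattern: one real tensor output,
    # one output that is simply None.
    return torch.relu(x), None

opt_fn = torch._dynamo.optimize("inductor")(fn)  # assumed entry point
out, none_out = opt_fn(torch.randn(8))
assert none_out is None
```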
test/inductor/test_torchinductor.py

```diff
@@ -4920,6 +4920,7 @@ class CommonTemplate:
         for name in [
             "test_as_strided",  # buffer reuse
             "test_cat",  # alias
+            "test_lowmem_dropout1",  # None as output
             "test_profiler_mark_wrapper_call",  # TODO: fallback to default wrapper for now
             "test_relu",  # multiple inputs
             "test_silu",  # single input, single output
```
torch/_inductor/codegen/wrapper.py

```diff
@@ -8,7 +8,7 @@ from typing import Any, Dict, List

 from .. import codecache, config, ir
 from ..codecache import cpp_compile_command, get_code_path
-from ..utils import dynamo_utils, has_triton, sympy_dot, sympy_product
+from ..utils import cache_on_self, dynamo_utils, has_triton, sympy_dot, sympy_product
 from ..virtualized import V
 from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel
 from .triton import texpr
```
```diff
@@ -312,6 +312,10 @@ class WrapperCodeGen(CodeGen):
             self.write_get_cuda_stream
         )

+    @cache_on_self
+    def get_output_refs(self):
+        return [x.codegen_reference() for x in V.graph.graph_outputs]
+
     def write_prefix(self):
         self.prefix.splice(
             """
```
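The new `get_output_refs()` is wrapped in `cache_on_self`, imported from `..utils` in the previous hunk. The real decorator lives in Inductor's utils module; the following is only a minimal sketch of what such a per-instance memoizer plausibly does, not a copy of it:

```python
import functools

def cache_on_self(fn):
    """Sketch of a cache_on_self-style decorator (assumed behavior):
    memoize a zero-argument method's result on the instance, so repeated
    calls (here: get_output_refs) build the reference list only once."""
    key = f"__{fn.__name__}_cache"

    @functools.wraps(fn)
    def wrapper(self):
        if not hasattr(self, key):
            setattr(self, key, fn(self))
        return getattr(self, key)

    return wrapper
```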
```diff
@@ -470,7 +474,7 @@ class WrapperCodeGen(CodeGen):
             else:
                 self.wrapper_call.writeline(line)

-        output_refs = [x.codegen_reference() for x in V.graph.graph_outputs]
+        output_refs = self.get_output_refs()
         if config.triton.debug_sync_graph:
             self.wrapper_call.writeline("torch.cuda.synchronize()")
         self.generate_return(output_refs)
```
```diff
@@ -562,6 +566,20 @@ class CppWrapperCodeGen(WrapperCodeGen):
         self._call_func_id = next(CppWrapperCodeGen.call_func_id)
         super().__init__()

+    @cache_on_self
+    def get_output_refs(self):
+        def has_cpp_codegen_func(x):
+            return hasattr(x, "cpp_wrapper_codegen_reference") and callable(
+                x.cpp_wrapper_codegen_reference
+            )
+
+        return [
+            x.cpp_wrapper_codegen_reference()
+            if has_cpp_codegen_func(x)
+            else x.codegen_reference()
+            for x in V.graph.graph_outputs
+        ]
+
     def write_prefix(self):
         self.prefix.splice(
             """
```
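The override dispatches by duck typing: outputs that know how to render themselves for the C++ wrapper supply `cpp_wrapper_codegen_reference()`, and everything else falls back to the plain `codegen_reference()`. A self-contained illustration with hypothetical stand-in classes (the real IR nodes live in torch/_inductor/ir.py and carry far more state):

```python
# Stand-in IR nodes, for illustration only.
class TensorBox:
    def codegen_reference(self):
        return "buf0"

class NoneBuffer:
    def codegen_reference(self):
        return "None"

    def cpp_wrapper_codegen_reference(self):
        return "at::Tensor()"

def has_cpp_codegen_func(x):
    return hasattr(x, "cpp_wrapper_codegen_reference") and callable(
        x.cpp_wrapper_codegen_reference
    )

outputs = [TensorBox(), NoneBuffer()]
refs = [
    x.cpp_wrapper_codegen_reference() if has_cpp_codegen_func(x) else x.codegen_reference()
    for x in outputs
]
assert refs == ["buf0", "at::Tensor()"]  # None renders as an undefined tensor
```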
```diff
@@ -576,7 +594,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
         )
         with self.wrapper_call.indent():
             inputs_len = len(V.graph.graph_inputs.keys())
-            output_refs = [x.codegen_reference() for x in V.graph.graph_outputs]
+            output_refs = self.get_output_refs()
             if output_refs:
                 if len(output_refs) == 1:
                     output_types = "at::Tensor"
```
torch/_inductor/graph.py

```diff
@@ -404,11 +404,6 @@ class GraphLowering(torch.fx.Interpreter):
             if value.get_dtype() != torch.float32:
                 self.disable_cpp_wrapper("inputs not FP32")

-    def check_output_for_cpp_buffer(self):
-        for item in self.graph_outputs:
-            if isinstance(item, ir.NoneAsConstantBuffer):
-                self.disable_cpp_wrapper("NoneAsConstantBuffer")
-
     def check_constant_for_cpp_buffer(self):
         if self.constants:
             self.disable_cpp_wrapper("Constants")
```
```diff
@@ -418,7 +413,6 @@ class GraphLowering(torch.fx.Interpreter):
         self.check_profiler_mark_wrapper_call()
         self.check_device_for_cpp_buffer()
         self.check_input_for_cpp_buffer()
-        self.check_output_for_cpp_buffer()
         self.check_constant_for_cpp_buffer()

     def init_wrapper_code(self):
```
torch/_inductor/ir.py

```diff
@@ -1998,6 +1998,9 @@ class NoneAsConstantBuffer(IRNode):
     def codegen_reference(self):
         return "None"

+    def cpp_wrapper_codegen_reference(self):
+        return "at::Tensor()"
+

 class ShapeAsConstantBuffer(IRNode):
     def __init__(self, shape):
```
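A default-constructed `at::Tensor()` is an undefined tensor, the closest C++ analogue to Python's `None`. As a rough illustration of how these references could end up in the generated C++ (the emission format below is assumed for the sketch, not copied from Inductor's codegen):

```python
# Illustrative only: a graph returning (buf0, None) yields these refs
# from CppWrapperCodeGen.get_output_refs() ...
output_refs = ["buf0", "at::Tensor()"]
# ... which the cpp wrapper could splice into a vector-returning body,
# where the undefined at::Tensor() occupies None's slot.
return_stmt = f"return std::vector<at::Tensor>({{{', '.join(output_refs)}}});"
print(return_stmt)
# return std::vector<at::Tensor>({buf0, at::Tensor()});
```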