[inductor] Consolidate kernel and cpp_kernel for wrapper codegen (#98741)

Summary: refactor to simplify the wrapper codegen logic

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98741
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/ngimel
This commit is contained in:
Bin Bao 2023-04-11 23:55:44 +00:00 committed by PyTorch MergeBot
parent 439a716785
commit ff9e34fb35
3 changed files with 28 additions and 53 deletions

View file

@ -200,6 +200,7 @@ class WrapperCodeGen(CodeGen):
self.ending = ""
self.comment = "#"
self.namespace = ""
self.none_str = "None"
self.size = "size()"
self.stride = "stride()"
@ -326,9 +327,7 @@ class WrapperCodeGen(CodeGen):
def generate_end(self, result):
return
def generate_extern_kernel_out(
self, output_view, codegen_reference, args, kernel, cpp_kernel
):
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
if output_view:
args.append(f"out={output_view.codegen_reference()}")
else:
@ -339,7 +338,6 @@ class WrapperCodeGen(CodeGen):
self,
name,
kernel,
cpp_kernel,
codegen_args,
cpp_op_schema,
cpp_kernel_key,
@ -708,6 +706,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
self.ending = ";"
self.comment = "//"
self.namespace = "at::"
self.none_str = "at::Tensor()"
self.extern_call_ops = set()
self.size = "sizes()"
self.stride = "strides()"
@ -740,14 +739,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
@cache_on_self
def get_output_refs(self):
from ..ir import NoneAsConstantBuffer
return [
"at::Tensor()"
if isinstance(x, NoneAsConstantBuffer)
else x.codegen_reference()
for x in V.graph.graph_outputs
]
return [x.codegen_reference() for x in V.graph.graph_outputs]
def mark_output_type(self):
# mark output type to unwrap tensor back to python scalar
@ -873,9 +865,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
"""
)
def generate_extern_kernel_out(
self, output_view, codegen_reference, args, kernel, cpp_kernel
):
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
if output_view:
output_as_strided = f"{output_view.codegen_reference()}"
output_name = f"{output_view.get_name()}_as_strided"
@ -884,7 +874,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
args.insert(0, output_name)
else:
args.insert(0, f"{codegen_reference}")
self.writeline(f"{cpp_kernel}({', '.join(args)});")
self.writeline(f"{kernel}({', '.join(args)});")
def codegen_sizevar(self, x: Expr) -> str:
from .cpp import cexpr
@ -927,7 +917,6 @@ class CppWrapperCodeGen(WrapperCodeGen):
self,
name,
kernel,
cpp_kernel,
codegen_args,
cpp_op_schema,
cpp_kernel_key,
@ -939,7 +928,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
static auto op_{cpp_kernel_key} =
c10::Dispatcher::singleton()
.findSchemaOrThrow(
\"{cpp_kernel}\",
\"{kernel}\",
\"{cpp_kernel_overload_name}\")
.typed<{cpp_op_schema}>();
"""

View file

@ -251,15 +251,7 @@ class GraphLowering(torch.fx.Interpreter):
assert not self.aot_mode, "AOT compilation failed"
log.debug("Set cpp_wrapper to False due to %s", cond)
def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
if isinstance(buffer, ir.ExternKernel):
if not getattr(buffer, "cpp_kernel", False):
self.disable_cpp_wrapper("ExternKernel")
def register_buffer(self, buffer: ir.ComputedBuffer):
if self.cpp_wrapper:
self.check_buffer_for_cpp_wrapper(buffer)
name = f"buf{len(self.buffers)}"
self.buffers.append(buffer)
self.name_to_buffer[name] = buffer

View file

@ -2045,7 +2045,7 @@ class RandSeedBuffer(ConstantBuffer):
class NoneAsConstantBuffer(IRNode):
def codegen_reference(self):
return "None"
return V.graph.wrapper_code.none_str
class ShapeAsConstantBuffer(IRNode):
@ -2778,7 +2778,6 @@ class ExternKernelOut(ExternKernel):
self.codegen_reference(),
args,
self.kernel,
self.cpp_kernel,
)
def __init__(
@ -2796,11 +2795,9 @@ class ExternKernelOut(ExternKernel):
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
)
self.output_view = output_view
self.cpp_kernel = cpp_kernel
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
self.name = V.graph.register_buffer(self)
if kernel is not None:
self.kernel = kernel
self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
def should_allocate(self):
return True
@ -2827,12 +2824,10 @@ class ExternKernelAlloc(ExternKernel):
super().__init__(
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
)
self.cpp_kernel = cpp_kernel
self.name = V.graph.register_buffer(self)
self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
self.cpp_constant_args = cpp_constant_args
self.name = V.graph.register_buffer(self)
if kernel is not None:
self.kernel = kernel
def should_allocate(self):
return False
@ -3448,18 +3443,21 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
class MKLPackedLinear(ExternKernelAlloc):
kernel = "torch.ops.mkl._mkl_linear"
def __init__(
self,
layout,
inputs,
constant_args=(),
cpp_constant_args=(),
kernel="torch.ops.mkl._mkl_linear",
cpp_kernel="mkl::_mkl_linear",
):
super().__init__(layout, inputs, constant_args, None, kernel, cpp_kernel)
super().__init__(
layout,
inputs,
constant_args,
None,
kernel="torch.ops.mkl._mkl_linear",
cpp_kernel="mkl::_mkl_linear",
)
self.cpp_kernel_key = "mkl_linear"
self.cpp_op_schema = """
at::Tensor(
@ -3481,7 +3479,6 @@ class MKLPackedLinear(ExternKernelAlloc):
wrapper.generate_fusion_ops_code(
self.get_name(),
self.kernel,
self.cpp_kernel,
args,
self.cpp_op_schema,
self.cpp_kernel_key,
@ -3489,8 +3486,6 @@ class MKLPackedLinear(ExternKernelAlloc):
@classmethod
def create(cls, x, packed_w, orig_w, batch_size):
kernel = "torch.ops.mkl._mkl_linear"
x = cls.require_stride1(cls.realize_input(x))
orig_w = cls.require_stride1(cls.realize_input(orig_w))
*m, _ = x.get_size()
@ -3510,23 +3505,25 @@ class MKLPackedLinear(ExternKernelAlloc):
inputs=inputs,
constant_args=constant_args,
cpp_constant_args=cpp_constant_args,
kernel=kernel,
)
class LinearUnary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._linear_pointwise"
def __init__(
self,
layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._linear_pointwise",
cpp_kernel="mkldnn::_linear_pointwise",
cpp_constant_args=(),
):
super().__init__(layout, inputs, constant_args, None, kernel, cpp_kernel)
super().__init__(
layout,
inputs,
constant_args,
None,
kernel="torch.ops.mkldnn._linear_pointwise",
cpp_kernel="mkldnn::_linear_pointwise",
)
self.cpp_kernel_key = "linear_pointwise"
self.cpp_op_schema = """
at::Tensor(
@ -3549,7 +3546,6 @@ class LinearUnary(ExternKernelAlloc):
wrapper.generate_fusion_ops_code(
self.get_name(),
self.kernel,
self.cpp_kernel,
args,
self.cpp_op_schema,
self.cpp_kernel_key,
@ -3557,7 +3553,6 @@ class LinearUnary(ExternKernelAlloc):
@classmethod
def create(cls, x, w, b, attr, scalars, algorithm):
kernel = "torch.ops.mkldnn._linear_pointwise"
x = cls.require_stride1(cls.realize_input(x))
w = cls.require_stride1(cls.realize_input(w))
@ -3587,7 +3582,6 @@ class LinearUnary(ExternKernelAlloc):
inputs=inputs,
constant_args=constant_args,
cpp_constant_args=cpp_constant_args,
kernel=kernel,
)
def apply_constraint(self):