mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Consolidate kernel and cpp_kernel for wrapper codegen (#98741)
Summary: refactor to simplify the wrapper codegen logic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98741
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/ngimel
This commit is contained in:
parent
439a716785
commit
ff9e34fb35
3 changed files with 28 additions and 53 deletions
|
|
@ -200,6 +200,7 @@ class WrapperCodeGen(CodeGen):
|
|||
self.ending = ""
|
||||
self.comment = "#"
|
||||
self.namespace = ""
|
||||
self.none_str = "None"
|
||||
self.size = "size()"
|
||||
self.stride = "stride()"
|
||||
|
||||
|
|
@ -326,9 +327,7 @@ class WrapperCodeGen(CodeGen):
|
|||
def generate_end(self, result):
|
||||
return
|
||||
|
||||
def generate_extern_kernel_out(
|
||||
self, output_view, codegen_reference, args, kernel, cpp_kernel
|
||||
):
|
||||
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
|
||||
if output_view:
|
||||
args.append(f"out={output_view.codegen_reference()}")
|
||||
else:
|
||||
|
|
@ -339,7 +338,6 @@ class WrapperCodeGen(CodeGen):
|
|||
self,
|
||||
name,
|
||||
kernel,
|
||||
cpp_kernel,
|
||||
codegen_args,
|
||||
cpp_op_schema,
|
||||
cpp_kernel_key,
|
||||
|
|
@ -708,6 +706,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
self.ending = ";"
|
||||
self.comment = "//"
|
||||
self.namespace = "at::"
|
||||
self.none_str = "at::Tensor()"
|
||||
self.extern_call_ops = set()
|
||||
self.size = "sizes()"
|
||||
self.stride = "strides()"
|
||||
|
|
@ -740,14 +739,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
|
||||
@cache_on_self
|
||||
def get_output_refs(self):
|
||||
from ..ir import NoneAsConstantBuffer
|
||||
|
||||
return [
|
||||
"at::Tensor()"
|
||||
if isinstance(x, NoneAsConstantBuffer)
|
||||
else x.codegen_reference()
|
||||
for x in V.graph.graph_outputs
|
||||
]
|
||||
return [x.codegen_reference() for x in V.graph.graph_outputs]
|
||||
|
||||
def mark_output_type(self):
|
||||
# mark output type to unwrap tensor back to python scalar
|
||||
|
|
@ -873,9 +865,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
"""
|
||||
)
|
||||
|
||||
def generate_extern_kernel_out(
|
||||
self, output_view, codegen_reference, args, kernel, cpp_kernel
|
||||
):
|
||||
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
|
||||
if output_view:
|
||||
output_as_strided = f"{output_view.codegen_reference()}"
|
||||
output_name = f"{output_view.get_name()}_as_strided"
|
||||
|
|
@ -884,7 +874,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
args.insert(0, output_name)
|
||||
else:
|
||||
args.insert(0, f"{codegen_reference}")
|
||||
self.writeline(f"{cpp_kernel}({', '.join(args)});")
|
||||
self.writeline(f"{kernel}({', '.join(args)});")
|
||||
|
||||
def codegen_sizevar(self, x: Expr) -> str:
|
||||
from .cpp import cexpr
|
||||
|
|
@ -927,7 +917,6 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
self,
|
||||
name,
|
||||
kernel,
|
||||
cpp_kernel,
|
||||
codegen_args,
|
||||
cpp_op_schema,
|
||||
cpp_kernel_key,
|
||||
|
|
@ -939,7 +928,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
static auto op_{cpp_kernel_key} =
|
||||
c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow(
|
||||
\"{cpp_kernel}\",
|
||||
\"{kernel}\",
|
||||
\"{cpp_kernel_overload_name}\")
|
||||
.typed<{cpp_op_schema}>();
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -251,15 +251,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
assert not self.aot_mode, "AOT compilation failed"
|
||||
log.debug("Set cpp_wrapper to False due to %s", cond)
|
||||
|
||||
def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
|
||||
if isinstance(buffer, ir.ExternKernel):
|
||||
if not getattr(buffer, "cpp_kernel", False):
|
||||
self.disable_cpp_wrapper("ExternKernel")
|
||||
|
||||
def register_buffer(self, buffer: ir.ComputedBuffer):
|
||||
if self.cpp_wrapper:
|
||||
self.check_buffer_for_cpp_wrapper(buffer)
|
||||
|
||||
name = f"buf{len(self.buffers)}"
|
||||
self.buffers.append(buffer)
|
||||
self.name_to_buffer[name] = buffer
|
||||
|
|
|
|||
|
|
@ -2045,7 +2045,7 @@ class RandSeedBuffer(ConstantBuffer):
|
|||
|
||||
class NoneAsConstantBuffer(IRNode):
|
||||
def codegen_reference(self):
|
||||
return "None"
|
||||
return V.graph.wrapper_code.none_str
|
||||
|
||||
|
||||
class ShapeAsConstantBuffer(IRNode):
|
||||
|
|
@ -2778,7 +2778,6 @@ class ExternKernelOut(ExternKernel):
|
|||
self.codegen_reference(),
|
||||
args,
|
||||
self.kernel,
|
||||
self.cpp_kernel,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
|
@ -2796,11 +2795,9 @@ class ExternKernelOut(ExternKernel):
|
|||
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
|
||||
)
|
||||
self.output_view = output_view
|
||||
self.cpp_kernel = cpp_kernel
|
||||
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
|
||||
self.name = V.graph.register_buffer(self)
|
||||
if kernel is not None:
|
||||
self.kernel = kernel
|
||||
self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel
|
||||
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
|
||||
|
||||
def should_allocate(self):
|
||||
return True
|
||||
|
|
@ -2827,12 +2824,10 @@ class ExternKernelAlloc(ExternKernel):
|
|||
super().__init__(
|
||||
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
|
||||
)
|
||||
self.cpp_kernel = cpp_kernel
|
||||
self.name = V.graph.register_buffer(self)
|
||||
self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel
|
||||
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
|
||||
self.cpp_constant_args = cpp_constant_args
|
||||
self.name = V.graph.register_buffer(self)
|
||||
if kernel is not None:
|
||||
self.kernel = kernel
|
||||
|
||||
def should_allocate(self):
|
||||
return False
|
||||
|
|
@ -3448,18 +3443,21 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
|
|||
|
||||
|
||||
class MKLPackedLinear(ExternKernelAlloc):
|
||||
kernel = "torch.ops.mkl._mkl_linear"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layout,
|
||||
inputs,
|
||||
constant_args=(),
|
||||
cpp_constant_args=(),
|
||||
kernel="torch.ops.mkl._mkl_linear",
|
||||
cpp_kernel="mkl::_mkl_linear",
|
||||
):
|
||||
super().__init__(layout, inputs, constant_args, None, kernel, cpp_kernel)
|
||||
super().__init__(
|
||||
layout,
|
||||
inputs,
|
||||
constant_args,
|
||||
None,
|
||||
kernel="torch.ops.mkl._mkl_linear",
|
||||
cpp_kernel="mkl::_mkl_linear",
|
||||
)
|
||||
self.cpp_kernel_key = "mkl_linear"
|
||||
self.cpp_op_schema = """
|
||||
at::Tensor(
|
||||
|
|
@ -3481,7 +3479,6 @@ class MKLPackedLinear(ExternKernelAlloc):
|
|||
wrapper.generate_fusion_ops_code(
|
||||
self.get_name(),
|
||||
self.kernel,
|
||||
self.cpp_kernel,
|
||||
args,
|
||||
self.cpp_op_schema,
|
||||
self.cpp_kernel_key,
|
||||
|
|
@ -3489,8 +3486,6 @@ class MKLPackedLinear(ExternKernelAlloc):
|
|||
|
||||
@classmethod
|
||||
def create(cls, x, packed_w, orig_w, batch_size):
|
||||
kernel = "torch.ops.mkl._mkl_linear"
|
||||
|
||||
x = cls.require_stride1(cls.realize_input(x))
|
||||
orig_w = cls.require_stride1(cls.realize_input(orig_w))
|
||||
*m, _ = x.get_size()
|
||||
|
|
@ -3510,23 +3505,25 @@ class MKLPackedLinear(ExternKernelAlloc):
|
|||
inputs=inputs,
|
||||
constant_args=constant_args,
|
||||
cpp_constant_args=cpp_constant_args,
|
||||
kernel=kernel,
|
||||
)
|
||||
|
||||
|
||||
class LinearUnary(ExternKernelAlloc):
|
||||
kernel = "torch.ops.mkldnn._linear_pointwise"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layout,
|
||||
inputs,
|
||||
constant_args=(),
|
||||
kernel="torch.ops.mkldnn._linear_pointwise",
|
||||
cpp_kernel="mkldnn::_linear_pointwise",
|
||||
cpp_constant_args=(),
|
||||
):
|
||||
super().__init__(layout, inputs, constant_args, None, kernel, cpp_kernel)
|
||||
super().__init__(
|
||||
layout,
|
||||
inputs,
|
||||
constant_args,
|
||||
None,
|
||||
kernel="torch.ops.mkldnn._linear_pointwise",
|
||||
cpp_kernel="mkldnn::_linear_pointwise",
|
||||
)
|
||||
self.cpp_kernel_key = "linear_pointwise"
|
||||
self.cpp_op_schema = """
|
||||
at::Tensor(
|
||||
|
|
@ -3549,7 +3546,6 @@ class LinearUnary(ExternKernelAlloc):
|
|||
wrapper.generate_fusion_ops_code(
|
||||
self.get_name(),
|
||||
self.kernel,
|
||||
self.cpp_kernel,
|
||||
args,
|
||||
self.cpp_op_schema,
|
||||
self.cpp_kernel_key,
|
||||
|
|
@ -3557,7 +3553,6 @@ class LinearUnary(ExternKernelAlloc):
|
|||
|
||||
@classmethod
|
||||
def create(cls, x, w, b, attr, scalars, algorithm):
|
||||
kernel = "torch.ops.mkldnn._linear_pointwise"
|
||||
x = cls.require_stride1(cls.realize_input(x))
|
||||
w = cls.require_stride1(cls.realize_input(w))
|
||||
|
||||
|
|
@ -3587,7 +3582,6 @@ class LinearUnary(ExternKernelAlloc):
|
|||
inputs=inputs,
|
||||
constant_args=constant_args,
|
||||
cpp_constant_args=cpp_constant_args,
|
||||
kernel=kernel,
|
||||
)
|
||||
|
||||
def apply_constraint(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue