Switch AOT Inductor test to export, add dynamic, fix invocation bug (#101585)


Pull Request resolved: https://github.com/pytorch/pytorch/pull/101585
Approved by: https://github.com/ngimel, https://github.com/desertfire
Michael Voznesensky 2023-05-16 21:53:26 +00:00 committed by PyTorch MergeBot
parent c3a893c659
commit 39f52c0218
3 changed files with 38 additions and 12 deletions


@@ -17,12 +17,12 @@ class Net(torch.nn.Module):
 x = torch.randn((32, 64), device="cuda")
 y = torch.randn((32, 64), device="cuda")
+for dynamic in [True, False]:
+    torch._dynamo.config.dynamic_shapes = dynamic
+    torch._dynamo.reset()
-with torch.no_grad():
-    from torch.fx.experimental.proxy_tensor import make_fx
-    # Using export is blocked by https://github.com/pytorch/pytorch/issues/99000
-    # module, _ = torch._dynamo.export(Net().cuda(), inp)
-    module = make_fx(Net().cuda())(x, y)
-    lib_path = torch._inductor.aot_compile(module, [x, y])
+    with torch.no_grad():
+        module, _ = torch._dynamo.export(Net().cuda(), x, y)
+        lib_path = torch._inductor.aot_compile(module, [x, y])
-shutil.copy(lib_path, "libaot_inductor_output.so")
+    shutil.copy(lib_path, f"libaot_inductor_output{'_dynamic' if dynamic else ''}.so")


@@ -6,7 +6,7 @@ import itertools
 import logging
 import math
 import operator
-from typing import Dict, Iterable, List, Set
+from typing import Any, Dict, Iterable, List, Set
 import sympy
@@ -441,6 +441,14 @@ class TritonOverrides(OpOverrides):
         return f"tl.math.ceil({x})"


+@dataclasses.dataclass
+class SymbolicCallArg:
+    inner: Any
+
+    def __str__(self):
+        return str(self.inner)
+
+
 @dataclasses.dataclass
 class IterationRanges:
     """
@@ -1646,13 +1654,29 @@ class TritonKernel(Kernel):
         grid = []
         # TODO(jansel): if there are constants, we shouldn't bother passing them as args
         for tree in self.range_trees:
+            assignment = False
             if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
                 expr = pexpr(tree.numel)
             else:
+                assignment = True
                 expr = f"{name}_{tree.prefix}numel"
-                code.writeline(f"{expr} = {pexpr(tree.numel)}")
+                # TODO(voz): Tragic. This should at the very least be a util to slap on declare and ending.
+                # The real fix here is to revisit our cross-language calling convention.
+                code.writeline(
+                    f"{code.declare}{expr} = {pexpr(tree.numel)}{code.ending}"
+                )
             if tree.prefix != "r" or self.inside_reduction:
-                call_args.append(expr)
+                if assignment:
+                    # We can get symbolic expressions here, like s0*64.
+                    # It is fine to have them here, but we need to handle them correctly as their own type,
+                    # which is tricky: we wrap them in a custom type, distinct both from plain scalars and
+                    # from sympy scalars.
+                    # The wrapper is consumed in `generate_args_decl`, which carries an apt comment: "TODO:
+                    # only works for constant now, need type info". This wrapper is not true type info, but
+                    # it suffices as a type hint for producing the correct code for this type.
+                    call_args.append(SymbolicCallArg(expr))
+                else:
+                    call_args.append(expr)
             if tree.prefix != "r":
                 grid.append(expr)
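
The switch to f"{code.declare}{expr} = {pexpr(tree.numel)}{code.ending}" exists because the same writeline must render as valid Python when the default wrapper emits it and as valid C++ when the AOT/C++ wrapper does; the TODO above asks for a utility precisely because every such call site must remember both attributes. A rough sketch of the mechanism, with stand-in stub classes rather than Inductor's real wrapper codegens:

class PyWrapperStub:                 # Python wrapper: bare assignment
    declare, ending = "", ""

class CppWrapperStub:                # C++ wrapper: declaration plus semicolon
    declare, ending = "auto ", ";"

expr, numel = "kernel0_xnumel", "64*s0"
for code in (PyWrapperStub, CppWrapperStub):
    print(f"{code.declare}{expr} = {numel}{code.ending}")
# kernel0_xnumel = 64*s0
# auto kernel0_xnumel = 64*s0;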
@@ -1662,7 +1686,7 @@
             )
         else:
             # TODO: refactor generate_kernel_call
-            call_args_str = ", ".join(call_args)
+            call_args_str = ", ".join(str(item) for item in call_args)
             stream_name = code.write_get_cuda_stream(
                 V.graph.scheduler.current_device.index
             )
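
The join now stringifies each element because str.join accepts only str instances; once SymbolicCallArg objects can appear in call_args, the old form raises TypeError. Continuing the sketch above:

call_args = ["buf0", SymbolicCallArg("kernel0_xnumel"), "buf1"]
# ", ".join(call_args)  # TypeError: sequence item 1: expected str instance, SymbolicCallArg found
print(", ".join(str(item) for item in call_args))  # buf0, kernel0_xnumel, buf1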


@@ -1155,7 +1155,9 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
         new_args = []
         for arg in call_args:
             var_name = f"var_{next(self.arg_var_id)}"
-            if is_int(arg):
+            if isinstance(arg, torch._inductor.codegen.triton.SymbolicCallArg):
+                self.writeline(f"auto {var_name} = {arg};")
+            elif is_int(arg):
                 self.writeline(f"int {var_name} = {arg};")
             elif is_float(arg):
                 self.writeline(f"float {var_name} = {arg};")