Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
Switch AOT Inductor test to export, add dynamic, fix invocation bug (#101585)
Fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101585
Approved by: https://github.com/ngimel, https://github.com/desertfire
parent c3a893c659
commit 39f52c0218
3 changed files with 38 additions and 12 deletions
Changed file 1 of 3 (the AOT Inductor test script; its path is not preserved in this view):

@@ -17,12 +17,12 @@ class Net(torch.nn.Module):
 x = torch.randn((32, 64), device="cuda")
 y = torch.randn((32, 64), device="cuda")
 
-with torch.no_grad():
-    from torch.fx.experimental.proxy_tensor import make_fx
-    # Using export is blocked by https://github.com/pytorch/pytorch/issues/99000
-    # module, _ = torch._dynamo.export(Net().cuda(), inp)
-    module = make_fx(Net().cuda())(x, y)
-    lib_path = torch._inductor.aot_compile(module, [x, y])
-
-shutil.copy(lib_path, "libaot_inductor_output.so")
+for dynamic in [True, False]:
+    torch._dynamo.config.dynamic_shapes = dynamic
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        module, _ = torch._dynamo.export(Net().cuda(), x, y)
+        lib_path = torch._inductor.aot_compile(module, [x, y])
+
+    shutil.copy(lib_path, f"libaot_inductor_output{'_dynamic' if dynamic else ''}.so")
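
The loop now produces one shared library per shape mode. As an illustration only (not part of the commit), here is a hypothetical smoke check of the expected outputs, using the same file names the shutil.copy calls above produce:

# Hypothetical smoke check, not part of this commit: after the loop above
# completes, one library per shape mode should sit next to the script.
import os

for suffix in ("", "_dynamic"):
    path = f"libaot_inductor_output{suffix}.so"
    assert os.path.exists(path), f"missing {path}"
    print(path, os.path.getsize(path), "bytes")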
Changed file 2 of 3 (Inductor's Triton codegen, per the class contexts below; path not preserved in this view), first hunk:

@@ -6,7 +6,7 @@ import itertools
 import logging
 import math
 import operator
-from typing import Dict, Iterable, List, Set
+from typing import Any, Dict, Iterable, List, Set
 
 import sympy
@@ -441,6 +441,14 @@ class TritonOverrides(OpOverrides):
         return f"tl.math.ceil({x})"
 
 
+@dataclasses.dataclass
+class SymbolicCallArg:
+    inner: Any
+
+    def __str__(self):
+        return str(self.inner)
+
+
 @dataclasses.dataclass
 class IterationRanges:
     """
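
A standalone sketch of what this small dataclass buys (only `dataclasses`/`typing` are assumed; the surrounding kernel machinery is irrelevant here): the wrapped value stays distinguishable by isinstance, while str() still interpolates it into generated source unchanged.

# Standalone sketch, mirroring the dataclass added above.
import dataclasses
from typing import Any

@dataclasses.dataclass
class SymbolicCallArg:
    inner: Any

    def __str__(self):
        return str(self.inner)

arg = SymbolicCallArg("s0*64")
assert isinstance(arg, SymbolicCallArg)   # carries its own type
assert f"x = {arg}" == "x = s0*64"        # prints like the raw expression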
@@ -1646,13 +1654,29 @@ class TritonKernel(Kernel):
         grid = []
         # TODO(jansel): if there are constants, we shouldn't bother passing them as args
         for tree in self.range_trees:
+            assignment = False
             if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
                 expr = pexpr(tree.numel)
             else:
+                assignment = True
                 expr = f"{name}_{tree.prefix}numel"
-                code.writeline(f"{expr} = {pexpr(tree.numel)}")
+                # TODO(voz): Tragic. This should at the very least be a util to slap on declare and ending.
+                # The real fix here is to revisit our cross-language calling convention.
+                code.writeline(
+                    f"{code.declare}{expr} = {pexpr(tree.numel)}{code.ending}"
+                )
             if tree.prefix != "r" or self.inside_reduction:
-                call_args.append(expr)
+                if assignment:
+                    # We can get symbolic expressions here, like s0*64. It is fine to have them, but
+                    # they need to be handled correctly as their own type, distinct from plain scalars
+                    # and from sympy scalars alike, so we wrap them in a custom type.
+                    # This is handled in `generate_args_decl`, which carries the apt comment "TODO: only
+                    # works for constant now, need type info". Agreed: this needs type info, and while
+                    # the wrapper is not true type info, it suffices as a type hint for producing the
+                    # correct code for this type.
+                    call_args.append(SymbolicCallArg(expr))
+                else:
+                    call_args.append(expr)
             if tree.prefix != "r":
                 grid.append(expr)
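
The switch from a bare writeline to `{code.declare}...{code.ending}` is what makes the emitted numel assignment legal in both wrapper languages. A rough sketch of that convention follows; the attribute values below are assumptions for illustration, not quoted from the codebase:

# Assumed values, illustrating the declare/ending convention only.
class PyWrapperLike:
    declare, ending = "", ""        # Python wrapper: plain assignment

class CppWrapperLike:
    declare, ending = "auto ", ";"  # C++ wrapper: declaration plus terminator

for code in (PyWrapperLike, CppWrapperLike):
    print(f"{code.declare}kernel0_xnumel = s0*64{code.ending}")
# kernel0_xnumel = s0*64
# auto kernel0_xnumel = s0*64;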
@@ -1662,7 +1686,7 @@ class TritonKernel(Kernel):
             )
         else:
             # TODO: refactor generate_kernel_call
-            call_args_str = ", ".join(call_args)
+            call_args_str = ", ".join(str(item) for item in call_args)
             stream_name = code.write_get_cuda_stream(
                 V.graph.scheduler.current_device.index
             )
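
This one-liner is the invocation bug fix from the commit title: `str.join` requires every element to already be a str, so the old call raises as soon as call_args contains a SymbolicCallArg. A minimal reproduction, with a stand-in class since the real one lives in the Triton codegen module:

# Minimal repro of the join failure the str() conversion fixes.
class SymbolicCallArg:                     # stand-in for the real class
    def __init__(self, inner): self.inner = inner
    def __str__(self): return str(self.inner)

call_args = ["buf0", SymbolicCallArg("s0*64")]
try:
    ", ".join(call_args)                           # old code path
except TypeError as exc:
    print("old join fails:", exc)                  # sequence item 1: expected str instance
print(", ".join(str(item) for item in call_args))  # fixed path -> buf0, s0*64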
Changed file 3 of 3 (the C++/CUDA wrapper codegen, per the class context below; path not preserved in this view):

@@ -1155,7 +1155,9 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
         new_args = []
         for arg in call_args:
             var_name = f"var_{next(self.arg_var_id)}"
-            if is_int(arg):
+            if isinstance(arg, torch._inductor.codegen.triton.SymbolicCallArg):
+                self.writeline(f"auto {var_name} = {arg};")
+            elif is_int(arg):
                 self.writeline(f"int {var_name} = {arg};")
             elif is_float(arg):
                 self.writeline(f"float {var_name} = {arg};")
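
End to end, the C++ wrapper now declares symbolic arguments with `auto`, letting the C++ compiler infer whatever type the symbolic expression evaluates to. A hedged sketch of the dispatch; `is_int`/`is_float` are stubbed here with plain isinstance checks, which may not match the real helpers:

# Sketch of the new dispatch; stubs stand in for Inductor internals.
class SymbolicCallArg:
    def __init__(self, inner): self.inner = inner
    def __str__(self): return str(self.inner)

def emit(call_args):
    for i, arg in enumerate(call_args):
        var_name = f"var_{i}"
        if isinstance(arg, SymbolicCallArg):
            print(f"auto {var_name} = {arg};")  # symbolic: let C++ infer
        elif isinstance(arg, int):              # stand-in for is_int
            print(f"int {var_name} = {arg};")
        elif isinstance(arg, float):            # stand-in for is_float
            print(f"float {var_name} = {arg};")

emit([SymbolicCallArg("s0*64"), 128, 0.5])
# auto var_0 = s0*64;
# int var_1 = 128;
# float var_2 = 0.5;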