diff --git a/test/cpp/aot_inductor/test.py b/test/cpp/aot_inductor/test.py
index 8023c7ebc91..46090fe4d30 100644
--- a/test/cpp/aot_inductor/test.py
+++ b/test/cpp/aot_inductor/test.py
@@ -17,12 +17,12 @@ class Net(torch.nn.Module):
 x = torch.randn((32, 64), device="cuda")
 y = torch.randn((32, 64), device="cuda")
 
+for dynamic in [True, False]:
+    torch._dynamo.config.dynamic_shapes = dynamic
+    torch._dynamo.reset()
-with torch.no_grad():
-    from torch.fx.experimental.proxy_tensor import make_fx
-    # Using export is blocked by https://github.com/pytorch/pytorch/issues/99000
-    # module, _ = torch._dynamo.export(Net().cuda(), inp)
-    module = make_fx(Net().cuda())(x, y)
-    lib_path = torch._inductor.aot_compile(module, [x, y])
+    with torch.no_grad():
+        module, _ = torch._dynamo.export(Net().cuda(), x, y)
+        lib_path = torch._inductor.aot_compile(module, [x, y])
 
-shutil.copy(lib_path, "libaot_inductor_output.so")
+    shutil.copy(lib_path, f"libaot_inductor_output{'_dynamic' if dynamic else ''}.so")
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 9a71e8d926f..84c5b54b9eb 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -6,7 +6,7 @@ import itertools
 import logging
 import math
 import operator
-from typing import Dict, Iterable, List, Set
+from typing import Any, Dict, Iterable, List, Set
 
 import sympy
 
@@ -441,6 +441,14 @@ class TritonOverrides(OpOverrides):
         return f"tl.math.ceil({x})"
 
 
+@dataclasses.dataclass
+class SymbolicCallArg:
+    inner: Any
+
+    def __str__(self):
+        return str(self.inner)
+
+
 @dataclasses.dataclass
 class IterationRanges:
     """
@@ -1646,13 +1654,29 @@ class TritonKernel(Kernel):
         grid = []
         # TODO(jansel): if there are constants, we shouldn't bother passing them as args
         for tree in self.range_trees:
+            assignment = False
             if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
                 expr = pexpr(tree.numel)
             else:
+                assignment = True
                 expr = f"{name}_{tree.prefix}numel"
-                code.writeline(f"{expr} = {pexpr(tree.numel)}")
+                # TODO(voz): Tragic. This should at the very least be a util to slap on declare and ending.
+                # The real fix here is to revisit our cross language calling convention.
+                code.writeline(
+                    f"{code.declare}{expr} = {pexpr(tree.numel)}{code.ending}"
+                )
             if tree.prefix != "r" or self.inside_reduction:
-                call_args.append(expr)
+                if assignment:
+                    # We can get symbolic expressions here, like s0*64.
+                    # It is fine to have them here, but we need to handle them correctly as their own type.
+                    # This is tricky to do, so we wrap them in a custom type, distinct from scalars but also
+                    # from sympy* scalars.
+                    # This is handled in `generate_args_decl`, which has a correct comment of: TODO: only works for
+                    # constant now, need type info. I agree, this needs type info, and while this is not true type
+                    # info, it suffices as a type hint for the purposes of producing the correct code for this type.
+                    call_args.append(SymbolicCallArg(expr))
+                else:
+                    call_args.append(expr)
             if tree.prefix != "r":
                 grid.append(expr)
 
@@ -1662,7 +1686,7 @@ class TritonKernel(Kernel):
             )
         else:
             # TODO: refactor generate_kernel_call
-            call_args_str = ", ".join(call_args)
+            call_args_str = ", ".join(str(item) for item in call_args)
             stream_name = code.write_get_cuda_stream(
                 V.graph.scheduler.current_device.index
             )
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 9f99bf239b3..995aaf97f27 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1155,7 +1155,9 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
         new_args = []
         for arg in call_args:
             var_name = f"var_{next(self.arg_var_id)}"
-            if is_int(arg):
+            if isinstance(arg, torch._inductor.codegen.triton.SymbolicCallArg):
+                self.writeline(f"auto {var_name} = {arg};")
+            elif is_int(arg):
                 self.writeline(f"int {var_name} = {arg};")
             elif is_float(arg):
                 self.writeline(f"float {var_name} = {arg};")
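
For illustration only (not part of the patch, names such as "buf0" and "var_0" are hypothetical): a minimal, self-contained sketch of what the SymbolicCallArg wrapper does. It carries a symbolic expression and stringifies to it, so joining kernel call args with str() keeps symbolic numels like s0*64 intact, while the C++ wrapper codegen can still branch on the wrapper type.

    import dataclasses
    from typing import Any


    # Standalone restatement of the wrapper added in triton.py above, for illustration.
    @dataclasses.dataclass
    class SymbolicCallArg:
        inner: Any

        def __str__(self):
            return str(self.inner)


    # Hypothetical call args: a buffer name plus a symbolic numel expression such as "s0*64".
    call_args = ["buf0", SymbolicCallArg("s0*64")]

    # Mirrors the change in TritonKernel.call_kernel: stringify each arg before joining,
    # so symbolic expressions survive unchanged in the generated call string.
    call_args_str = ", ".join(str(item) for item in call_args)
    assert call_args_str == "buf0, s0*64"

    # Mirrors the new CudaWrapperCodeGen branch: symbolic args are emitted as `auto` declarations.
    for arg in call_args:
        if isinstance(arg, SymbolicCallArg):
            print(f"auto var_0 = {arg};")  # -> auto var_0 = s0*64;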